diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml index 0eced80cc15b8..56d3416c5f7c3 100644 --- a/airflow-core/pyproject.toml +++ b/airflow-core/pyproject.toml @@ -67,6 +67,7 @@ version = "3.2.0" dependencies = [ "a2wsgi>=1.10.8", + "cachetools>=5.0.0", # aiosqlite 0.22.0 has a problem with hanging pytest sessions and we excluded it # See https://github.com/omnilib/aiosqlite/issues/369 # It seems that while our test issues are fixed in 0.22.1, sqlalchemy 2 itself diff --git a/airflow-core/src/airflow/api_fastapi/common/dagbag.py b/airflow-core/src/airflow/api_fastapi/common/dagbag.py index ce81f6906b721..563af29c08d4b 100644 --- a/airflow-core/src/airflow/api_fastapi/common/dagbag.py +++ b/airflow-core/src/airflow/api_fastapi/common/dagbag.py @@ -16,21 +16,42 @@ # under the License. from __future__ import annotations +import logging from typing import TYPE_CHECKING, Annotated from fastapi import Depends, HTTPException, Request, status from sqlalchemy.orm import Session +from airflow.configuration import conf from airflow.models.dagbag import DBDagBag if TYPE_CHECKING: from airflow.models.dagrun import DagRun from airflow.serialization.definitions.dag import SerializedDAG +log = logging.getLogger(__name__) + def create_dag_bag() -> DBDagBag: - """Create DagBag to retrieve DAGs from the database.""" - return DBDagBag() + """Create DagBag with configurable LRU+TTL caching for API server usage.""" + cache_size = conf.getint("api", "dag_cache_size", fallback=64) + cache_ttl_config = conf.getint("api", "dag_cache_ttl", fallback=3600) + + if cache_size < 0: + log.warning("dag_cache_size must be >= 0, disabling cache") + cache_size = 0 + if cache_ttl_config < 0: + log.warning("dag_cache_ttl must be >= 0, disabling TTL") + cache_ttl_config = 0 + + # Disable caching if cache_size is 0 + if cache_size <= 0: + return DBDagBag(cache_size=0) + + # Disable TTL if cache_ttl is 0 + cache_ttl: int | None = cache_ttl_config if cache_ttl_config > 0 else None + + return DBDagBag(cache_size=cache_size, cache_ttl=cache_ttl) def dag_bag_from_app(request: Request) -> DBDagBag: diff --git a/airflow-core/src/airflow/cli/commands/api_server_command.py b/airflow-core/src/airflow/cli/commands/api_server_command.py index 1753499916a45..5d8417b405355 100644 --- a/airflow-core/src/airflow/cli/commands/api_server_command.py +++ b/airflow-core/src/airflow/cli/commands/api_server_command.py @@ -82,6 +82,9 @@ def _run_api_server(args, apps: str, num_workers: int, worker_timeout: int, prox # Control access log based on uvicorn log level - disable for ERROR and above access_log_enabled = uvicorn_log_level not in ("error", "critical", "fatal") + # Worker recycling: restart workers after N requests to prevent memory accumulation + worker_max_requests = conf.getint("api", "worker_max_requests", fallback=10000) + uvicorn_kwargs = { "host": args.host, "port": args.port, @@ -95,6 +98,8 @@ def _run_api_server(args, apps: str, num_workers: int, worker_timeout: int, prox "log_level": uvicorn_log_level, "proxy_headers": proxy_headers, } + if worker_max_requests > 0: + uvicorn_kwargs["limit_max_requests"] = worker_max_requests # Only set the log_config if it is provided, otherwise use the default uvicorn logging configuration. if args.log_config and args.log_config != "-": # The [api/log_config] is migrated from [api/access_logfile] and [api/access_logfile] defaults to "-" for stdout for Gunicorn. diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index b26792b575c0b..6953180539998 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1446,6 +1446,33 @@ api: type: string example: ~ default: "False" + dag_cache_size: + description: | + Size of the LRU cache for SerializedDAG objects in the API server. + Set to 0 to disable caching. When caching is enabled, DAGs are cached + in memory to avoid repeated database queries and deserialization. + version_added: 3.2.0 + type: integer + example: ~ + default: "64" + dag_cache_ttl: + description: | + Time-to-live (seconds) for cached SerializedDAG objects in the API server. + After this time, cached DAGs will be re-fetched from the database on next access. + Set to 0 to disable TTL (cache entries will only be evicted by LRU policy). + version_added: 3.2.0 + type: integer + example: ~ + default: "3600" + worker_max_requests: + description: | + Maximum number of requests a worker will handle before being recycled. + This helps prevent memory growth from long-running processes by periodically + restarting workers. Set to 0 to disable worker recycling. + version_added: 3.2.0 + type: integer + example: ~ + default: "10000" base_url: description: | The base url of the API server. Airflow cannot guess what domain or CNAME you are using. diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index abf8c41760372..e2ea93bf1c773 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -18,12 +18,17 @@ from __future__ import annotations import hashlib +import logging +from collections.abc import MutableMapping +from threading import RLock from typing import TYPE_CHECKING, Any +from cachetools import LRUCache, TTLCache from sqlalchemy import String, inspect, select from sqlalchemy.orm import Mapped, joinedload from sqlalchemy.orm.attributes import NO_VALUE +from airflow._shared.observability.metrics.stats import Stats from airflow.models.base import Base, StringID from airflow.models.dag_version import DagVersion from airflow.utils.sqlalchemy import mapped_column @@ -37,34 +42,110 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.serialization.definitions.dag import SerializedDAG +log = logging.getLogger(__name__) + class DBDagBag: """ - Internal class for retrieving and caching dags in the scheduler. + Internal class for retrieving dags from the database. + + Optionally supports LRU+TTL caching when cache_size is provided. + The scheduler uses this without caching, while the API server can + enable caching via configuration. :meta private: """ - def __init__(self, load_op_links: bool = True) -> None: - self._dags: dict[str, SerializedDAG] = {} # dag_version_id to dag + def __init__( + self, + load_op_links: bool = True, + cache_size: int | None = None, + cache_ttl: int | None = None, + ) -> None: + """ + Initialize DBDagBag. + + :param load_op_links: Should the extra operator link be loaded when de-serializing the DAG? + :param cache_size: Size of LRU cache. If None or 0, uses unbounded dict (no eviction). + :param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL (LRU only). + """ self.load_op_links = load_op_links + self._dags: MutableMapping[str, SerializedDAG] = {} + self._lock: RLock | None = None + self._use_cache = False + + # Initialize bounded cache if cache_size is provided and > 0 + if cache_size and cache_size > 0: + if cache_ttl and cache_ttl > 0: + self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl) + else: + self._dags = LRUCache(maxsize=cache_size) + # Lock required: cachetools caches are NOT thread-safe + # (LRU reordering and TTL cleanup mutate internal linked lists) + self._lock = RLock() + self._use_cache = True def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None: + """Read and optionally cache a SerializedDAG from a SerializedDagModel.""" serdag.load_op_links = self.load_op_links - if dag := serdag.dag: + dag = serdag.dag + if not dag: + return None + if self._lock: + with self._lock: + self._dags[serdag.dag_version_id] = dag + Stats.gauge("api_server.dag_bag.cache_size", len(self._dags)) + else: self._dags[serdag.dag_version_id] = dag return dag def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None: - if dag := self._dags.get(version_id): + # Check cache first + if self._lock: + with self._lock: + dag = self._dags.get(version_id) + else: + dag = self._dags.get(version_id) + + if dag: + if self._use_cache: + Stats.incr("api_server.dag_bag.cache_hit") return dag + + if self._use_cache: + Stats.incr("api_server.dag_bag.cache_miss") + dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)]) if not dag_version: return None if not (serdag := dag_version.serialized_dag): return None + + # Double-checked locking: another thread may have cached it while we queried DB + if self._lock: + with self._lock: + if dag := self._dags.get(version_id): + return dag return self._read_dag(serdag) + def clear_cache(self) -> int: + """ + Clear all cached DAGs. + + :return: Number of entries cleared from the cache. + """ + if self._lock: + with self._lock: + count = len(self._dags) + self._dags.clear() + else: + count = len(self._dags) + self._dags.clear() + + if self._use_cache: + Stats.incr("api_server.dag_bag.cache_clear") + return count + @staticmethod def _version_from_dag_run(dag_run: DagRun, *, session: Session) -> DagVersion | None: if not dag_run.bundle_version: @@ -86,11 +167,17 @@ def get_dag_for_run(self, dag_run: DagRun, session: Session) -> SerializedDAG | return None def iter_all_latest_version_dags(self, *, session: Session) -> Generator[SerializedDAG, None, None]: - """Walk through all latest version dags available in the database.""" + """ + Walk through all latest version dags available in the database. + + Note: This method does NOT cache the DAGs to avoid cache thrashing when + iterating over many DAGs. Each DAG is deserialized fresh from the database. + """ from airflow.models.serialized_dag import SerializedDagModel for sdm in session.scalars(select(SerializedDagModel)): - if dag := self._read_dag(sdm): + sdm.load_op_links = self.load_op_links + if dag := sdm.dag: yield dag def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> SerializedDAG | None: diff --git a/airflow-core/tests/unit/api_fastapi/common/test_dagbag.py b/airflow-core/tests/unit/api_fastapi/common/test_dagbag.py index 27f34064e5f77..f8681b2ad4bc8 100644 --- a/airflow-core/tests/unit/api_fastapi/common/test_dagbag.py +++ b/airflow-core/tests/unit/api_fastapi/common/test_dagbag.py @@ -19,6 +19,7 @@ from unittest import mock import pytest +from cachetools import LRUCache, TTLCache from airflow.api_fastapi.app import purge_cached_app from airflow.sdk import BaseOperator @@ -82,3 +83,50 @@ def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, test_c assert resp2.status_code == 200 assert self.dagbag_call_counter["count"] == 1 + + +class TestCreateDagBag: + """Tests for create_dag_bag() function.""" + + @mock.patch("airflow.api_fastapi.common.dagbag.conf") + def test_creates_cached_dag_bag_by_default(self, mock_conf): + """Test that create_dag_bag creates a cached DBDagBag by default.""" + from airflow.api_fastapi.common.dagbag import create_dag_bag + + mock_conf.getint.side_effect = lambda section, key, fallback: { + "dag_cache_size": 64, + "dag_cache_ttl": 3600, + }.get(key, fallback) + + dag_bag = create_dag_bag() + assert dag_bag._use_cache is True + assert isinstance(dag_bag._dags, TTLCache) + + @mock.patch("airflow.api_fastapi.common.dagbag.conf") + def test_creates_unbounded_dag_bag_when_cache_size_zero(self, mock_conf): + """Test that create_dag_bag creates unbounded DBDagBag when cache_size is 0.""" + from airflow.api_fastapi.common.dagbag import create_dag_bag + + mock_conf.getint.side_effect = lambda section, key, fallback: { + "dag_cache_size": 0, + "dag_cache_ttl": 3600, + }.get(key, fallback) + + dag_bag = create_dag_bag() + assert dag_bag._use_cache is False + assert isinstance(dag_bag._dags, dict) + assert dag_bag._lock is None + + @mock.patch("airflow.api_fastapi.common.dagbag.conf") + def test_creates_lru_only_dag_bag_when_ttl_zero(self, mock_conf): + """Test that create_dag_bag creates LRU-only cache when cache_ttl is 0.""" + from airflow.api_fastapi.common.dagbag import create_dag_bag + + mock_conf.getint.side_effect = lambda section, key, fallback: { + "dag_cache_size": 64, + "dag_cache_ttl": 0, + }.get(key, fallback) + + dag_bag = create_dag_bag() + assert dag_bag._use_cache is True + assert isinstance(dag_bag._dags, LRUCache) diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py index 3b5b98877262e..e49a069872d78 100644 --- a/airflow-core/tests/unit/models/test_dagbag.py +++ b/airflow-core/tests/unit/models/test_dagbag.py @@ -16,13 +16,237 @@ # under the License. from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock, patch + import pytest +import time_machine +from cachetools import LRUCache, TTLCache + +from airflow.models.dagbag import DBDagBag pytestmark = pytest.mark.db_test -# This file previously contained tests for DagBag functionality, but those tests -# have been moved to airflow-core/tests/unit/dag_processing/test_dagbag.py to match -# the source code reorganization where DagBag moved from models to dag_processing. -# -# Tests for models-specific functionality (DBDagBag, DagPriorityParsingRequest, etc.) -# would remain in this file, but currently no such tests exist. + +class TestDBDagBagCache: + """Tests for DBDagBag optional caching behavior.""" + + def test_no_caching_by_default(self): + """Test that DBDagBag uses a simple dict without caching by default.""" + dag_bag = DBDagBag() + assert dag_bag._use_cache is False + assert isinstance(dag_bag._dags, dict) + assert dag_bag._lock is None + + def test_lru_cache_enabled_with_cache_size(self): + """Test that LRU cache is enabled when cache_size is provided.""" + dag_bag = DBDagBag(cache_size=10) + assert dag_bag._use_cache is True + assert isinstance(dag_bag._dags, LRUCache) + assert dag_bag._lock is not None + + def test_ttl_cache_enabled_with_cache_size_and_ttl(self): + """Test that TTL cache is enabled when both cache_size and cache_ttl are provided.""" + dag_bag = DBDagBag(cache_size=10, cache_ttl=60) + assert dag_bag._use_cache is True + assert isinstance(dag_bag._dags, TTLCache) + assert dag_bag._lock is not None + + def test_zero_cache_size_uses_unbounded_dict(self): + """Test that cache_size=0 uses unbounded dict (same as no caching).""" + dag_bag = DBDagBag(cache_size=0, cache_ttl=60) + assert dag_bag._use_cache is False + assert isinstance(dag_bag._dags, dict) + assert dag_bag._lock is None + + def test_clear_cache_with_caching(self): + """Test clear_cache() with caching enabled.""" + dag_bag = DBDagBag(cache_size=10, cache_ttl=60) + + # Add some mock DAGs to cache + mock_dag = MagicMock() + dag_bag._dags["version_1"] = mock_dag + dag_bag._dags["version_2"] = mock_dag + assert len(dag_bag._dags) == 2 + + # Clear cache + count = dag_bag.clear_cache() + assert count == 2 + assert len(dag_bag._dags) == 0 + + def test_clear_cache_without_caching(self): + """Test clear_cache() without caching enabled.""" + dag_bag = DBDagBag() + + # Add some mock DAGs + mock_dag = MagicMock() + dag_bag._dags["version_1"] = mock_dag + assert len(dag_bag._dags) == 1 + + # Clear cache + count = dag_bag.clear_cache() + assert count == 1 + assert len(dag_bag._dags) == 0 + + def test_ttl_cache_expiry(self): + """Test that cached DAGs expire after TTL.""" + with time_machine.travel("2025-01-01 00:00:00", tick=False): + dag_bag = DBDagBag(cache_size=10, cache_ttl=1) # 1 second TTL + + # Add a mock DAG to cache + mock_dag = MagicMock() + dag_bag._dags["test_version_id"] = mock_dag + assert "test_version_id" in dag_bag._dags + + # Jump ahead beyond TTL + with time_machine.travel("2025-01-01 00:00:02", tick=False): + # Cache should have expired + assert dag_bag._dags.get("test_version_id") is None + + def test_lru_eviction(self): + """Test that LRU eviction works when cache is full.""" + dag_bag = DBDagBag(cache_size=2) + + # Add 3 DAGs - first one should be evicted + dag_bag._dags["version_1"] = MagicMock() + dag_bag._dags["version_2"] = MagicMock() + dag_bag._dags["version_3"] = MagicMock() + + # version_1 should be evicted (LRU) + assert dag_bag._dags.get("version_1") is None + assert dag_bag._dags.get("version_2") is not None + assert dag_bag._dags.get("version_3") is not None + + def test_thread_safety_with_caching(self): + """Test concurrent access doesn't cause race conditions with caching enabled.""" + dag_bag = DBDagBag(cache_size=100, cache_ttl=60) + errors = [] + mock_session = MagicMock() + + def make_dag_version(version_id: str) -> MagicMock: + serdag = MagicMock() + serdag.dag = MagicMock() + serdag.dag_version_id = version_id + return MagicMock(serialized_dag=serdag) + + def get_dag_version(model, version_id, options=None): + return make_dag_version(version_id) + + mock_session.get.side_effect = get_dag_version + + def access_cache(i): + try: + dag_bag._get_dag(f"version_{i % 5}", mock_session) + except Exception as e: + errors.append(e) + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(access_cache, i) for i in range(100)] + for f in futures: + f.result() + + assert not errors + + def test_read_dag_stores_in_bounded_cache(self): + """Test that _read_dag stores DAG in bounded cache when cache_size > 0.""" + dag_bag = DBDagBag(cache_size=10, cache_ttl=60) + + mock_sdm = MagicMock() + mock_sdm.dag = MagicMock() + mock_sdm.dag_version_id = "test_version" + + result = dag_bag._read_dag(mock_sdm) + + assert result == mock_sdm.dag + assert "test_version" in dag_bag._dags + assert dag_bag._lock is not None # lock exists for bounded cache + + def test_read_dag_stores_in_unbounded_dict(self): + """Test that _read_dag stores DAG in unbounded dict when no cache_size.""" + dag_bag = DBDagBag() + + mock_sdm = MagicMock() + mock_sdm.dag = MagicMock() + mock_sdm.dag_version_id = "test_version" + + result = dag_bag._read_dag(mock_sdm) + + assert result == mock_sdm.dag + assert "test_version" in dag_bag._dags + assert dag_bag._lock is None # no lock for unbounded dict + + def test_cache_size_zero_stores_in_unbounded_dict(self): + """Test that cache_size=0 stores DAGs in unbounded dict (same as default).""" + dag_bag = DBDagBag(cache_size=0) + + mock_sdm = MagicMock() + mock_sdm.dag = MagicMock() + mock_sdm.dag_version_id = "test_version" + + result = dag_bag._read_dag(mock_sdm) + + assert result == mock_sdm.dag + # cache_size=0 uses unbounded dict, so DAGs are still stored + assert "test_version" in dag_bag._dags + + def test_iter_all_latest_version_dags_does_not_cache(self): + """Test that iter_all_latest_version_dags does not cache to prevent thrashing.""" + dag_bag = DBDagBag(cache_size=10, cache_ttl=60) + + # Create mock session and SerializedDagModel + mock_session = MagicMock() + mock_sdm = MagicMock() + mock_sdm.dag = MagicMock() + mock_sdm.dag_version_id = "test_version" + mock_session.scalars.return_value = [mock_sdm] + + # Iterate through DAGs + list(dag_bag.iter_all_latest_version_dags(session=mock_session)) + + # Cache should be empty - iter doesn't cache to prevent thrashing + assert len(dag_bag._dags) == 0 + + @patch("airflow.models.dagbag.Stats") + def test_cache_hit_metric_emitted(self, mock_stats): + """Test that cache hit metric is emitted when caching is enabled.""" + dag_bag = DBDagBag(cache_size=10, cache_ttl=60) + mock_session = MagicMock() + dag_bag._dags["test_version"] = MagicMock() + + dag_bag._get_dag("test_version", mock_session) + + mock_stats.incr.assert_called_with("api_server.dag_bag.cache_hit") + + @patch("airflow.models.dagbag.Stats") + def test_cache_miss_metric_emitted(self, mock_stats): + """Test that cache miss metric is emitted when caching is enabled.""" + dag_bag = DBDagBag(cache_size=10, cache_ttl=60) + mock_session = MagicMock() + mock_session.get.return_value = None + + dag_bag._get_dag("missing_version", mock_session) + + mock_stats.incr.assert_called_with("api_server.dag_bag.cache_miss") + + @patch("airflow.models.dagbag.Stats") + def test_cache_clear_metric_emitted(self, mock_stats): + """Test that cache clear metric is emitted when caching is enabled.""" + dag_bag = DBDagBag(cache_size=10, cache_ttl=60) + dag_bag._dags["test_version"] = MagicMock() + + dag_bag.clear_cache() + + mock_stats.incr.assert_called_with("api_server.dag_bag.cache_clear") + + @patch("airflow.models.dagbag.Stats") + def test_cache_size_gauge_emitted(self, mock_stats): + """Test that cache size gauge is emitted when a DAG is cached.""" + dag_bag = DBDagBag(cache_size=10, cache_ttl=60) + mock_serdag = MagicMock() + mock_serdag.dag_version_id = "test_version_1" + mock_serdag.dag = MagicMock() + mock_serdag.load_op_links = True + + dag_bag._read_dag(mock_serdag) + + mock_stats.gauge.assert_called_with("api_server.dag_bag.cache_size", 1) diff --git a/shared/observability/src/airflow_shared/observability/metrics/metrics_template.yaml b/shared/observability/src/airflow_shared/observability/metrics/metrics_template.yaml index e171980c41d71..e6092ea4f14b5 100644 --- a/shared/observability/src/airflow_shared/observability/metrics/metrics_template.yaml +++ b/shared/observability/src/airflow_shared/observability/metrics/metrics_template.yaml @@ -262,6 +262,24 @@ metrics: legacy_name: "-" name_variables: [] + - name: "api_server.dag_bag.cache_hit" + description: "Number of cache hits when retrieving SerializedDAG from DBDagBag in the API server" + type: "counter" + legacy_name: "-" + name_variables: [] + + - name: "api_server.dag_bag.cache_miss" + description: "Number of cache misses when retrieving SerializedDAG from DBDagBag in the API server" + type: "counter" + legacy_name: "-" + name_variables: [] + + - name: "api_server.dag_bag.cache_clear" + description: "Number of times the DBDagBag cache was cleared in the API server" + type: "counter" + legacy_name: "-" + name_variables: [] + # ========== # Gauges # ========== @@ -271,6 +289,12 @@ metrics: legacy_name: "-" name_variables: [] + - name: "api_server.dag_bag.cache_size" + description: "Current number of SerializedDAG objects cached in the API server's DBDagBag" + type: "gauge" + legacy_name: "-" + name_variables: [] + - name: "dag_processing.import_errors" description: "Number of errors from trying to parse Dag files" type: "gauge"