# Add configurable LRU+TTL caching for API server DAG retrieval #60804
Base: main
```diff
@@ -18,12 +18,17 @@
 from __future__ import annotations

 import hashlib
 import logging
+from collections.abc import MutableMapping
+from threading import RLock
 from typing import TYPE_CHECKING, Any

+from cachetools import LRUCache, TTLCache
 from sqlalchemy import String, inspect, select
 from sqlalchemy.orm import Mapped, joinedload
 from sqlalchemy.orm.attributes import NO_VALUE

+from airflow._shared.observability.metrics.stats import Stats
 from airflow.models.base import Base, StringID
 from airflow.models.dag_version import DagVersion
 from airflow.utils.sqlalchemy import mapped_column

@@ -37,34 +42,110 @@
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.serialization.definitions.dag import SerializedDAG

 log = logging.getLogger(__name__)


 class DBDagBag:
     """
-    Internal class for retrieving and caching dags in the scheduler.
+    Internal class for retrieving dags from the database.
+
+    Optionally supports LRU+TTL caching when cache_size is provided.
+    The scheduler uses this without caching, while the API server can
+    enable caching via configuration.

     :meta private:
     """

-    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+    def __init__(
+        self,
+        load_op_links: bool = True,
+        cache_size: int | None = None,
+        cache_ttl: int | None = None,
+    ) -> None:
+        """
+        Initialize DBDagBag.
+
+        :param load_op_links: Should the extra operator link be loaded when de-serializing the DAG?
+        :param cache_size: Size of LRU cache. If None or 0, uses unbounded dict (no eviction).
+        :param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL (LRU only).
+        """
         self.load_op_links = load_op_links
+        self._dags: MutableMapping[str, SerializedDAG] = {}
+        self._lock: RLock | None = None
+        self._use_cache = False
+
+        # Initialize bounded cache if cache_size is provided and > 0
+        if cache_size and cache_size > 0:
+            if cache_ttl and cache_ttl > 0:
+                self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+            else:
+                self._dags = LRUCache(maxsize=cache_size)
+            # Lock required: cachetools caches are NOT thread-safe
+            # (LRU reordering and TTL cleanup mutate internal linked lists)
+            self._lock = RLock()
+            self._use_cache = True

     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
+        """Read and optionally cache a SerializedDAG from a SerializedDagModel."""
         serdag.load_op_links = self.load_op_links
-        if dag := serdag.dag:
+        dag = serdag.dag
+        if not dag:
+            return None
+        if self._lock:
+            with self._lock:
+                self._dags[serdag.dag_version_id] = dag
+            Stats.gauge("api_server.dag_bag.cache_size", len(self._dags))
+        else:
             self._dags[serdag.dag_version_id] = dag
         return dag

     def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
+        # Check cache first
+        if self._lock:
+            with self._lock:
+                dag = self._dags.get(version_id)
+        else:
+            dag = self._dags.get(version_id)
+
+        if dag:
+            if self._use_cache:
```

> **Member:** It seems we could consolidate …
```diff
+                Stats.incr("api_server.dag_bag.cache_hit")
             return dag
+
+        if self._use_cache:
+            Stats.incr("api_server.dag_bag.cache_miss")

         dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
         if not dag_version:
             return None
         if not (serdag := dag_version.serialized_dag):
             return None
+
+        # Double-checked locking: another thread may have cached it while we queried DB
+        if self._lock:
+            with self._lock:
+                if dag := self._dags.get(version_id):
+                    return dag
         return self._read_dag(serdag)

+    def clear_cache(self) -> int:
+        """
+        Clear all cached DAGs.
+
+        :return: Number of entries cleared from the cache.
+        """
+        if self._lock:
+            with self._lock:
+                count = len(self._dags)
+                self._dags.clear()
+        else:
+            count = len(self._dags)
+            self._dags.clear()
+
+        if self._use_cache:
+            Stats.incr("api_server.dag_bag.cache_clear")
+        return count

     @staticmethod
     def _version_from_dag_run(dag_run: DagRun, *, session: Session) -> DagVersion | None:
         if not dag_run.bundle_version:

@@ -86,11 +167,17 @@ def get_dag_for_run(self, dag_run: DagRun, session: Session) -> SerializedDAG | None:
         return None

     def iter_all_latest_version_dags(self, *, session: Session) -> Generator[SerializedDAG, None, None]:
-        """Walk through all latest version dags available in the database."""
+        """
+        Walk through all latest version dags available in the database.
+
+        Note: This method does NOT cache the DAGs to avoid cache thrashing when
+        iterating over many DAGs. Each DAG is deserialized fresh from the database.
+        """
         from airflow.models.serialized_dag import SerializedDagModel

         for sdm in session.scalars(select(SerializedDagModel)):
-            if dag := self._read_dag(sdm):
+            sdm.load_op_links = self.load_op_links
+            if dag := sdm.dag:
                 yield dag

     def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> SerializedDAG | None:
```
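The two cache classes used in the diff come from the `cachetools` library. As the inline comment notes, neither is thread-safe, which is why the PR wraps access in an `RLock`. A standalone sketch (not Airflow code) of their eviction semantics:

```python
from cachetools import LRUCache, TTLCache

# LRUCache evicts the least recently used entry once maxsize is exceeded.
lru = LRUCache(maxsize=2)
lru["a"] = 1
lru["b"] = 2
_ = lru["a"]        # touching "a" makes "b" the least recently used entry
lru["c"] = 3        # inserting a third entry evicts "b"
print(sorted(lru))  # ['a', 'c']

# TTLCache layers a per-entry time-to-live on top of LRU eviction;
# expired entries are treated as missing on access.
ttl = TTLCache(maxsize=2, ttl=3600)
ttl["a"] = 1
print("a" in ttl)   # True (well within the one-hour TTL)
```

Both classes implement `MutableMapping`, which is why the diff can widen the `_dags` annotation and keep a plain `dict` as the uncached fallback.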
```diff
@@ -19,6 +19,7 @@
 from unittest import mock

 import pytest
+from cachetools import LRUCache, TTLCache

 from airflow.api_fastapi.app import purge_cached_app
 from airflow.sdk import BaseOperator

@@ -82,3 +83,50 @@ def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, test_c…
         assert resp2.status_code == 200

         assert self.dagbag_call_counter["count"] == 1
+
+
+class TestCreateDagBag:
+    """Tests for create_dag_bag() function."""
```

> **Member:** Although not necessary, we could consolidate these test methods using …

```diff
+    @mock.patch("airflow.api_fastapi.common.dagbag.conf")
+    def test_creates_cached_dag_bag_by_default(self, mock_conf):
+        """Test that create_dag_bag creates a cached DBDagBag by default."""
+        from airflow.api_fastapi.common.dagbag import create_dag_bag
+
+        mock_conf.getint.side_effect = lambda section, key, fallback: {
+            "dag_cache_size": 64,
+            "dag_cache_ttl": 3600,
+        }.get(key, fallback)
+
+        dag_bag = create_dag_bag()
+        assert dag_bag._use_cache is True
+        assert isinstance(dag_bag._dags, TTLCache)
+
+    @mock.patch("airflow.api_fastapi.common.dagbag.conf")
+    def test_creates_unbounded_dag_bag_when_cache_size_zero(self, mock_conf):
+        """Test that create_dag_bag creates an unbounded DBDagBag when cache_size is 0."""
+        from airflow.api_fastapi.common.dagbag import create_dag_bag
+
+        mock_conf.getint.side_effect = lambda section, key, fallback: {
+            "dag_cache_size": 0,
+            "dag_cache_ttl": 3600,
+        }.get(key, fallback)
+
+        dag_bag = create_dag_bag()
+        assert dag_bag._use_cache is False
+        assert isinstance(dag_bag._dags, dict)
+        assert dag_bag._lock is None
+
+    @mock.patch("airflow.api_fastapi.common.dagbag.conf")
+    def test_creates_lru_only_dag_bag_when_ttl_zero(self, mock_conf):
+        """Test that create_dag_bag creates an LRU-only cache when cache_ttl is 0."""
+        from airflow.api_fastapi.common.dagbag import create_dag_bag
+
+        mock_conf.getint.side_effect = lambda section, key, fallback: {
+            "dag_cache_size": 64,
+            "dag_cache_ttl": 0,
+        }.get(key, fallback)
+
+        dag_bag = create_dag_bag()
+        assert dag_bag._use_cache is True
+        assert isinstance(dag_bag._dags, LRUCache)
```

> **Member:** Not sure, would it be better to use the existing `_disable_cache` as the condition?
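The three cache-mode tests differ only in their config values and expected cache type, so the consolidation the reviewer hints at could look like the sketch below. `make_cache` is a local stand-in that mirrors the cache-selection branch from `DBDagBag.__init__` in the diff, used here so the example runs without the Airflow test harness:

```python
import pytest
from cachetools import LRUCache, TTLCache

def make_cache(cache_size, cache_ttl):
    # Mirrors the cache-selection logic from DBDagBag.__init__ in the diff.
    if cache_size and cache_size > 0:
        if cache_ttl and cache_ttl > 0:
            return TTLCache(maxsize=cache_size, ttl=cache_ttl)
        return LRUCache(maxsize=cache_size)
    return {}

@pytest.mark.parametrize(
    "cache_size, cache_ttl, expected_type",
    [
        (64, 3600, TTLCache),  # size + TTL -> TTL cache
        (64, 0, LRUCache),     # size only -> plain LRU cache
        (0, 3600, dict),       # size 0 -> unbounded dict, no caching
    ],
)
def test_cache_selection(cache_size, cache_ttl, expected_type):
    assert type(make_cache(cache_size, cache_ttl)) is expected_type
```

In the real test module the parametrized cases would drive `mock_conf.getint.side_effect` and assert against `create_dag_bag()` instead of the local helper.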