Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ version = "3.2.0"

dependencies = [
"a2wsgi>=1.10.8",
"cachetools>=5.0.0",
# aiosqlite 0.22.0 has a problem with hanging pytest sessions and we excluded it
# See https://github.com/omnilib/aiosqlite/issues/369
# It seems that while our test issues are fixed in 0.22.1, sqlalchemy 2 itself
Expand Down
25 changes: 23 additions & 2 deletions airflow-core/src/airflow/api_fastapi/common/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,42 @@
# under the License.
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Request, status
from sqlalchemy.orm import Session

from airflow.configuration import conf
from airflow.models.dagbag import DBDagBag

if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.serialization.definitions.dag import SerializedDAG

log = logging.getLogger(__name__)


def create_dag_bag() -> DBDagBag:
    """
    Create the DBDagBag used by the API server, with configurable LRU+TTL caching.

    Reads ``[api] dag_cache_size`` (fallback 64) and ``[api] dag_cache_ttl``
    (fallback 3600). A negative value for either option is logged as a warning
    and treated as 0.

    :return: ``DBDagBag(cache_size=0)`` when the effective size is 0; otherwise a
        bag created with the configured size and TTL, where a TTL of 0 is passed
        as ``None``.
    """
    cache_size = conf.getint("api", "dag_cache_size", fallback=64)
    cache_ttl = conf.getint("api", "dag_cache_ttl", fallback=3600)

    if cache_size < 0:
        log.warning("dag_cache_size must be >= 0, disabling cache")
        cache_size = 0
    if cache_ttl < 0:
        log.warning("dag_cache_ttl must be >= 0, disabling TTL")
        cache_ttl = 0

    if cache_size == 0:
        # No bounded cache is requested.
        # NOTE(review): DBDagBag documents cache_size=0 as "unbounded dict (no eviction)",
        # so this turns off eviction rather than caching -- confirm that is what
        # "disabling cache" is meant to convey.
        return DBDagBag(cache_size=0)

    # A TTL of 0 means "no TTL": entries are then only evicted by the LRU policy.
    return DBDagBag(cache_size=cache_size, cache_ttl=cache_ttl or None)


def dag_bag_from_app(request: Request) -> DBDagBag:
Expand Down
5 changes: 5 additions & 0 deletions airflow-core/src/airflow/cli/commands/api_server_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def _run_api_server(args, apps: str, num_workers: int, worker_timeout: int, prox
# Control access log based on uvicorn log level - disable for ERROR and above
access_log_enabled = uvicorn_log_level not in ("error", "critical", "fatal")

# Worker recycling: restart workers after N requests to prevent memory accumulation
worker_max_requests = conf.getint("api", "worker_max_requests", fallback=10000)

uvicorn_kwargs = {
"host": args.host,
"port": args.port,
Expand All @@ -95,6 +98,8 @@ def _run_api_server(args, apps: str, num_workers: int, worker_timeout: int, prox
"log_level": uvicorn_log_level,
"proxy_headers": proxy_headers,
}
if worker_max_requests > 0:
uvicorn_kwargs["limit_max_requests"] = worker_max_requests
# Only set the log_config if it is provided, otherwise use the default uvicorn logging configuration.
if args.log_config and args.log_config != "-":
# The [api/log_config] is migrated from [api/access_logfile] and [api/access_logfile] defaults to "-" for stdout for Gunicorn.
Expand Down
27 changes: 27 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,33 @@ api:
type: string
example: ~
default: "False"
dag_cache_size:
description: |
Maximum number of SerializedDAG objects kept in the API server's LRU cache.
Set to 0 to disable the bounded cache; DAGs are then kept in an unbounded
in-memory dict with no eviction. Cached DAGs are served from memory to avoid
repeated database queries and deserialization.
version_added: 3.2.0
type: integer
example: ~
default: "64"
dag_cache_ttl:
description: |
Time-to-live (seconds) for cached SerializedDAG objects in the API server.
After this time, cached DAGs will be re-fetched from the database on next access.
Set to 0 to disable TTL (cache entries will only be evicted by LRU policy).
version_added: 3.2.0
type: integer
example: ~
default: "3600"
worker_max_requests:
description: |
Maximum number of requests a worker will handle before being recycled.
This helps prevent memory growth from long-running processes by periodically
restarting workers. Set to 0 to disable worker recycling.
version_added: 3.2.0
type: integer
example: ~
default: "10000"
base_url:
description: |
The base url of the API server. Airflow cannot guess what domain or CNAME you are using.
Expand Down
101 changes: 94 additions & 7 deletions airflow-core/src/airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
from __future__ import annotations

import hashlib
import logging
from collections.abc import MutableMapping
from threading import RLock
from typing import TYPE_CHECKING, Any

from cachetools import LRUCache, TTLCache
from sqlalchemy import String, inspect, select
from sqlalchemy.orm import Mapped, joinedload
from sqlalchemy.orm.attributes import NO_VALUE

from airflow._shared.observability.metrics.stats import Stats
from airflow.models.base import Base, StringID
from airflow.models.dag_version import DagVersion
from airflow.utils.sqlalchemy import mapped_column
Expand All @@ -37,34 +42,110 @@
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.definitions.dag import SerializedDAG

log = logging.getLogger(__name__)


class DBDagBag:
"""
Internal class for retrieving and caching dags in the scheduler.
Internal class for retrieving dags from the database.

Optionally supports LRU+TTL caching when cache_size is provided.
The scheduler uses this without caching, while the API server can
enable caching via configuration.

:meta private:
"""

def __init__(
    self,
    load_op_links: bool = True,
    cache_size: int | None = None,
    cache_ttl: int | None = None,
) -> None:
    """
    Initialize DBDagBag.

    :param load_op_links: Should the extra operator link be loaded when de-serializing the DAG?
    :param cache_size: Maximum number of entries in the bounded cache. If None or 0, a plain
        unbounded dict is used instead (no eviction, no lock).
    :param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL (LRU only).
        Only consulted when ``cache_size`` is positive.
    """
    self.load_op_links = load_op_links
    # dag_version_id -> dag
    self._dags: MutableMapping[str, SerializedDAG] = {}
    self._lock: RLock | None = None
    self._use_cache = False

    # Initialize bounded cache if cache_size is provided and > 0.
    # NOTE(review): a reviewer asked whether an existing ``_disable_cache`` flag (not visible
    # in this excerpt) should be the condition here, i.e. ``if not self._disable_cache:``.
    if cache_size and cache_size > 0:
        if cache_ttl and cache_ttl > 0:
            self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
        else:
            self._dags = LRUCache(maxsize=cache_size)
        # Lock required: cachetools caches are NOT thread-safe
        # (LRU reordering and TTL cleanup mutate internal linked lists)
        self._lock = RLock()
        self._use_cache = True

def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
    """
    Read a SerializedDAG from a SerializedDagModel and store it in ``self._dags``.

    :param serdag: Model row to read the dag from; its ``load_op_links`` is set from this bag.
    :return: The dag, or None when ``serdag.dag`` is falsy (nothing is stored in that case).
    """
    serdag.load_op_links = self.load_op_links
    dag = serdag.dag
    if not dag:
        return None

    if self._lock is None:
        # Plain dict storage: no lock and no cache-size metric.
        self._dags[serdag.dag_version_id] = dag
        return dag

    with self._lock:
        self._dags[serdag.dag_version_id] = dag
        cached_count = len(self._dags)
    # Emit the gauge after releasing the lock so metric emission never extends
    # the critical section.
    Stats.gauge("api_server.dag_bag.cache_size", cached_count)
    return dag

def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
    """
    Return the dag for ``version_id`` from ``self._dags``, falling back to the database.

    :param version_id: Identifier passed to ``session.get(DagVersion, ...)``; also the key
        used in ``self._dags``.
    :param session: SQLAlchemy session used when the dag is not already held in memory.
    :return: The dag, or None when the version row, its serialized dag, or the dag read
        from it is missing.
    """
    # Check cache first
    if self._lock is not None:
        with self._lock:
            dag = self._dags.get(version_id)
    else:
        dag = self._dags.get(version_id)

    # NOTE(review): a reviewer suggested consolidating ``_use_cache`` with a
    # ``_disable_cache`` flag (not visible in this excerpt) into a single variable.
    if dag:
        if self._use_cache:
            Stats.incr("api_server.dag_bag.cache_hit")
        return dag

    if self._use_cache:
        Stats.incr("api_server.dag_bag.cache_miss")

    dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
    if not dag_version:
        return None
    if not (serdag := dag_version.serialized_dag):
        return None

    # Re-check under the lock: another thread may have cached it while we queried the DB.
    if self._lock is not None:
        with self._lock:
            if dag := self._dags.get(version_id):
                return dag
    return self._read_dag(serdag)

def clear_cache(self) -> int:
    """
    Drop every dag currently held in ``self._dags``.

    :return: How many entries were removed.
    """
    lock = self._lock
    if lock is None:
        removed = len(self._dags)
        self._dags.clear()
    else:
        # Count and clear under the lock so both see the same contents.
        with lock:
            removed = len(self._dags)
            self._dags.clear()

    if self._use_cache:
        Stats.incr("api_server.dag_bag.cache_clear")
    return removed

@staticmethod
def _version_from_dag_run(dag_run: DagRun, *, session: Session) -> DagVersion | None:
if not dag_run.bundle_version:
Expand All @@ -86,11 +167,17 @@ def get_dag_for_run(self, dag_run: DagRun, session: Session) -> SerializedDAG |
return None

def iter_all_latest_version_dags(self, *, session: Session) -> Generator[SerializedDAG, None, None]:
    """
    Walk through all latest version dags available in the database.

    Note: This method does NOT store the dags in ``self._dags``, to avoid cache
    thrashing when iterating over many dags. Each dag is read directly from its
    ``SerializedDagModel`` row.

    :param session: SQLAlchemy session used to select the ``SerializedDagModel`` rows.
    :return: Generator yielding every dag whose row produces a truthy ``dag``.
    """
    from airflow.models.serialized_dag import SerializedDagModel

    for sdm in session.scalars(select(SerializedDagModel)):
        # Mirror what _read_dag does before deserializing, minus the cache write.
        sdm.load_op_links = self.load_op_links
        if dag := sdm.dag:
            yield dag

def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> SerializedDAG | None:
Expand Down
48 changes: 48 additions & 0 deletions airflow-core/tests/unit/api_fastapi/common/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unittest import mock

import pytest
from cachetools import LRUCache, TTLCache

from airflow.api_fastapi.app import purge_cached_app
from airflow.sdk import BaseOperator
Expand Down Expand Up @@ -82,3 +83,50 @@ def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, test_c
assert resp2.status_code == 200

assert self.dagbag_call_counter["count"] == 1


class TestCreateDagBag:
    """Tests for create_dag_bag() function."""

    # Consolidated into one parametrized test, as suggested in review
    # (dag_cache_size, dag_cache_ttl, expected_class).
    @pytest.mark.parametrize(
        ("dag_cache_size", "dag_cache_ttl", "expected_class", "expect_cache"),
        [
            pytest.param(64, 3600, TTLCache, True, id="defaults-give-ttl-cache"),
            pytest.param(64, 0, LRUCache, True, id="ttl-zero-gives-lru-only-cache"),
            pytest.param(0, 3600, dict, False, id="size-zero-gives-unbounded-dict"),
            # Negative values are clamped to 0 by create_dag_bag.
            pytest.param(-1, 3600, dict, False, id="negative-size-treated-as-zero"),
            pytest.param(64, -1, LRUCache, True, id="negative-ttl-treated-as-zero"),
        ],
    )
    @mock.patch("airflow.api_fastapi.common.dagbag.conf")
    def test_create_dag_bag_cache_type(
        self, mock_conf, dag_cache_size, dag_cache_ttl, expected_class, expect_cache
    ):
        """Test that create_dag_bag picks the storage type matching the [api] cache options."""
        from airflow.api_fastapi.common.dagbag import create_dag_bag

        settings = {
            "dag_cache_size": dag_cache_size,
            "dag_cache_ttl": dag_cache_ttl,
        }
        mock_conf.getint.side_effect = lambda section, key, fallback: settings.get(key, fallback)

        dag_bag = create_dag_bag()

        assert dag_bag._use_cache is expect_cache
        assert isinstance(dag_bag._dags, expected_class)
        # The lock is only created together with a bounded cache.
        assert (dag_bag._lock is not None) is expect_cache
Loading
Loading