Skip to content

Commit ba20d06

Browse files
authored
Clean up typing with max_execution_date query builder (#36958)
1 parent 63e93d7 commit ba20d06

2 files changed

Lines changed: 7 additions & 8 deletions

File tree

airflow/models/dag.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3049,8 +3049,8 @@ def bulk_write_to_db(
30493049
)
30503050
query = with_row_locks(query, of=DagModel, session=session)
30513051
orm_dags: list[DagModel] = session.scalars(query).unique().all()
3052-
existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags}
3053-
missing_dag_ids = dag_ids.difference(existing_dags)
3052+
existing_dags: dict[str, DagModel] = {x.dag_id: x for x in orm_dags}
3053+
missing_dag_ids = dag_ids.difference(existing_dags.keys())
30543054

30553055
for missing_dag_id in missing_dag_ids:
30563056
orm_dag = DagModel(dag_id=missing_dag_id)
@@ -3067,7 +3067,7 @@ def bulk_write_to_db(
30673067
# Skip these queries entirely if no DAGs can be scheduled to save time.
30683068
if any(dag.timetable.can_be_scheduled for dag in dags):
30693069
# Get the latest automated dag run for each existing dag as a single query (avoid n+1 query)
3070-
query = cls._get_latest_runs_query(existing_dags, session)
3070+
query = cls._get_latest_runs_query(dags=list(existing_dags.keys()))
30713071
latest_runs = {run.dag_id: run for run in session.scalars(query)}
30723072

30733073
# Get number of active dagruns for all dags we are processing as a single query.
@@ -3240,16 +3240,15 @@ def bulk_write_to_db(
32403240
cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
32413241

32423242
@classmethod
3243-
def _get_latest_runs_query(cls, dags, session) -> Query:
3243+
def _get_latest_runs_query(cls, dags: list[str]) -> Query:
32443244
"""
32453245
Query the database to retrieve the last automated run for each dag.
32463246
32473247
:param dags: dags to query
3248-
:param session: sqlalchemy session object
32493248
"""
32503249
if len(dags) == 1:
32513250
# Index optimized fast path to avoid more complicated & slower groupby queryplan
3252-
existing_dag_id = list(dags)[0].dag_id
3251+
existing_dag_id = dags[0]
32533252
last_automated_runs_subq = (
32543253
select(func.max(DagRun.execution_date).label("max_execution_date"))
32553254
.where(

tests/models/test_dag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4140,7 +4140,7 @@ def test_validate_setup_teardown_trigger_rule(self):
41404140
def test_get_latest_runs_query_one_dag(dag_maker, session):
41414141
with dag_maker(dag_id="dag1") as dag1:
41424142
...
4143-
query = DAG._get_latest_runs_query(dags=[dag1], session=session)
4143+
query = DAG._get_latest_runs_query(dags=[dag1.dag_id])
41444144
actual = [x.strip() for x in str(query.compile()).splitlines()]
41454145
expected = [
41464146
"SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
@@ -4157,7 +4157,7 @@ def test_get_latest_runs_query_two_dags(dag_maker, session):
41574157
...
41584158
with dag_maker(dag_id="dag2") as dag2:
41594159
...
4160-
query = DAG._get_latest_runs_query(dags=[dag1, dag2], session=session)
4160+
query = DAG._get_latest_runs_query(dags=[dag1.dag_id, dag2.dag_id])
41614161
actual = [x.strip() for x in str(query.compile()).splitlines()]
41624162
print("\n".join(actual))
41634163
expected = [

0 commit comments

Comments
 (0)