Skip to content

Commit 7bde1ac

Browse files
committed
Add SQLA-version-dependent dialect kwarg generator
1 parent 2c2898f commit 7bde1ac

6 files changed

Lines changed: 24 additions & 9 deletions

File tree

airflow-core/src/airflow/jobs/scheduler_job_runner.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,12 @@
7575
from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
7676
from airflow.utils.session import NEW_SESSION, create_session, provide_session
7777
from airflow.utils.span_status import SpanStatus
78-
from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, with_row_locks
78+
from airflow.utils.sqlalchemy import (
79+
is_lock_not_available_error,
80+
prohibit_commit,
81+
make_dialect_kwarg,
82+
with_row_locks,
83+
)
7984
from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
8085
from airflow.utils.thread_safe_dict import ThreadSafeDict
8186
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -102,6 +107,8 @@
102107
TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule"
103108
""":meta private:"""
104109

110+
dialect_kwarg = make_dialect_kwarg("mysql")
111+
105112

106113
class SchedulerDagBag:
107114
"""
@@ -400,7 +407,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
400407

401408
query = (
402409
select(TI)
403-
.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
410+
.with_hint(TI, "USE INDEX (ti_state)", **dialect_kwarg)
404411
.join(TI.dag_run)
405412
.where(DR.state == DagRunState.RUNNING)
406413
.join(TI.dag_model)
@@ -2237,7 +2244,7 @@ def _find_task_instances_without_heartbeats(self, *, session: Session) -> list[T
22372244
.options(selectinload(TI.dag_model))
22382245
.options(selectinload(TI.dag_run))
22392246
.options(selectinload(TI.dag_version))
2240-
.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
2247+
.with_hint(TI, "USE INDEX (ti_state)", **dialect_kwarg)
22412248
.join(DM, TI.dag_id == DM.dag_id)
22422249
.where(
22432250
TI.state.in_((TaskInstanceState.RUNNING, TaskInstanceState.RESTARTING)),

airflow-core/src/airflow/migrations/versions/0036_3_0_0_add_name_field_to_dataset_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
_STRING_COLUMN_TYPE = sa.String(length=1500).with_variant(
5050
sa.String(length=1500, collation="latin1_general_cs"),
51-
dialect_name="mysql",
51+
"mysql",
5252
)
5353

5454

@@ -128,7 +128,7 @@ def downgrade():
128128
"uri",
129129
type_=sa.String(length=3000).with_variant(
130130
sa.String(length=3000, collation="latin1_general_cs"),
131-
dialect_name="mysql",
131+
"mysql",
132132
),
133133
nullable=False,
134134
)

airflow-core/src/airflow/migrations/versions/0038_3_0_0_add_asset_active.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
_STRING_COLUMN_TYPE = sa.String(length=1500).with_variant(
4141
sa.String(length=1500, collation="latin1_general_cs"),
42-
dialect_name="mysql",
42+
"mysql",
4343
)
4444

4545

airflow-core/src/airflow/migrations/versions/0039_3_0_0_tweak_assetaliasmodel_to_match_asset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151

5252
_STRING_COLUMN_TYPE = sa.String(length=1500).with_variant(
5353
sa.String(length=1500, collation="latin1_general_cs"),
54-
dialect_name="mysql",
54+
"mysql",
5555
)
5656

5757

@@ -77,7 +77,7 @@ def downgrade():
7777
"name",
7878
type_=sa.String(length=3000).with_variant(
7979
sa.String(length=3000, collation="latin1_general_cs"),
80-
dialect_name="mysql",
80+
"mysql",
8181
),
8282
nullable=False,
8383
)

airflow-core/src/airflow/models/dagrun.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,10 +552,11 @@ def get_running_dag_runs_to_examine(cls, session: Session) -> Query:
552552
"""
553553
from airflow.models.backfill import BackfillDagRun
554554
from airflow.models.dag import DagModel
555+
from airflow.utils.sqlalchemy import make_dialect_kwarg
555556

556557
query = (
557558
select(cls)
558-
.with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
559+
.with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", **make_dialect_kwarg("mysql"))
559560
.where(cls.state == DagRunState.RUNNING)
560561
.join(DagModel, DagModel.dag_id == cls.dag_id)
561562
.join(BackfillDagRun, BackfillDagRun.dag_run_id == DagRun.id, isouter=True)

airflow-core/src/airflow/utils/sqlalchemy.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from airflow.utils.timezone import make_naive, utc
3737

3838
if TYPE_CHECKING:
39+
from collections.abc import Iterable
40+
3941
from kubernetes.client.models.v1_pod import V1Pod
4042
from sqlalchemy.exc import OperationalError
4143
from sqlalchemy.orm import Query, Session
@@ -448,3 +450,8 @@ def get_orm_mapper():
448450

449451
def is_sqlalchemy_v1() -> bool:
450452
return version.parse(metadata.version("sqlalchemy")).major == 1
453+
454+
455+
def make_dialect_kwarg(dialect: str) -> dict[str, str | Iterable[str]]:
456+
"""Create an SQLAlchemy-version-aware dialect keyword argument."""
457+
return {"dialect_name": dialect} if is_sqlalchemy_v1() else {"dialect_names": (dialect,)}

0 commit comments

Comments
 (0)