Skip to content

Commit 1a85446

Browse files
authored
Add more type hints to the code base (#30503)
* Fully type Pool
  Also fix a bug where create_or_update_pool silently fails when an empty name is given. An error is raised instead now.
* Add types to 'airflow dags'
* Add types to 'airflow task' and 'airflow job'
* Improve KubernetesExecutor typing
* Add types to BackfillJob
  This triggers an existing typing bug that pickle_id is incorrectly typed as str in executors, while it should be int in practice. This is fixed to keep things straight.
* Add types to job classes
* Fix missing DagModel case in SchedulerJob
* Add types to DagCode
* Add more types to DagRun
* Add types to serialized DAG model
* Add more types to TaskInstance and TaskReschedule
* Add types to Trigger
* Add types to MetastoreBackend
* Add types to external task sensor
* Add types to AirflowSecurityManager
  This uncovers a couple of incorrect type hints in the base SecurityManager (in fab_security), which are also fixed.
* Add types to views
  This slightly improves how view functions are typechecked and should prevent some trivial bugs.
1 parent 0b83f06 commit 1a85446

32 files changed

Lines changed: 562 additions & 390 deletions

airflow/api/common/delete_dag.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,20 +21,21 @@
2121
import logging
2222

2323
from sqlalchemy import and_, or_
24+
from sqlalchemy.orm import Session
2425

2526
from airflow import models
2627
from airflow.exceptions import AirflowException, DagNotFound
2728
from airflow.models import DagModel, TaskFail
2829
from airflow.models.serialized_dag import SerializedDagModel
2930
from airflow.utils.db import get_sqla_model_classes
30-
from airflow.utils.session import provide_session
31+
from airflow.utils.session import NEW_SESSION, provide_session
3132
from airflow.utils.state import State
3233

3334
log = logging.getLogger(__name__)
3435

3536

3637
@provide_session
37-
def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
38+
def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session = NEW_SESSION) -> int:
3839
"""
3940
Delete a DAG by a dag_id.
4041

airflow/api/common/experimental/pool.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,15 +19,16 @@
1919
from __future__ import annotations
2020

2121
from deprecated import deprecated
22+
from sqlalchemy.orm import Session
2223

2324
from airflow.exceptions import AirflowBadRequest, PoolNotFound
2425
from airflow.models import Pool
25-
from airflow.utils.session import provide_session
26+
from airflow.utils.session import NEW_SESSION, provide_session
2627

2728

2829
@deprecated(reason="Use Pool.get_pool() instead", version="2.2.4")
2930
@provide_session
30-
def get_pool(name, session=None):
31+
def get_pool(name, session: Session = NEW_SESSION):
3132
"""Get pool by a given name."""
3233
if not (name and name.strip()):
3334
raise AirflowBadRequest("Pool name shouldn't be empty")
@@ -41,14 +42,14 @@ def get_pool(name, session=None):
4142

4243
@deprecated(reason="Use Pool.get_pools() instead", version="2.2.4")
4344
@provide_session
44-
def get_pools(session=None):
45+
def get_pools(session: Session = NEW_SESSION):
4546
"""Get all pools."""
4647
return session.query(Pool).all()
4748

4849

4950
@deprecated(reason="Use Pool.create_pool() instead", version="2.2.4")
5051
@provide_session
51-
def create_pool(name, slots, description, session=None):
52+
def create_pool(name, slots, description, session: Session = NEW_SESSION):
5253
"""Create a pool with given parameters."""
5354
if not (name and name.strip()):
5455
raise AirflowBadRequest("Pool name shouldn't be empty")
@@ -79,7 +80,7 @@ def create_pool(name, slots, description, session=None):
7980

8081
@deprecated(reason="Use Pool.delete_pool() instead", version="2.2.4")
8182
@provide_session
82-
def delete_pool(name, session=None):
83+
def delete_pool(name, session: Session = NEW_SESSION):
8384
"""Delete pool by a given name."""
8485
if not (name and name.strip()):
8586
raise AirflowBadRequest("Pool name shouldn't be empty")

airflow/cli/commands/dag_command.py

Lines changed: 50 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,11 @@
2121
import errno
2222
import json
2323
import logging
24+
import operator
2425
import signal
2526
import subprocess
2627
import sys
28+
import warnings
2729

2830
from graphviz.dot import Dot
2931
from sqlalchemy.orm import Session
@@ -47,33 +49,7 @@
4749
log = logging.getLogger(__name__)
4850

4951

50-
@cli_utils.action_cli
51-
def dag_backfill(args, dag=None):
52-
"""Creates backfill job or dry run for a DAG or list of DAGs using regex."""
53-
logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
54-
55-
signal.signal(signal.SIGTERM, sigint_handler)
56-
57-
import warnings
58-
59-
warnings.warn(
60-
"--ignore-first-depends-on-past is deprecated as the value is always set to True",
61-
category=RemovedInAirflow3Warning,
62-
)
63-
64-
if args.ignore_first_depends_on_past is False:
65-
args.ignore_first_depends_on_past = True
66-
67-
if not args.start_date and not args.end_date:
68-
raise AirflowException("Provide a start_date and/or end_date")
69-
70-
if not dag:
71-
dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex)
72-
else:
73-
dags = dag if type(dag) == list else [dag]
74-
75-
dags.sort(key=lambda d: d.dag_id)
76-
52+
def _run_dag_backfill(dags: list[DAG], args) -> None:
7753
# If only one date is passed, using same as start and end
7854
args.end_date = args.end_date or args.start_date
7955
args.start_date = args.start_date or args.end_date
@@ -133,12 +109,39 @@ def dag_backfill(args, dag=None):
133109
print(str(vr))
134110
sys.exit(1)
135111

112+
113+
@cli_utils.action_cli
114+
def dag_backfill(args, dag: list[DAG] | DAG | None = None) -> None:
115+
"""Creates backfill job or dry run for a DAG or list of DAGs using regex."""
116+
logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
117+
signal.signal(signal.SIGTERM, sigint_handler)
118+
warnings.warn(
119+
"--ignore-first-depends-on-past is deprecated as the value is always set to True",
120+
category=RemovedInAirflow3Warning,
121+
)
122+
123+
if args.ignore_first_depends_on_past is False:
124+
args.ignore_first_depends_on_past = True
125+
126+
if not args.start_date and not args.end_date:
127+
raise AirflowException("Provide a start_date and/or end_date")
128+
129+
if not dag:
130+
dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex)
131+
elif isinstance(dag, list):
132+
dags = dag
133+
else:
134+
dags = [dag]
135+
del dag
136+
137+
dags.sort(key=lambda d: d.dag_id)
138+
_run_dag_backfill(dags, args)
136139
if len(dags) > 1:
137140
log.info("All of the backfills are done.")
138141

139142

140143
@cli_utils.action_cli
141-
def dag_trigger(args):
144+
def dag_trigger(args) -> None:
142145
"""Creates a dag run for the specified dag."""
143146
api_client = get_current_api_client()
144147
try:
@@ -159,7 +162,7 @@ def dag_trigger(args):
159162

160163

161164
@cli_utils.action_cli
162-
def dag_delete(args):
165+
def dag_delete(args) -> None:
163166
"""Deletes all DB records related to the specified dag."""
164167
api_client = get_current_api_client()
165168
if (
@@ -177,18 +180,18 @@ def dag_delete(args):
177180

178181

179182
@cli_utils.action_cli
180-
def dag_pause(args):
183+
def dag_pause(args) -> None:
181184
"""Pauses a DAG."""
182185
set_is_paused(True, args)
183186

184187

185188
@cli_utils.action_cli
186-
def dag_unpause(args):
189+
def dag_unpause(args) -> None:
187190
"""Unpauses a DAG."""
188191
set_is_paused(False, args)
189192

190193

191-
def set_is_paused(is_paused, args):
194+
def set_is_paused(is_paused: bool, args) -> None:
192195
"""Sets is_paused for DAG by a given dag_id."""
193196
dag = DagModel.get_dagmodel(args.dag_id)
194197

@@ -200,7 +203,7 @@ def set_is_paused(is_paused, args):
200203
print(f"Dag: {args.dag_id}, paused: {is_paused}")
201204

202205

203-
def dag_dependencies_show(args):
206+
def dag_dependencies_show(args) -> None:
204207
"""Displays DAG dependencies, save to file or show as imgcat image."""
205208
dot = render_dag_dependencies(SerializedDagModel.get_dag_dependencies())
206209
filename = args.save
@@ -219,7 +222,7 @@ def dag_dependencies_show(args):
219222
print(dot.source)
220223

221224

222-
def dag_show(args):
225+
def dag_show(args) -> None:
223226
"""Displays DAG or saves its graphic representation to the file."""
224227
dag = get_dag(args.subdir, args.dag_id)
225228
dot = render_dag(dag)
@@ -239,7 +242,7 @@ def dag_show(args):
239242
print(dot.source)
240243

241244

242-
def _display_dot_via_imgcat(dot: Dot):
245+
def _display_dot_via_imgcat(dot: Dot) -> None:
243246
data = dot.pipe(format="png")
244247
try:
245248
with subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) as proc:
@@ -255,15 +258,15 @@ def _display_dot_via_imgcat(dot: Dot):
255258
raise
256259

257260

258-
def _save_dot_to_file(dot: Dot, filename: str):
261+
def _save_dot_to_file(dot: Dot, filename: str) -> None:
259262
filename_without_ext, _, ext = filename.rpartition(".")
260263
dot.render(filename=filename_without_ext, format=ext, cleanup=True)
261264
print(f"File {filename} saved")
262265

263266

264267
@cli_utils.action_cli
265268
@provide_session
266-
def dag_state(args, session=NEW_SESSION):
269+
def dag_state(args, session: Session = NEW_SESSION) -> None:
267270
"""
268271
Returns the state (and conf if exists) of a DagRun at the command line.
269272
>>> airflow dags state tutorial 2015-01-01T00:00:00.000000
@@ -284,7 +287,7 @@ def dag_state(args, session=NEW_SESSION):
284287

285288

286289
@cli_utils.action_cli
287-
def dag_next_execution(args):
290+
def dag_next_execution(args) -> None:
288291
"""
289292
Returns the next execution datetime of a DAG at the command line.
290293
>>> airflow dags next-execution tutorial
@@ -312,15 +315,15 @@ def print_execution_interval(interval: DataInterval | None):
312315
next_interval = dag.get_next_data_interval(last_parsed_dag)
313316
print_execution_interval(next_interval)
314317

315-
for i in range(1, args.num_executions):
318+
for _ in range(1, args.num_executions):
316319
next_info = dag.next_dagrun_info(next_interval, restricted=False)
317320
next_interval = None if next_info is None else next_info.data_interval
318321
print_execution_interval(next_interval)
319322

320323

321324
@cli_utils.action_cli
322325
@suppress_logs_and_warning
323-
def dag_list_dags(args):
326+
def dag_list_dags(args) -> None:
324327
"""Displays dags with or without stats at the command line."""
325328
dagbag = DagBag(process_subdir(args.subdir))
326329
if dagbag.import_errors:
@@ -332,7 +335,7 @@ def dag_list_dags(args):
332335
file=sys.stderr,
333336
)
334337
AirflowConsole().print_as(
335-
data=sorted(dagbag.dags.values(), key=lambda d: d.dag_id),
338+
data=sorted(dagbag.dags.values(), key=operator.attrgetter("dag_id")),
336339
output=args.output,
337340
mapper=lambda x: {
338341
"dag_id": x.dag_id,
@@ -345,7 +348,7 @@ def dag_list_dags(args):
345348

346349
@cli_utils.action_cli
347350
@suppress_logs_and_warning
348-
def dag_list_import_errors(args):
351+
def dag_list_import_errors(args) -> None:
349352
"""Displays dags with import errors on the command line."""
350353
dagbag = DagBag(process_subdir(args.subdir))
351354
data = []
@@ -359,7 +362,7 @@ def dag_list_import_errors(args):
359362

360363
@cli_utils.action_cli
361364
@suppress_logs_and_warning
362-
def dag_report(args):
365+
def dag_report(args) -> None:
363366
"""Displays dagbag stats at the command line."""
364367
dagbag = DagBag(process_subdir(args.subdir))
365368
AirflowConsole().print_as(
@@ -378,7 +381,7 @@ def dag_report(args):
378381
@cli_utils.action_cli
379382
@suppress_logs_and_warning
380383
@provide_session
381-
def dag_list_jobs(args, dag=None, session=NEW_SESSION):
384+
def dag_list_jobs(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
382385
"""Lists latest n jobs."""
383386
queries = []
384387
if dag:
@@ -408,7 +411,7 @@ def dag_list_jobs(args, dag=None, session=NEW_SESSION):
408411
@cli_utils.action_cli
409412
@suppress_logs_and_warning
410413
@provide_session
411-
def dag_list_dag_runs(args, dag=None, session=NEW_SESSION):
414+
def dag_list_dag_runs(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
412415
"""Lists dag runs for a given DAG."""
413416
if dag:
414417
args.dag_id = dag.dag_id
@@ -445,7 +448,7 @@ def dag_list_dag_runs(args, dag=None, session=NEW_SESSION):
445448

446449
@provide_session
447450
@cli_utils.action_cli
448-
def dag_test(args, dag=None, session=None):
451+
def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
449452
"""Execute one single DagRun for a given DAG and execution date."""
450453
run_conf = None
451454
if args.conf:
@@ -481,7 +484,7 @@ def dag_test(args, dag=None, session=None):
481484

482485
@provide_session
483486
@cli_utils.action_cli
484-
def dag_reserialize(args, session: Session = NEW_SESSION):
487+
def dag_reserialize(args, session: Session = NEW_SESSION) -> None:
485488
"""Serialize a DAG instance."""
486489
session.query(SerializedDagModel).delete(synchronize_session=False)
487490

airflow/cli/commands/jobs_command.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,14 +16,16 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from sqlalchemy.orm import Session
20+
1921
from airflow.jobs.base_job import BaseJob
2022
from airflow.utils.net import get_hostname
21-
from airflow.utils.session import provide_session
23+
from airflow.utils.session import NEW_SESSION, provide_session
2224
from airflow.utils.state import State
2325

2426

2527
@provide_session
26-
def check(args, session=None):
28+
def check(args, session: Session = NEW_SESSION) -> None:
2729
"""Checks if job(s) are still alive."""
2830
if args.allow_multiple and not args.limit > 1:
2931
raise SystemExit("To use option --allow-multiple, you must set the limit to a value greater than 1.")

0 commit comments

Comments
 (0)