Skip to content

Commit b882246

Browse files
AIP-84 Introduce SessionDep and AsyncSessionDep construct (#44461)
1 parent b2d2bcb commit b882246

27 files changed

Lines changed: 140 additions & 200 deletions

airflow/api_fastapi/common/db/common.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,38 +23,29 @@
2323
from __future__ import annotations
2424

2525
from collections.abc import Sequence
26-
from typing import TYPE_CHECKING, Literal, overload
26+
from typing import TYPE_CHECKING, Annotated, Literal, overload
2727

28+
from fastapi import Depends
2829
from sqlalchemy.ext.asyncio import AsyncSession
30+
from sqlalchemy.orm import Session
2931

3032
from airflow.utils.db import get_query_count, get_query_count_async
3133
from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session
3234

3335
if TYPE_CHECKING:
34-
from sqlalchemy.orm import Session
3536
from sqlalchemy.sql import Select
3637

3738
from airflow.api_fastapi.common.parameters import BaseParam
3839

3940

40-
def get_session() -> Session:
41-
"""
42-
Dependency for providing a session.
43-
44-
For non route function please use the :class:`airflow.utils.session.provide_session` decorator.
45-
46-
Example usage:
47-
48-
.. code:: python
49-
50-
@router.get("/your_path")
51-
def your_route(session: Annotated[Session, Depends(get_session)]):
52-
pass
53-
"""
41+
def _get_session() -> Session:
5442
with create_session(scoped=False) as session:
5543
yield session
5644

5745

46+
SessionDep = Annotated[Session, Depends(_get_session)]
47+
48+
5849
def apply_filters_to_select(
5950
*, statement: Select, filters: Sequence[BaseParam | None] | None = None
6051
) -> Select:
@@ -68,22 +59,14 @@ def apply_filters_to_select(
6859
return statement
6960

7061

71-
async def get_async_session() -> AsyncSession:
72-
"""
73-
Dependency for providing a session.
74-
75-
Example usage:
76-
77-
.. code:: python
78-
79-
@router.get("/your_path")
80-
def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]):
81-
pass
82-
"""
62+
async def _get_async_session() -> AsyncSession:
8363
async with create_session_async() as session:
8464
yield session
8565

8666

67+
AsyncSessionDep = Annotated[AsyncSession, Depends(_get_async_session)]
68+
69+
8770
@overload
8871
async def paginated_select_async(
8972
*,

airflow/api_fastapi/core_api/routes/public/assets.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222

2323
from fastapi import Depends, HTTPException, status
2424
from sqlalchemy import delete, select
25-
from sqlalchemy.orm import Session, joinedload, subqueryload
25+
from sqlalchemy.orm import joinedload, subqueryload
2626

27-
from airflow.api_fastapi.common.db.common import get_session, paginated_select
27+
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
2828
from airflow.api_fastapi.common.parameters import (
2929
OptionalDateTimeQuery,
3030
QueryAssetDagIdPatternSearch,
@@ -91,7 +91,7 @@ def get_assets(
9191
SortParam,
9292
Depends(SortParam(["id", "uri", "created_at", "updated_at"], AssetModel).dynamic_depends()),
9393
],
94-
session: Annotated[Session, Depends(get_session)],
94+
session: SessionDep,
9595
) -> AssetCollectionResponse:
9696
"""Get assets."""
9797
assets_select, total_entries = paginated_select(
@@ -141,7 +141,7 @@ def get_asset_events(
141141
source_task_id: QuerySourceTaskIdFilter,
142142
source_run_id: QuerySourceRunIdFilter,
143143
source_map_index: QuerySourceMapIndexFilter,
144-
session: Annotated[Session, Depends(get_session)],
144+
session: SessionDep,
145145
) -> AssetEventCollectionResponse:
146146
"""Get asset events."""
147147
assets_event_select, total_entries = paginated_select(
@@ -168,7 +168,7 @@ def get_asset_events(
168168
)
169169
def create_asset_event(
170170
body: CreateAssetEventsBody,
171-
session: Annotated[Session, Depends(get_session)],
171+
session: SessionDep,
172172
) -> AssetEventResponse:
173173
"""Create asset events."""
174174
asset = session.scalar(select(AssetModel).where(AssetModel.uri == body.uri).limit(1))
@@ -198,7 +198,7 @@ def create_asset_event(
198198
)
199199
def get_asset_queued_events(
200200
uri: str,
201-
session: Annotated[Session, Depends(get_session)],
201+
session: SessionDep,
202202
before: OptionalDateTimeQuery = None,
203203
) -> QueuedEventCollectionResponse:
204204
"""Get queued asset events for an asset."""
@@ -233,7 +233,7 @@ def get_asset_queued_events(
233233
)
234234
def get_asset(
235235
uri: str,
236-
session: Annotated[Session, Depends(get_session)],
236+
session: SessionDep,
237237
) -> AssetResponse:
238238
"""Get an asset."""
239239
asset = session.scalar(
@@ -258,7 +258,7 @@ def get_asset(
258258
)
259259
def get_dag_asset_queued_events(
260260
dag_id: str,
261-
session: Annotated[Session, Depends(get_session)],
261+
session: SessionDep,
262262
before: OptionalDateTimeQuery = None,
263263
) -> QueuedEventCollectionResponse:
264264
"""Get queued asset events for a DAG."""
@@ -296,7 +296,7 @@ def get_dag_asset_queued_events(
296296
def get_dag_asset_queued_event(
297297
dag_id: str,
298298
uri: str,
299-
session: Annotated[Session, Depends(get_session)],
299+
session: SessionDep,
300300
before: OptionalDateTimeQuery = None,
301301
) -> QueuedEventResponse:
302302
"""Get a queued asset event for a DAG."""
@@ -327,7 +327,7 @@ def get_dag_asset_queued_event(
327327
)
328328
def delete_asset_queued_events(
329329
uri: str,
330-
session: Annotated[Session, Depends(get_session)],
330+
session: SessionDep,
331331
before: OptionalDateTimeQuery = None,
332332
):
333333
"""Delete queued asset events for an asset."""
@@ -350,7 +350,7 @@ def delete_asset_queued_events(
350350
)
351351
def delete_dag_asset_queued_events(
352352
dag_id: str,
353-
session: Annotated[Session, Depends(get_session)],
353+
session: SessionDep,
354354
before: OptionalDateTimeQuery = None,
355355
):
356356
where_clause = _generate_queued_event_where_clause(dag_id=dag_id, before=before)
@@ -375,7 +375,7 @@ def delete_dag_asset_queued_events(
375375
def delete_dag_asset_queued_event(
376376
dag_id: str,
377377
uri: str,
378-
session: Annotated[Session, Depends(get_session)],
378+
session: SessionDep,
379379
before: OptionalDateTimeQuery = None,
380380
):
381381
"""Delete a queued asset event for a DAG."""

airflow/api_fastapi/core_api/routes/public/backfills.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020

2121
from fastapi import Depends, HTTPException, status
2222
from sqlalchemy import select, update
23-
from sqlalchemy.ext.asyncio import AsyncSession
24-
from sqlalchemy.orm import Session
2523

26-
from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async
24+
from airflow.api_fastapi.common.db.common import (
25+
AsyncSessionDep,
26+
SessionDep,
27+
paginated_select_async,
28+
)
2729
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
2830
from airflow.api_fastapi.common.router import AirflowRouter
2931
from airflow.api_fastapi.core_api.datamodels.backfills import (
@@ -58,7 +60,7 @@ async def list_backfills(
5860
SortParam,
5961
Depends(SortParam(["id"], Backfill).dynamic_depends()),
6062
],
61-
session: Annotated[AsyncSession, Depends(get_async_session)],
63+
session: AsyncSessionDep,
6264
) -> BackfillCollectionResponse:
6365
select_stmt, total_entries = await paginated_select_async(
6466
statement=select(Backfill).where(Backfill.dag_id == dag_id),
@@ -80,7 +82,7 @@ async def list_backfills(
8082
)
8183
def get_backfill(
8284
backfill_id: str,
83-
session: Annotated[Session, Depends(get_session)],
85+
session: SessionDep,
8486
) -> BackfillResponse:
8587
backfill = session.get(Backfill, backfill_id)
8688
if backfill:
@@ -97,7 +99,7 @@ def get_backfill(
9799
]
98100
),
99101
)
100-
def pause_backfill(backfill_id, session: Annotated[Session, Depends(get_session)]) -> BackfillResponse:
102+
def pause_backfill(backfill_id, session: SessionDep) -> BackfillResponse:
101103
b = session.get(Backfill, backfill_id)
102104
if not b:
103105
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Could not find backfill with id {backfill_id}")
@@ -118,7 +120,7 @@ def pause_backfill(backfill_id, session: Annotated[Session, Depends(get_session)
118120
]
119121
),
120122
)
121-
def unpause_backfill(backfill_id, session: Annotated[Session, Depends(get_session)]) -> BackfillResponse:
123+
def unpause_backfill(backfill_id, session: SessionDep) -> BackfillResponse:
122124
b = session.get(Backfill, backfill_id)
123125
if not b:
124126
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Could not find backfill with id {backfill_id}")
@@ -138,7 +140,7 @@ def unpause_backfill(backfill_id, session: Annotated[Session, Depends(get_sessio
138140
]
139141
),
140142
)
141-
def cancel_backfill(backfill_id, session: Annotated[Session, Depends(get_session)]) -> BackfillResponse:
143+
def cancel_backfill(backfill_id, session: SessionDep) -> BackfillResponse:
142144
b: Backfill = session.get(Backfill, backfill_id)
143145
if not b:
144146
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Could not find backfill with id {backfill_id}")

airflow/api_fastapi/core_api/routes/public/connections.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121

2222
from fastapi import Depends, HTTPException, Query, status
2323
from sqlalchemy import select
24-
from sqlalchemy.orm import Session
2524

26-
from airflow.api_fastapi.common.db.common import get_session, paginated_select
25+
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
2726
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
2827
from airflow.api_fastapi.common.router import AirflowRouter
2928
from airflow.api_fastapi.core_api.datamodels.connections import (
@@ -49,7 +48,7 @@
4948
)
5049
def delete_connection(
5150
connection_id: str,
52-
session: Annotated[Session, Depends(get_session)],
51+
session: SessionDep,
5352
):
5453
"""Delete a connection entry."""
5554
connection = session.scalar(select(Connection).filter_by(conn_id=connection_id))
@@ -68,7 +67,7 @@ def delete_connection(
6867
)
6968
def get_connection(
7069
connection_id: str,
71-
session: Annotated[Session, Depends(get_session)],
70+
session: SessionDep,
7271
) -> ConnectionResponse:
7372
"""Get a connection entry."""
7473
connection = session.scalar(select(Connection).filter_by(conn_id=connection_id))
@@ -98,7 +97,7 @@ def get_connections(
9897
).dynamic_depends()
9998
),
10099
],
101-
session: Annotated[Session, Depends(get_session)],
100+
session: SessionDep,
102101
) -> ConnectionCollectionResponse:
103102
"""Get all connection entries."""
104103
connection_select, total_entries = paginated_select(
@@ -124,7 +123,7 @@ def get_connections(
124123
)
125124
def post_connection(
126125
post_body: ConnectionBody,
127-
session: Annotated[Session, Depends(get_session)],
126+
session: SessionDep,
128127
) -> ConnectionResponse:
129128
"""Create connection entry."""
130129
try:
@@ -157,7 +156,7 @@ def post_connection(
157156
def patch_connection(
158157
connection_id: str,
159158
patch_body: ConnectionBody,
160-
session: Annotated[Session, Depends(get_session)],
159+
session: SessionDep,
161160
update_mask: list[str] | None = Query(None),
162161
) -> ConnectionResponse:
163162
"""Update a connection entry."""

airflow/api_fastapi/core_api/routes/public/dag_parsing.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717
from __future__ import annotations
1818

1919
from collections.abc import Sequence
20-
from typing import TYPE_CHECKING, Annotated
20+
from typing import TYPE_CHECKING
2121

22-
from fastapi import Depends, HTTPException, Request, status
22+
from fastapi import HTTPException, Request, status
2323
from itsdangerous import BadSignature, URLSafeSerializer
2424
from sqlalchemy import select
25-
from sqlalchemy.orm import Session
2625

27-
from airflow.api_fastapi.common.db.common import get_session
26+
from airflow.api_fastapi.common.db.common import SessionDep
2827
from airflow.api_fastapi.common.router import AirflowRouter
2928
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
3029
from airflow.auth.managers.models.resource_details import DagDetails
@@ -44,7 +43,7 @@
4443
)
4544
def reparse_dag_file(
4645
file_token: str,
47-
session: Annotated[Session, Depends(get_session)],
46+
session: SessionDep,
4847
request: Request,
4948
) -> None:
5049
"""Request re-parsing a DAG file."""

0 commit comments

Comments
 (0)