Skip to content

Commit 0c8da84

Browse files
committed
feat(task_sdk): add support for inlet_events in Task Context
1 parent 0cd8547 commit 0c8da84

10 files changed

Lines changed: 432 additions & 148 deletions

File tree

airflow/api_fastapi/execution_api/datamodels/asset.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717

1818
from __future__ import annotations
1919

20+
from datetime import datetime
21+
22+
from pydantic import Field, field_validator
23+
2024
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
25+
from airflow.sdk.execution_time.secrets_masker import redact
2126

2227

2328
class AssetResponse(BaseModel):
@@ -36,7 +41,42 @@ class AssetAliasResponse(BaseModel):
3641
group: str
3742

3843

39-
class AssetProfile(StrictBaseModel):
44+
class DagRunAssetReference(StrictBaseModel):
    """DAGRun serializer for asset responses."""

    # Identifiers of the referenced DAG run.
    run_id: str
    dag_id: str
    # The attribute is named ``execution_date`` but is declared with the
    # pydantic alias ``logical_date``.
    execution_date: datetime = Field(alias="logical_date")
    start_date: datetime
    # The only nullable timestamp on this model.
    end_date: datetime | None
    state: str
    data_interval_start: datetime
    data_interval_end: datetime
55+
56+
57+
class AssetEventResponse(BaseModel):
    """Asset event schema with fields that are needed for Runtime."""

    id: int
    asset_id: int
    # uri / name / group are optional and default to None.
    uri: str | None = Field(alias="uri", default=None)
    name: str | None = Field(alias="name", default=None)
    group: str | None = Field(alias="group", default=None)
    # Passed through ``redact_extra`` (below) after normal validation.
    extra: dict | None = None
    source_task_id: str | None = None
    source_dag_id: str | None = None
    source_run_id: str | None = None
    # NOTE(review): unlike the other ``source_*`` fields this one is required
    # and non-nullable -- confirm that is intended.
    source_map_index: int
    created_dagruns: list[DagRunAssetReference]
    timestamp: datetime

    @field_validator("extra", mode="after")
    @classmethod
    def redact_extra(cls, v: dict | None):
        # Replaces the validated ``extra`` value with whatever ``redact`` returns for it.
        # NOTE(review): ``extra`` is declared ``dict | None`` -- confirm ``redact``
        # accepts ``None`` without raising.
        return redact(v)
77+
78+
79+
class AssetProfile(BaseModel):
4080
"""
4181
Profile of an Asset.
4282

airflow/api_fastapi/execution_api/routes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from airflow.api_fastapi.common.router import AirflowRouter
2020
from airflow.api_fastapi.execution_api.routes import (
21+
asset_events,
2122
assets,
2223
connections,
2324
health,
@@ -28,6 +29,7 @@
2829

2930
execution_api_router = AirflowRouter()
3031
execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
32+
execution_api_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
3133
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
3234
execution_api_router.include_router(health.router, tags=["Health"])
3335
execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
airflow/api_fastapi/execution_api/routes/asset_events.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from typing import TYPE_CHECKING, Annotated
21+
22+
from fastapi import HTTPException, Query, status
23+
from sqlalchemy import and_, select
24+
25+
from airflow.api_fastapi.common.db.common import SessionDep
26+
from airflow.api_fastapi.common.router import AirflowRouter
27+
from airflow.api_fastapi.execution_api.datamodels.asset import AssetEventResponse
28+
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel
29+
from airflow.utils.db import LazySelectSequence
30+
31+
if TYPE_CHECKING:
32+
from sqlalchemy.engine import Row
33+
from sqlalchemy.sql.expression import Select, TextClause
34+
35+
# Router that the asset-event lookup endpoints in this module register on.
# The ``responses`` mapping declares the 404 and 401 outcomes at router level
# (presumably applied to every route, as with FastAPI's APIRouter -- confirm
# against AirflowRouter).
# TODO: Add dependency on JWT token
router = AirflowRouter(
    responses={
        status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
        status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
    },
)
42+
43+
44+
class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
    """
    Lazily evaluated, list-like view over ``AssetEvent`` rows.

    Supplies the two hooks used by ``LazySelectSequence``: how to rebuild the
    select from a textual statement, and how to turn a result row into an
    ``AssetEvent``.

    :meta private:
    """

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        # Re-wrap the textual statement so it yields AssetEvent entities again.
        entity_select = select(AssetEvent)
        return entity_select.from_statement(stmt)

    @staticmethod
    def _process_row(row: Row) -> AssetEvent:
        # The AssetEvent entity is the first (index 0) element of the row.
        event = row[0]
        return event
58+
59+
60+
def _get_asset_event_through_sql_clause(
    *, join_clause, where_clause, session: SessionDep
) -> AssetEventResponse:
    """
    Select asset events with the given join and filter and validate the result.

    :param join_clause: Target handed to ``select(AssetEvent).join(...)``.
    :param where_clause: Filter handed to ``.where(...)``.
    :param session: Database session given to the lazy sequence.
    :return: The result of ``AssetEventResponse.model_validate`` on the lookup.
    :raises HTTPException: 404, via ``_raise_if_not_found``, when the lookup
        result is ``None``.
    """
    events_query = select(AssetEvent).join(join_clause).where(where_clause)
    found = LazyAssetEventSelectSequence.from_select(
        events_query,
        order_by=[AssetEvent.timestamp],
        session=session,
    )
    # NOTE(review): the guard below only fires on ``None``. Whether
    # ``from_select`` can ever return ``None`` (as opposed to an empty
    # sequence) is not visible here -- confirm an empty result really yields 404.
    _raise_if_not_found(asset_event=found, msg="Not found")
    return AssetEventResponse.model_validate(found)
70+
71+
72+
@router.get("/by-asset-name-uri")
def get_asset_event_by_asset_name_uri(
    name: Annotated[str, Query(description="The name of the Asset")],
    uri: Annotated[str, Query(description="The URI of the Asset")],
    session: SessionDep,
) -> AssetEventResponse:
    # Documented with comments rather than a docstring so the generated API
    # description of this operation stays exactly as it was.
    # The asset is identified by its exact (name, uri) pair; events are reached
    # through the AssetEvent -> asset relationship.
    asset_filter = and_(AssetModel.name == name, AssetModel.uri == uri)
    return _get_asset_event_through_sql_clause(
        join_clause=AssetEvent.asset,
        where_clause=asset_filter,
        session=session,
    )
83+
84+
85+
@router.get("/by-asset-uri")
def get_asset_event_by_uri(
    uri: Annotated[str, Query(description="The URI of the Asset")],
    session: SessionDep,
) -> AssetEventResponse:
    # Documented with comments rather than a docstring so the generated API
    # description of this operation stays exactly as it was.
    # Filter on the URI column and additionally require ``AssetModel.active.has()``.
    uri_filter = and_(AssetModel.uri == uri, AssetModel.active.has())
    return _get_asset_event_through_sql_clause(
        join_clause=AssetEvent.asset,
        where_clause=uri_filter,
        session=session,
    )
95+
96+
97+
@router.get("/by-asset-name")
def get_asset_event_by_name(
    name: Annotated[str, Query(description="The name of the Asset")],
    session: SessionDep,
) -> AssetEventResponse:
    # Documented with comments rather than a docstring so the generated API
    # description of this operation stays exactly as it was.
    #
    # Looks up asset events by asset *name*. The filter must therefore compare
    # the ``name`` column; comparing ``AssetModel.uri`` against the supplied
    # name (as this route previously did) only matches assets whose URI happens
    # to equal their name and otherwise returns nothing or the wrong asset.
    # ``AssetModel.active.has()`` is kept as an additional condition, exactly
    # as in the by-uri lookup.
    name_filter = and_(AssetModel.name == name, AssetModel.active.has())
    return _get_asset_event_through_sql_clause(
        join_clause=AssetEvent.asset,
        where_clause=name_filter,
        session=session,
    )
107+
108+
109+
@router.get("/by-alias-name")
def get_asset_event_by_alias_name(
    name: Annotated[str, Query(description="The name of the Asset Alias")],
    session: SessionDep,
) -> AssetEventResponse:
    # Documented with comments rather than a docstring so the generated API
    # description of this operation stays exactly as it was.
    # Events are reached through ``AssetEvent.source_aliases`` and filtered on
    # the alias name.
    alias_filter = AssetAliasModel.name == name
    return _get_asset_event_through_sql_clause(
        join_clause=AssetEvent.source_aliases,
        where_clause=alias_filter,
        session=session,
    )
119+
120+
121+
def _raise_if_not_found(asset_event, msg):
122+
if asset_event is None:
123+
raise HTTPException(
124+
status.HTTP_404_NOT_FOUND,
125+
detail={
126+
"reason": "not_found",
127+
"message": msg,
128+
},
129+
)

airflow/utils/context.py

Lines changed: 3 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -21,50 +21,33 @@
2121

2222
from collections.abc import (
2323
Container,
24-
Iterator,
25-
Mapping,
2624
)
2725
from typing import (
2826
TYPE_CHECKING,
2927
Any,
30-
Union,
3128
cast,
3229
)
3330

3431
import attrs
35-
from sqlalchemy import and_, select
32+
from sqlalchemy import select
3633

3734
from airflow.models.asset import (
38-
AssetAliasModel,
39-
AssetEvent,
4035
AssetModel,
41-
fetch_active_assets_by_name,
42-
fetch_active_assets_by_uri,
4336
)
4437
from airflow.sdk.definitions.asset import (
4538
Asset,
46-
AssetAlias,
47-
AssetAliasUniqueKey,
48-
AssetNameRef,
49-
AssetRef,
50-
AssetUniqueKey,
51-
AssetUriRef,
5239
)
5340
from airflow.sdk.definitions.context import Context
5441
from airflow.sdk.execution_time.context import (
5542
ConnectionAccessor as ConnectionAccessorSDK,
43+
InletEventsAccessors as InletEventsAccessorsSDK,
5644
OutletEventAccessors as OutletEventAccessorsSDK,
5745
VariableAccessor as VariableAccessorSDK,
5846
)
59-
from airflow.utils.db import LazySelectSequence
6047
from airflow.utils.session import create_session
6148
from airflow.utils.types import NOTSET
6249

6350
if TYPE_CHECKING:
64-
from sqlalchemy.engine import Row
65-
from sqlalchemy.orm import Session
66-
from sqlalchemy.sql.expression import Select, TextClause
67-
6851
from airflow.sdk.types import OutletEventAccessorsProtocol
6952

7053
# NOTE: Please keep this in sync with the following:
@@ -170,105 +153,14 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
170153
return asset.to_public()
171154

172155

173-
class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
174-
"""
175-
List-like interface to lazily access AssetEvent rows.
176-
177-
:meta private:
178-
"""
179-
180-
@staticmethod
181-
def _rebuild_select(stmt: TextClause) -> Select:
182-
return select(AssetEvent).from_statement(stmt)
183-
184-
@staticmethod
185-
def _process_row(row: Row) -> AssetEvent:
186-
return row[0]
187-
188-
189156
@attrs.define(init=False)
190-
class InletEventsAccessors(Mapping[Union[int, Asset, AssetAlias, AssetRef], LazyAssetEventSelectSequence]):
157+
class InletEventsAccessors(InletEventsAccessorsSDK):
191158
"""
192159
Lazy mapping for inlet asset events accessors.
193160
194161
:meta private:
195162
"""
196163

197-
_inlets: list[Any]
198-
_assets: dict[AssetUniqueKey, Asset]
199-
_asset_aliases: dict[AssetAliasUniqueKey, AssetAlias]
200-
_session: Session
201-
202-
def __init__(self, inlets: list, *, session: Session) -> None:
203-
self._inlets = inlets
204-
self._session = session
205-
self._assets = {}
206-
self._asset_aliases = {}
207-
208-
_asset_ref_names: list[str] = []
209-
_asset_ref_uris: list[str] = []
210-
for inlet in inlets:
211-
if isinstance(inlet, Asset):
212-
self._assets[AssetUniqueKey.from_asset(inlet)] = inlet
213-
elif isinstance(inlet, AssetAlias):
214-
self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet
215-
elif isinstance(inlet, AssetNameRef):
216-
_asset_ref_names.append(inlet.name)
217-
elif isinstance(inlet, AssetUriRef):
218-
_asset_ref_uris.append(inlet.uri)
219-
220-
if _asset_ref_names:
221-
for _, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items():
222-
self._assets[AssetUniqueKey.from_asset(asset)] = asset
223-
if _asset_ref_uris:
224-
for _, asset in fetch_active_assets_by_uri(_asset_ref_uris, self._session).items():
225-
self._assets[AssetUniqueKey.from_asset(asset)] = asset
226-
227-
def __iter__(self) -> Iterator[Asset | AssetAlias]:
228-
return iter(self._inlets)
229-
230-
def __len__(self) -> int:
231-
return len(self._inlets)
232-
233-
def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> LazyAssetEventSelectSequence:
234-
if isinstance(key, int): # Support index access; it's easier for trivial cases.
235-
obj = self._inlets[key]
236-
if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
237-
raise IndexError(key)
238-
else:
239-
obj = key
240-
241-
if isinstance(obj, Asset):
242-
asset = self._assets[AssetUniqueKey.from_asset(obj)]
243-
join_clause = AssetEvent.asset
244-
where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
245-
elif isinstance(obj, AssetAlias):
246-
asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
247-
join_clause = AssetEvent.source_aliases
248-
where_clause = AssetAliasModel.name == asset_alias.name
249-
elif isinstance(obj, AssetNameRef):
250-
try:
251-
asset = next(a for k, a in self._assets.items() if k.name == obj.name)
252-
except StopIteration:
253-
raise KeyError(obj) from None
254-
join_clause = AssetEvent.asset
255-
where_clause = and_(AssetModel.name == asset.name, AssetModel.active.has())
256-
elif isinstance(obj, AssetUriRef):
257-
try:
258-
asset = next(a for k, a in self._assets.items() if k.uri == obj.uri)
259-
except StopIteration:
260-
raise KeyError(obj) from None
261-
join_clause = AssetEvent.asset
262-
where_clause = and_(AssetModel.uri == asset.uri, AssetModel.active.has())
263-
else:
264-
raise ValueError(key)
265-
266-
return LazyAssetEventSelectSequence.from_select(
267-
select(AssetEvent).join(join_clause).where(where_clause),
268-
order_by=[AssetEvent.timestamp],
269-
session=self._session,
270-
)
271-
272164

273165
def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
274166
"""

0 commit comments

Comments
 (0)