Skip to content

Commit 943ecd7

Browse files
committed
AIP-72: Add support for outlet_events in Task Context
Part of #45717. This PR adds support for `outlet_events` in the Context dict within the Task SDK by adding an endpoint on the API Server, which is queried when `outlet_events` is accessed.
1 parent 060eeb7 commit 943ecd7

17 files changed

Lines changed: 679 additions & 105 deletions

File tree

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from datetime import datetime
21+
22+
from airflow.api_fastapi.core_api.base import BaseModel
23+
24+
25+
class DagScheduleAssetReference(BaseModel):
    """DAG schedule reference schema for assets."""

    # All three fields are declared without defaults, so each must be
    # present in the data being validated.
    dag_id: str
    created_at: datetime
    updated_at: datetime
31+
32+
33+
class TaskOutletAssetReference(BaseModel):
    """Task outlet reference schema for assets."""

    # The referencing task is addressed by the (dag_id, task_id) pair.
    dag_id: str
    task_id: str
    # Both timestamps are declared without defaults and are therefore required.
    created_at: datetime
    updated_at: datetime
40+
41+
42+
class AssetResponse(BaseModel):
    """Asset schema for responses."""

    # Scalar attributes of the asset.
    id: int
    name: str
    uri: str
    group: str
    # The only optional field: falls back to ``None`` when not supplied.
    extra: dict | None = None
    created_at: datetime
    updated_at: datetime
    # Nested collections, each validated with a schema declared in this module.
    # ``AssetAliasResponse`` is declared further down the module; the annotation
    # is a lazy string because of ``from __future__ import annotations``.
    consuming_dags: list[DagScheduleAssetReference]
    producing_tasks: list[TaskOutletAssetReference]
    aliases: list[AssetAliasResponse]
55+
56+
57+
class AssetAliasResponse(BaseModel):
    """Asset alias schema for responses."""

    # All fields are required; this schema is the element type of
    # ``AssetResponse.aliases``.
    id: int
    name: str
    group: str

airflow/api_fastapi/execution_api/routes/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,17 @@
1717
from __future__ import annotations
1818

1919
from airflow.api_fastapi.common.router import AirflowRouter
20-
from airflow.api_fastapi.execution_api.routes import connections, health, task_instances, variables, xcoms
20+
from airflow.api_fastapi.execution_api.routes import (
21+
assets,
22+
connections,
23+
health,
24+
task_instances,
25+
variables,
26+
xcoms,
27+
)
2128

2229
execution_api_router = AirflowRouter()
30+
execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
2331
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
2432
execution_api_router.include_router(health.router, tags=["Health"])
2533
execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from typing import Annotated
21+
22+
from fastapi import HTTPException, Query, status
23+
from sqlalchemy import select
24+
25+
from airflow.api_fastapi.common.db.common import SessionDep
26+
from airflow.api_fastapi.common.router import AirflowRouter
27+
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
28+
from airflow.models.asset import AssetModel
29+
30+
# TODO: Add dependency on JWT token
31+
router = AirflowRouter(
32+
responses={
33+
status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
34+
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
35+
},
36+
)
37+
38+
39+
@router.get(
    "",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "description": "Either 'name' or 'uri' query parameter must be provided"
        },
        # This route serves assets, so the documented 403 must refer to the asset.
        status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the asset"},
    },
)
def get_asset(
    session: SessionDep,
    name: Annotated[str | None, Query(description="The name of the Asset")] = None,
    uri: Annotated[str | None, Query(description="The URI of the Asset")] = None,
) -> AssetResponse:
    """Get an Airflow Asset by `name` or `uri`."""
    # The docstring above is kept as a single line on purpose: FastAPI publishes
    # it as the operation description, so the longer notes live in comments.
    #
    # Lookup rules:
    #   * a non-empty ``name`` wins over ``uri`` when both are supplied;
    #   * only rows matching ``AssetModel.active.has()`` are considered;
    #   * no match -> 404 (raised by ``_raise_if_not_found``);
    #   * neither parameter supplied (or both empty) -> 400.
    if name:
        asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has()))
        _raise_if_not_found(asset, f"Asset with name {name} not found")
    elif uri:
        asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has()))
        _raise_if_not_found(asset, f"Asset with URI {uri} not found")
    else:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            detail={
                "reason": "bad_request",
                "message": "Either 'name' or 'uri' query parameter must be provided",
            },
        )
    # ``asset`` is guaranteed non-None here; convert the ORM row to the response schema.
    return AssetResponse.model_validate(asset)
69+
70+
71+
def _raise_if_not_found(asset, msg):
72+
if asset is None:
73+
raise HTTPException(
74+
status.HTTP_404_NOT_FOUND,
75+
detail={
76+
"reason": "not_found",
77+
"message": msg,
78+
},
79+
)

airflow/serialization/serialized_objects.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
BaseAsset,
6565
)
6666
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
67+
from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
6768
from airflow.serialization.dag_dependency import DagDependency
6869
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
6970
from airflow.serialization.helpers import serialize_template_field
@@ -77,10 +78,8 @@
7778
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
7879
from airflow.utils.code_utils import get_python_source
7980
from airflow.utils.context import (
80-
AssetAliasEvent,
8181
ConnectionAccessor,
8282
Context,
83-
OutletEventAccessor,
8483
OutletEventAccessors,
8584
VariableAccessor,
8685
)

airflow/utils/context.py

Lines changed: 14 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
from __future__ import annotations
2121

22-
import contextlib
2322
from collections.abc import (
23+
Callable,
2424
Container,
2525
Iterator,
2626
Mapping,
@@ -51,9 +51,9 @@
5151
AssetRef,
5252
AssetUniqueKey,
5353
AssetUriRef,
54-
BaseAssetUniqueKey,
5554
)
5655
from airflow.sdk.definitions.context import Context
56+
from airflow.sdk.execution_time.context import OutletEventAccessors as OutletEventAccessorsSDK
5757
from airflow.utils.db import LazySelectSequence
5858
from airflow.utils.session import create_session
5959
from airflow.utils.types import NOTSET
@@ -156,104 +156,27 @@ def get(self, key: str, default_conn: Any = None) -> Any:
156156
return default_conn
157157

158158

159-
@attrs.define()
160-
class AssetAliasEvent:
161-
"""
162-
Representation of an asset event to be triggered by an asset alias.
163-
164-
:meta private:
165-
"""
159+
def _get_asset(name: str | None = None, uri: str | None = None) -> Asset:
160+
if name:
161+
with create_session() as session:
162+
asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has()))
163+
elif uri:
164+
with create_session() as session:
165+
asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has()))
166+
else:
167+
raise ValueError("Either name or uri must be provided")
166168

167-
source_alias_name: str
168-
dest_asset_key: AssetUniqueKey
169-
extra: dict[str, Any]
169+
return asset.to_public()
170170

171171

172-
@attrs.define()
173-
class OutletEventAccessor:
174-
"""
175-
Wrapper to access an outlet asset event in template.
176-
177-
:meta private:
178-
"""
179-
180-
key: BaseAssetUniqueKey
181-
extra: dict[str, Any] = attrs.Factory(dict)
182-
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)
183-
184-
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
185-
"""Add an AssetEvent to an existing Asset."""
186-
if not isinstance(self.key, AssetAliasUniqueKey):
187-
return
188-
189-
asset_alias_name = self.key.name
190-
event = AssetAliasEvent(
191-
source_alias_name=asset_alias_name,
192-
dest_asset_key=AssetUniqueKey.from_asset(asset),
193-
extra=extra or {},
194-
)
195-
self.asset_alias_events.append(event)
196-
197-
198-
class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]):
172+
class OutletEventAccessors(OutletEventAccessorsSDK):
199173
"""
200174
Lazy mapping of outlet asset event accessors.
201175
202176
:meta private:
203177
"""
204178

205-
_asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {}
206-
207-
def __init__(self) -> None:
208-
self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}
209-
210-
def __str__(self) -> str:
211-
return f"OutletEventAccessors(_dict={self._dict})"
212-
213-
def __iter__(self) -> Iterator[Asset | AssetAlias]:
214-
return (
215-
key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict
216-
)
217-
218-
def __len__(self) -> int:
219-
return len(self._dict)
220-
221-
def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor:
222-
hashable_key: BaseAssetUniqueKey
223-
if isinstance(key, Asset):
224-
hashable_key = AssetUniqueKey.from_asset(key)
225-
elif isinstance(key, AssetAlias):
226-
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
227-
elif isinstance(key, AssetRef):
228-
hashable_key = self._resolve_asset_ref(key)
229-
else:
230-
raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")
231-
232-
if hashable_key not in self._dict:
233-
self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
234-
return self._dict[hashable_key]
235-
236-
def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
237-
with contextlib.suppress(KeyError):
238-
return self._asset_ref_cache[ref]
239-
240-
refs_to_cache: list[AssetRef]
241-
with create_session() as session:
242-
if isinstance(ref, AssetNameRef):
243-
asset = session.scalar(
244-
select(AssetModel).where(AssetModel.name == ref.name, AssetModel.active.has())
245-
)
246-
refs_to_cache = [ref, AssetUriRef(asset.uri)]
247-
elif isinstance(ref, AssetUriRef):
248-
asset = session.scalar(
249-
select(AssetModel).where(AssetModel.uri == ref.uri, AssetModel.active.has())
250-
)
251-
refs_to_cache = [ref, AssetNameRef(asset.name)]
252-
else:
253-
raise TypeError(f"Unimplemented asset ref: {type(ref)}")
254-
for ref in refs_to_cache:
255-
self._asset_ref_cache[ref] = unique_key = AssetUniqueKey.from_asset(asset)
256-
return unique_key
179+
_get_asset_func: Callable[..., Asset] = _get_asset
257180

258181

259182
class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):

task_sdk/src/airflow/sdk/api/client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from airflow.sdk import __version__
3636
from airflow.sdk.api.datamodels._generated import (
37+
AssetResponse,
3738
ConnectionResponse,
3839
DagRunType,
3940
TerminalTIState,
@@ -267,6 +268,25 @@ def set(
267268
return {"ok": True}
268269

269270

271+
class AssetOperations:
    """Asset-related calls issued through the wrapped API ``client``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse:
        """
        Fetch an asset from the API server by ``name`` or ``uri``.

        A non-empty ``name`` takes precedence; ``uri`` is only sent when ``name``
        is empty or ``None``.

        :raises ValueError: If neither ``name`` nor ``uri`` is a non-empty value.
        """
        # Reject the request up front when there is nothing to look up by.
        if not name and not uri:
            raise ValueError("Either `name` or `uri` must be provided")

        # Exactly one query parameter is sent, with ``name`` winning over ``uri``.
        query = {"name": name} if name else {"uri": uri}

        response = self.client.get("assets/", params=query)
        return AssetResponse.model_validate_json(response.read())
288+
289+
270290
class BearerAuth(httpx.Auth):
271291
def __init__(self, token: str):
272292
self.token: str = token
@@ -374,6 +394,12 @@ def xcoms(self) -> XComOperations:
374394
"""Operations related to XComs."""
375395
return XComOperations(self)
376396

397+
@lru_cache()  # type: ignore[misc]
@property
def assets(self) -> AssetOperations:
    """Operations related to Assets."""
    return AssetOperations(self)
402+
377403

378404
# This is only used for parsing. ServerResponseError is raised instead
379405
class _ErrorBody(BaseModel):

0 commit comments

Comments
 (0)