|
21 | 21 |
|
22 | 22 | from collections.abc import ( |
23 | 23 | Container, |
24 | | - Iterator, |
25 | | - Mapping, |
26 | 24 | ) |
27 | 25 | from typing import ( |
28 | 26 | TYPE_CHECKING, |
29 | 27 | Any, |
30 | | - Union, |
31 | 28 | cast, |
32 | 29 | ) |
33 | 30 |
|
34 | 31 | import attrs |
35 | | -from sqlalchemy import and_, select |
| 32 | +from sqlalchemy import select |
36 | 33 |
|
37 | 34 | from airflow.models.asset import ( |
38 | | - AssetAliasModel, |
39 | | - AssetEvent, |
40 | 35 | AssetModel, |
41 | | - fetch_active_assets_by_name, |
42 | | - fetch_active_assets_by_uri, |
43 | 36 | ) |
44 | 37 | from airflow.sdk.definitions.asset import ( |
45 | 38 | Asset, |
46 | | - AssetAlias, |
47 | | - AssetAliasUniqueKey, |
48 | | - AssetNameRef, |
49 | | - AssetRef, |
50 | | - AssetUniqueKey, |
51 | | - AssetUriRef, |
52 | 39 | ) |
53 | 40 | from airflow.sdk.definitions.context import Context |
54 | 41 | from airflow.sdk.execution_time.context import ( |
55 | 42 | ConnectionAccessor as ConnectionAccessorSDK, |
| 43 | + InletEventsAccessors as InletEventsAccessorsSDK, |
56 | 44 | OutletEventAccessors as OutletEventAccessorsSDK, |
57 | 45 | VariableAccessor as VariableAccessorSDK, |
58 | 46 | ) |
59 | | -from airflow.utils.db import LazySelectSequence |
60 | 47 | from airflow.utils.session import create_session |
61 | 48 | from airflow.utils.types import NOTSET |
62 | 49 |
|
63 | 50 | if TYPE_CHECKING: |
64 | | - from sqlalchemy.engine import Row |
65 | | - from sqlalchemy.orm import Session |
66 | | - from sqlalchemy.sql.expression import Select, TextClause |
67 | | - |
68 | 51 | from airflow.sdk.types import OutletEventAccessorsProtocol |
69 | 52 |
|
70 | 53 | # NOTE: Please keep this in sync with the following: |
@@ -170,105 +153,14 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset |
170 | 153 | return asset.to_public() |
171 | 154 |
|
172 | 155 |
|
173 | | -class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): |
174 | | - """ |
175 | | - List-like interface to lazily access AssetEvent rows. |
176 | | -
|
177 | | - :meta private: |
178 | | - """ |
179 | | - |
180 | | - @staticmethod |
181 | | - def _rebuild_select(stmt: TextClause) -> Select: |
182 | | - return select(AssetEvent).from_statement(stmt) |
183 | | - |
184 | | - @staticmethod |
185 | | - def _process_row(row: Row) -> AssetEvent: |
186 | | - return row[0] |
187 | | - |
188 | | - |
189 | 156 | @attrs.define(init=False) |
190 | | -class InletEventsAccessors(Mapping[Union[int, Asset, AssetAlias, AssetRef], LazyAssetEventSelectSequence]): |
| 157 | +class InletEventsAccessors(InletEventsAccessorsSDK): |
191 | 158 | """ |
192 | 159 | Lazy mapping for inlet asset events accessors. |
193 | 160 |
|
194 | 161 | :meta private: |
195 | 162 | """ |
196 | 163 |
|
197 | | - _inlets: list[Any] |
198 | | - _assets: dict[AssetUniqueKey, Asset] |
199 | | - _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias] |
200 | | - _session: Session |
201 | | - |
202 | | - def __init__(self, inlets: list, *, session: Session) -> None: |
203 | | - self._inlets = inlets |
204 | | - self._session = session |
205 | | - self._assets = {} |
206 | | - self._asset_aliases = {} |
207 | | - |
208 | | - _asset_ref_names: list[str] = [] |
209 | | - _asset_ref_uris: list[str] = [] |
210 | | - for inlet in inlets: |
211 | | - if isinstance(inlet, Asset): |
212 | | - self._assets[AssetUniqueKey.from_asset(inlet)] = inlet |
213 | | - elif isinstance(inlet, AssetAlias): |
214 | | - self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet |
215 | | - elif isinstance(inlet, AssetNameRef): |
216 | | - _asset_ref_names.append(inlet.name) |
217 | | - elif isinstance(inlet, AssetUriRef): |
218 | | - _asset_ref_uris.append(inlet.uri) |
219 | | - |
220 | | - if _asset_ref_names: |
221 | | - for _, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items(): |
222 | | - self._assets[AssetUniqueKey.from_asset(asset)] = asset |
223 | | - if _asset_ref_uris: |
224 | | - for _, asset in fetch_active_assets_by_uri(_asset_ref_uris, self._session).items(): |
225 | | - self._assets[AssetUniqueKey.from_asset(asset)] = asset |
226 | | - |
227 | | - def __iter__(self) -> Iterator[Asset | AssetAlias]: |
228 | | - return iter(self._inlets) |
229 | | - |
230 | | - def __len__(self) -> int: |
231 | | - return len(self._inlets) |
232 | | - |
233 | | - def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> LazyAssetEventSelectSequence: |
234 | | - if isinstance(key, int): # Support index access; it's easier for trivial cases. |
235 | | - obj = self._inlets[key] |
236 | | - if not isinstance(obj, (Asset, AssetAlias, AssetRef)): |
237 | | - raise IndexError(key) |
238 | | - else: |
239 | | - obj = key |
240 | | - |
241 | | - if isinstance(obj, Asset): |
242 | | - asset = self._assets[AssetUniqueKey.from_asset(obj)] |
243 | | - join_clause = AssetEvent.asset |
244 | | - where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri) |
245 | | - elif isinstance(obj, AssetAlias): |
246 | | - asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)] |
247 | | - join_clause = AssetEvent.source_aliases |
248 | | - where_clause = AssetAliasModel.name == asset_alias.name |
249 | | - elif isinstance(obj, AssetNameRef): |
250 | | - try: |
251 | | - asset = next(a for k, a in self._assets.items() if k.name == obj.name) |
252 | | - except StopIteration: |
253 | | - raise KeyError(obj) from None |
254 | | - join_clause = AssetEvent.asset |
255 | | - where_clause = and_(AssetModel.name == asset.name, AssetModel.active.has()) |
256 | | - elif isinstance(obj, AssetUriRef): |
257 | | - try: |
258 | | - asset = next(a for k, a in self._assets.items() if k.uri == obj.uri) |
259 | | - except StopIteration: |
260 | | - raise KeyError(obj) from None |
261 | | - join_clause = AssetEvent.asset |
262 | | - where_clause = and_(AssetModel.uri == asset.uri, AssetModel.active.has()) |
263 | | - else: |
264 | | - raise ValueError(key) |
265 | | - |
266 | | - return LazyAssetEventSelectSequence.from_select( |
267 | | - select(AssetEvent).join(join_clause).where(where_clause), |
268 | | - order_by=[AssetEvent.timestamp], |
269 | | - session=self._session, |
270 | | - ) |
271 | | - |
272 | 164 |
|
273 | 165 | def context_merge(context: Context, *args: Any, **kwargs: Any) -> None: |
274 | 166 | """ |
|
0 commit comments