|
19 | 19 |
|
20 | 20 | from __future__ import annotations |
21 | 21 |
|
22 | | -import contextlib |
23 | 22 | from collections.abc import ( |
| 23 | + Callable, |
24 | 24 | Container, |
25 | 25 | Iterator, |
26 | 26 | Mapping, |
|
51 | 51 | AssetRef, |
52 | 52 | AssetUniqueKey, |
53 | 53 | AssetUriRef, |
54 | | - BaseAssetUniqueKey, |
55 | 54 | ) |
56 | 55 | from airflow.sdk.definitions.context import Context |
| 56 | +from airflow.sdk.execution_time.context import OutletEventAccessors as OutletEventAccessorsSDK |
57 | 57 | from airflow.utils.db import LazySelectSequence |
58 | 58 | from airflow.utils.session import create_session |
59 | 59 | from airflow.utils.types import NOTSET |
@@ -156,104 +156,27 @@ def get(self, key: str, default_conn: Any = None) -> Any: |
156 | 156 | return default_conn |
157 | 157 |
|
158 | 158 |
|
159 | | -@attrs.define() |
160 | | -class AssetAliasEvent: |
161 | | - """ |
162 | | - Represeation of asset event to be triggered by an asset alias. |
163 | | -
|
164 | | - :meta private: |
165 | | - """ |
| 159 | +def _get_asset(name: str | None = None, uri: str | None = None) -> Asset: |
| 160 | + if name: |
| 161 | + with create_session() as session: |
| 162 | + asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has())) |
| 163 | + elif uri: |
| 164 | + with create_session() as session: |
| 165 | + asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has())) |
| 166 | + else: |
| 167 | + raise ValueError("Either name or uri must be provided") |
166 | 168 |
|
167 | | - source_alias_name: str |
168 | | - dest_asset_key: AssetUniqueKey |
169 | | - extra: dict[str, Any] |
| 169 | + return asset.to_public() |
170 | 170 |
|
171 | 171 |
|
172 | | -@attrs.define() |
173 | | -class OutletEventAccessor: |
174 | | - """ |
175 | | - Wrapper to access an outlet asset event in template. |
176 | | -
|
177 | | - :meta private: |
178 | | - """ |
179 | | - |
180 | | - key: BaseAssetUniqueKey |
181 | | - extra: dict[str, Any] = attrs.Factory(dict) |
182 | | - asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) |
183 | | - |
184 | | - def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: |
185 | | - """Add an AssetEvent to an existing Asset.""" |
186 | | - if not isinstance(self.key, AssetAliasUniqueKey): |
187 | | - return |
188 | | - |
189 | | - asset_alias_name = self.key.name |
190 | | - event = AssetAliasEvent( |
191 | | - source_alias_name=asset_alias_name, |
192 | | - dest_asset_key=AssetUniqueKey.from_asset(asset), |
193 | | - extra=extra or {}, |
194 | | - ) |
195 | | - self.asset_alias_events.append(event) |
196 | | - |
197 | | - |
198 | | -class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]): |
| 172 | +class OutletEventAccessors(OutletEventAccessorsSDK): |
199 | 173 | """ |
200 | 174 | Lazy mapping of outlet asset event accessors. |
201 | 175 |
|
202 | 176 | :meta private: |
203 | 177 | """ |
204 | 178 |
|
205 | | - _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {} |
206 | | - |
207 | | - def __init__(self) -> None: |
208 | | - self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} |
209 | | - |
210 | | - def __str__(self) -> str: |
211 | | - return f"OutletEventAccessors(_dict={self._dict})" |
212 | | - |
213 | | - def __iter__(self) -> Iterator[Asset | AssetAlias]: |
214 | | - return ( |
215 | | - key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict |
216 | | - ) |
217 | | - |
218 | | - def __len__(self) -> int: |
219 | | - return len(self._dict) |
220 | | - |
221 | | - def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: |
222 | | - hashable_key: BaseAssetUniqueKey |
223 | | - if isinstance(key, Asset): |
224 | | - hashable_key = AssetUniqueKey.from_asset(key) |
225 | | - elif isinstance(key, AssetAlias): |
226 | | - hashable_key = AssetAliasUniqueKey.from_asset_alias(key) |
227 | | - elif isinstance(key, AssetRef): |
228 | | - hashable_key = self._resolve_asset_ref(key) |
229 | | - else: |
230 | | - raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") |
231 | | - |
232 | | - if hashable_key not in self._dict: |
233 | | - self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) |
234 | | - return self._dict[hashable_key] |
235 | | - |
236 | | - def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: |
237 | | - with contextlib.suppress(KeyError): |
238 | | - return self._asset_ref_cache[ref] |
239 | | - |
240 | | - refs_to_cache: list[AssetRef] |
241 | | - with create_session() as session: |
242 | | - if isinstance(ref, AssetNameRef): |
243 | | - asset = session.scalar( |
244 | | - select(AssetModel).where(AssetModel.name == ref.name, AssetModel.active.has()) |
245 | | - ) |
246 | | - refs_to_cache = [ref, AssetUriRef(asset.uri)] |
247 | | - elif isinstance(ref, AssetUriRef): |
248 | | - asset = session.scalar( |
249 | | - select(AssetModel).where(AssetModel.uri == ref.uri, AssetModel.active.has()) |
250 | | - ) |
251 | | - refs_to_cache = [ref, AssetNameRef(asset.name)] |
252 | | - else: |
253 | | - raise TypeError(f"Unimplemented asset ref: {type(ref)}") |
254 | | - for ref in refs_to_cache: |
255 | | - self._asset_ref_cache[ref] = unique_key = AssetUniqueKey.from_asset(asset) |
256 | | - return unique_key |
| 179 | + _get_asset_func: Callable[..., Asset] = _get_asset |
257 | 180 |
|
258 | 181 |
|
259 | 182 | class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): |
|
0 commit comments