Skip to content

Commit 5107c22

Browse files
committed
Use Protocol for OutletEventAccessor
Follow-up of #45727 to use Protocol to allow auto-completion on IDE while not introducing runtime dep
1 parent 0f8707e commit 5107c22

13 files changed

Lines changed: 63 additions & 28 deletions

File tree

airflow/models/taskinstance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@
163163
from airflow.models.dagrun import DagRun
164164
from airflow.models.operator import Operator
165165
from airflow.sdk.definitions.dag import DAG
166-
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
166+
from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol
167167
from airflow.timetables.base import DataInterval
168168
from airflow.typing_compat import Literal, TypeGuard
169169
from airflow.utils.task_group import TaskGroup
@@ -2730,7 +2730,7 @@ def _run_raw_task(
27302730
)
27312731

27322732
def _register_asset_changes(
2733-
self, *, events: OutletEventAccessors, session: Session | None = None
2733+
self, *, events: OutletEventAccessorsProtocol, session: Session | None = None
27342734
) -> None:
27352735
if session:
27362736
TaskInstance._register_asset_changes_int(ti=self, events=events, session=session)
@@ -2740,7 +2740,7 @@ def _register_asset_changes(
27402740
@staticmethod
27412741
@provide_session
27422742
def _register_asset_changes_int(
2743-
ti: TaskInstance, *, events: OutletEventAccessors, session: Session = NEW_SESSION
2743+
ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session: Session = NEW_SESSION
27442744
) -> None:
27452745
if TYPE_CHECKING:
27462746
assert ti.task

airflow/serialization/serialized_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from airflow.sdk.definitions.asset import (
5757
Asset,
5858
AssetAlias,
59+
AssetAliasEvent,
5960
AssetAliasUniqueKey,
6061
AssetAll,
6162
AssetAny,
@@ -64,7 +65,7 @@
6465
BaseAsset,
6566
)
6667
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
67-
from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
68+
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
6869
from airflow.serialization.dag_dependency import DagDependency
6970
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
7071
from airflow.serialization.helpers import serialize_template_field
@@ -80,7 +81,6 @@
8081
from airflow.utils.context import (
8182
ConnectionAccessor,
8283
Context,
83-
OutletEventAccessors,
8484
VariableAccessor,
8585
)
8686
from airflow.utils.db import LazySelectSequence

airflow/utils/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from sqlalchemy.sql.expression import Select, TextClause
6464

6565
from airflow.models.baseoperator import BaseOperator
66+
from airflow.sdk.types import OutletEventAccessorsProtocol
6667

6768
# NOTE: Please keep this in sync with the following:
6869
# * Context in task_sdk/src/airflow/sdk/definitions/context.py
@@ -331,7 +332,7 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context:
331332
return cast(Context, new)
332333

333334

334-
def context_get_outlet_events(context: Context) -> OutletEventAccessors:
335+
def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
335336
try:
336337
return context["outlet_events"]
337338
except KeyError:

airflow/utils/operator_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from airflow.utils.types import NOTSET
3030

3131
if TYPE_CHECKING:
32-
from airflow.utils.context import OutletEventAccessors
32+
from airflow.sdk.types import OutletEventAccessorsProtocol
3333

3434
P = ParamSpec("P")
3535
R = TypeVar("R")
@@ -230,7 +230,7 @@ def run(*args, **kwargs): ...
230230

231231
def ExecutionCallableRunner(
232232
func: Callable[P, R],
233-
outlet_events: OutletEventAccessors,
233+
outlet_events: OutletEventAccessorsProtocol,
234234
*,
235235
logger: logging.Logger,
236236
) -> _ExecutionCallableRunner:

providers/edge/src/airflow/providers/edge/example_dags/win_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
if TYPE_CHECKING:
4848
try:
49-
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol as TaskInstance
49+
from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
5050
except ImportError:
5151
from airflow.models import TaskInstance # type: ignore[assignment]
5252
from airflow.utils.context import Context

providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
if TYPE_CHECKING:
3333
try:
34-
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
34+
from airflow.sdk.types import RuntimeTaskInstanceProtocol
3535
except ImportError:
3636
from airflow.models import TaskInstance as RuntimeTaskInstanceProtocol # type: ignore[assignment]
3737
from airflow.utils.context import Context

task_sdk/src/airflow/sdk/definitions/asset/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,3 +660,12 @@ def as_expression(self) -> Any:
660660
:meta private:
661661
"""
662662
return {"all": [o.as_expression() for o in self.objects]}
663+
664+
665+
@attrs.define
666+
class AssetAliasEvent:
667+
"""Representation of asset event to be triggered by an asset alias."""
668+
669+
source_alias_name: str
670+
dest_asset_key: AssetUniqueKey
671+
extra: dict[str, Any]

task_sdk/src/airflow/sdk/definitions/context.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
from airflow.models.operator import Operator
2828
from airflow.sdk.definitions.baseoperator import BaseOperator
2929
from airflow.sdk.definitions.dag import DAG
30-
from airflow.sdk.definitions.protocols import DagRunProtocol, RuntimeTaskInstanceProtocol
30+
from airflow.sdk.types import (
31+
DagRunProtocol,
32+
OutletEventAccessorsProtocol,
33+
RuntimeTaskInstanceProtocol,
34+
)
3135

3236

3337
class Context(TypedDict, total=False):
@@ -38,8 +42,7 @@ class Context(TypedDict, total=False):
3842
dag_run: DagRunProtocol
3943
data_interval_end: datetime | None
4044
data_interval_start: datetime | None
41-
# outlet_events: OutletEventAccessors
42-
outlet_events: Any
45+
outlet_events: OutletEventAccessorsProtocol
4346
ds: str
4447
ds_nodash: str
4548
expanded_ti_count: int | None

task_sdk/src/airflow/sdk/execution_time/context.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from airflow.sdk.definitions.asset import (
2929
Asset,
3030
AssetAlias,
31+
AssetAliasEvent,
3132
AssetAliasUniqueKey,
3233
AssetNameRef,
3334
AssetRef,
@@ -174,15 +175,6 @@ def __eq__(self, other: object) -> bool:
174175
return True
175176

176177

177-
@attrs.define
178-
class AssetAliasEvent:
179-
"""Representation of asset event to be triggered by an asset alias."""
180-
181-
source_alias_name: str
182-
dest_asset_key: AssetUniqueKey
183-
extra: dict[str, Any]
184-
185-
186178
@attrs.define
187179
class OutletEventAccessor:
188180
"""Wrapper to access an outlet asset event in template."""

task_sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def get_template_context(self) -> Context:
137137
}
138138
context.update(context_from_server)
139139

140-
# TODO: We should use/move TypeDict from airflow.utils.context.Context
141140
return context
142141

143142
def render_templates(

0 commit comments

Comments
 (0)