Skip to content

Commit cbc9aaf

Browse files
noamcohen97ephraimbuddy
authored andcommitted
Make sure multiple_outputs is inferred correctly even when using TypedDict (#36652)
* Use `issubclass()` to check if return type is a dictionary * Compare type to `typing.Mapping` instead of `typing.Dict` * Add documentation (cherry picked from commit e11b91c)
1 parent 6614348 commit cbc9aaf

3 files changed

Lines changed: 16 additions & 4 deletions

File tree

airflow/decorators/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
Callable,
2828
ClassVar,
2929
Collection,
30-
Dict,
3130
Generic,
3231
Iterator,
3332
Mapping,
@@ -351,7 +350,7 @@ def fake():
351350
except TypeError: # Can't evaluate return type.
352351
return False
353352
ttype = getattr(return_type, "__origin__", return_type)
354-
return ttype is dict or ttype is Dict
353+
return issubclass(ttype, Mapping)
355354

356355
def __attrs_post_init__(self):
357356
if "self" in self.function_signature.parameters:

docs/apache-airflow/tutorial/taskflow.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,8 @@ Tasks can also infer multiple outputs by using dict Python typing.
428428
def identity_dict(x: int, y: int) -> dict[str, int]:
429429
return {"x": x, "y": y}
430430
431-
By using the typing ``Dict`` for the function return type, the ``multiple_outputs`` parameter
432-
is automatically set to true.
431+
By using the typing ``dict``, or any other class that conforms to the ``typing.Mapping`` protocol,
432+
for the function return type, the ``multiple_outputs`` parameter is automatically set to true.
433433

434434
Note, If you manually set the ``multiple_outputs`` parameter the inference is disabled and
435435
the parameter value is used.

tests/decorators/test_python.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,19 @@ def identity_dict_with_decorator_call(x: int, y: int) -> resolve(annotation):
9797

9898
assert identity_dict_with_decorator_call(5, 5).operator.multiple_outputs is True
9999

100+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="PEP 589 is implemented in Python 3.8")
101+
def test_infer_multiple_outputs_typed_dict(self):
102+
from typing import TypedDict
103+
104+
class TypeDictClass(TypedDict):
105+
pass
106+
107+
@task_decorator
108+
def t1() -> TypeDictClass:
109+
return {}
110+
111+
assert t1().operator.multiple_outputs is True
112+
100113
def test_infer_multiple_outputs_forward_annotation(self):
101114
if TYPE_CHECKING:
102115

0 commit comments

Comments
 (0)