Skip to content

Commit 4ae5d25

Browse files
committed
feat: Support expression modernization
1 parent 009d903 commit 4ae5d25

2 files changed

Lines changed: 119 additions & 1 deletion

File tree

src/griffe/_internal/expressions.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ def _expr_as_dict(expression: Expr, **kwargs: Any) -> dict[str, Any]:
150150
return fields
151151

152152

153+
_modern_types = {
154+
"typing.Tuple": "tuple",
155+
"typing.Dict": "dict",
156+
"typing.List": "list",
157+
"typing.Set": "set",
158+
}
159+
153160
# YORE: EOL 3.9: Remove block.
154161
_dataclass_opts: dict[str, bool] = {}
155162
if sys.version_info >= (3, 10):
@@ -265,6 +272,11 @@ def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
265272
yield "."
266273
yield from _yield(value, flat=flat, outer_precedence=precedence)
267274

275+
def modernize(self) -> ExprName | ExprAttribute:
276+
if modern := _modern_types.get(self.canonical_path):
277+
return ExprName(modern, parent=self.last.parent)
278+
return self
279+
268280
def append(self, value: ExprName) -> None:
269281
"""Append a name to this attribute.
270282
@@ -716,6 +728,11 @@ def __eq__(self, other: object) -> bool:
716728
def iterate(self, *, flat: bool = True) -> Iterator[ExprName]: # noqa: ARG002
717729
yield self
718730

731+
def modernize(self) -> ExprName:
732+
if modern := _modern_types.get(self.canonical_path):
733+
return ExprName(modern, parent=self.parent)
734+
return self
735+
719736
@property
720737
def path(self) -> str:
721738
"""The full, resolved name.
@@ -878,7 +895,7 @@ class ExprSubscript(Expr):
878895

879896
left: str | Expr
880897
"""Left part."""
881-
slice: Expr
898+
slice: str | Expr
882899
"""Slice part."""
883900

884901
def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
@@ -888,6 +905,33 @@ def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
888905
yield from _yield(self.slice, flat=flat, outer_precedence=_OperatorPrecedence.NONE)
889906
yield "]"
890907

908+
@staticmethod
909+
def _to_binop(elements: Sequence[Expr], op: str) -> ExprBinOp:
910+
if len(elements) == 2: # noqa: PLR2004
911+
left, right = elements
912+
if isinstance(left, Expr):
913+
left = left.modernize()
914+
if isinstance(right, Expr):
915+
right = right.modernize()
916+
return ExprBinOp(left=left, operator=op, right=right)
917+
918+
left = ExprSubscript._to_binop(elements[:-1], op=op)
919+
right = elements[-1]
920+
if isinstance(right, Expr):
921+
right = right.modernize()
922+
return ExprBinOp(left=left, operator=op, right=right)
923+
924+
def modernize(self) -> ExprBinOp | ExprSubscript:
925+
if self.canonical_path == "typing.Union":
926+
return self._to_binop(self.slice.elements, op="|") # type: ignore[union-attr]
927+
if self.canonical_path == "typing.Optional":
928+
left = self.slice if isinstance(self.slice, str) else self.slice.modernize()
929+
return ExprBinOp(left=left, operator="|", right="None")
930+
return ExprSubscript(
931+
left=self.left if isinstance(self.left, str) else self.left.modernize(),
932+
slice=self.slice if isinstance(self.slice, str) else self.slice.modernize(),
933+
)
934+
891935
@property
892936
def path(self) -> str:
893937
"""The path of this subscript's left part."""
@@ -922,6 +966,12 @@ def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
922966
if not self.implicit:
923967
yield ")"
924968

969+
def modernize(self) -> ExprTuple:
970+
return ExprTuple(
971+
elements=[el if isinstance(el, str) else el.modernize() for el in self.elements],
972+
implicit=self.implicit,
973+
)
974+
925975

926976
# YORE: EOL 3.9: Replace `**_dataclass_opts` with `slots=True` within line.
927977
@dataclass(eq=True, **_dataclass_opts)

tests/test_expressions.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,74 @@ def test_length_one_tuple_as_string() -> None:
9595
assert str(module["x"].value) == "('a',)"
9696

9797

98+
@pytest.mark.parametrize(
99+
("annotation", "modernized"),
100+
[
101+
("Union[str, int, float]", "str | int | float"),
102+
("typing.Union[str, int, float]", "str | int | float"),
103+
("Union[Tuple[str, ...], Dict[str, int]]", "tuple[str, ...] | dict[str, int]"),
104+
("typing.Union[typing.Tuple[str, ...], typing.Dict[str, int]]", "tuple[str, ...] | dict[str, int]"),
105+
("Tuple[List[Dict[str, Set[str]]]]", "tuple[list[dict[str, set[str]]]]"),
106+
("typing.Tuple[typing.List[typing.Dict[str, typing.Set[str]]]]", "tuple[list[dict[str, set[str]]]]"),
107+
("Optional[Tuple[List[bool]]]", "tuple[list[bool]] | None"),
108+
("typing.Optional[typing.Tuple[typing.List[bool]]]", "tuple[list[bool]] | None"),
109+
],
110+
)
111+
def test_modernizing_specific_expressions(annotation: str, modernized: str) -> None:
112+
"""Modernize expressions correctly.
113+
114+
Parameters:
115+
annotation: Original annotation (parametrized).
116+
modernized: Expected modernized annotation (parametrized).
117+
"""
118+
with temporary_visited_module(
119+
f"""
120+
import typing
121+
from typing import Union, Optional, Tuple, Dict, List, Set, Literal
122+
a: {annotation}
123+
""",
124+
) as module:
125+
expression = module["a"].annotation
126+
assert str(expression.modernize()) == modernized
127+
128+
129+
@pytest.mark.parametrize(
130+
"annotation",
131+
[
132+
"typing.Literal['s']",
133+
"Literal['s']",
134+
],
135+
)
136+
def test_handling_modernization_without_crashing(annotation: str) -> None:
137+
"""Modernizing expressions never crashes.
138+
139+
Parameters:
140+
annotation: Original annotation (parametrized).
141+
"""
142+
with temporary_visited_module(
143+
f"""
144+
import typing
145+
from typing import Union, Optional, Tuple, Dict, List, Set, Literal
146+
a: {annotation}
147+
""",
148+
) as module:
149+
module["a"].annotation.modernize()
150+
151+
152+
@pytest.mark.parametrize("code", syntax_examples)
153+
def test_modernizing_idempotence(code: str) -> None:
154+
"""Modernize expressions that can't be modernized.
155+
156+
Parameters:
157+
code: An expression (parametrized).
158+
"""
159+
top_node = compile(code, filename="<>", mode="exec", flags=ast.PyCF_ONLY_AST, optimize=2)
160+
expression = get_expression(top_node.body[0].value, parent=Module("module")) # type: ignore[attr-defined]
161+
modernized = expression.modernize() # type: ignore[union-attr]
162+
assert expression == modernized
163+
assert str(expression) == str(modernized)
164+
165+
98166
def test_resolving_init_parameter() -> None:
99167
"""Instance attribute values should resolve to matching parameters.
100168

0 commit comments

Comments
 (0)