Skip to content

Commit 324301b

Browse files
committed
refactor(utils/decorators): rewrite remove task decorator to use ast
1 parent 9285cc7 commit 324301b

4 files changed

Lines changed: 56 additions & 54 deletions

File tree

airflow/utils/decorators.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,54 +18,55 @@
1818
from __future__ import annotations
1919

2020
import sys
21-
from collections import deque
2221
from typing import Callable, TypeVar
2322

23+
import libcst as cst
24+
2425
T = TypeVar("T", bound=Callable)
2526

2627

28+
class _TaskDecoratorRemover(cst.CSTTransformer):
29+
def __init__(self, task_decorator_name):
30+
self.decorators_to_remove = {
31+
"setup",
32+
"teardown",
33+
"task.skip_if",
34+
"task.run_if",
35+
task_decorator_name.strip("@"),
36+
}
37+
38+
def _is_task_decorator(self, decorator: cst.Decorator) -> bool:
39+
if isinstance(decorator.decorator, cst.Name):
40+
return decorator.decorator.value in self.decorators_to_remove
41+
elif isinstance(decorator.decorator, cst.Attribute):
42+
if isinstance(decorator.decorator.value, cst.Name):
43+
return (
44+
f"{decorator.decorator.value.value}.{decorator.decorator.attr.value}"
45+
in self.decorators_to_remove
46+
)
47+
elif isinstance(decorator.decorator, cst.Call):
48+
return self._is_task_decorator(cst.Decorator(decorator=decorator.decorator.func))
49+
return False
50+
51+
def leave_FunctionDef(
52+
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
53+
) -> cst.FunctionDef:
54+
new_decorators = [dec for dec in updated_node.decorators if not self._is_task_decorator(dec)]
55+
if len(new_decorators) == len(updated_node.decorators):
56+
return updated_node
57+
return updated_node.with_changes(decorators=new_decorators)
58+
59+
2760
def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
2861
"""
2962
Remove @task or similar decorators as well as @setup and @teardown.
3063
3164
:param python_source: The python source code
3265
:param task_decorator_name: the decorator name
33-
34-
TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse
3566
"""
36-
37-
def _remove_task_decorator(py_source, decorator_name):
38-
# if no line starts with @decorator_name, we can early exit
39-
for line in py_source.split("\n"):
40-
if line.startswith(decorator_name):
41-
break
42-
else:
43-
return python_source
44-
split = python_source.split(decorator_name, 1)
45-
before_decorator, after_decorator = split[0], split[1]
46-
if after_decorator[0] == "(":
47-
after_decorator = _balance_parens(after_decorator)
48-
if after_decorator[0] == "\n":
49-
after_decorator = after_decorator[1:]
50-
return before_decorator + after_decorator
51-
52-
decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name]
53-
for decorator in decorators:
54-
python_source = _remove_task_decorator(python_source, decorator)
55-
return python_source
56-
57-
58-
def _balance_parens(after_decorator):
59-
num_paren = 1
60-
after_decorator = deque(after_decorator)
61-
after_decorator.popleft()
62-
while num_paren:
63-
current = after_decorator.popleft()
64-
if current == "(":
65-
num_paren = num_paren + 1
66-
elif current == ")":
67-
num_paren = num_paren - 1
68-
return "".join(after_decorator)
67+
source_tree = cst.parse_module(python_source)
68+
modified_tree = source_tree.visit(_TaskDecoratorRemover(task_decorator_name))
69+
return modified_tree.code
6970

7071

7172
class _autostacklevel_warn:

hatch_build.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@
391391
"jinja2>=3.0.0",
392392
"jsonschema>=4.18.0",
393393
"lazy-object-proxy>=1.2.0",
394+
"libcst >=1.1.0",
394395
"linkify-it-py>=2.0.0",
395396
"lockfile>=0.12.2",
396397
"markdown-it-py>=2.1.0",

providers/standard/tests/provider_tests/standard/utils/test_python_virtualenv.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -192,25 +192,25 @@ def test_should_create_virtualenv_with_extra_packages_uv(self, mock_execute_in_s
192192
)
193193

194194
def test_remove_task_decorator(self):
195-
py_source = '@task.virtualenv(serializer="dill")\ndef f():\nimport funcsigs'
195+
py_source = '@task.virtualenv(serializer="dill")\ndef f():\n import funcsigs'
196196
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
197-
assert res == "def f():\nimport funcsigs"
197+
assert res == "def f():\n import funcsigs"
198198

199199
def test_remove_decorator_no_parens(self):
200-
py_source = "@task.virtualenv\ndef f():\nimport funcsigs"
200+
py_source = "@task.virtualenv\ndef f():\n import funcsigs"
201201
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
202-
assert res == "def f():\nimport funcsigs"
202+
assert res == "def f():\n import funcsigs"
203203

204204
def test_remove_decorator_including_comment(self):
205-
py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs"
205+
py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\n import funcsigs"
206206
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
207-
assert res == "def f():\n# @task.virtualenv\nimport funcsigs"
207+
assert res == "def f():\n# @task.virtualenv\n import funcsigs"
208208

209209
def test_remove_decorator_nested(self):
210-
py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs"
210+
py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\n import funcsigs"
211211
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
212-
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
212+
assert res == "@foo\n@bar\ndef f():\n import funcsigs"
213213

214-
py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\nimport funcsigs"
214+
py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\n import funcsigs"
215215
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
216-
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
216+
assert res == "@foo\n@bar\ndef f():\n import funcsigs"

tests/utils/test_preexisting_python_virtualenv_decorator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,20 @@
2222

2323
class TestExternalPythonDecorator:
2424
def test_remove_task_decorator(self):
25-
py_source = '@task.external_python(serializer="dill")\ndef f():\nimport funcsigs'
25+
py_source = '@task.external_python(serializer="dill")\ndef f():\n import funcsigs'
2626
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
27-
assert res == "def f():\nimport funcsigs"
27+
assert res == "def f():\n import funcsigs"
2828

2929
def test_remove_decorator_no_parens(self):
30-
py_source = "@task.external_python\ndef f():\nimport funcsigs"
30+
py_source = "@task.external_python\ndef f():\n import funcsigs"
3131
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
32-
assert res == "def f():\nimport funcsigs"
32+
assert res == "def f():\n import funcsigs"
3333

3434
def test_remove_decorator_nested(self):
35-
py_source = "@foo\n@task.external_python\n@bar\ndef f():\nimport funcsigs"
35+
py_source = "@foo\n@task.external_python\n@bar\ndef f():\n import funcsigs"
3636
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
37-
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
37+
assert res == "@foo\n@bar\ndef f():\n import funcsigs"
3838

39-
py_source = "@foo\n@task.external_python()\n@bar\ndef f():\nimport funcsigs"
39+
py_source = "@foo\n@task.external_python()\n@bar\ndef f():\n import funcsigs"
4040
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
41-
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
41+
assert res == "@foo\n@bar\ndef f():\n import funcsigs"

0 commit comments

Comments
 (0)