|
18 | 18 | from __future__ import annotations |
19 | 19 |
|
20 | 20 | import sys |
21 | | -from collections import deque |
22 | 21 | from typing import Callable, TypeVar |
23 | 22 |
|
| 23 | +import libcst as cst |
| 24 | + |
24 | 25 | T = TypeVar("T", bound=Callable) |
25 | 26 |
|
26 | 27 |
|
| 28 | +class _TaskDecoratorRemover(cst.CSTTransformer): |
| 29 | + def __init__(self, task_decorator_name): |
| 30 | + self.decorators_to_remove = { |
| 31 | + "setup", |
| 32 | + "teardown", |
| 33 | + "task.skip_if", |
| 34 | + "task.run_if", |
| 35 | + task_decorator_name.strip("@"), |
| 36 | + } |
| 37 | + |
| 38 | + def _is_task_decorator(self, decorator: cst.Decorator) -> bool: |
| 39 | + if isinstance(decorator.decorator, cst.Name): |
| 40 | + return decorator.decorator.value in self.decorators_to_remove |
| 41 | + elif isinstance(decorator.decorator, cst.Attribute): |
| 42 | + if isinstance(decorator.decorator.value, cst.Name): |
| 43 | + return ( |
| 44 | + f"{decorator.decorator.value.value}.{decorator.decorator.attr.value}" |
| 45 | + in self.decorators_to_remove |
| 46 | + ) |
| 47 | + elif isinstance(decorator.decorator, cst.Call): |
| 48 | + return self._is_task_decorator(cst.Decorator(decorator=decorator.decorator.func)) |
| 49 | + return False |
| 50 | + |
| 51 | + def leave_FunctionDef( |
| 52 | + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef |
| 53 | + ) -> cst.FunctionDef: |
| 54 | + new_decorators = [dec for dec in updated_node.decorators if not self._is_task_decorator(dec)] |
| 55 | + if len(new_decorators) == len(updated_node.decorators): |
| 56 | + return updated_node |
| 57 | + return updated_node.with_changes(decorators=new_decorators) |
| 58 | + |
| 59 | + |
27 | 60 | def remove_task_decorator(python_source: str, task_decorator_name: str) -> str: |
28 | 61 | """ |
29 | 62 | Remove @task or similar decorators as well as @setup and @teardown. |
30 | 63 |
|
31 | 64 | :param python_source: The python source code |
32 | 65 | :param task_decorator_name: the decorator name |
33 | | -
|
34 | | - TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse |
35 | 66 | """ |
36 | | - |
37 | | - def _remove_task_decorator(py_source, decorator_name): |
38 | | - # if no line starts with @decorator_name, we can early exit |
39 | | - for line in py_source.split("\n"): |
40 | | - if line.startswith(decorator_name): |
41 | | - break |
42 | | - else: |
43 | | - return python_source |
44 | | - split = python_source.split(decorator_name, 1) |
45 | | - before_decorator, after_decorator = split[0], split[1] |
46 | | - if after_decorator[0] == "(": |
47 | | - after_decorator = _balance_parens(after_decorator) |
48 | | - if after_decorator[0] == "\n": |
49 | | - after_decorator = after_decorator[1:] |
50 | | - return before_decorator + after_decorator |
51 | | - |
52 | | - decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name] |
53 | | - for decorator in decorators: |
54 | | - python_source = _remove_task_decorator(python_source, decorator) |
55 | | - return python_source |
56 | | - |
57 | | - |
58 | | -def _balance_parens(after_decorator): |
59 | | - num_paren = 1 |
60 | | - after_decorator = deque(after_decorator) |
61 | | - after_decorator.popleft() |
62 | | - while num_paren: |
63 | | - current = after_decorator.popleft() |
64 | | - if current == "(": |
65 | | - num_paren = num_paren + 1 |
66 | | - elif current == ")": |
67 | | - num_paren = num_paren - 1 |
68 | | - return "".join(after_decorator) |
| 67 | + source_tree = cst.parse_module(python_source) |
| 68 | + modified_tree = source_tree.visit(_TaskDecoratorRemover(task_decorator_name)) |
| 69 | + return modified_tree.code |
69 | 70 |
|
70 | 71 |
|
71 | 72 | class _autostacklevel_warn: |
|
0 commit comments