Skip to content

Commit f45e11d

Browse files
Copilot and Josverl committed
Implement literal docstring merging for module and class level
Co-authored-by: Josverl <[email protected]>
1 parent 4e6fdd5 commit f45e11d

File tree

2 files changed

+142
-2
lines changed

2 files changed

+142
-2
lines changed

src/stubber/codemod/merge_docstub.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
242242
updated_node = self.add_missed_overloads(updated_node, stack_id=())
243243
# Add any missing @mp_available
244244
updated_node = self.add_missed_mp_available(updated_node, stack_id=())
245+
# Add any missing literal docstrings
246+
updated_node = self.add_missed_literal_docstrings(updated_node, stack_id=())
245247
return updated_node
246248

247249
def add_missed_overloads(self, updated_node: Mod_Class_T, stack_id: tuple) -> Mod_Class_T:
@@ -378,6 +380,83 @@ def add_missed_mp_available(self, updated_node: Mod_Class_T, stack_id: tuple) ->
378380
# cst.IndentedBlock(body=tuple(updated_body))) # type: ignore
379381
return updated_node
380382

383+
def add_missed_literal_docstrings(self, updated_node: Mod_Class_T, stack_id: tuple) -> Mod_Class_T:
    """
    Add any missing literal docstrings to the updated_node.

    A literal docstring is a bare string expression placed directly after a
    single-name assignment (``CONST = value`` or ``CONST: int = value``).
    For every such assignment in the body of ``updated_node`` that has a
    docstring recorded in ``self.annotations`` and is not already followed
    by a string expression, the recorded docstring statement is inserted
    immediately after the assignment.

    Args:
        updated_node: the ``cst.Module`` or ``cst.ClassDef`` whose body is scanned.
        stack_id: scope key used to look up ``self.annotations``; an empty
            tuple selects ``MODULE_KEY``.

    Returns:
        ``updated_node`` unchanged when nothing had to be inserted, otherwise
        a copy whose body contains the added docstring statements.

    Raises:
        ValueError: if docstrings are recorded for this scope and
            ``updated_node`` is neither a ``cst.Module`` nor a ``cst.ClassDef``.
    """
    # Module level uses the dedicated key, class level uses the stack itself.
    lookup_key = stack_id if stack_id else MODULE_KEY

    literal_docstrings = {}
    if lookup_key in self.annotations and self.annotations[lookup_key].literal_docstrings:
        literal_docstrings = self.annotations[lookup_key].literal_docstrings
    if not literal_docstrings:
        return updated_node

    # Get the statements to scan
    if isinstance(updated_node, cst.Module):
        body = updated_node.body
    elif isinstance(updated_node, cst.ClassDef):
        body = updated_node.body.body
    else:
        raise ValueError(f"Unsupported node type: {updated_node}")

    # Copying docstrings is switched off: nothing can be inserted, so skip the scan.
    if not self.copy_docstr:
        return updated_node

    new_body = []
    inserted = False
    for index, stmt in enumerate(body):
        new_body.append(stmt)

        # Only single-statement lines can be a literal assignment
        if not (isinstance(stmt, cst.SimpleStatementLine) and len(stmt.body) == 1):
            continue

        small_stmt = stmt.body[0]
        literal_name = None
        if isinstance(small_stmt, cst.Assign):
            # Regular assignment: CONST = value
            targets = small_stmt.targets
            if len(targets) == 1 and isinstance(targets[0].target, cst.Name):
                literal_name = targets[0].target.value
        elif isinstance(small_stmt, cst.AnnAssign):
            # Annotated assignment: CONST: int = value
            if isinstance(small_stmt.target, cst.Name):
                literal_name = small_stmt.target.value

        if not literal_name or literal_name not in literal_docstrings:
            continue

        # Keep a docstring that is already present. Implicitly concatenated
        # strings count as well, otherwise a duplicate would be inserted.
        follower = body[index + 1] if index + 1 < len(body) else None
        if (
            isinstance(follower, cst.SimpleStatementLine)
            and len(follower.body) == 1
            and isinstance(follower.body[0], cst.Expr)
            and isinstance(follower.body[0].value, (cst.SimpleString, cst.ConcatenatedString))
        ):
            continue

        new_body.append(literal_docstrings[literal_name])
        inserted = True
        log.trace(f"Added literal docstring for {literal_name}")

    # Avoid rebuilding the tree when the body is unchanged
    if not inserted:
        return updated_node

    if isinstance(updated_node, cst.Module):
        return updated_node.with_changes(body=tuple(new_body))
    # ClassDef: the statements live one level down, inside the class body block
    new_indented_body = updated_node.body.with_changes(body=tuple(new_body))
    return updated_node.with_changes(body=new_indented_body)
459+
381460
def locate_function_by_name(self, overload, updated_body):
382461
"""locate the (last) function by name"""
383462
matched = False
@@ -426,6 +505,8 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
426505
updated_node = self.add_missed_overloads(updated_node, stack_id)
427506
# Add any missing @mp_available
428507
updated_node = self.add_missed_mp_available(updated_node, stack_id)
508+
# Add any missing literal docstrings
509+
updated_node = self.add_missed_literal_docstrings(updated_node, stack_id)
429510
return updated_node
430511

431512
# ------------------------------------------------------------------------

src/stubber/typing_collector.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class AnnoValue:
3333
"function / method overloads read from the docstub source"
3434
mp_available: List[TypeInfo] = field(default_factory=list)
3535
"function / method `overloads` read from the docstub source"
36+
literal_docstrings: Dict[str, cst.SimpleStatementLine] = field(default_factory=dict)
37+
"literal/constant name -> docstring node mappings for literal docstrings"
3638

3739

3840
class TransformError(Exception):
@@ -85,12 +87,58 @@ def __init__(self):
8587
] = {}
8688
self.comments: List[str] = []
8789

90+
def _collect_literal_docstrings(self, body: Sequence[cst.BaseStatement]) -> Dict[str, cst.SimpleStatementLine]:
    """
    Collect literal docstrings from a sequence of statements.

    Looks for patterns like:
        CONSTANT = value
        '''docstring for constant'''

    Only assignments to a single plain name are considered, either regular
    (``CONST = value``) or annotated (``CONST: int = value``), and the
    docstring must be the statement immediately following the assignment.
    Both plain string literals and implicitly concatenated string literals
    are accepted as docstrings.

    Args:
        body: the statements of one scope (module body or class body).

    Returns:
        A dict mapping each literal name to the statement line that holds
        its docstring. If a name is assigned and documented more than once,
        the last docstring is kept.
    """
    literal_docstrings: Dict[str, cst.SimpleStatementLine] = {}

    # Walk the statements pairwise: (candidate assignment, statement after it)
    for stmt, next_stmt in zip(body, body[1:]):
        if not (isinstance(stmt, cst.SimpleStatementLine) and len(stmt.body) == 1):
            continue

        small_stmt = stmt.body[0]
        literal_name = None
        if isinstance(small_stmt, cst.Assign):
            # Handle regular assignment: CONST = value
            targets = small_stmt.targets
            if len(targets) == 1 and isinstance(targets[0].target, cst.Name):
                literal_name = targets[0].target.value
        elif isinstance(small_stmt, cst.AnnAssign):
            # Handle annotated assignment: CONST: int = value
            if isinstance(small_stmt.target, cst.Name):
                literal_name = small_stmt.target.value
        if not literal_name:
            continue

        # The next statement must be a lone string expression to be a docstring
        if (
            isinstance(next_stmt, cst.SimpleStatementLine)
            and len(next_stmt.body) == 1
            and isinstance(next_stmt.body[0], cst.Expr)
            and isinstance(next_stmt.body[0].value, (cst.SimpleString, cst.ConcatenatedString))
        ):
            literal_docstrings[literal_name] = next_stmt

    return literal_docstrings
126+
88127
# ------------------------------------------------------------
89128
def visit_Module(self, node: cst.Module) -> bool:
    """
    Record the module docstring and the module-level literal docstrings.

    Both are stored in ``self.annotations`` under ``MODULE_KEY``; the entry
    is created on demand when the module has literal docstrings but no
    module docstring.

    Always returns True.
    """
    module_doc = node.get_docstring()
    if module_doc:
        self.annotations[MODULE_KEY] = AnnoValue(docstring=module_doc)

    # Literal docstrings found directly in the module body
    found = self._collect_literal_docstrings(node.body)
    if found:
        entry = self.annotations.get(MODULE_KEY)
        if entry is None:
            # No module docstring was stored, start from an empty entry
            entry = AnnoValue()
            self.annotations[MODULE_KEY] = entry
        entry.literal_docstrings.update(found)

    return True
95143

96144
def visit_Comment(self, node: cst.Comment) -> None:
@@ -108,6 +156,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
108156
"""
109157
collect info from a classdef:
110158
- name, decorators, docstring
159+
- class-level literal docstrings
111160
"""
112161
# "Store the class docstring
113162
docstr_node = self.update_append_first_node(node)
@@ -120,7 +169,17 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
120169
def_type="classdef",
121170
def_node=node,
122171
)
123-
self.annotations[tuple(self.stack)] = AnnoValue(type_info=ti)
172+
173+
# Collect class-level literal docstrings
174+
literal_docstrings = {}
175+
if isinstance(node.body, cst.IndentedBlock):
176+
literal_docstrings = self._collect_literal_docstrings(node.body.body)
177+
178+
anno_value = AnnoValue(type_info=ti)
179+
if literal_docstrings:
180+
anno_value.literal_docstrings.update(literal_docstrings)
181+
182+
self.annotations[tuple(self.stack)] = anno_value
124183

125184
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
126185
"""remove the class name from the stack"""

0 commit comments

Comments
 (0)