@@ -242,6 +242,8 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
242242 updated_node = self .add_missed_overloads (updated_node , stack_id = ())
243243 # Add any missing @mp_available
244244 updated_node = self .add_missed_mp_available (updated_node , stack_id = ())
245+ # Add any missing literal docstrings
246+ updated_node = self .add_missed_literal_docstrings (updated_node , stack_id = ())
245247 return updated_node
246248
247249 def add_missed_overloads (self , updated_node : Mod_Class_T , stack_id : tuple ) -> Mod_Class_T :
@@ -378,6 +380,83 @@ def add_missed_mp_available(self, updated_node: Mod_Class_T, stack_id: tuple) ->
378380 # cst.IndentedBlock(body=tuple(updated_body))) # type: ignore
379381 return updated_node
380382
383+ def add_missed_literal_docstrings (self , updated_node : Mod_Class_T , stack_id : tuple ) -> Mod_Class_T :
384+ """
385+ Add any missing literal docstrings to the updated_node
386+ """
387+ # Find literal docstrings for this scope
388+ literal_docstrings = {}
389+
390+ # For module level (empty stack_id), use MODULE_KEY
391+ if len (stack_id ) == 0 :
392+ lookup_key = MODULE_KEY
393+ else :
394+ lookup_key = stack_id
395+
396+ if lookup_key in self .annotations and self .annotations [lookup_key ].literal_docstrings :
397+ literal_docstrings = self .annotations [lookup_key ].literal_docstrings
398+
399+ if not literal_docstrings :
400+ return updated_node
401+
402+ # Get the body to modify
403+ if isinstance (updated_node , cst .Module ):
404+ updated_body = list (updated_node .body )
405+ elif isinstance (updated_node , cst .ClassDef ):
406+ updated_body = list (updated_node .body .body )
407+ else :
408+ raise ValueError (f"Unsupported node type: { updated_node } " )
409+
410+ # Find assignment statements and insert docstrings after them
411+ new_body = []
412+ i = 0
413+ while i < len (updated_body ):
414+ stmt = updated_body [i ]
415+ new_body .append (stmt )
416+
417+ # Check if this is an assignment statement for a literal that has a docstring
418+ if isinstance (stmt , cst .SimpleStatementLine ) and len (stmt .body ) == 1 :
419+ literal_name = None
420+
421+ if isinstance (stmt .body [0 ], cst .Assign ):
422+ # Regular assignment: CONST = value
423+ targets = stmt .body [0 ].targets
424+ if len (targets ) == 1 and isinstance (targets [0 ].target , cst .Name ):
425+ literal_name = targets [0 ].target .value
426+ elif isinstance (stmt .body [0 ], cst .AnnAssign ):
427+ # Annotated assignment: CONST: int = value
428+ if isinstance (stmt .body [0 ].target , cst .Name ):
429+ literal_name = stmt .body [0 ].target .value
430+
431+ # If we found a literal with a docstring, check if docstring is missing
432+ if literal_name and literal_name in literal_docstrings :
433+ # Check if the next statement is already a docstring for this literal
434+ has_existing_docstring = False
435+ if i + 1 < len (updated_body ):
436+ next_stmt = updated_body [i + 1 ]
437+ if (isinstance (next_stmt , cst .SimpleStatementLine ) and
438+ len (next_stmt .body ) == 1 and
439+ isinstance (next_stmt .body [0 ], cst .Expr ) and
440+ isinstance (next_stmt .body [0 ].value , cst .SimpleString )):
441+ has_existing_docstring = True
442+
443+ # If no existing docstring, add the one from doc_stub
444+ if not has_existing_docstring and self .copy_docstr :
445+ docstr_node = literal_docstrings [literal_name ]
446+ new_body .append (docstr_node )
447+ log .trace (f"Added literal docstring for { literal_name } " )
448+
449+ i += 1
450+
451+ # Update the node with the new body
452+ if isinstance (updated_node , cst .Module ):
453+ updated_node = updated_node .with_changes (body = tuple (new_body ))
454+ elif isinstance (updated_node , cst .ClassDef ):
455+ new_indented_body = updated_node .body .with_changes (body = tuple (new_body ))
456+ updated_node = updated_node .with_changes (body = new_indented_body )
457+
458+ return updated_node
459+
381460 def locate_function_by_name (self , overload , updated_body ):
382461 """locate the (last) function by name"""
383462 matched = False
@@ -426,6 +505,8 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
426505 updated_node = self .add_missed_overloads (updated_node , stack_id )
427506 # Add any missing @mp_available
428507 updated_node = self .add_missed_mp_available (updated_node , stack_id )
508+ # Add any missing literal docstrings
509+ updated_node = self .add_missed_literal_docstrings (updated_node , stack_id )
429510 return updated_node
430511
431512 # ------------------------------------------------------------------------
0 commit comments