Spaces:
Running
Running
from __future__ import annotations | |
from typing import Union | |
import libcst as cst | |
from libcst._nodes.module import Module | |
# Node types whose first body statement may be a docstring: modules,
# classes and functions. Evaluated at runtime, so `Union` is used here
# rather than the `X | Y` syntax.
DocstringNode = Union[
    cst.Module,
    cst.ClassDef,
    cst.FunctionDef,
]
def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine | None:
    """Return the statement holding the docstring of a module, class or function.

    Only a docstring written as the first statement of the body on its own
    line is recognised. One-line suites such as ``def f(): "doc"`` are not:
    their first element is a small statement, not a ``SimpleStatementLine``.

    Args:
        body: The module, class or function node to inspect. (Despite the
            parameter name, this is the node itself, not its body.)

    Returns:
        The first ``SimpleStatementLine`` of the node's body if it consists of
        a plain string literal (a ``str``, not ``bytes`` and not an f-string
        concatenation). Otherwise ``None``.
    """
    # A Module exposes its statements directly; classes and functions wrap
    # them in a suite whose own ``body`` holds the statements.
    statements = body.body if isinstance(body, cst.Module) else body.body.body
    if not statements:
        return None

    statement = statements[0]
    if not isinstance(statement, cst.SimpleStatementLine) or not statement.body:
        return None

    expr = statement.body[0]
    if not isinstance(expr, cst.Expr):
        return None

    value = expr.value
    if not isinstance(value, (cst.SimpleString, cst.ConcatenatedString)):
        return None

    # evaluated_value is bytes for b"..." literals and None when a concatenated
    # string contains an f-string. Neither is a docstring in Python.
    if not isinstance(value.evaluated_value, str):
        return None
    return statement
def _decorator_name(expr: cst.BaseExpression) -> str | None:
    """Return the trailing identifier of a decorator expression, if it has one."""
    # `@name(...)` / `@pkg.name(...)`: look at the callee.
    func = getattr(expr, "func", None)
    if func is not None:
        expr = func
    # `@pkg.name`: the identifier of interest is the attribute, not `pkg`.
    attr = getattr(expr, "attr", None)
    if attr is not None:
        expr = attr
    value = getattr(expr, "value", None)
    return value if isinstance(value, str) else None


def has_decorator(node: DocstringNode, name: str) -> bool:
    """Return whether ``node`` carries a decorator whose final name is ``name``.

    Matches ``@name``, ``@name(...)``, ``@pkg.name`` and ``@pkg.name(...)``,
    so both ``@overload`` and ``@typing.overload`` are recognised.

    Args:
        node: A module, class or function node. Nodes without a
            ``decorators`` attribute (such as modules) never match.
        name: The bare decorator name to look for, e.g. ``"overload"``.

    Returns:
        ``True`` if any decorator matches, ``False`` otherwise.
    """
    return any(
        _decorator_name(decorator.decorator) == name
        for decorator in getattr(node, "decorators", ())
    )
class DocstringCollector(cst.CSTVisitor):
    """Visitor that records the docstring statement of every documented scope.

    Scopes are identified by the path of names leading to them: the module is
    ``("",)``, a top-level function ``f`` is ``("", "f")``, a method ``m`` of
    class ``C`` is ``("", "C", "m")``, and so on. Definitions decorated with
    ``overload`` are not recorded.

    Attributes:
        stack: Names of the scopes enclosing the node currently being visited.
        docstrings: Maps each scope path to the ``SimpleStatementLine`` that
            holds that scope's docstring.
    """

    def __init__(self):
        self.stack: list[str] = []
        self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}

    def _enter(self, scope_name: str) -> None:
        """Open a new scope named ``scope_name``."""
        self.stack.append(scope_name)

    def _close_scope(self, node: DocstringNode) -> None:
        """Close the innermost scope and remember its docstring, if any."""
        path = tuple(self.stack)
        del self.stack[-1]
        if has_decorator(node, "overload"):
            return
        found = get_docstring_statement(node)
        if found:
            self.docstrings[path] = found

    def visit_Module(self, node: cst.Module) -> bool | None:
        self._enter("")

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self._enter(node.name.value)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self._enter(node.name.value)

    def leave_Module(self, node: cst.Module) -> None:
        return self._close_scope(node)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        return self._close_scope(node)

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        return self._close_scope(node)
class DocstringTransformer(cst.CSTTransformer):
    """A transformer that installs collected docstrings into a CST.

    Each module, class and function is identified by the path of names from
    the module down to it; the module itself contributes ``""``. When
    ``docstrings`` has an entry for that path, the entry replaces the node's
    existing docstring, or is inserted as the first statement if there is
    none. Definitions decorated with ``overload`` are left untouched.

    Attributes:
        stack: Names of the scopes enclosing the node currently being visited.
        docstrings: Maps scope paths to the docstring statements to install.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
    ):
        super().__init__()
        self.stack: list[str] = []
        self.docstrings = docstrings

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        return self._leave(original_node, updated_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
        """Pop the current scope and install its docstring, if one was collected."""
        key = tuple(self.stack)
        self.stack.pop()
        if has_decorator(updated_node, "overload"):
            return updated_node
        statement = self.docstrings.get(key)
        if statement is None:
            return updated_node
        # Use the *original* node to decide whether a docstring already exists
        # and must be replaced rather than preceded.
        replace = get_docstring_statement(original_node) is not None
        if isinstance(updated_node, cst.Module):
            return self._install_in_module(updated_node, statement, replace)
        return self._install_in_definition(updated_node, statement, replace)

    @staticmethod
    def _install_in_module(
        module: cst.Module, statement: cst.SimpleStatementLine, replace: bool
    ) -> cst.Module:
        """Make ``statement`` the first statement of ``module``."""
        body = module.body
        if replace:
            return module.with_changes(body=(statement, *body[1:]))
        if not body:
            return module.with_changes(body=(statement,))
        # Keep a blank line between the new docstring and the existing code.
        # EmptyLine is not a statement, so it goes into the next statement's
        # leading_lines instead of directly into Module.body.
        first = body[0].with_changes(leading_lines=(cst.EmptyLine(), *body[0].leading_lines))
        return module.with_changes(body=(statement, first, *body[1:]))

    @staticmethod
    def _install_in_definition(
        node: cst.ClassDef | cst.FunctionDef,
        statement: cst.SimpleStatementLine,
        replace: bool,
    ) -> cst.ClassDef | cst.FunctionDef:
        """Make ``statement`` the first statement of a class or function body."""
        suite = node.body
        if isinstance(suite, cst.SimpleStatementSuite):
            # `def f(): ...` keeps its body on the header line. Such a suite can
            # only hold small statements, so expand it into an indented block
            # that has room for a docstring line.
            rest = cst.SimpleStatementLine(
                body=suite.body,
                trailing_whitespace=suite.trailing_whitespace,
            )
            return node.with_changes(body=cst.IndentedBlock(body=(statement, rest)))
        rest = suite.body[1:] if replace else suite.body
        return node.with_changes(body=suite.with_changes(body=(statement, *rest)))
def merge_docstring(code: str, documented_code: str) -> str:
    """Copy the docstrings of ``documented_code`` into ``code``.

    Both inputs are parsed with libcst. Docstrings found in the documented
    version are matched to the original by their module/class/function name
    path. They then replace the original's docstrings, or are inserted where
    the original has none. Definitions decorated with ``overload`` are skipped.

    Args:
        code: Source of the module that should receive the docstrings.
        documented_code: Source of the same module with docstrings present.

    Returns:
        The source of ``code`` regenerated with the collected docstrings.
    """
    target_tree = cst.parse_module(code)
    source_tree = cst.parse_module(documented_code)

    collector = DocstringCollector()
    source_tree.visit(collector)

    merged_tree = target_tree.visit(DocstringTransformer(collector.docstrings))
    return merged_tree.code