Spaces:
Running
Running
""" | |
@Time : 2024/7/24 16:37 | |
@Author : didi | |
@File : utils.py | |
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py | |
""" | |
import ast | |
import traceback | |
from enum import Enum | |
from typing import Dict, Generator, List, Optional, Set, Tuple | |
import tree_sitter_python | |
from tree_sitter import Language, Node, Parser | |
class NodeType(Enum):
    """Tree-sitter node type names that the sanitizer inspects.

    Each member's value is compared against ``Node.type``. ``IMPORT`` is the
    only member whose value is a list, because it is matched with ``in``
    against either of the two import statement forms.
    """

    # Definitions
    CLASS = "class_definition"
    FUNCTION = "function_definition"

    # Both flavours of import: ``import x`` and ``from x import y``
    IMPORT = ["import_statement", "import_from_statement"]

    # Names and attribute access
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"

    # Statements
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"
def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """
    Traverse the tree structure starting from the given node.

    Nodes are yielded in pre-order: a node first, then each of its children
    from first to last, recursively. The walk is iterative (cursor based), so
    it does not recurse in Python.

    :param node: The root node to start the traversal from.
    :return: A generator object that yields ``node`` and all its descendants.
    """
    cursor = node.walk()
    # Number of levels the cursor currently sits below ``node``. It grows on
    # every successful descent and shrinks on every ascent, so ``depth == 0``
    # means "back at the start node" and the walk never leaves its subtree.
    depth = 0
    # True once the subtree under the cursor's current node has been fully
    # yielded and the cursor must move sideways or upwards.
    visited_children = False
    while True:
        if not visited_children:
            yield cursor.node
            if cursor.goto_first_child():
                depth += 1
            else:
                visited_children = True
        elif depth > 0 and cursor.goto_next_sibling():
            visited_children = False
        elif depth == 0 or not cursor.goto_parent():
            # Either the start node's subtree is exhausted, or the cursor
            # refuses to move up: in both cases the traversal is complete.
            break
        else:
            depth -= 1
def syntax_check(code: str, verbose: bool = False) -> bool:
    """
    Report whether ``code`` parses as Python source.

    :param code: Candidate source text.
    :param verbose: When true, print the traceback of the parse failure.
    :return: True if ``ast.parse`` accepts the text, False otherwise.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        # Besides SyntaxError, ast.parse raises ValueError for source that
        # contains null bytes (older Python versions) and may raise
        # RecursionError / MemoryError on pathologically nested input. All of
        # these mean "not usable code" rather than a reason to crash.
        if verbose:
            traceback.print_exc()
        return False
    return True
def code_extract(text: str) -> str:
    """
    Extract the largest run of consecutive lines of ``text`` that parses as Python.

    Every span of at least two consecutive lines is a candidate; the winner is
    the syntactically valid span with the most non-blank lines (the earliest
    such span on ties). If no multi-line span parses, the first non-blank
    single line that parses on its own is returned; failing that, the first
    line of ``text`` is returned unchanged.

    :param text: Arbitrary text that may embed Python code.
    :return: The selected lines joined with newlines.
    """
    lines = text.split("\n")

    # nonblank_before[k] == number of non-blank lines in lines[:k]; lets the
    # size of any span be computed without rescanning it.
    nonblank_before = [0]
    for line in lines:
        nonblank_before.append(nonblank_before[-1] + (1 if line.strip() else 0))

    best_span = (0, 0)
    best_count = 0
    for start in range(len(lines)):
        for end in range(start + 1, len(lines)):
            count = nonblank_before[end + 1] - nonblank_before[start]
            # Only a strictly larger span can replace the current best, so
            # skip the (expensive) parse for spans that cannot win.
            if count <= best_count:
                continue
            if syntax_check("\n".join(lines[start : end + 1])):
                best_count = count
                best_span = (start, end)

    if best_count == 0:
        # No multi-line span parsed. A lone valid line sandwiched between
        # unparsable lines is still better than blindly returning line 0.
        for index, line in enumerate(lines):
            if line.strip() and syntax_check(line):
                best_span = (index, index)
                break

    return "\n".join(lines[best_span[0] : best_span[1] + 1])
def get_definition_name(node: Node) -> Optional[str]:
    """
    Return the text of the first direct child of ``node`` that is an identifier.

    Only direct children are inspected; nested nodes are not searched.

    :param node: A syntax-tree node (the callers pass class, function and assignment nodes).
    :return: The identifier decoded as UTF-8, or None when no direct child is an identifier.
    """
    for child in node.children:
        if child.type == NodeType.IDENTIFIER.value:
            return child.text.decode("utf8")
    # Made explicit: callers receive None when the node has no identifier child.
    return None
def has_return_statement(node: Node) -> bool:
    """
    Tell whether a return statement occurs anywhere in the subtree of ``node``.

    The whole subtree is scanned via ``traverse_tree``, so a return statement
    inside a nested definition counts as well. Scanning stops at the first hit.

    :param node: The node whose subtree is searched.
    :return: True if any visited node is a return statement, otherwise False.
    """
    return_kind = NodeType.RETURN.value
    return any(descendant.type == return_kind for descendant in traverse_tree(node))
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """
    Map each definition name to the set of identifiers found beneath its node.

    Every identifier node below the definition is collected (not only called
    names), so the set also contains the identifier children of the definition
    node itself. The given node's own type is not examined, only its descendants.

    :param nodes: ``(name, node)`` pairs, one per definition.
    :return: A dict from each name to the identifier texts (UTF-8 decoded) under its node.
    """
    identifier_kind = NodeType.IDENTIFIER.value
    name2deps = {}
    for name, definition in nodes:
        deps = set()
        # Explicit stack instead of recursion: deeply nested syntax trees
        # must not run into Python's recursion limit. Visiting order is
        # irrelevant because the result is a set.
        pending = list(definition.children)
        while pending:
            current = pending.pop()
            if current.type == identifier_kind:
                deps.add(current.text.decode("utf8"))
            else:
                pending.extend(current.children)
        name2deps[name] = deps
    return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """
    Collect every name reachable from ``entrypoint`` in ``call_graph``.

    :param entrypoint: The name to start from; it is always part of the result.
    :param call_graph: Maps a name to the names it refers to. Names that are
        not keys of the mapping are treated as having no outgoing edges.
    :return: The set of reachable names, including ``entrypoint`` itself.
    """
    visited = {entrypoint}
    # Worklist of names whose neighbours still need expanding. Popping from
    # the end is O(1); the order of expansion does not affect the final set.
    pending = [entrypoint]
    while pending:
        current = pending.pop()
        for neighbour in call_graph.get(current, ()):
            if neighbour not in visited:
                visited.add(neighbour)
                pending.append(neighbour)
    return visited
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Reduce the given text to its relevant top-level Python code.

    The text is first narrowed with ``code_extract`` and then parsed with
    tree-sitter. Only top-level nodes are considered:

    * import statements are always kept;
    * class definitions, function definitions that contain a return
      statement, and assignments are kept once per name (the first
      occurrence of a name wins, regardless of what kind of definition it is).

    When ``entrypoint`` is given, only the kept definitions whose name is
    reachable from it through the identifier dependency graph are emitted.

    :param code: The input text containing Python code.
    :param entrypoint: Optional name used as the root of the dependency analysis.
    :return: The kept imports followed by the kept definitions, in source
        order, separated by newlines.
    """
    source = code_extract(code)
    source_bytes = source.encode("utf8")
    parser = Parser(Language(tree_sitter_python.language()))
    tree = parser.parse(source_bytes)

    import_kinds = NodeType.IMPORT.value
    class_kind = NodeType.CLASS.value
    function_kind = NodeType.FUNCTION.value
    expression_kind = NodeType.EXPRESSION.value
    assignment_kind = NodeType.ASSIGNMENT.value

    # Classes, functions and variables share one namespace for de-duplication.
    seen_names = set()
    imports = []
    definitions = []

    for top_level in tree.root_node.children:
        kind = top_level.type
        if kind in import_kinds:
            imports.append(top_level)
            continue

        if kind == class_kind or kind == function_kind:
            target = top_level
        elif kind == expression_kind and top_level.children[0].type == assignment_kind:
            # Keep the assignment itself rather than its statement wrapper.
            target = top_level.children[0]
        else:
            continue

        name = get_definition_name(target)
        if name in seen_names:
            continue
        # Functions that never return a value are discarded.
        if kind == function_kind and not has_return_statement(top_level):
            continue

        definitions.append((name, target))
        seen_names.add(name)

    reachable = None
    if entrypoint:
        reachable = get_function_dependency(entrypoint, get_deps(definitions))

    chunks = [source_bytes[item.start_byte : item.end_byte] for item in imports]
    for name, item in definitions:
        if entrypoint and name not in reachable:
            continue
        chunks.append(source_bytes[item.start_byte : item.end_byte])

    return b"\n".join(chunks).decode("utf8")