|
"Handle AST objects." |
|
|
|
import ast |
|
|
|
from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union |
|
|
|
|
|
import asdl |
|
import attr |
|
|
|
|
|
class ASTWrapperVisitor(asdl.VisitorBase):
    '''Used by ASTWrapper to collect information about an ASDL module.

    After visiting a module the instance holds:

    - constructors: every constructor, keyed by constructor name.
    - sum_types: every sum type, keyed by type name.
    - product_types: every product type, keyed by type name.
    - fieldless_constructors: the constructors that declare no fields.

    While walking it also checks that all fields have names.

    The visit* methods are presumably dispatched by asdl.VisitorBase.visit
    according to the class of the visited object -- confirm against asdl.
    '''

    def __init__(self):
        super().__init__()
        self.constructors = {}
        self.sum_types = {}
        self.product_types = {}
        self.fieldless_constructors = {}

    def visitModule(self, mod):
        '''Visits every definition of the module.'''
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type_):
        '''Visits the value of a type definition, passing the type name along.'''
        self.visit(type_.value, str(type_.name))

    def visitSum(self, sum_, name):
        '''Records the sum type under `name` and visits its constructors.'''
        self.sum_types[name] = sum_
        for t in sum_.types:
            self.visit(t, name)

    def visitConstructor(self, cons, _name):
        '''Records a constructor (names must be unique) and visits its fields.

        `_name` (the name handed down by visitSum) is not used here.
        '''
        assert cons.name not in self.constructors
        self.constructors[cons.name] = cons
        if not cons.fields:
            self.fieldless_constructors[cons.name] = cons
        for f in cons.fields:
            self.visit(f, cons.name)

    def visitField(self, field, name):
        '''Raises ValueError when the field has no name.

        `name` is the constructor or product type owning the field; it is only
        used in the error message.
        '''
        if field.name is None:
            raise ValueError('Field of type {} in {} lacks name'.format(
                field.type, name))

    def visitProduct(self, prod, name):
        '''Records the product type under `name` and visits its fields.'''
        self.product_types[name] = prod
        for f in prod.fields:
            self.visit(f, name)
|
|
|
|
|
# A "singular" type is a constructor or a product type; these are the two
# kinds of entries ASTWrapper stores in `singular_types`.
SingularType = Union[asdl.Constructor, asdl.Product]
|
|
|
|
|
class ASTWrapper(object):
    '''Provides helper methods on the ASDL AST.

    Wraps an ASDL module definition and indexes its constructors, sum types
    and product types, so that trees made of plain dicts (each carrying a
    '_type' entry) can be verified and searched.
    '''

    # Predicates applied by verify_ast to fields whose type is a primitive.
    default_primitive_type_checkers = {
        'identifier': lambda x: isinstance(x, str),
        'int': lambda x: isinstance(x, int),
        'string': lambda x: isinstance(x, str),
        'bytes': lambda x: isinstance(x, bytes),
        'object': lambda x: isinstance(x, object),
        'singleton': lambda x: x is True or x is False or x is None
    }

    def __init__(self, ast_def, custom_primitive_type_checkers=None):
        '''Indexes the grammar described by `ast_def`.

        Args:
            ast_def: ASDL module definition. It is walked with
                ASTWrapperVisitor and its `types` attribute is exposed
                through the `types` property.
            custom_primitive_type_checkers: optional mapping from primitive
                type name to a predicate. Entries are merged over
                `default_primitive_type_checkers`, so they can add new
                primitive types or override the default ones.
        '''
        # None (rather than `{}`) is the default so that no dict object is
        # shared between calls.
        if custom_primitive_type_checkers is None:
            custom_primitive_type_checkers = {}

        self.ast_def = ast_def

        visitor = ASTWrapperVisitor()
        visitor.visit(ast_def)

        self.constructors = visitor.constructors
        self.sum_types = visitor.sum_types
        self.product_types = visitor.product_types
        self.seq_fragment_constructors = {}
        self.primitive_type_checkers = {
            **self.default_primitive_type_checkers,
            **custom_primitive_type_checkers
        }
        self.custom_primitive_types = set(custom_primitive_type_checkers.keys())
        self.primitive_types = set(self.primitive_type_checkers.keys())

        # Constructors and product types together form the singular types.
        self.singular_types = {}
        self.singular_types.update(self.constructors)
        self.singular_types.update(self.product_types)

        # Sorted constructor names of each sum type.
        self.sum_type_vocabs = {
            name: sorted(t.name for t in sum_type.types)
            for name, sum_type in self.sum_types.items()
        }
        self.constructor_to_sum_type = {
            constructor.name: name
            for name, sum_type in self.sum_types.items()
            for constructor in sum_type.types
        }
        # Starts out as an independent copy of constructor_to_sum_type.
        self.seq_fragment_constructor_to_sum_type = dict(
            self.constructor_to_sum_type)
        self.fieldless_constructors = sorted(
            visitor.fieldless_constructors.keys())

    @property
    def types(self):
        '''The `types` mapping of the wrapped ASDL module definition.'''
        return self.ast_def.types

    @property
    def root_type(self):
        '''The stored root type.'''
        # NOTE(review): nothing in this class assigns `_root_type`, so this
        # raises AttributeError unless it is set from outside -- confirm.
        return self._root_type

    def add_sum_type(self, name, sum_type):
        '''Registers a new sum type and all of its constructors.

        `name` must not already be a known sum type.
        '''
        assert name not in self.sum_types
        self.sum_types[name] = sum_type
        self.types[name] = sum_type

        # NOTE(review): `sum_type_vocabs` is not refreshed here -- confirm
        # whether that is intended.
        for type_ in sum_type.types:
            self._add_constructor(name, type_)

    def add_constructors_to_sum_type(self, sum_type_name, constructors):
        '''Registers `constructors` and appends them to an existing sum type.'''
        for constructor in constructors:
            self._add_constructor(sum_type_name, constructor)
        self.sum_types[sum_type_name].types += constructors

    def remove_product_type(self, product_type_name):
        '''Removes a product type from every index that holds it.

        Raises KeyError when the name is unknown.
        '''
        self.singular_types.pop(product_type_name)
        self.product_types.pop(product_type_name)
        self.types.pop(product_type_name)

    def add_seq_fragment_type(self, sum_type_name, constructors):
        '''Registers `constructors` as sequence fragments of a sum type.

        The constructors are indexed like ordinary ones, but are stored on the
        sum type under `seq_fragment_types` (created on first use) instead of
        `types`; verify_ast only accepts them for elements of sequences.
        '''
        for constructor in constructors:
            self._add_constructor(sum_type_name, constructor)

        sum_type = self.sum_types[sum_type_name]
        if not hasattr(sum_type, 'seq_fragment_types'):
            sum_type.seq_fragment_types = []
        sum_type.seq_fragment_types += constructors

    def _add_constructor(self, sum_type_name, constructor):
        '''Indexes one constructor as belonging to `sum_type_name`.

        The constructor name must be new in all indexes.
        '''
        cons_name = constructor.name
        # Check every index before touching any of them, so that a duplicate
        # name cannot leave the indexes half-updated.
        assert cons_name not in self.constructors
        assert cons_name not in self.singular_types
        assert cons_name not in self.constructor_to_sum_type

        self.constructors[cons_name] = constructor
        self.singular_types[cons_name] = constructor
        self.constructor_to_sum_type[cons_name] = sum_type_name

        if not constructor.fields:
            self.fieldless_constructors.append(cons_name)
            self.fieldless_constructors.sort()

    def _check_expected_type(self, node_type, expected_type, field_path,
                             is_seq):
        '''Raises ValueError unless `node_type` is allowed for `expected_type`.'''
        declared = self.types[expected_type]
        if isinstance(declared, asdl.Product):
            if node_type != expected_type:
                raise ValueError(
                    'Expected type {}, but instead saw {}. path: {}'.format(
                        expected_type, node_type, field_path))
        elif isinstance(declared, asdl.Sum):
            possible_names = [t.name for t in declared.types]
            if is_seq:
                # Sequence fragments are only legal inside sequences.
                possible_names += [
                    t.name
                    for t in getattr(declared, 'seq_fragment_types', [])]
            if node_type not in possible_names:
                raise ValueError(
                    'Expected one of {}, but instead saw {}. path: {}'.format(
                        ', '.join(possible_names), node_type, field_path))
        else:
            raise ValueError('Unexpected type in ASDL: {}'.format(declared))

    def verify_ast(self, node, expected_type=None, field_path=(), is_seq=False):
        '''Checks that `node` conforms to the current ASDL.

        Args:
            node: dict with a '_type' entry naming its constructor or product
                type, plus one entry per field.
            expected_type: name of the ASDL type required at this position,
                or None to skip that check.
            field_path: tuple of the field names leading to `node`; it is only
                used in error messages.
            is_seq: whether `node` is an element of a sequential field, in
                which case the `seq_fragment_types` of a sum type are
                accepted as well.

        Returns:
            True when `node` and everything below it is valid.

        Raises:
            ValueError: if the node is None or not a dict, has an unknown or
                unexpected type, lacks a required field, or holds a
                non-sequence in a sequential field.
            AssertionError: if a field value is rejected by its checker.
            KeyError: if the node has no '_type' entry.
        '''
        if node is None:
            raise ValueError('node is None. path: {}'.format(field_path))
        if not isinstance(node, dict):
            raise ValueError('node is type {}. path: {}'.format(
                type(node), field_path))

        node_type = node['_type']
        if expected_type is not None:
            self._check_expected_type(
                node_type, expected_type, field_path, is_seq)

        if node_type in self.types:
            declared = self.types[node_type]
            if isinstance(declared, asdl.Sum):
                raise ValueError('sum type {} not allowed as node type. path: {}'.
                                 format(node_type, field_path))
            fields_to_check = declared.fields
        elif node_type in self.constructors:
            fields_to_check = self.constructors[node_type].fields
        else:
            raise ValueError('Unknown node_type {}. path: {}'.format(node_type,
                                                                     field_path))

        for field in fields_to_check:
            if field.name not in node:
                # Optional and sequential fields may be left out entirely.
                if field.opt or field.seq:
                    continue
                raise ValueError('required field {} is missing. path: {}'.format(
                    field.name, field_path))

            value = node[field.name]
            if field.seq:
                if not isinstance(value, (list, tuple)):
                    raise ValueError('sequential field {} is not sequence. path: {}'.
                                     format(field.name, field_path))
                items = value
            else:
                items = (value, )

            is_primitive = field.type in self.primitive_type_checkers
            for item in items:
                if is_primitive:
                    ok = self.primitive_type_checkers[field.type](item)
                else:
                    ok = self.verify_ast(
                        item, field.type, field_path + (field.name, ),
                        is_seq=field.seq)
                # Raised explicitly instead of using an `assert` statement:
                # `python -O` strips assert statements, which would skip both
                # the primitive checks and the recursion into child nodes.
                # AssertionError is kept as the exception type for callers.
                if not ok:
                    raise AssertionError(
                        'field {} of {} failed the {} check. path: {}'.format(
                            field.name, node_type, field.type, field_path))
        return True

    def find_all_descendants_of_type(self, tree, type, descend_pred=lambda field: True):
        '''Yields the values of all fields below `tree` declared with `type`.

        Args:
            tree: root node (a dict with a '_type' entry); anything that is
                not a dict is skipped.
            type: ASDL type name to look for. It is compared with the declared
                type of each field, not with the '_type' of the values.
            descend_pred: predicate on the field description; fields for which
                it returns a false value are neither reported nor descended
                into.

        Values of a matching field are yielded as they are and are not
        searched any further; values of other fields are searched in turn.
        '''
        queue = [tree]
        while queue:
            # Popping from the end makes this a depth-first walk.
            node = queue.pop()
            if not isinstance(node, dict):
                continue
            for field_info in self.singular_types[node['_type']].fields:
                if field_info.opt and field_info.name not in node:
                    continue
                if not descend_pred(field_info):
                    continue

                if field_info.seq:
                    values = node.get(field_info.name, [])
                else:
                    values = [node[field_info.name]]

                if field_info.type == type:
                    for value in values:
                        yield value
                else:
                    queue.extend(values)
|
|
|
|
|
|
|
# Type of a tree node: a plain dict keyed by strings. Presumably the dict form
# that ASTWrapper.verify_ast checks (a '_type' entry plus one entry per
# field) -- confirm with callers.
Node = Dict[str, Any]
|
|
|
@attr.s
class HoleValuePlaceholder:
    '''Placeholder record with three attrs-generated fields.

    Presumably stands in for the value of a not-yet-filled hole in a tree --
    confirm with callers.
    '''
    # Identifier of the placeholder.
    id = attr.ib()
    # Presumably whether the placeholder stands for a sequential field.
    is_seq = attr.ib()
    # Presumably whether the placeholder stands for an optional field.
    is_opt = attr.ib()
|
|