diff options
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/ast_util.py')
-rw-r--r-- | tensorflow/contrib/autograph/pyct/ast_util.py | 87 |
1 files changed, 59 insertions, 28 deletions
diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py index 86e3f56a64..d7453b0781 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -20,7 +20,6 @@ from __future__ import print_function import ast -import collections import gast from tensorflow.contrib.autograph.pyct import anno @@ -185,6 +184,7 @@ class PatternMatcher(gast.NodeVisitor): if v != p: return self.no_match() + def matches(node, pattern): """Basic pattern matcher for AST. @@ -253,30 +253,61 @@ def apply_to_single_assignments(targets, values, apply_fn): apply_fn(target, values) -def iter_fields(node): - for field in sorted(node._fields): - try: - yield getattr(node, field) - except AttributeError: - pass - - -def iter_child_nodes(node): - for field in iter_fields(node): - if isinstance(field, gast.AST): - yield field - elif isinstance(field, list): - for item in field: - if isinstance(item, gast.AST): - yield item - - -def parallel_walk(node_a, node_b): - todo_a = collections.deque([node_a]) - todo_b = collections.deque([node_b]) - while todo_a and todo_b: - node_a = todo_a.popleft() - node_b = todo_b.popleft() - todo_a.extend(iter_child_nodes(node_a)) - todo_b.extend(iter_child_nodes(node_b)) - yield node_a, node_b +def parallel_walk(node, other): + """Walks two ASTs in parallel. + + The two trees must have identical structure. + + Args: + node: Union[ast.AST, Iterable[ast.AST]] + other: Union[ast.AST, Iterable[ast.AST]] + Yields: + Tuple[ast.AST, ast.AST] + Raises: + ValueError: if the two trees don't have identical structure. + """ + if isinstance(node, (list, tuple)): + node_stack = list(node) + else: + node_stack = [node] + + if isinstance(other, (list, tuple)): + other_stack = list(other) + else: + other_stack = [other] + + while node_stack and other_stack: + assert len(node_stack) == len(other_stack) + n = node_stack.pop() + o = other_stack.pop() + + if (not isinstance(n, (ast.AST, gast.AST)) or + not isinstance(o, (ast.AST, gast.AST)) or + n.__class__.__name__ != o.__class__.__name__): + raise ValueError('inconsistent nodes: {} and {}'.format(n, o)) + + yield n, o + + for f in n._fields: + n_child = getattr(n, f, None) + o_child = getattr(o, f, None) + if f.startswith('__') or n_child is None or o_child is None: + continue + + if isinstance(n_child, (list, tuple)): + if (not isinstance(o_child, (list, tuple)) or + len(n_child) != len(o_child)): + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) + node_stack.extend(n_child) + other_stack.extend(o_child) + + elif isinstance(n_child, (gast.AST, ast.AST)): + node_stack.append(n_child) + other_stack.append(o_child) + + elif n_child != o_child: + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) |