aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/ast_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/ast_util.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/ast_util.py87
1 files changed, 59 insertions, 28 deletions
diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py
index 86e3f56a64..d7453b0781 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util.py
+++ b/tensorflow/contrib/autograph/pyct/ast_util.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import ast
-import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
@@ -185,6 +184,7 @@ class PatternMatcher(gast.NodeVisitor):
if v != p:
return self.no_match()
+
def matches(node, pattern):
"""Basic pattern matcher for AST.
@@ -253,30 +253,61 @@ def apply_to_single_assignments(targets, values, apply_fn):
apply_fn(target, values)
-def iter_fields(node):
- for field in sorted(node._fields):
- try:
- yield getattr(node, field)
- except AttributeError:
- pass
-
-
-def iter_child_nodes(node):
- for field in iter_fields(node):
- if isinstance(field, gast.AST):
- yield field
- elif isinstance(field, list):
- for item in field:
- if isinstance(item, gast.AST):
- yield item
-
-
-def parallel_walk(node_a, node_b):
- todo_a = collections.deque([node_a])
- todo_b = collections.deque([node_b])
- while todo_a and todo_b:
- node_a = todo_a.popleft()
- node_b = todo_b.popleft()
- todo_a.extend(iter_child_nodes(node_a))
- todo_b.extend(iter_child_nodes(node_b))
- yield node_a, node_b
+def parallel_walk(node, other):
+ """Walks two ASTs in parallel.
+
+ The two trees must have identical structure.
+
+ Args:
+ node: Union[ast.AST, Iterable[ast.AST]]
+ other: Union[ast.AST, Iterable[ast.AST]]
+ Yields:
+ Tuple[ast.AST, ast.AST]
+ Raises:
+ ValueError: if the two trees don't have identical structure.
+ """
+ if isinstance(node, (list, tuple)):
+ node_stack = list(node)
+ else:
+ node_stack = [node]
+
+ if isinstance(other, (list, tuple)):
+ other_stack = list(other)
+ else:
+ other_stack = [other]
+
+ while node_stack and other_stack:
+ assert len(node_stack) == len(other_stack)
+ n = node_stack.pop()
+ o = other_stack.pop()
+
+ if (not isinstance(n, (ast.AST, gast.AST)) or
+ not isinstance(o, (ast.AST, gast.AST)) or
+ n.__class__.__name__ != o.__class__.__name__):
+ raise ValueError('inconsistent nodes: {} and {}'.format(n, o))
+
+ yield n, o
+
+ for f in n._fields:
+ n_child = getattr(n, f, None)
+ o_child = getattr(o, f, None)
+ if f.startswith('__') or n_child is None or o_child is None:
+ continue
+
+ if isinstance(n_child, (list, tuple)):
+ if (not isinstance(o_child, (list, tuple)) or
+ len(n_child) != len(o_child)):
+ raise ValueError(
+ 'inconsistent values for field {}: {} and {}'.format(
+ f, n_child, o_child))
+ node_stack.extend(n_child)
+ other_stack.extend(o_child)
+
+ elif isinstance(n_child, (gast.AST, ast.AST)):
+ node_stack.append(n_child)
+ other_stack.append(o_child)
+
+ elif n_child != o_child:
+ raise ValueError(
+ 'inconsistent values for field {}: {} and {}'.format(
+ f, n_child, o_child))