aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/ast_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/ast_util.py')
-rw-r--r--tensorflow/python/autograph/pyct/ast_util.py313
1 files changed, 313 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/ast_util.py b/tensorflow/python/autograph/pyct/ast_util.py
new file mode 100644
index 0000000000..7df3b8858c
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/ast_util.py
@@ -0,0 +1,313 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST manipulation utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+
+
+class CleanCopier(object):
+ """NodeTransformer-like visitor that copies an AST."""
+
+ def __init__(self, preserve_annos):
+ super(CleanCopier, self).__init__()
+ self.preserve_annos = preserve_annos
+
+ def copy(self, node):
+ """Returns a deep copy of node (excluding some fields, see copy_clean)."""
+
+ if isinstance(node, list):
+ return [self.copy(n) for n in node]
+ elif isinstance(node, tuple):
+ return tuple(self.copy(n) for n in node)
+ elif not isinstance(node, (gast.AST, ast.AST)):
+ # Assuming everything that's not an AST, list or tuple is a value type
+ # and may simply be assigned.
+ return node
+
+ assert isinstance(node, (gast.AST, ast.AST))
+
+ new_fields = {}
+ for f in node._fields:
+ if not f.startswith('__') and hasattr(node, f):
+ new_fields[f] = self.copy(getattr(node, f))
+ new_node = type(node)(**new_fields)
+
+ if self.preserve_annos:
+ for k in self.preserve_annos:
+ anno.copyanno(node, new_node, k)
+ return new_node
+
+
+def copy_clean(node, preserve_annos=None):
+ """Creates a deep copy of an AST.
+
+ The copy will not include fields that are prefixed by '__', with the
+ exception of user-specified annotations.
+
+ Args:
+ node: ast.AST
+ preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
+ copy
+ Returns:
+ ast.AST
+ """
+ return CleanCopier(preserve_annos).copy(node)
+
+
+class SymbolRenamer(gast.NodeTransformer):
+ """Transformer that can rename symbols to a simple names."""
+
+ def __init__(self, name_map):
+ self.name_map = name_map
+
+ def _process(self, node):
+ qn = anno.getanno(node, anno.Basic.QN)
+ if qn in self.name_map:
+ new_node = gast.Name(str(self.name_map[qn]), node.ctx, None)
+ # All annotations get carried over.
+ for k in anno.keys(node):
+ anno.copyanno(node, new_node, k)
+ return new_node
+ return self.generic_visit(node)
+
+ def visit_Name(self, node):
+ return self._process(node)
+
+ def visit_Attribute(self, node):
+ if anno.hasanno(node, anno.Basic.QN):
+ return self._process(node)
+ # Attributes of dynamic objects will not have a QN.
+ return self.generic_visit(node)
+
+
+def rename_symbols(node, name_map):
+ """Renames symbols in an AST. Requires qual_names annotations."""
+ renamer = SymbolRenamer(name_map)
+ if isinstance(node, list):
+ return [renamer.visit(n) for n in node]
+ elif isinstance(node, tuple):
+ return tuple(renamer.visit(n) for n in node)
+ return renamer.visit(node)
+
+
+def keywords_to_dict(keywords):
+ """Converts a list of ast.keyword objects to a dict."""
+ keys = []
+ values = []
+ for kw in keywords:
+ keys.append(gast.Str(kw.arg))
+ values.append(kw.value)
+ return gast.Dict(keys=keys, values=values)
+
+
+class PatternMatcher(gast.NodeVisitor):
+ """Matches a node against a pattern represented by a node."""
+
+ def __init__(self, pattern):
+ self.pattern = pattern
+ self.pattern_stack = []
+ self.matches = True
+
+ def compare_and_visit(self, node, pattern):
+ self.pattern_stack.append(self.pattern)
+ self.pattern = pattern
+ self.generic_visit(node)
+ self.pattern = self.pattern_stack.pop()
+
+ def no_match(self):
+ self.matches = False
+ return False
+
+ def is_wildcard(self, p):
+ if isinstance(p, (list, tuple)) and len(p) == 1:
+ p, = p
+ if isinstance(p, gast.Name) and p.id == '_':
+ return True
+ if p == '_':
+ return True
+ return False
+
+ def generic_visit(self, node):
+ if not self.matches:
+ return
+
+ pattern = self.pattern
+ for f in node._fields:
+ if f.startswith('__'):
+ continue
+
+ if not hasattr(node, f):
+ if hasattr(pattern, f) and getattr(pattern, f):
+ return self.no_match()
+ else:
+ continue
+ if not hasattr(pattern, f):
+ return self.no_match()
+
+ v = getattr(node, f)
+ p = getattr(pattern, f)
+
+ if self.is_wildcard(p):
+ continue
+ if isinstance(v, (list, tuple)):
+ if not isinstance(p, (list, tuple)) or len(v) != len(p):
+ return self.no_match()
+ for v_item, p_item in zip(v, p):
+ self.compare_and_visit(v_item, p_item)
+ elif isinstance(v, (gast.AST, ast.AST)):
+ if not isinstance(v, type(p)) and not isinstance(p, type(v)):
+ return self.no_match()
+ self.compare_and_visit(v, p)
+ else:
+ # Assume everything else is a value type.
+ if v != p:
+ return self.no_match()
+
+
+def matches(node, pattern):
+ """Basic pattern matcher for AST.
+
+ The pattern may contain wildcards represented by the symbol '_'. A node
+ matches a pattern if for every node in the tree, either there is a node of
+ the same type in pattern, or a Name node with id='_'.
+
+ Args:
+ node: ast.AST
+ pattern: ast.AST
+ Returns:
+ bool
+ """
+ if isinstance(pattern, str):
+ pattern = parser.parse_expression(pattern)
+ matcher = PatternMatcher(pattern)
+ matcher.visit(node)
+ return matcher.matches
+
+
+# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+def apply_to_single_assignments(targets, values, apply_fn):
+ """Applies a function to each individual assignment.
+
+ This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
+ It tries to break down the unpacking if possible. In effect, it has the same
+ effect as passing the assigned values in SSA form to apply_fn.
+
+ Examples:
+
+ The following will result in apply_fn(a, c), apply_fn(b, d):
+
+ a, b = c, d
+
+ The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
+
+ a, b = c
+
+ The following will result in apply_fn(a, (b, c)):
+
+ a = b, c
+
+ It uses the visitor pattern to allow subclasses to process single
+ assignments individually.
+
+ Args:
+ targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST, should be
+ used with the targets field of an ast.Assign node
+ values: ast.AST
+ apply_fn: Callable[[ast.AST, ast.AST], None], called with the
+ respective nodes of each single assignment
+ """
+ if not isinstance(targets, (list, tuple)):
+ targets = (targets,)
+ for target in targets:
+ if isinstance(target, (gast.Tuple, gast.List)):
+ for i in range(len(target.elts)):
+ target_el = target.elts[i]
+ if isinstance(values, (gast.Tuple, gast.List)):
+ value_el = values.elts[i]
+ else:
+ idx = parser.parse_expression(str(i))
+ value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
+ apply_to_single_assignments(target_el, value_el, apply_fn)
+ else:
+ apply_fn(target, values)
+
+
+def parallel_walk(node, other):
+ """Walks two ASTs in parallel.
+
+ The two trees must have identical structure.
+
+ Args:
+ node: Union[ast.AST, Iterable[ast.AST]]
+ other: Union[ast.AST, Iterable[ast.AST]]
+ Yields:
+ Tuple[ast.AST, ast.AST]
+ Raises:
+ ValueError: if the two trees don't have identical structure.
+ """
+ if isinstance(node, (list, tuple)):
+ node_stack = list(node)
+ else:
+ node_stack = [node]
+
+ if isinstance(other, (list, tuple)):
+ other_stack = list(other)
+ else:
+ other_stack = [other]
+
+ while node_stack and other_stack:
+ assert len(node_stack) == len(other_stack)
+ n = node_stack.pop()
+ o = other_stack.pop()
+
+ if (not isinstance(n, (ast.AST, gast.AST)) or
+ not isinstance(o, (ast.AST, gast.AST)) or
+ n.__class__.__name__ != o.__class__.__name__):
+ raise ValueError('inconsistent nodes: {} and {}'.format(n, o))
+
+ yield n, o
+
+ for f in n._fields:
+ n_child = getattr(n, f, None)
+ o_child = getattr(o, f, None)
+ if f.startswith('__') or n_child is None or o_child is None:
+ continue
+
+ if isinstance(n_child, (list, tuple)):
+ if (not isinstance(o_child, (list, tuple)) or
+ len(n_child) != len(o_child)):
+ raise ValueError(
+ 'inconsistent values for field {}: {} and {}'.format(
+ f, n_child, o_child))
+ node_stack.extend(n_child)
+ other_stack.extend(o_child)
+
+ elif isinstance(n_child, (gast.AST, ast.AST)):
+ node_stack.append(n_child)
+ other_stack.append(o_child)
+
+ elif n_child != o_child:
+ raise ValueError(
+ 'inconsistent values for field {}: {} and {}'.format(
+ f, n_child, o_child))