aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/ast_util.py
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-11 16:20:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 16:32:19 -0700
commit668c079f4e6020131978b7a812c3b92eea9c47b9 (patch)
tree269836fd98f37b3a099e6b4cceeb3256416705fa /tensorflow/python/autograph/pyct/ast_util.py
parentefd9e0d073a6632f7632f7fe43ae4364cc2c834b (diff)
Move AutoGraph to core. This CL moves the entirety of the code base, keeping the frontend autograph module in contrib for backward compatibility. Certain files, like notebooks and the readme file may be referenced from the outside, so a copy of those is kept as well. In addition, the notebooks subdirectory of examples is also kept in contrib because the extension the build file relies on is not available in the PIP package.
PiperOrigin-RevId: 212543067
Diffstat (limited to 'tensorflow/python/autograph/pyct/ast_util.py')
-rw-r--r--tensorflow/python/autograph/pyct/ast_util.py313
1 files changed, 313 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/ast_util.py b/tensorflow/python/autograph/pyct/ast_util.py
new file mode 100644
index 0000000000..7df3b8858c
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/ast_util.py
@@ -0,0 +1,313 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST manipulation utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+
+
+class CleanCopier(object):
+ """NodeTransformer-like visitor that copies an AST."""
+
+ def __init__(self, preserve_annos):
+ super(CleanCopier, self).__init__()
+ self.preserve_annos = preserve_annos
+
+ def copy(self, node):
+ """Returns a deep copy of node (excluding some fields, see copy_clean)."""
+
+ if isinstance(node, list):
+ return [self.copy(n) for n in node]
+ elif isinstance(node, tuple):
+ return tuple(self.copy(n) for n in node)
+ elif not isinstance(node, (gast.AST, ast.AST)):
+ # Assuming everything that's not an AST, list or tuple is a value type
+ # and may simply be assigned.
+ return node
+
+ assert isinstance(node, (gast.AST, ast.AST))
+
+ new_fields = {}
+ for f in node._fields:
+ if not f.startswith('__') and hasattr(node, f):
+ new_fields[f] = self.copy(getattr(node, f))
+ new_node = type(node)(**new_fields)
+
+ if self.preserve_annos:
+ for k in self.preserve_annos:
+ anno.copyanno(node, new_node, k)
+ return new_node
+
+
+def copy_clean(node, preserve_annos=None):
+ """Creates a deep copy of an AST.
+
+ The copy will not include fields that are prefixed by '__', with the
+ exception of user-specified annotations.
+
+ Args:
+ node: ast.AST
+ preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
+ copy
+ Returns:
+ ast.AST
+ """
+ return CleanCopier(preserve_annos).copy(node)
+
+
+class SymbolRenamer(gast.NodeTransformer):
+ """Transformer that can rename symbols to a simple names."""
+
+ def __init__(self, name_map):
+ self.name_map = name_map
+
+ def _process(self, node):
+ qn = anno.getanno(node, anno.Basic.QN)
+ if qn in self.name_map:
+ new_node = gast.Name(str(self.name_map[qn]), node.ctx, None)
+ # All annotations get carried over.
+ for k in anno.keys(node):
+ anno.copyanno(node, new_node, k)
+ return new_node
+ return self.generic_visit(node)
+
+ def visit_Name(self, node):
+ return self._process(node)
+
+ def visit_Attribute(self, node):
+ if anno.hasanno(node, anno.Basic.QN):
+ return self._process(node)
+ # Attributes of dynamic objects will not have a QN.
+ return self.generic_visit(node)
+
+
+def rename_symbols(node, name_map):
+ """Renames symbols in an AST. Requires qual_names annotations."""
+ renamer = SymbolRenamer(name_map)
+ if isinstance(node, list):
+ return [renamer.visit(n) for n in node]
+ elif isinstance(node, tuple):
+ return tuple(renamer.visit(n) for n in node)
+ return renamer.visit(node)
+
+
+def keywords_to_dict(keywords):
+ """Converts a list of ast.keyword objects to a dict."""
+ keys = []
+ values = []
+ for kw in keywords:
+ keys.append(gast.Str(kw.arg))
+ values.append(kw.value)
+ return gast.Dict(keys=keys, values=values)
+
+
+class PatternMatcher(gast.NodeVisitor):
+ """Matches a node against a pattern represented by a node."""
+
+ def __init__(self, pattern):
+ self.pattern = pattern
+ self.pattern_stack = []
+ self.matches = True
+
+ def compare_and_visit(self, node, pattern):
+ self.pattern_stack.append(self.pattern)
+ self.pattern = pattern
+ self.generic_visit(node)
+ self.pattern = self.pattern_stack.pop()
+
+ def no_match(self):
+ self.matches = False
+ return False
+
+ def is_wildcard(self, p):
+ if isinstance(p, (list, tuple)) and len(p) == 1:
+ p, = p
+ if isinstance(p, gast.Name) and p.id == '_':
+ return True
+ if p == '_':
+ return True
+ return False
+
+ def generic_visit(self, node):
+ if not self.matches:
+ return
+
+ pattern = self.pattern
+ for f in node._fields:
+ if f.startswith('__'):
+ continue
+
+ if not hasattr(node, f):
+ if hasattr(pattern, f) and getattr(pattern, f):
+ return self.no_match()
+ else:
+ continue
+ if not hasattr(pattern, f):
+ return self.no_match()
+
+ v = getattr(node, f)
+ p = getattr(pattern, f)
+
+ if self.is_wildcard(p):
+ continue
+ if isinstance(v, (list, tuple)):
+ if not isinstance(p, (list, tuple)) or len(v) != len(p):
+ return self.no_match()
+ for v_item, p_item in zip(v, p):
+ self.compare_and_visit(v_item, p_item)
+ elif isinstance(v, (gast.AST, ast.AST)):
+ if not isinstance(v, type(p)) and not isinstance(p, type(v)):
+ return self.no_match()
+ self.compare_and_visit(v, p)
+ else:
+ # Assume everything else is a value type.
+ if v != p:
+ return self.no_match()
+
+
+def matches(node, pattern):
+ """Basic pattern matcher for AST.
+
+ The pattern may contain wildcards represented by the symbol '_'. A node
+ matches a pattern if for every node in the tree, either there is a node of
+ the same type in pattern, or a Name node with id='_'.
+
+ Args:
+ node: ast.AST
+ pattern: ast.AST
+ Returns:
+ bool
+ """
+ if isinstance(pattern, str):
+ pattern = parser.parse_expression(pattern)
+ matcher = PatternMatcher(pattern)
+ matcher.visit(node)
+ return matcher.matches
+
+
+# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+def apply_to_single_assignments(targets, values, apply_fn):
+ """Applies a function to each individual assignment.
+
+ This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
+ It tries to break down the unpacking if possible. In effect, it has the same
+ effect as passing the assigned values in SSA form to apply_fn.
+
+ Examples:
+
+ The following will result in apply_fn(a, c), apply_fn(b, d):
+
+ a, b = c, d
+
+ The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
+
+ a, b = c
+
+ The following will result in apply_fn(a, (b, c)):
+
+ a = b, c
+
+ It uses the visitor pattern to allow subclasses to process single
+ assignments individually.
+
+ Args:
+ targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST, should be
+ used with the targets field of an ast.Assign node
+ values: ast.AST
+ apply_fn: Callable[[ast.AST, ast.AST], None], called with the
+ respective nodes of each single assignment
+ """
+ if not isinstance(targets, (list, tuple)):
+ targets = (targets,)
+ for target in targets:
+ if isinstance(target, (gast.Tuple, gast.List)):
+ for i in range(len(target.elts)):
+ target_el = target.elts[i]
+ if isinstance(values, (gast.Tuple, gast.List)):
+ value_el = values.elts[i]
+ else:
+ idx = parser.parse_expression(str(i))
+ value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
+ apply_to_single_assignments(target_el, value_el, apply_fn)
+ else:
+ apply_fn(target, values)
+
+
+def parallel_walk(node, other):
+ """Walks two ASTs in parallel.
+
+ The two trees must have identical structure.
+
+ Args:
+ node: Union[ast.AST, Iterable[ast.AST]]
+ other: Union[ast.AST, Iterable[ast.AST]]
+ Yields:
+ Tuple[ast.AST, ast.AST]
+ Raises:
+ ValueError: if the two trees don't have identical structure.
+ """
+ if isinstance(node, (list, tuple)):
+ node_stack = list(node)
+ else:
+ node_stack = [node]
+
+ if isinstance(other, (list, tuple)):
+ other_stack = list(other)
+ else:
+ other_stack = [other]
+
+ while node_stack and other_stack:
+ assert len(node_stack) == len(other_stack)
+ n = node_stack.pop()
+ o = other_stack.pop()
+
+ if (not isinstance(n, (ast.AST, gast.AST)) or
+ not isinstance(o, (ast.AST, gast.AST)) or
+ n.__class__.__name__ != o.__class__.__name__):
+ raise ValueError('inconsistent nodes: {} and {}'.format(n, o))
+
+ yield n, o
+
+ for f in n._fields:
+ n_child = getattr(n, f, None)
+ o_child = getattr(o, f, None)
+ if f.startswith('__') or n_child is None or o_child is None:
+ continue
+
+ if isinstance(n_child, (list, tuple)):
+ if (not isinstance(o_child, (list, tuple)) or
+ len(n_child) != len(o_child)):
+ raise ValueError(
+ 'inconsistent values for field {}: {} and {}'.format(
+ f, n_child, o_child))
+ node_stack.extend(n_child)
+ other_stack.extend(o_child)
+
+ elif isinstance(n_child, (gast.AST, ast.AST)):
+ node_stack.append(n_child)
+ other_stack.append(o_child)
+
+ elif n_child != o_child:
+ raise ValueError(
+ 'inconsistent values for field {}: {} and {}'.format(
+ f, n_child, o_child))