aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/templates.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/templates.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/templates.py92
1 files changed, 59 insertions, 33 deletions
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py
index 9c479ebc2f..5831d57ceb 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/contrib/autograph/pyct/templates.py
@@ -26,6 +26,7 @@ import textwrap
import gast
+from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
@@ -43,39 +44,65 @@ class ReplaceTransformer(gast.NodeTransformer):
"""
self.replacements = replacements
self.in_replacements = False
+ self.preserved_annos = {
+ anno.Basic.ORIGIN,
+ anno.Basic.SKIP_PROCESSING,
+ anno.Static.ORIG_DEFINITIONS,
+ }
+
+ def _prepare_replacement(self, replaced, key):
+ """Prepares a replacement AST that's safe to swap in for a node.
+
+ Args:
+ replaced: ast.AST, the node being replaced
+ key: Hashable, the key of the replacement AST
+ Returns:
+ ast.AST, the replacement AST
+ """
+ repl = self.replacements[key]
+
+ new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos)
+ if isinstance(new_nodes, gast.AST):
+ new_nodes = [new_nodes]
+
+ return new_nodes
def visit_Expr(self, node):
- if (isinstance(node.value, gast.Name) and
- node.value.id in self.replacements):
- return self.visit(node.value)
- self.generic_visit(node)
- return node
+ # When replacing a placeholder with an entire statement, the replacement
+ # must stand on its own and not be wrapped in an Expr.
+ new_value = self.visit(node.value)
+ if new_value is node.value:
+ return node
+ return new_value
def visit_keyword(self, node):
- if node.arg in self.replacements:
- repl = self.replacements[node.arg]
- if isinstance(repl, gast.keyword):
- return repl
- elif (isinstance(repl, (list, tuple)) and repl and
- all(isinstance(r, gast.keyword) for r in repl)):
- return repl
- # TODO(mdan): We may allow replacing with a string as well.
- # For example, if one wanted to replace foo with bar in foo=baz, then
- # we could allow changing just node arg, so that we end up with bar=baz.
- raise ValueError(
- 'a keyword argument may only be replaced by another keyword or a '
- 'non-empty list of keywords. Found: %s' % repl)
- return self.generic_visit(node)
+ if node.arg not in self.replacements:
+ return self.generic_visit(node)
+
+ repl = self._prepare_replacement(node, node.arg)
+ if isinstance(repl, gast.keyword):
+ return repl
+ elif (repl and isinstance(repl, (list, tuple)) and
+ all(isinstance(r, gast.keyword) for r in repl)):
+ return repl
+ # TODO(mdan): We may allow replacing with a string as well.
+ # For example, if one wanted to replace foo with bar in foo=baz, then
+ # we could allow changing just node arg, so that we end up with bar=baz.
+ raise ValueError(
+ 'a keyword argument may only be replaced by another keyword or a '
+ 'non-empty list of keywords. Found: %s' % repl)
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
- if node.name in self.replacements:
- repl = self.replacements[node.name]
- if not isinstance(repl, (gast.Name, ast.Name)):
- raise ValueError(
- 'a function name can only be replaced by a Name node. Found: %s' %
- repl)
- node.name = repl.id
+ if node.name not in self.replacements:
+ return node
+
+ repl = self.replacements[node.name]
+ if not isinstance(repl, (gast.Name, ast.Name)):
+ raise ValueError(
+ 'a function name can only be replaced by a Name node. Found: %s' %
+ repl)
+ node.name = repl.id
return node
def _check_has_context(self, node):
@@ -113,8 +140,8 @@ class ReplaceTransformer(gast.NodeTransformer):
def _set_inner_child_context(self, node, ctx):
if isinstance(node, gast.Attribute):
- self._set_inner_child_context(node.value, ctx)
- node.ctx = gast.Load()
+ self._set_inner_child_context(node.value, gast.Load())
+ node.ctx = ctx
elif isinstance(node, gast.Tuple):
for e in node.elts:
self._set_inner_child_context(e, ctx)
@@ -148,6 +175,7 @@ class ReplaceTransformer(gast.NodeTransformer):
node = self.generic_visit(node)
if node.attr not in self.replacements:
return node
+
repl = self.replacements[node.attr]
if not isinstance(repl, gast.Name):
raise ValueError(
@@ -159,9 +187,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if node.id not in self.replacements:
return node
- new_nodes = ast_util.copy_clean(self.replacements[node.id])
- if isinstance(new_nodes, gast.AST):
- new_nodes = [new_nodes]
+ new_nodes = self._prepare_replacement(node, node.id)
# Preserve the target context.
for n in new_nodes:
@@ -182,7 +208,7 @@ class ReplaceTransformer(gast.NodeTransformer):
def _convert_to_ast(n):
- """Convert from a known data type to AST."""
+ """Converts from a known data type to AST."""
if isinstance(n, str):
# Note: the node will receive the ctx value from the template, see
# ReplaceTransformer.visit_Name.
@@ -197,7 +223,7 @@ def _convert_to_ast(n):
def replace(template, **replacements):
- """Replace placeholders in a Python template.
+ """Replaces placeholders in a Python template.
AST Name and Tuple nodes always receive the context that inferred from
the template. However, when replacing more complex nodes (that can potentially