aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/py2tf/pyct/templates.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/py2tf/pyct/templates.py')
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py
index 6acc03bfce..4fadc793e6 100644
--- a/tensorflow/contrib/py2tf/pyct/templates.py
+++ b/tensorflow/contrib/py2tf/pyct/templates.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import ast
+import copy
import gast
@@ -61,14 +62,17 @@ class ReplaceTransformer(gast.NodeTransformer):
return node
def visit_Name(self, node):
- # Note: The caller is reposnsible with making sure the replacement
- # Name nodes have the proper ctx set up.
- # TODO(mdan): Is it possible to always infer the proper context here?
if node.id in self.replacements:
# TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
- new_nodes = self.replacements[node.id]
+ new_nodes = copy.copy(self.replacements[node.id])
if isinstance(new_nodes, gast.AST):
new_nodes = [new_nodes]
+ # Preserve the target context.
+ for n in new_nodes:
+ if isinstance(n, gast.Tuple):
+ for e in n.elts:
+ e.ctx = node.ctx
+ n.ctx = node.ctx
if len(new_nodes) == 1:
new_nodes, = new_nodes
return new_nodes