diff options
Diffstat (limited to 'tensorflow/contrib/py2tf/pyct/templates.py')
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/templates.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py index 6acc03bfce..4fadc793e6 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/py2tf/pyct/templates.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function import ast +import copy import gast @@ -61,14 +62,17 @@ class ReplaceTransformer(gast.NodeTransformer): return node def visit_Name(self, node): - # Note: The caller is reposnsible with making sure the replacement - # Name nodes have the proper ctx set up. - # TODO(mdan): Is it possible to always infer the proper context here? if node.id in self.replacements: # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations. - new_nodes = self.replacements[node.id] + new_nodes = copy.copy(self.replacements[node.id]) if isinstance(new_nodes, gast.AST): new_nodes = [new_nodes] + # Preserve the target context. + for n in new_nodes: + if isinstance(n, gast.Tuple): + for e in n.elts: + e.ctx = node.ctx + n.ctx = node.ctx if len(new_nodes) == 1: new_nodes, = new_nodes return new_nodes |