aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-10 14:40:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 15:10:35 -0700
commit6d3af1df20f611641665f63e8bb49a875823432b (patch)
treec8931753cc52512978428eff6126aa067bdc3fb7 /tensorflow/contrib/autograph
parentb828f89263e054bfa7c7a808cab1506834ab906d (diff)
Add support for list literals in template replacement values.
PiperOrigin-RevId: 212337233
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/pyct/templates.py6
-rw-r--r--tensorflow/contrib/autograph/pyct/templates_test.py36
2 files changed, 39 insertions, 3 deletions
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py
index 5831d57ceb..d81c50f524 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/contrib/autograph/pyct/templates.py
@@ -113,7 +113,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._check_inner_children_have_context(node.value)
self._check_has_context(node)
- elif isinstance(node, gast.Tuple):
+ elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._check_inner_children_have_context(e)
self._check_has_context(node)
@@ -142,7 +142,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._set_inner_child_context(node.value, gast.Load())
node.ctx = ctx
- elif isinstance(node, gast.Tuple):
+ elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._set_inner_child_context(e, ctx)
node.ctx = ctx
@@ -191,7 +191,7 @@ class ReplaceTransformer(gast.NodeTransformer):
# Preserve the target context.
for n in new_nodes:
- if isinstance(n, gast.Tuple):
+ if isinstance(n, (gast.Tuple, gast.List)):
for e in n.elts:
self._set_inner_child_context(e, node.ctx)
if isinstance(n, gast.Attribute):
diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py
index 77e8ff62fd..074105ea50 100644
--- a/tensorflow/contrib/autograph/pyct/templates_test.py
+++ b/tensorflow/contrib/autograph/pyct/templates_test.py
@@ -110,6 +110,42 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)
+ def test_replace_list_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+ def test_replace_tuple_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+ def test_replace_complex_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ function_call_arg = node.body[0].targets[0].value.args[0]
+ self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+
def test_replace_call_keyword(self):
template = """
def test_fn():