diff options
author | 2018-09-10 14:40:21 -0700 | |
---|---|---|
committer | 2018-09-10 15:10:35 -0700 | |
commit | 6d3af1df20f611641665f63e8bb49a875823432b (patch) | |
tree | c8931753cc52512978428eff6126aa067bdc3fb7 /tensorflow/contrib/autograph | |
parent | b828f89263e054bfa7c7a808cab1506834ab906d (diff) |
Add support for list literals in template replacement values.
PiperOrigin-RevId: 212337233
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/pyct/templates.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/pyct/templates_test.py | 36 |
2 files changed, 39 insertions, 3 deletions
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index 5831d57ceb..d81c50f524 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -113,7 +113,7 @@ class ReplaceTransformer(gast.NodeTransformer): if isinstance(node, gast.Attribute): self._check_inner_children_have_context(node.value) self._check_has_context(node) - elif isinstance(node, gast.Tuple): + elif isinstance(node, (gast.Tuple, gast.List)): for e in node.elts: self._check_inner_children_have_context(e) self._check_has_context(node) @@ -142,7 +142,7 @@ class ReplaceTransformer(gast.NodeTransformer): if isinstance(node, gast.Attribute): self._set_inner_child_context(node.value, gast.Load()) node.ctx = ctx - elif isinstance(node, gast.Tuple): + elif isinstance(node, (gast.Tuple, gast.List)): for e in node.elts: self._set_inner_child_context(e, ctx) node.ctx = ctx @@ -191,7 +191,7 @@ class ReplaceTransformer(gast.NodeTransformer): # Preserve the target context. for n in new_nodes: - if isinstance(n, gast.Tuple): + if isinstance(n, (gast.Tuple, gast.List)): for e in n.elts: self._set_inner_child_context(e, node.ctx) if isinstance(n, gast.Attribute): diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py index 77e8ff62fd..074105ea50 100644 --- a/tensorflow/contrib/autograph/pyct/templates_test.py +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -110,6 +110,42 @@ class TemplatesTest(test.TestCase): self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load) self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load) + def test_replace_list_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store) + + def test_replace_tuple_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store) + + def test_replace_complex_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace( + template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + function_call_arg = node.body[0].targets[0].value.args[0] + self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load) + self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load) + self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load) + def test_replace_call_keyword(self): template = """ def test_fn(): |