diff options
author | 2018-09-26 09:39:54 -0700 | |
---|---|---|
committer | 2018-09-26 09:45:37 -0700 | |
commit | eac28534e883283977ebae4dc4dea00cdd601fbc (patch) | |
tree | 556c9d2c52acef0ab5086e3ec692e411939410ed /tensorflow/python/autograph | |
parent | 319da67052b067231d01f46692ce429da7a06f97 (diff) |
Extend support for Index nodes in template expansions.
PiperOrigin-RevId: 214618421
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r-- | tensorflow/python/autograph/pyct/templates.py | 2 | ||||
-rw-r--r-- | tensorflow/python/autograph/pyct/templates_test.py | 12 |
2 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py index 1bf0515745..1af8fca599 100644 --- a/tensorflow/python/autograph/pyct/templates.py +++ b/tensorflow/python/autograph/pyct/templates.py @@ -123,6 +123,8 @@ class ReplaceTransformer(gast.NodeTransformer): self._check_inner_children_have_context(e) for e in node.values: self._check_inner_children_have_context(e) + elif isinstance(node, gast.Index): + self._check_inner_children_have_context(node.value) elif isinstance(node, gast.Subscript): self._check_inner_children_have_context(node.value) self._check_inner_children_have_context(node.slice) diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py index 078d9a149b..3032241846 100644 --- a/tensorflow/python/autograph/pyct/templates_test.py +++ b/tensorflow/python/autograph/pyct/templates_test.py @@ -158,6 +158,18 @@ class TemplatesTest(test.TestCase): self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load) self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load) + def test_replace_index(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace( + template, foo=parser.parse_expression('foo(a[b]).bar'))[0] + function_call_arg = node.body[0].targets[0].value.args[0] + self.assertIsInstance(function_call_arg.ctx, gast.Load) + self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load) + def test_replace_call_keyword(self): template = """ def test_fn(): |