aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-26 09:39:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 09:45:37 -0700
commiteac28534e883283977ebae4dc4dea00cdd601fbc (patch)
tree556c9d2c52acef0ab5086e3ec692e411939410ed /tensorflow/python/autograph
parent319da67052b067231d01f46692ce429da7a06f97 (diff)
Extend support for Index nodes in template expansions.
PiperOrigin-RevId: 214618421
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/pyct/templates.py2
-rw-r--r--tensorflow/python/autograph/pyct/templates_test.py12
2 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 1bf0515745..1af8fca599 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -123,6 +123,8 @@ class ReplaceTransformer(gast.NodeTransformer):
self._check_inner_children_have_context(e)
for e in node.values:
self._check_inner_children_have_context(e)
+ elif isinstance(node, gast.Index):
+ self._check_inner_children_have_context(node.value)
elif isinstance(node, gast.Subscript):
self._check_inner_children_have_context(node.value)
self._check_inner_children_have_context(node.slice)
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 078d9a149b..3032241846 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -158,6 +158,18 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+ def test_replace_index(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
+ function_call_arg = node.body[0].targets[0].value.args[0]
+ self.assertIsInstance(function_call_arg.ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load)
+
def test_replace_call_keyword(self):
template = """
def test_fn():