aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-18 09:28:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 09:32:24 -0700
commit0c8a8289da120ee353c4fba5decb0bea9014e0a7 (patch)
treee921666e40a4ea0d84513fe6a106a8238cdda8ad /tensorflow/python/autograph
parentb1ff7c2cedcc7d49d430d56655870e6d68a0c8f7 (diff)
Extend template expansion support for arithmetic expressions.
PiperOrigin-RevId: 213462334
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/pyct/templates.py11
-rw-r--r--tensorflow/python/autograph/pyct/templates_test.py12
2 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 68c2a35fac..1bf0515745 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -109,6 +109,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if not node.ctx:
raise ValueError('node %s is missing ctx value' % node)
+ # TODO(mdan): Rewrite _check and _set using a separate transformer.
def _check_inner_children_have_context(self, node):
if isinstance(node, gast.Attribute):
self._check_inner_children_have_context(node.value)
@@ -131,6 +132,11 @@ class ReplaceTransformer(gast.NodeTransformer):
self._check_inner_children_have_context(node.upper)
if node.step:
self._check_inner_children_have_context(node.step)
+ elif isinstance(node, gast.BinOp):
+ self._check_inner_children_have_context(node.left)
+ self._check_inner_children_have_context(node.right)
+ elif isinstance(node, gast.UnaryOp):
+ self._check_inner_children_have_context(node.operand)
elif isinstance(node, gast.Name):
self._check_has_context(node)
elif isinstance(node, (gast.Str, gast.Num)):
@@ -166,6 +172,11 @@ class ReplaceTransformer(gast.NodeTransformer):
elif isinstance(node, gast.Subscript):
self._set_inner_child_context(node.value, ctx)
self._check_inner_children_have_context(node.slice)
+ elif isinstance(node, gast.BinOp):
+ self._check_inner_children_have_context(node.left)
+ self._check_inner_children_have_context(node.right)
+ elif isinstance(node, gast.UnaryOp):
+ self._check_inner_children_have_context(node.operand)
elif isinstance(node, (gast.Str, gast.Num)):
pass
else:
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 66268cfaad..078d9a149b 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -132,6 +132,18 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+ def test_replace_expression_context(self):
+ template = """
+ def test_fn(foo):
+ foo
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('a + 2 * b / -c'))[0]
+ self.assertIsInstance(node.body[0].ctx, gast.Load)
+ self.assertIsInstance(node.body[0].left.ctx, gast.Load)
+ self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load)
+
def test_replace_complex_context(self):
template = """
def test_fn(foo):