diff options
author | Dan Moldovan <mdan@google.com> | 2018-09-18 09:28:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 09:32:24 -0700 |
commit | 0c8a8289da120ee353c4fba5decb0bea9014e0a7 (patch) | |
tree | e921666e40a4ea0d84513fe6a106a8238cdda8ad /tensorflow/python/autograph | |
parent | b1ff7c2cedcc7d49d430d56655870e6d68a0c8f7 (diff) |
Extend template expansion support for arithmetic expressions.
PiperOrigin-RevId: 213462334
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r-- | tensorflow/python/autograph/pyct/templates.py | 11 | ||||
-rw-r--r-- | tensorflow/python/autograph/pyct/templates_test.py | 12 |
2 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py index 68c2a35fac..1bf0515745 100644 --- a/tensorflow/python/autograph/pyct/templates.py +++ b/tensorflow/python/autograph/pyct/templates.py @@ -109,6 +109,7 @@ class ReplaceTransformer(gast.NodeTransformer): if not node.ctx: raise ValueError('node %s is missing ctx value' % node) + # TODO(mdan): Rewrite _check and _set using a separate transformer. def _check_inner_children_have_context(self, node): if isinstance(node, gast.Attribute): self._check_inner_children_have_context(node.value) @@ -131,6 +132,11 @@ class ReplaceTransformer(gast.NodeTransformer): self._check_inner_children_have_context(node.upper) if node.step: self._check_inner_children_have_context(node.step) + elif isinstance(node, gast.BinOp): + self._check_inner_children_have_context(node.left) + self._check_inner_children_have_context(node.right) + elif isinstance(node, gast.UnaryOp): + self._check_inner_children_have_context(node.operand) elif isinstance(node, gast.Name): self._check_has_context(node) elif isinstance(node, (gast.Str, gast.Num)): @@ -166,6 +172,11 @@ class ReplaceTransformer(gast.NodeTransformer): elif isinstance(node, gast.Subscript): self._set_inner_child_context(node.value, ctx) self._check_inner_children_have_context(node.slice) + elif isinstance(node, gast.BinOp): + self._check_inner_children_have_context(node.left) + self._check_inner_children_have_context(node.right) + elif isinstance(node, gast.UnaryOp): + self._check_inner_children_have_context(node.operand) elif isinstance(node, (gast.Str, gast.Num)): pass else: diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py index 66268cfaad..078d9a149b 100644 --- a/tensorflow/python/autograph/pyct/templates_test.py +++ b/tensorflow/python/autograph/pyct/templates_test.py @@ -132,6 +132,18 @@ class TemplatesTest(test.TestCase): self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store) self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store) + def test_replace_expression_context(self): + template = """ + def test_fn(foo): + foo + """ + + node = templates.replace( + template, foo=parser.parse_expression('a + 2 * b / -c'))[0] + self.assertIsInstance(node.body[0].ctx, gast.Load) + self.assertIsInstance(node.body[0].left.ctx, gast.Load) + self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load) + def test_replace_complex_context(self): template = """ def test_fn(foo): |