diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-22 09:16:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 09:22:22 -0700 |
commit | cffdccdc642c1fe852e61cb3236aa00ee53c92bf (patch) | |
tree | 35d8618307782e49a7e2f0d6d7eb3860515e6e52 /tensorflow/contrib/autograph | |
parent | ce40173f61c79af05dcd0e0330cdb80bb179585d (diff) |
Raise AutoGraph error if variable is created inside a loop body.
PiperOrigin-RevId: 209775953
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/converters/control_flow.py | 16 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/converters/control_flow_test.py | 20 |
2 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 5a5a6ad63a..f7dd3183b0 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -95,6 +95,18 @@ class ControlFlowTransformer(converter.Base): return 'no variables' return ', '.join(map(str, symbol_set)) + def _validate_no_live_vars_created(self, node): + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + live_vars_created_in_body = live_vars_out & body_scope.created + if live_vars_created_in_body: + raise ValueError( + 'The following variables are created inside the loop and used later:' + '\n%s\n' + 'Variables must be declared outside loops because loops may not' + ' necessarily execute.' % self._fmt_symbol_list( + live_vars_created_in_body)) + def visit_If(self, node): node = self.generic_visit(node) @@ -197,6 +209,8 @@ class ControlFlowTransformer(converter.Base): def visit_While(self, node): self.generic_visit(node) + self._validate_no_live_vars_created(node) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced @@ -262,6 +276,8 @@ class ControlFlowTransformer(converter.Base): def visit_For(self, node): self.generic_visit(node) + self._validate_no_live_vars_created(node) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 6cb907f69a..02bc00dbc8 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -57,6 +57,17 @@ class ControlFlowTest(converter_testing.TestCase): self.assertTransformedResult(test_fn, constant_op.constant(5), 0) + def test_while_variable_defined_in_body(self): + def bad_while_loop(n): + while n > 0: + n -= 1 + s = n + return s + + node, ctx = self.prepare(bad_while_loop, {}) + with self.assertRaises(transformer.AutographParseError): + control_flow.transform(node, ctx) + def test_if_basic(self): def test_fn(n): @@ -196,6 +207,15 @@ class ControlFlowTest(converter_testing.TestCase): self.assertEqual(result.test_fn(5), 10) self.assertEqual(eval_count[0], 1) + def test_for_variable_defined_in_body(self): + def bad_for_loop(n): + for i in range(n): + s = i + return s + + node, ctx = self.prepare(bad_for_loop, {}) + with self.assertRaises(transformer.AutographParseError): + control_flow.transform(node, ctx) if __name__ == '__main__': test.main() |