diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-28 09:10:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 09:18:56 -0700 |
commit | 8e5c118ce835e0b8625ef073e2f4d978c70498ae (patch) | |
tree | 6cb261f729c84b1731c02138045949d0a9962dbe /tensorflow/contrib/autograph | |
parent | 8987d1cfd3c17eab4e28da376fdc718f53d82e19 (diff) |
While loop dispatch depends only on whether variables directly referenced in the condition are tensors.
This fixes a bug where a variable in an inner loop could be referenced before creation. These variables would be used in the AG while_stmt to determine whether to dispatch to tf.while_loop or run the Python loop.
PiperOrigin-RevId: 210550604
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/converters/control_flow.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/converters/control_flow_test.py | 18 |
2 files changed, 20 insertions, 1 deletions
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 8d314250a0..3530fbb2ec 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -217,7 +217,7 @@ class ControlFlowTransformer(converter.Base): cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE) cond_closure = set() - for s in cond_scope.referenced: + for s in cond_scope.used: for root in s.support_set: if root not in body_scope.created: cond_closure.add(root) @@ -250,6 +250,7 @@ class ControlFlowTransformer(converter.Base): node_body = ast_util.rename_symbols(node.body, ssf_map) test = ast_util.rename_symbols(node.test, ssf_map) + # TODO(b/113118541) investigate the need-for and correctness-of extra_deps template = """ def test_name(state_ssf): return test diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 2a6f3cb395..1d04ba3ba6 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -48,6 +48,24 @@ class ControlFlowTest(converter_testing.TestCase): self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5)) + def test_while_nested(self): + + def test_fn(n): + i = 0 + j = 0 + s = 0 + while i < n: + while j < i: + j += 3 + u = i + j # 'u' is not defined within the inner loop + s += u + i += 1 + j = 0 + return s, i, j, n + + self.assertTransformedResult(test_fn, constant_op.constant(5), + (25, 5, 0, 5)) + def test_while_single_output(self): def test_fn(n): |