aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 09:10:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 09:18:56 -0700
commit8e5c118ce835e0b8625ef073e2f4d978c70498ae (patch)
tree6cb261f729c84b1731c02138045949d0a9962dbe /tensorflow/contrib/autograph
parent8987d1cfd3c17eab4e28da376fdc718f53d82e19 (diff)
While loop dispatch depends only on whether variables directly referenced in the condition are tensors.
This fixes a bug where a variable in an inner loop could be referenced before creation. These variables would be used in the AG while_stmt to determine whether to dispatch to tf.while_loop or run the Python loop. PiperOrigin-RevId: 210550604
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow.py3
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow_test.py18
2 files changed, 20 insertions, 1 deletions
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index 8d314250a0..3530fbb2ec 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -217,7 +217,7 @@ class ControlFlowTransformer(converter.Base):
cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
cond_closure = set()
- for s in cond_scope.referenced:
+ for s in cond_scope.used:
for root in s.support_set:
if root not in body_scope.created:
cond_closure.add(root)
@@ -250,6 +250,7 @@ class ControlFlowTransformer(converter.Base):
node_body = ast_util.rename_symbols(node.body, ssf_map)
test = ast_util.rename_symbols(node.test, ssf_map)
+ # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
template = """
def test_name(state_ssf):
return test
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index 2a6f3cb395..1d04ba3ba6 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -48,6 +48,24 @@ class ControlFlowTest(converter_testing.TestCase):
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))
+ def test_while_nested(self):
+
+ def test_fn(n):
+ i = 0
+ j = 0
+ s = 0
+ while i < n:
+ while j < i:
+ j += 3
+ u = i + j # 'u' is not defined within the inner loop
+ s += u
+ i += 1
+ j = 0
+ return s, i, j, n
+
+ self.assertTransformedResult(test_fn, constant_op.constant(5),
+ (25, 5, 0, 5))
+
def test_while_single_output(self):
def test_fn(n):