aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-07-31 04:28:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 04:31:50 -0700
commit3bec2640dcbd251f4eb2517d9ae7d8909886375d (patch)
treefbdacf68f0bfa103664e33a3d2740ca3595676a5 /tensorflow/contrib/autograph
parent3a1df26a25aa1b5fc7b897c0187c982c07f98368 (diff)
Use TF constants for the break/continue control variables, to ensure control dependencies get created correctly. This renders break cond continue incompatible with Python inputs, but that's an extremely very unlikely use case.
PiperOrigin-RevId: 206738877
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements.py6
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements_test.py38
-rw-r--r--tensorflow/contrib/autograph/converters/continue_statements.py4
-rw-r--r--tensorflow/contrib/autograph/converters/continue_statements_test.py32
4 files changed, 47 insertions, 33 deletions
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 2a60750bda..180779670d 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -42,7 +42,7 @@ class BreakTransformer(converter.Base):
var_name = self.state[_Break].control_var_name
# TODO(mdan): This will fail when expanded inside a top-level else block.
template = """
- var_name = True
+ var_name = tf.constant(True)
continue
"""
return templates.replace(template, var_name=var_name)
@@ -85,7 +85,7 @@ class BreakTransformer(converter.Base):
guarded_orelse = self._guard_if_present(node.orelse, break_var)
template = """
- var_name = False
+ var_name = tf.constant(False)
while test and not var_name:
body
else:
@@ -122,7 +122,7 @@ class BreakTransformer(converter.Base):
# the control variable is marked as used.
# TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
template = """
- var_name = False
+ var_name = tf.constant(False)
for target in iter_:
(var_name,)
body
diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py
index c26ca2946c..fcae7d68c0 100644
--- a/tensorflow/contrib/autograph/converters/break_statements_test.py
+++ b/tensorflow/contrib/autograph/converters/break_statements_test.py
@@ -20,13 +20,16 @@ from __future__ import print_function
from tensorflow.contrib.autograph.converters import break_statements
from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.eager import context as tfe_ctx
+from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class BreakCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
- with self.converted(test_fn, break_statements, {}) as result:
+ with self.converted(test_fn, break_statements, {},
+ constant_op.constant) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_while_loop(self):
@@ -40,9 +43,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 1)
- self.assertTransformedEquivalent(test_fn, 4)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 4)
def test_for_loop(self):
@@ -55,7 +59,8 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- with self.converted(test_fn, break_statements, {}) as result:
+ with self.converted(test_fn, break_statements, {},
+ constant_op.constant) as result:
# The break is incompletely canonicalized. The loop will not interrupt,
# but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4]))
@@ -77,9 +82,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 3)
- self.assertTransformedEquivalent(test_fn, 11)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 11)
def test_nested_loops(self):
@@ -99,10 +105,11 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 2)
- self.assertTransformedEquivalent(test_fn, 3)
- self.assertTransformedEquivalent(test_fn, 5)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 5)
def test_loop_orelse(self):
@@ -120,9 +127,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 2)
- self.assertTransformedEquivalent(test_fn, 3)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, 3)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py
index 958bde0a58..0476e97c15 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements.py
+++ b/tensorflow/contrib/autograph/converters/continue_statements.py
@@ -37,7 +37,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
def visit_Continue(self, node):
self.set_local(CONTINUE_USED, True)
template = """
- var_name = True
+ var_name = tf.constant(True)
"""
return templates.replace(
template, var_name=self.get_local(CONTROL_VAR_NAME))
@@ -92,7 +92,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
if self.get_local(CONTINUE_USED, False):
template = """
- var_name = False
+ var_name = tf.constant(False)
"""
control_var_init = templates.replace(template, var_name=continue_var)
nodes = control_var_init + nodes
diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py
index 3a7c7d1486..37c15211b4 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements_test.py
+++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py
@@ -20,13 +20,16 @@ from __future__ import print_function
from tensorflow.contrib.autograph.converters import continue_statements
from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.eager import context as tfe_ctx
+from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class ContinueCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
- with self.converted(test_fn, continue_statements, {}) as result:
+ with self.converted(test_fn, continue_statements, {},
+ constant_op.constant) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_basic(self):
@@ -40,10 +43,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 1)
- self.assertTransformedEquivalent(test_fn, 3)
- self.assertTransformedEquivalent(test_fn, 4)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 4)
def test_for_loop(self):
@@ -56,10 +60,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
- self.assertTransformedEquivalent(test_fn, [])
- self.assertTransformedEquivalent(test_fn, [1])
- self.assertTransformedEquivalent(test_fn, [2])
- self.assertTransformedEquivalent(test_fn, [1, 2, 3])
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, [])
+ self.assertTransformedEquivalent(test_fn, [1])
+ self.assertTransformedEquivalent(test_fn, [2])
+ self.assertTransformedEquivalent(test_fn, [1, 2, 3])
def test_nested(self):
@@ -78,10 +83,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
- self.assertTransformedEquivalent(test_fn, 0)
- self.assertTransformedEquivalent(test_fn, 1)
- self.assertTransformedEquivalent(test_fn, 3)
- self.assertTransformedEquivalent(test_fn, 4)
+ with tfe_ctx.eager_mode():
+ self.assertTransformedEquivalent(test_fn, 0)
+ self.assertTransformedEquivalent(test_fn, 1)
+ self.assertTransformedEquivalent(test_fn, 3)
+ self.assertTransformedEquivalent(test_fn, 4)
if __name__ == '__main__':