aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-12-06 15:21:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-06 15:24:46 -0800
commit2d8206b6b5daf8f5bedd94f32c61eb2c00fd7c25 (patch)
tree3dea677f15f420196276d0d647ca9930cf5b3859 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parentc2e6d554cfd7a19fa46c03e2e4eae264580b3692 (diff)
Add Python checks to prevent mixing ops from different while loops.
The executor can currently catch some errors like this by trying to reconstruct the while loop contexts by tracing the graph from enter nodes, but this doesn't catch everything and can cause hangs or other undesirable behavior. This change puts the check in Python and also provides better debugging information. In addition, this change refactors some logic from control_flow_ops.py to a new file, control_flow_util.py. This is so we can call CheckInputFromValidContext from ops.py without creating circular imports between ops.py and control_flow_ops.py. PiperOrigin-RevId: 178161679
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py116
1 files changed, 116 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 1b7f9b110c..ad02a9e58c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -2622,6 +2622,122 @@ class ControlFlowTest(test.TestCase):
1)
+class ControlFlowContextCheckTest(test.TestCase):
+
+ def _getWhileTensor(self):
+ """Creates and returns a tensor from a while context."""
+ tensor = []
+
+ def body(i):
+ if not tensor:
+ tensor.append(constant_op.constant(1))
+ return i + tensor[0]
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0])
+ return tensor[0]
+
+ def _getCondTensor(self):
+ cond_tensor = []
+ def true_fn():
+ if not cond_tensor:
+ cond_tensor.append(constant_op.constant(1))
+ return cond_tensor[0]
+ control_flow_ops.cond(math_ops.less(1, 2), true_fn,
+ lambda: constant_op.constant(0))
+ return cond_tensor[0]
+
+ def testInvalidContext(self):
+ # Accessing a while loop tensor outside of control flow is illegal.
+ while_tensor = self._getWhileTensor()
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' "
+ "is in a while loop. See info log for more details."):
+ math_ops.add(1, while_tensor)
+
+ def testInvalidContextInCond(self):
+ # Accessing a while loop tensor in cond is illegal.
+ while_tensor = self._getWhileTensor()
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Cannot use 'while/Const_1' as input to 'cond/Add' because "
+ "'while/Const_1' is in a while loop. See info log for more details."):
+ # TODO(skyewm): this passes if we return while_tensor directly instead
+ # of using it as input to another op.
+ control_flow_ops.cond(math_ops.less(1, 2),
+ lambda: math_ops.add(1, while_tensor),
+ lambda: constant_op.constant(0))
+
+ def testInvalidContextInWhile(self):
+ # Accessing a while loop tensor in a different while loop is illegal.
+ while_tensor = self._getWhileTensor()
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Cannot use 'while_1/Add' as input to 'while/Const_1' because they are "
+ "in different while loops. See info log for more details."):
+ control_flow_ops.while_loop(lambda i: i < 10,
+ lambda x: math_ops.add(1, while_tensor), [0])
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Cannot use 'while_2/NextIteration' as input to 'while/Const_1' "
+ "because they are in different while loops. See info log for more "
+ "details."):
+ control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0])
+
+ def testValidCondContext(self):
+ # Accessing a tensor from a cond context is OK (although dangerous).
+ cond_tensor = self._getCondTensor()
+ math_ops.add(1, cond_tensor)
+
+ def testValidCondContextBranches(self):
+ # Accessing a tensor from a cond context from the other branch's cond
+ # context is OK (although dangerous).
+ cond_tensor = []
+ def branch_fn():
+ if not cond_tensor:
+ cond_tensor.append(constant_op.constant(1))
+ return cond_tensor[0]
+
+ control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn)
+
+ def testValidWhileContext(self):
+ # Accessing a tensor in a nested while is OK.
+ def body(_):
+ c = constant_op.constant(1)
+ return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0])
+
+ control_flow_ops.while_loop(lambda i: i < 5, body, [0])
+
+ def testValidNestedContexts(self):
+ # Accessing a tensor from a cond context in a while context, all inside an
+ # outer while context, is OK.
+ def body(_):
+ cond_tensor = self._getCondTensor()
+ # Create another cond containing the while loop for good measure
+ return control_flow_ops.cond(
+ math_ops.less(1, 2),
+ lambda: control_flow_ops.while_loop(lambda i: i < 3,
+ lambda i: i + cond_tensor, [0]),
+ lambda: constant_op.constant(0))
+
+ control_flow_ops.while_loop(lambda i: i < 5, body, [0])
+
+ def testInvalidNestedContexts(self):
+ # Accessing a tensor from a while context in a different while context, all
+ # inside a cond context, is illegal.
+ def true_fn():
+ while_tensor = self._getWhileTensor()
+ return control_flow_ops.while_loop(lambda i: i < 3,
+ lambda i: i + while_tensor, [0])
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Cannot use 'cond/while_1/add' as input to 'cond/while/Const_1' because"
+ " they are in different while loops. See info log for more details."):
+ control_flow_ops.cond(math_ops.less(1, 2), true_fn,
+ lambda: constant_op.constant(0))
+
+
class TupleTest(test.TestCase):
def testTensors(self):