aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-12-14 17:50:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-14 17:53:36 -0800
commitca431de46797155f296639cd978f1d2c370c89d5 (patch)
treeb023d36999e9d27ffb3b9b940901423474ff7157 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent14db6c339cc4aa0a1640dd7b86029f3a1ebad395 (diff)
Raise error if maximum_iterations argument to while_loop is defined in control flow context.
This also modifies dynamic_rnn to not provide a maximum_iterations argument if it's called within control flow. This is a hopefully temporary solution until we better support this usage. PiperOrigin-RevId: 179125216
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 5b0abaa2eb..7f2c2545dc 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -747,6 +747,19 @@ class ControlFlowTest(test.TestCase):
maximum_iterations=1)
self.assertEqual(1, r.eval())
+ def testInvalidMaximumIterationsContext(self):
+ def outer_body(i, r):
+ r = control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + 1, [0],
+ maximum_iterations=r.shape[0])
+ return i, r
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ "maximum_iterations tensor cannot be declared in tf.cond or "
+ "tf.while_loop"):
+ control_flow_ops.while_loop(lambda i, r: i < 3, outer_body,
+ [0, constant_op.constant([1])])
+
# Have more than 10 parallel iterations and hence exercise k-bound
# most of the time.
def testWhile_3(self):