author     Eugene Brevdo <ebrevdo@google.com>  2018-01-16 16:17:36 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-01-16 16:23:05 -0800
commit     c9096fd166a9d7fdb62c6cb747a74edb73630b0c (patch)
tree       4c9cd12946750b03b2dc850916fe2e16db3d955e /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent     1de8ca3edb22c232b6cd4a87076bd5e0a7f6b86f (diff)
[TF] Fix XLA Control Flow gradient stacks max_size creation.
Stack creation uses tf.while_loop's maximum_iterations iff the while_loop was created inside an XLA/TPU context. Added several error checks to ensure this provides useful error messages if the limited use case is not supported.

PiperOrigin-RevId: 182128135
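For context, the pattern the new tests exercise looks roughly like the sketch below. This is not part of the commit; it assumes the TF 1.x graph-mode APIs the tests themselves use (control_flow_ops.XLAControlFlowContext, tf.while_loop, tf.gradients), and the constant bound of 10 is purely illustrative.

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops  # provides XLAControlFlowContext

    v = tf.constant(1.0)

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    # Inside an XLA context, the gradient stacks for this loop take their
    # max_size from maximum_iterations, so a usable bound (here an
    # illustrative constant) has to be supplied for backprop to work.
    _, out = tf.while_loop(
        lambda i, _: i < 3,
        lambda i, x: (i + 1, x * v),
        (tf.constant(0), tf.constant(1.0)),
        maximum_iterations=10)
    # With the new checks, omitting maximum_iterations above would make this
    # gradient call raise a descriptive ValueError instead of failing later.
    grads = tf.gradients(out, v)
    xla_context.Exit()

The tests added below cover the valid nested case as well as the missing-bound and sibling-context error cases that the new checks report.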
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py  162
1 file changed, 153 insertions(+), 9 deletions(-)
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 7f2c2545dc..6e18ed132c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -747,18 +747,162 @@ class ControlFlowTest(test.TestCase):
                                       maximum_iterations=1)
       self.assertEqual(1, r.eval())
 
-  def testInvalidMaximumIterationsContext(self):
-    def outer_body(i, r):
-      r = control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + 1, [0],
-                                      maximum_iterations=r.shape[0])
-      return i, r
+  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
+    v = constant_op.constant(1.0)
+    def training_loop_with_gradient(i):
+      out = control_flow_ops.while_loop(
+          lambda i_, _: i_ < 3,
+          lambda i_, j: [i_ + 1, j * v],
+          [0, 1.0],
+          maximum_iterations=i)
+      g = gradients_impl.gradients(out, v)
+      with ops.control_dependencies(g):
+        return i + 1
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    # Create training loop, ensure we can call gradient() of
+    # while_loop inside the training loop.
+    loop = control_flow_ops.while_loop(
+        lambda i: i < 3, training_loop_with_gradient, [0])
+    xla_context.Exit()
+
+    loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.
+
+    # Should execute without issue.
+    self.assertEqual(3, self.evaluate(loop_execute))
+
+  def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
+    v = constant_op.constant(1.0)
+    def inner_body(i, x):
+      out = control_flow_ops.while_loop(
+          lambda i, _: i < 3,
+          lambda i, j: [i + 1, j * v],
+          [0, x],
+          maximum_iterations=i)
+      return out
+
+    def create_while_loop(maximum_iterations=None):
+      return control_flow_ops.while_loop(
+          lambda i, _: i < 3, inner_body, [0, 1.0],
+          maximum_iterations=maximum_iterations)
+
+    loop_no_xla = create_while_loop(maximum_iterations=5)
+    # maximum_iterations is fine outside of an XLA scope
+    gs = gradients_impl.gradients(loop_no_xla, v)
+    self.evaluate(gs)  # This should execute without error.
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    loop_no_maxiter = create_while_loop()
+    loop_with_maxiter = create_while_loop(maximum_iterations=2)
+    xla_context.Exit()
 
     with self.assertRaisesRegexp(
         ValueError,
-        "maximum_iterations tensor cannot be declared in tf.cond or "
-        "tf.while_loop"):
-      control_flow_ops.while_loop(lambda i, r: i < 3, outer_body,
-                                  [0, constant_op.constant([1])])
+        r"Cannot create a gradient accumulator for tensor '.+' inside "
+        r"XLA while_loop because maximum_iterations was not passed to "
+        r"the tf.while_loop call \('.+'\)."):
+      _ = gradients_impl.gradients(loop_no_maxiter, v)
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
+        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
+        r"'.+' must be statically known \(e.g. a constant value or known "
+        r"shape dimension\), or be defined at or outside the while loop "
+        r"context '.*' \(currently defined in '.*'\)"):
+      _ = gradients_impl.gradients(loop_with_maxiter, v)
+
+  def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
+    v = constant_op.constant(1.0)
+
+    def create_while_loop():
+      max_iter_holder = []
+      def create_mi():
+        max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
+        return 1.0
+      _ = control_flow_ops.cond(constant_op.constant(True),
+                                create_mi, create_mi)
+
+      return control_flow_ops.while_loop(
+          lambda i, _: i < 3, lambda i, x: (i + 1, v * x), (0, 1.0),
+          maximum_iterations=max_iter_holder[0])
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    loop = create_while_loop()
+    xla_context.Exit()
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
+        r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
+        r"while_loop context '.+' must be statically known \(e.g. a constant "
+        r"value or known shape dimension\), or be defined at or outside the "
+        r"while loop context '' \(currently defined in 'cond/.+'\)"):
+      _ = gradients_impl.gradients(loop, v)
+
+  def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
+    v = constant_op.constant(1.0)
+
+    p = array_ops.placeholder(dtype=dtypes.int32)
+
+    def mid_body_builder(iterations):
+      def mid_body(i, x):
+        r = control_flow_ops.while_loop(
+            lambda *_: True,
+            lambda i, x: (i + 1, v * x),
+            (0, x),
+            maximum_iterations=iterations, name="inner")
+        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
+      return mid_body
+
+    def outer_body(i, x):
+      iterations = array_ops.size(p, name="iterations")
+      return (
+          i + 1,
+          x + control_flow_ops.while_loop(
+              lambda *_: True, mid_body_builder(iterations), (0, x),
+              maximum_iterations=iterations, name="mid")[1])
+
+    def create_while_loop():
+      with ops.device("/cpu:0"):
+        r = control_flow_ops.while_loop(
+            lambda *_: True, outer_body, (0, 1.0),
+            maximum_iterations=5, name="outer")
+        return array_ops.identity(r[1])
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    final_with_xla_context = create_while_loop()
+    xla_context.Exit()
+
+    final_without_xla_context = create_while_loop()
+
+    with self.test_session(use_gpu=False) as sess:
+      opts = config_pb2.RunOptions(
+          trace_level=config_pb2.RunOptions.FULL_TRACE)
+      run_metadata = config_pb2.RunMetadata()
+
+      final_value_without_xla_context = sess.run(
+          final_without_xla_context,
+          feed_dict={p: [0, 0, 0]})
+
+      final_value_with_xla_context = sess.run(
+          final_with_xla_context,
+          feed_dict={p: [0, 0, 0]},
+          options=opts, run_metadata=run_metadata)
+
+      node_stats = run_metadata.step_stats.dev_stats[0].node_stats
+      stack_push_count = len(
+          [x for x in node_stats if x.node_name.endswith("StackPushV2")])
+      # Pushes to the stack = product of maximum_iterations values;
+      # the last two "3"s comes from size(p), when p == [0, 0, 0].
+      self.assertEqual(stack_push_count, 5 * 3 * 3)
+
+      self.assertAllClose(
+          final_value_with_xla_context, final_value_without_xla_context)
 
     # Have more than 10 parallel iterations and hence exercise k-bound
     # most of the time.