diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-01-16 16:17:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-16 16:23:05 -0800 |
commit | c9096fd166a9d7fdb62c6cb747a74edb73630b0c (patch) | |
tree | 4c9cd12946750b03b2dc850916fe2e16db3d955e /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | 1de8ca3edb22c232b6cd4a87076bd5e0a7f6b86f (diff) |
[TF] Fix XLA Control Flow gradient stacks max_size creation.
Stack creation uses tf.while_loop's maximum_iterations iff the while_loop
was created inside an XLA/TPU context. Added several error checks to ensure
this provides useful error messages if the limited use case is not supported.
PiperOrigin-RevId: 182128135
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 162 |
1 files changed, 153 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 7f2c2545dc..6e18ed132c 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -747,18 +747,162 @@ class ControlFlowTest(test.TestCase): maximum_iterations=1) self.assertEqual(1, r.eval()) - def testInvalidMaximumIterationsContext(self): - def outer_body(i, r): - r = control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + 1, [0], - maximum_iterations=r.shape[0]) - return i, r + def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self): + v = constant_op.constant(1.0) + def training_loop_with_gradient(i): + out = control_flow_ops.while_loop( + lambda i_, _: i_ < 3, + lambda i_, j: [i_ + 1, j * v], + [0, 1.0], + maximum_iterations=i) + g = gradients_impl.gradients(out, v) + with ops.control_dependencies(g): + return i + 1 + + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + # Create training loop, ensure we can call gradient() of + # while_loop inside the training loop. + loop = control_flow_ops.while_loop( + lambda i: i < 3, training_loop_with_gradient, [0]) + xla_context.Exit() + + loop_execute = array_ops.identity(loop) # Because loop is not fetchable. + + # Should execute without issue. 
+ self.assertEqual(3, self.evaluate(loop_execute)) + + def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self): + v = constant_op.constant(1.0) + def inner_body(i, x): + out = control_flow_ops.while_loop( + lambda i, _: i < 3, + lambda i, j: [i + 1, j * v], + [0, x], + maximum_iterations=i) + return out + + def create_while_loop(maximum_iterations=None): + return control_flow_ops.while_loop( + lambda i, _: i < 3, inner_body, [0, 1.0], + maximum_iterations=maximum_iterations) + + loop_no_xla = create_while_loop(maximum_iterations=5) + # maximum_iterations is fine outside of an XLA scope + gs = gradients_impl.gradients(loop_no_xla, v) + self.evaluate(gs) # This should execute without error. + + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + loop_no_maxiter = create_while_loop() + loop_with_maxiter = create_while_loop(maximum_iterations=2) + xla_context.Exit() with self.assertRaisesRegexp( ValueError, - "maximum_iterations tensor cannot be declared in tf.cond or " - "tf.while_loop"): - control_flow_ops.while_loop(lambda i, r: i < 3, outer_body, - [0, constant_op.constant([1])]) + r"Cannot create a gradient accumulator for tensor '.+' inside " + r"XLA while_loop because maximum_iterations was not passed to " + r"the tf.while_loop call \('.+'\)."): + _ = gradients_impl.gradients(loop_no_maxiter, v) + + with self.assertRaisesRegexp( + ValueError, + r"Cannot create a gradient accumulator for tensor '.+' inside XLA " + r"while_loop. maximum_iterations tensor '.+' for while_loop context " + r"'.+' must be statically known \(e.g. 
a constant value or known " + r"shape dimension\), or be defined at or outside the while loop " + r"context '.*' \(currently defined in '.*'\)"): + _ = gradients_impl.gradients(loop_with_maxiter, v) + + def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self): + v = constant_op.constant(1.0) + + def create_while_loop(): + max_iter_holder = [] + def create_mi(): + max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=())) + return 1.0 + _ = control_flow_ops.cond(constant_op.constant(True), + create_mi, create_mi) + + return control_flow_ops.while_loop( + lambda i, _: i < 3, lambda i, x: (i + 1, v * x), (0, 1.0), + maximum_iterations=max_iter_holder[0]) + + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + loop = create_while_loop() + xla_context.Exit() + + with self.assertRaisesRegexp( + ValueError, + r"Cannot create a gradient accumulator for tensor '.+' inside XLA " + r"while_loop. maximum_iterations tensor '.*Placeholder:0' for " + r"while_loop context '.+' must be statically known \(e.g. 
a constant " + r"value or known shape dimension\), or be defined at or outside the " + r"while loop context '' \(currently defined in 'cond/.+'\)"): + _ = gradients_impl.gradients(loop, v) + + def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self): + v = constant_op.constant(1.0) + + p = array_ops.placeholder(dtype=dtypes.int32) + + def mid_body_builder(iterations): + def mid_body(i, x): + r = control_flow_ops.while_loop( + lambda *_: True, + lambda i, x: (i + 1, v * x), + (0, x), + maximum_iterations=iterations, name="inner") + return (i + 1, gradients_impl.gradients(x + r[1], v)[0]) + return mid_body + + def outer_body(i, x): + iterations = array_ops.size(p, name="iterations") + return ( + i + 1, + x + control_flow_ops.while_loop( + lambda *_: True, mid_body_builder(iterations), (0, x), + maximum_iterations=iterations, name="mid")[1]) + + def create_while_loop(): + with ops.device("/cpu:0"): + r = control_flow_ops.while_loop( + lambda *_: True, outer_body, (0, 1.0), + maximum_iterations=5, name="outer") + return array_ops.identity(r[1]) + + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + final_with_xla_context = create_while_loop() + xla_context.Exit() + + final_without_xla_context = create_while_loop() + + with self.test_session(use_gpu=False) as sess: + opts = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + + final_value_without_xla_context = sess.run( + final_without_xla_context, + feed_dict={p: [0, 0, 0]}) + + final_value_with_xla_context = sess.run( + final_with_xla_context, + feed_dict={p: [0, 0, 0]}, + options=opts, run_metadata=run_metadata) + + node_stats = run_metadata.step_stats.dev_stats[0].node_stats + stack_push_count = len( + [x for x in node_stats if x.node_name.endswith("StackPushV2")]) + # Pushes to the stack = product of maximum_iterations values; + # the last two "3"s come from size(p), when p == [0, 0, 0]. 
+ self.assertEqual(stack_push_count, 5 * 3 * 3) + + self.assertAllClose( + final_value_with_xla_context, final_value_without_xla_context) # Have more than 10 parallel iterations and hence exercise k-bound # most of the time. |