diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-18 13:36:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-18 13:39:30 -0700 |
commit | ab251a0ec66a3c8b88ca467e49bfc68d18a2a8e9 (patch) | |
tree | a7c5b15fb417b2f66d52a51a55de622dcc221c73 /tensorflow/python/kernel_tests/cond_v2_test.py | |
parent | 3d3196f34173e5c6e1f9297e2fcd4c316fe903fd (diff) |
Enables `If` operator lowering in cond_v2 when XLA is disabled. Lowering allows cond_v2 to avoid some of the limitations of Functions, allowing users to specify devices & colocation inside of cond_v2 branches, and enabling non-strict evaluation & partial pruning of branches. This brings cond_v2 closer to feature parity with tf.cond.
However, we do not lower `If` in the XLA context because it is easier for XLA to apply its own optimizations when dealing with un-lowered `If` operators than with lowered switch/merge control flow.
Also adds a toggleable flag in for InlineFunctionBody in function.cc that prevents the function caller device from overriding the devices of function body nodes. This is necessary for cond_v2 branches to support explicitly-specified devices.
Adds several tests to make sure that:
- lowering is usually enabled
- lowering is disabled for XLA
- node colocation inside of cond_v2 branches works
- explicit device placement inside of cond_v2 branches works
PiperOrigin-RevId: 201049850
Diffstat (limited to 'tensorflow/python/kernel_tests/cond_v2_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 113 |
1 files changed, 112 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 76bbd61604..759db5d5f4 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -100,7 +101,7 @@ class NewCondTest(test.TestCase): self.assertEqual(sess.run(out, {pred: False}), [2.0]) def _createCond(self, name): - pred = array_ops.placeholder(dtypes.bool, name="pred") + pred = constant_op.constant(True, name="pred") x = constant_op.constant(1.0, name="x") def true_fn(): @@ -200,6 +201,65 @@ class NewCondTest(test.TestCase): # d2[x]/dx2 = 0 self.assertEqual(false_val, [0.0]) + def testLowering(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + out_cond = self._createCond("cond") + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond, options=run_options, run_metadata=run_metadata) + + # If lowering was enabled, there should be a `Switch` node + switch_found = any( + any(node.op == "Switch" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertTrue(switch_found, + "A `Switch` op should exist if the graph was lowered.") + + # If lowering was enabled, there should be no `If` node + if_found = any( + any(node.op == "If" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertFalse(if_found, + "An `If` op was found, but it should be lowered.") + + def testLoweringDisabledInXLA(self): + with self.test_session(graph=ops.Graph()) as sess: + # Build the cond_v2 in an XLA context + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + out_cond = self._createCond("cond") + xla_context.Exit() + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond, options=run_options, run_metadata=run_metadata) + + # Lowering disabled in XLA, there should be no `Switch` node + switch_found = any( + any(node.op == "Switch" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertFalse( + switch_found, + "A `Switch` op exists, but the graph should not be lowered.") + + # Lowering disabled in XLA, there should still be an `If` node + if_found = any( + any(node.op == "If" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertTrue( + if_found, + "An `If` op was not found, but the graph should not be lowered.") + class CondV2CollectionTest(test.TestCase): @@ -387,6 +447,34 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): d = constant_op.constant([2.0], name="d") self.assertEqual([b"loc:@a"], d.op.colocation_groups()) + def testColocateWithInCondGraphPartitioning(self): + with ops.Graph().as_default() as g: + with self.test_session( + graph=g, + config=config_pb2.ConfigProto(device_count={"CPU": 2}) + ) as sess: + + with ops.device("/device:CPU:0"): + a = constant_op.constant([2.0], name="a") + with ops.device("/device:CPU:1"): + b = constant_op.constant([2.0], name="b") + + def fn(): + with ops.colocate_with(b.op): + c = math_ops.add(a, a, name="c") + return c + out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond_2, options=run_options, run_metadata=run_metadata) + + # We expect there to be two partitions because of the + # colocate_with. We are only running the cond, which has a data + # dependency on `a` but not on `b`. So, without the colocate_with + # we would expect execution on just one device. + self.assertTrue(len(run_metadata.partition_graphs) >= 2) + def testDeviceBeforeCond(self): with ops.Graph().as_default() as g: with self.test_session(graph=g): @@ -421,5 +509,28 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): d = constant_op.constant(4.0) self.assertEqual("/device:CPU:0", d.op.device) + def testDeviceInCondGraphPartitioning(self): + with ops.Graph().as_default() as g: + with self.test_session( + graph=g, + config=config_pb2.ConfigProto(device_count={"CPU": 2}) + ) as sess: + + def fn(): + with ops.device("/device:CPU:1"): + c = math_ops.add(a, a, name="c") + return c + + with ops.device("/device:CPU:0"): + a = constant_op.constant([2.0], name="a") + out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond_2, options=run_options, run_metadata=run_metadata) + + self.assertTrue(len(run_metadata.partition_graphs) >= 2) + + if __name__ == "__main__": test.main() |