aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/cond_v2_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-18 13:36:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 13:39:30 -0700
commitab251a0ec66a3c8b88ca467e49bfc68d18a2a8e9 (patch)
treea7c5b15fb417b2f66d52a51a55de622dcc221c73 /tensorflow/python/kernel_tests/cond_v2_test.py
parent3d3196f34173e5c6e1f9297e2fcd4c316fe903fd (diff)
Enables `If` operator lowering in cond_v2 when XLA is disabled. Lowering allows cond_v2 to avoid some of the limitations of Functions, allowing users to specify devices & colocation inside of cond_v2 branches, and enabling non-strict evaluation & partial pruning of branches. This brings cond_v2 closer to feature parity with tf.cond.
However, we do not lower `If` in the XLA context because it is easier for XLA to apply its own optimizations when dealing with un-lowered `If` operators than with lowered switch/merge control flow. Also adds a toggleable flag in for InlineFunctionBody in function.cc that prevents the function caller device from overriding the devices of function body nodes. This is necessary for cond_v2 branches to support explicitly-specified devices. Adds several tests to make sure that: - lowering is usually enabled - lowering is disabled for XLA - node colocation inside of cond_v2 branches works - explicit device placement inside of cond_v2 branches works PiperOrigin-RevId: 201049850
Diffstat (limited to 'tensorflow/python/kernel_tests/cond_v2_test.py')
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py113
1 files changed, 112 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 76bbd61604..759db5d5f4 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -100,7 +101,7 @@ class NewCondTest(test.TestCase):
self.assertEqual(sess.run(out, {pred: False}), [2.0])
def _createCond(self, name):
- pred = array_ops.placeholder(dtypes.bool, name="pred")
+ pred = constant_op.constant(True, name="pred")
x = constant_op.constant(1.0, name="x")
def true_fn():
@@ -200,6 +201,65 @@ class NewCondTest(test.TestCase):
# d2[x]/dx2 = 0
self.assertEqual(false_val, [0.0])
+ def testLowering(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ out_cond = self._createCond("cond")
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(out_cond, options=run_options, run_metadata=run_metadata)
+
+ # If lowering was enabled, there should be a `Switch` node
+ switch_found = any(
+ any(node.op == "Switch" for node in graph.node)
+ for graph in run_metadata.partition_graphs
+ )
+
+ self.assertTrue(switch_found,
+ "A `Switch` op should exist if the graph was lowered.")
+
+ # If lowering was enabled, there should be no `If` node
+ if_found = any(
+ any(node.op == "If" for node in graph.node)
+ for graph in run_metadata.partition_graphs
+ )
+
+ self.assertFalse(if_found,
+ "An `If` op was found, but it should be lowered.")
+
+ def testLoweringDisabledInXLA(self):
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Build the cond_v2 in an XLA context
+ xla_context = control_flow_ops.XLAControlFlowContext()
+ xla_context.Enter()
+ out_cond = self._createCond("cond")
+ xla_context.Exit()
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(out_cond, options=run_options, run_metadata=run_metadata)
+
+ # Lowering disabled in XLA, there should be no `Switch` node
+ switch_found = any(
+ any(node.op == "Switch" for node in graph.node)
+ for graph in run_metadata.partition_graphs
+ )
+
+ self.assertFalse(
+ switch_found,
+ "A `Switch` op exists, but the graph should not be lowered.")
+
+ # Lowering disabled in XLA, there should still be an `If` node
+ if_found = any(
+ any(node.op == "If" for node in graph.node)
+ for graph in run_metadata.partition_graphs
+ )
+
+ self.assertTrue(
+ if_found,
+ "An `If` op was not found, but the graph should not be lowered.")
+
class CondV2CollectionTest(test.TestCase):
@@ -387,6 +447,34 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
d = constant_op.constant([2.0], name="d")
self.assertEqual([b"loc:@a"], d.op.colocation_groups())
+ def testColocateWithInCondGraphPartitioning(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(
+ graph=g,
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})
+ ) as sess:
+
+ with ops.device("/device:CPU:0"):
+ a = constant_op.constant([2.0], name="a")
+ with ops.device("/device:CPU:1"):
+ b = constant_op.constant([2.0], name="b")
+
+ def fn():
+ with ops.colocate_with(b.op):
+ c = math_ops.add(a, a, name="c")
+ return c
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)
+
+ # We expect there to be two partitions because of the
+ # colocate_with. We are only running the cond, which has a data
+ # dependency on `a` but not on `b`. So, without the colocate_with
+ # we would expect execution on just one device.
+ self.assertTrue(len(run_metadata.partition_graphs) >= 2)
+
def testDeviceBeforeCond(self):
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -421,5 +509,28 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
d = constant_op.constant(4.0)
self.assertEqual("/device:CPU:0", d.op.device)
+ def testDeviceInCondGraphPartitioning(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(
+ graph=g,
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})
+ ) as sess:
+
+ def fn():
+ with ops.device("/device:CPU:1"):
+ c = math_ops.add(a, a, name="c")
+ return c
+
+ with ops.device("/device:CPU:0"):
+ a = constant_op.constant([2.0], name="a")
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)
+
+ self.assertTrue(len(run_metadata.partition_graphs) >= 2)
+
+
if __name__ == "__main__":
test.main()