From 131f6f8429ffa0511a3d5a6a595843d3d96ec942 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 10 Oct 2018 08:28:08 -0700 Subject: cond_v2: raise an error if pred is a Python bool. This is to match the existing behavior of tf.cond. PiperOrigin-RevId: 216534084 --- tensorflow/python/kernel_tests/cond_v2_test.py | 34 +++++++++++++--------- .../kernel_tests/control_flow_ops_py_test.py | 1 - tensorflow/python/ops/cond_v2_impl.py | 3 ++ 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index a424a0f219..0e7c2f8ae6 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -670,7 +670,7 @@ class CondV2CollectionTest(test.TestCase): y_const = constant_op.constant(ops.get_collection("y")[0]) return math_ops.add(x_const, y_const) - cnd = cond_v2.cond_v2(True, fn, fn) + cnd = cond_v2.cond_v2(constant_op.constant(True), fn, fn) self.assertEquals(cnd.eval(), 7) def testCollectionTensorValueAccessInCond(self): @@ -705,9 +705,7 @@ class CondV2CollectionTest(test.TestCase): z = math_ops.add(x, y) return math_ops.mul(x, z) - cnd = cond_v2.cond_v2( - True, true_fn, - false_fn) + cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn) self.assertEquals(cnd.eval(), 14) read_z_collection = ops.get_collection("z") @@ -780,10 +778,12 @@ class CondV2ContainerTest(test.TestCase): return constant_op.constant(6.0) with ops.container("l1"): - cnd_true = cond_v2.cond_v2(True, true_fn, false_fn) + cnd_true = cond_v2.cond_v2( + constant_op.constant(True), true_fn, false_fn) self.assertEquals(cnd_true.eval(), 2) - cnd_false = cond_v2.cond_v2(False, true_fn, false_fn) + cnd_false = cond_v2.cond_v2( + constant_op.constant(False), true_fn, false_fn) self.assertEquals(cnd_false.eval(), 6) v4 = variables.Variable([3]) @@ -812,7 +812,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.colocate_with(a.op): - self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3) def fn2(): c = constant_op.constant(3.0) @@ -821,7 +822,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.colocate_with(a.op): with ops.colocate_with(b.op): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3) def testColocateWithInAndOutOfCond(self): with ops.Graph().as_default() as g: @@ -837,7 +839,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.colocate_with(a.op): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3) d = constant_op.constant([2.0], name="d") self.assertEqual([b"loc:@a"], d.op.colocation_groups()) @@ -858,7 +861,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.colocate_with(b.op): c = math_ops.add(a, a, name="c") return c - out_cond_2 = cond_v2.cond_v2(True, fn, fn) + out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn) run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() @@ -880,7 +883,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:CPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3) def fn2(): c = constant_op.constant(3.0) @@ -888,7 +892,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:GPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3) def testDeviceInAndOutOfCond(self): with ops.Graph().as_default() as g: @@ -902,7 +907,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:CPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3) d = constant_op.constant(4.0) self.assertEqual("/device:CPU:0", d.op.device) @@ -921,7 +927,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.device("/device:CPU:0"): a = constant_op.constant([2.0], name="a") - out_cond_2 = cond_v2.cond_v2(True, fn, fn) + out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn) run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index a5f85b97f7..46b8b10e90 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -333,7 +333,6 @@ class ControlFlowTest(test.TestCase): with self.assertRaisesOpError("has inputs from different frames"): res.eval(feed_dict={data: 1.0}) - @test_util.disable_control_flow_v2("b/113294340") def testCondBool(self): values = constant_op.constant(10) fn1 = lambda: math_ops.add(values, 1) diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py index c9aa4d4889..81d9cba042 100644 --- a/tensorflow/python/ops/cond_v2_impl.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -52,6 +52,9 @@ _gradients_impl = None def cond_v2(pred, true_fn, false_fn, name="cond"): """Like tf.cond, except emits a single If op.""" + if isinstance(pred, bool): + raise TypeError("pred must not be a Python bool", pred) + if not name: name = "cond" -- cgit v1.2.3