aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-10-10 08:28:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-10 08:33:21 -0700
commit131f6f8429ffa0511a3d5a6a595843d3d96ec942 (patch)
tree65eea249647113fb07f037e1eaac66103a5f513d
parentf146d586bf93b918d6f3e014b230abee49170a52 (diff)
cond_v2: raise an error if pred is a Python bool.
This is to match the existing behavior of tf.cond. PiperOrigin-RevId: 216534084
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py34
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py1
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py3
3 files changed, 23 insertions, 15 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index a424a0f219..0e7c2f8ae6 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -670,7 +670,7 @@ class CondV2CollectionTest(test.TestCase):
y_const = constant_op.constant(ops.get_collection("y")[0])
return math_ops.add(x_const, y_const)
- cnd = cond_v2.cond_v2(True, fn, fn)
+ cnd = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
self.assertEquals(cnd.eval(), 7)
def testCollectionTensorValueAccessInCond(self):
@@ -705,9 +705,7 @@ class CondV2CollectionTest(test.TestCase):
z = math_ops.add(x, y)
return math_ops.mul(x, z)
- cnd = cond_v2.cond_v2(
- True, true_fn,
- false_fn)
+ cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
self.assertEquals(cnd.eval(), 14)
read_z_collection = ops.get_collection("z")
@@ -780,10 +778,12 @@ class CondV2ContainerTest(test.TestCase):
return constant_op.constant(6.0)
with ops.container("l1"):
- cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
+ cnd_true = cond_v2.cond_v2(
+ constant_op.constant(True), true_fn, false_fn)
self.assertEquals(cnd_true.eval(), 2)
- cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
+ cnd_false = cond_v2.cond_v2(
+ constant_op.constant(False), true_fn, false_fn)
self.assertEquals(cnd_false.eval(), 6)
v4 = variables.Variable([3])
@@ -812,7 +812,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -821,7 +822,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(a.op):
with ops.colocate_with(b.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
def testColocateWithInAndOutOfCond(self):
with ops.Graph().as_default() as g:
@@ -837,7 +839,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
d = constant_op.constant([2.0], name="d")
self.assertEqual([b"loc:@a"], d.op.colocation_groups())
@@ -858,7 +861,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(b.op):
c = math_ops.add(a, a, name="c")
return c
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)
+ out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
@@ -880,7 +883,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -888,7 +892,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:GPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
def testDeviceInAndOutOfCond(self):
with ops.Graph().as_default() as g:
@@ -902,7 +907,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
d = constant_op.constant(4.0)
self.assertEqual("/device:CPU:0", d.op.device)
@@ -921,7 +927,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.device("/device:CPU:0"):
a = constant_op.constant([2.0], name="a")
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)
+ out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index a5f85b97f7..46b8b10e90 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -333,7 +333,6 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesOpError("has inputs from different frames"):
res.eval(feed_dict={data: 1.0})
- @test_util.disable_control_flow_v2("b/113294340")
def testCondBool(self):
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index c9aa4d4889..81d9cba042 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -52,6 +52,9 @@ _gradients_impl = None
def cond_v2(pred, true_fn, false_fn, name="cond"):
"""Like tf.cond, except emits a single If op."""
+ if isinstance(pred, bool):
+ raise TypeError("pred must not be a Python bool", pred)
+
if not name:
name = "cond"