aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/cond_v2_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/cond_v2_test.py')
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py34
1 files changed, 20 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index a424a0f219..0e7c2f8ae6 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -670,7 +670,7 @@ class CondV2CollectionTest(test.TestCase):
y_const = constant_op.constant(ops.get_collection("y")[0])
return math_ops.add(x_const, y_const)
- cnd = cond_v2.cond_v2(True, fn, fn)
+ cnd = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
self.assertEquals(cnd.eval(), 7)
def testCollectionTensorValueAccessInCond(self):
@@ -705,9 +705,7 @@ class CondV2CollectionTest(test.TestCase):
z = math_ops.add(x, y)
return math_ops.mul(x, z)
- cnd = cond_v2.cond_v2(
- True, true_fn,
- false_fn)
+ cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
self.assertEquals(cnd.eval(), 14)
read_z_collection = ops.get_collection("z")
@@ -780,10 +778,12 @@ class CondV2ContainerTest(test.TestCase):
return constant_op.constant(6.0)
with ops.container("l1"):
- cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
+ cnd_true = cond_v2.cond_v2(
+ constant_op.constant(True), true_fn, false_fn)
self.assertEquals(cnd_true.eval(), 2)
- cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
+ cnd_false = cond_v2.cond_v2(
+ constant_op.constant(False), true_fn, false_fn)
self.assertEquals(cnd_false.eval(), 6)
v4 = variables.Variable([3])
@@ -812,7 +812,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -821,7 +822,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(a.op):
with ops.colocate_with(b.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
def testColocateWithInAndOutOfCond(self):
with ops.Graph().as_default() as g:
@@ -837,7 +839,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
d = constant_op.constant([2.0], name="d")
self.assertEqual([b"loc:@a"], d.op.colocation_groups())
@@ -858,7 +861,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(b.op):
c = math_ops.add(a, a, name="c")
return c
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)
+ out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
@@ -880,7 +883,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -888,7 +892,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:GPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
def testDeviceInAndOutOfCond(self):
with ops.Graph().as_default() as g:
@@ -902,7 +907,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
+ self.assertEquals(
+ cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
d = constant_op.constant(4.0)
self.assertEqual("/device:CPU:0", d.op.device)
@@ -921,7 +927,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.device("/device:CPU:0"):
a = constant_op.constant([2.0], name="a")
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)
+ out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()