diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/cond_v2_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index a424a0f219..0e7c2f8ae6 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -670,7 +670,7 @@ class CondV2CollectionTest(test.TestCase): y_const = constant_op.constant(ops.get_collection("y")[0]) return math_ops.add(x_const, y_const) - cnd = cond_v2.cond_v2(True, fn, fn) + cnd = cond_v2.cond_v2(constant_op.constant(True), fn, fn) self.assertEquals(cnd.eval(), 7) def testCollectionTensorValueAccessInCond(self): @@ -705,9 +705,7 @@ class CondV2CollectionTest(test.TestCase): z = math_ops.add(x, y) return math_ops.mul(x, z) - cnd = cond_v2.cond_v2( - True, true_fn, - false_fn) + cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn) self.assertEquals(cnd.eval(), 14) read_z_collection = ops.get_collection("z") @@ -780,10 +778,12 @@ class CondV2ContainerTest(test.TestCase): return constant_op.constant(6.0) with ops.container("l1"): - cnd_true = cond_v2.cond_v2(True, true_fn, false_fn) + cnd_true = cond_v2.cond_v2( + constant_op.constant(True), true_fn, false_fn) self.assertEquals(cnd_true.eval(), 2) - cnd_false = cond_v2.cond_v2(False, true_fn, false_fn) + cnd_false = cond_v2.cond_v2( + constant_op.constant(False), true_fn, false_fn) self.assertEquals(cnd_false.eval(), 6) v4 = variables.Variable([3]) @@ -812,7 +812,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.colocate_with(a.op): - self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3) def fn2(): c = constant_op.constant(3.0) @@ -821,7 +822,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.colocate_with(a.op): with ops.colocate_with(b.op): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3) def testColocateWithInAndOutOfCond(self): with ops.Graph().as_default() as g: @@ -837,7 +839,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.colocate_with(a.op): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3) d = constant_op.constant([2.0], name="d") self.assertEqual([b"loc:@a"], d.op.colocation_groups()) @@ -858,7 +861,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.colocate_with(b.op): c = math_ops.add(a, a, name="c") return c - out_cond_2 = cond_v2.cond_v2(True, fn, fn) + out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn) run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() @@ -880,7 +883,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:CPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3) def fn2(): c = constant_op.constant(3.0) @@ -888,7 +892,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:GPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3) def testDeviceInAndOutOfCond(self): with ops.Graph().as_default() as g: @@ -902,7 +907,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:CPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) + self.assertEquals( + cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3) d = constant_op.constant(4.0) self.assertEqual("/device:CPU:0", d.op.device) @@ -921,7 +927,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.device("/device:CPU:0"): a = constant_op.constant([2.0], name="a") - out_cond_2 = cond_v2.cond_v2(True, fn, fn) + out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn) run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() |