diff options
author | 2018-09-24 21:29:42 -0700 | |
---|---|---|
committer | 2018-09-24 21:34:13 -0700 | |
commit | c1644948d23cae271b140d67101c1a386e5495fd (patch) | |
tree | 002efca36c4f95f75b08358343c3701de014880b /tensorflow/python/kernel_tests | |
parent | 9875df75c308d7498e601ae9a4b57db6aad47056 (diff) |
Unpack output of cond_v2 if it is a singleton to match behavior of cond.
PiperOrigin-RevId: 214381126
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 31 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 23 |
2 files changed, 19 insertions, 35 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 5c0e24117f..377c041675 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -131,7 +131,7 @@ class CondV2Test(test.TestCase): def false_fn(): return x + 1 - return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op + return cond_v2.cond_v2(pred, true_fn, false_fn, name=name).op def testDefaultName(self): with ops.Graph().as_default(): @@ -569,8 +569,7 @@ class CondV2Test(test.TestCase): ops.add_to_collection("pred", pred) cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") - for c in cond: - ops.add_to_collection("cond", c) + ops.add_to_collection("cond", cond) meta_graph = saver.export_meta_graph() with ops.Graph().as_default() as g: @@ -672,7 +671,7 @@ class CondV2CollectionTest(test.TestCase): return math_ops.add(x_const, y_const) cnd = cond_v2.cond_v2(True, fn, fn) - self.assertEquals(cnd[0].eval(), 7) + self.assertEquals(cnd.eval(), 7) def testCollectionTensorValueAccessInCond(self): """Read tensors from collections inside of cond_v2 & use them.""" @@ -689,7 +688,7 @@ class CondV2CollectionTest(test.TestCase): return math_ops.add(x_read, y_read) cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn) - self.assertEquals(cnd[0].eval(), 7) + self.assertEquals(cnd.eval(), 7) def testCollectionIntValueWriteInCond(self): """Make sure Int writes to collections work inside of cond_v2.""" @@ -709,7 +708,7 @@ class CondV2CollectionTest(test.TestCase): cnd = cond_v2.cond_v2( True, true_fn, false_fn) - self.assertEquals(cnd[0].eval(), 14) + self.assertEquals(cnd.eval(), 14) read_z_collection = ops.get_collection("z") self.assertEquals(read_z_collection, [7]) @@ -782,10 +781,10 @@ class CondV2ContainerTest(test.TestCase): with ops.container("l1"): cnd_true = cond_v2.cond_v2(True, true_fn, false_fn) - self.assertEquals(cnd_true[0].eval(), 2) + self.assertEquals(cnd_true.eval(), 2) cnd_false = cond_v2.cond_v2(False, true_fn, false_fn) - self.assertEquals(cnd_false[0].eval(), 6) + self.assertEquals(cnd_false.eval(), 6) v4 = variables.Variable([3]) q4 = data_flow_ops.FIFOQueue(1, dtypes.float32) @@ -813,7 +812,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.colocate_with(a.op): - self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3) def fn2(): c = constant_op.constant(3.0) @@ -822,7 +821,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.colocate_with(a.op): with ops.colocate_with(b.op): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) def testColocateWithInAndOutOfCond(self): with ops.Graph().as_default() as g: @@ -838,7 +837,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.colocate_with(a.op): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) d = constant_op.constant([2.0], name="d") self.assertEqual([b"loc:@a"], d.op.colocation_groups()) @@ -859,7 +858,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.colocate_with(b.op): c = math_ops.add(a, a, name="c") return c - out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + out_cond_2 = cond_v2.cond_v2(True, fn, fn) run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() @@ -881,7 +880,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:CPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3) def fn2(): c = constant_op.constant(3.0) @@ -889,7 +888,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:GPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) def testDeviceInAndOutOfCond(self): with ops.Graph().as_default() as g: @@ -903,7 +902,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:CPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) d = constant_op.constant(4.0) self.assertEqual("/device:CPU:0", d.op.device) @@ -922,7 +921,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.device("/device:CPU:0"): a = constant_op.constant([2.0], name="a") - out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + out_cond_2 = cond_v2.cond_v2(True, fn, fn) run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 2996539004..fc4d2a3809 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -422,8 +422,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(r.values.get_shape(), (2,)) def testCondResource(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): rv = resource_variable_ops.ResourceVariable(True) @@ -484,15 +482,12 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(11, result) def testCond_1(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") self._testCond_1(use_gpu=False) - self._testCond_1(use_gpu=True) + # TODO(b/116526896): Enable GPU tests. + # self._testCond_1(use_gpu=True) def testCond_2(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = constant_op.constant(10) @@ -503,8 +498,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(9, result) def testCond_3(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = constant_op.constant(10) @@ -556,8 +549,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(4, count.eval()) def testCond_6(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): v1 = variables.Variable([7]) @@ -583,8 +574,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([11, 12], sess.run(r)) def testCondRef(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = gen_state_ops.variable( @@ -1444,7 +1433,7 @@ class ControlFlowTest(test.TestCase): def testCondWhile_1(self): if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") + return unittest.skip("b/113294340 (enable while_v2)") with self.cached_session(): n = ops.convert_to_tensor(0, name="n") @@ -1457,7 +1446,7 @@ class ControlFlowTest(test.TestCase): def testCondWhile_2(self): if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") + return unittest.skip("b/113294340 (enable while_v2)") with self.cached_session(): n = ops.convert_to_tensor(0) @@ -2633,8 +2622,6 @@ class ControlFlowTest(test.TestCase): self.assertEqual(5.0, result.eval()) def testOneValueCond(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): c = array_ops.placeholder(dtypes.int32, shape=[]) @@ -2651,8 +2638,6 @@ class ControlFlowTest(test.TestCase): self.assertEqual([2], i.eval(feed_dict={c: 0})) def testExampleCond(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = ops.convert_to_tensor([-2.0, 2.0], name="x") |