diff options
author | Saurabh Saxena <srbs@google.com> | 2018-09-24 21:29:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 21:34:13 -0700 |
commit | c1644948d23cae271b140d67101c1a386e5495fd (patch) | |
tree | 002efca36c4f95f75b08358343c3701de014880b /tensorflow/python | |
parent | 9875df75c308d7498e601ae9a4b57db6aad47056 (diff) |
Unpack output of cond_v2 if it is a singleton to match behavior of cond.
PiperOrigin-RevId: 214381126
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 31 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 23 | ||||
-rw-r--r-- | tensorflow/python/ops/cond_v2_impl.py | 6 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops_impl.py | 2 |
4 files changed, 24 insertions, 38 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 5c0e24117f..377c041675 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -131,7 +131,7 @@ class CondV2Test(test.TestCase): def false_fn(): return x + 1 - return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op + return cond_v2.cond_v2(pred, true_fn, false_fn, name=name).op def testDefaultName(self): with ops.Graph().as_default(): @@ -569,8 +569,7 @@ class CondV2Test(test.TestCase): ops.add_to_collection("pred", pred) cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") - for c in cond: - ops.add_to_collection("cond", c) + ops.add_to_collection("cond", cond) meta_graph = saver.export_meta_graph() with ops.Graph().as_default() as g: @@ -672,7 +671,7 @@ class CondV2CollectionTest(test.TestCase): return math_ops.add(x_const, y_const) cnd = cond_v2.cond_v2(True, fn, fn) - self.assertEquals(cnd[0].eval(), 7) + self.assertEquals(cnd.eval(), 7) def testCollectionTensorValueAccessInCond(self): """Read tensors from collections inside of cond_v2 & use them.""" @@ -689,7 +688,7 @@ class CondV2CollectionTest(test.TestCase): return math_ops.add(x_read, y_read) cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn) - self.assertEquals(cnd[0].eval(), 7) + self.assertEquals(cnd.eval(), 7) def testCollectionIntValueWriteInCond(self): """Make sure Int writes to collections work inside of cond_v2.""" @@ -709,7 +708,7 @@ class CondV2CollectionTest(test.TestCase): cnd = cond_v2.cond_v2( True, true_fn, false_fn) - self.assertEquals(cnd[0].eval(), 14) + self.assertEquals(cnd.eval(), 14) read_z_collection = ops.get_collection("z") self.assertEquals(read_z_collection, [7]) @@ -782,10 +781,10 @@ class CondV2ContainerTest(test.TestCase): with ops.container("l1"): cnd_true = cond_v2.cond_v2(True, true_fn, false_fn) - self.assertEquals(cnd_true[0].eval(), 2) + self.assertEquals(cnd_true.eval(), 2) cnd_false = cond_v2.cond_v2(False, true_fn, false_fn) - self.assertEquals(cnd_false[0].eval(), 6) + self.assertEquals(cnd_false.eval(), 6) v4 = variables.Variable([3]) q4 = data_flow_ops.FIFOQueue(1, dtypes.float32) @@ -813,7 +812,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.colocate_with(a.op): - self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3) def fn2(): c = constant_op.constant(3.0) @@ -822,7 +821,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.colocate_with(a.op): with ops.colocate_with(b.op): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) def testColocateWithInAndOutOfCond(self): with ops.Graph().as_default() as g: @@ -838,7 +837,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.colocate_with(a.op): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) d = constant_op.constant([2.0], name="d") self.assertEqual([b"loc:@a"], d.op.colocation_groups()) @@ -859,7 +858,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.colocate_with(b.op): c = math_ops.add(a, a, name="c") return c - out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + out_cond_2 = cond_v2.cond_v2(True, fn, fn) run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() @@ -881,7 +880,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:CPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3) def fn2(): c = constant_op.constant(3.0) @@ -889,7 +888,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:GPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) def testDeviceInAndOutOfCond(self): with ops.Graph().as_default() as g: @@ -903,7 +902,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): return c with ops.device("/device:CPU:0"): - self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3) d = constant_op.constant(4.0) self.assertEqual("/device:CPU:0", d.op.device) @@ -922,7 +921,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): with ops.device("/device:CPU:0"): a = constant_op.constant([2.0], name="a") - out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + out_cond_2 = cond_v2.cond_v2(True, fn, fn) run_options = config_pb2.RunOptions(output_partition_graphs=True) run_metadata = config_pb2.RunMetadata() diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 2996539004..fc4d2a3809 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -422,8 +422,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(r.values.get_shape(), (2,)) def testCondResource(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): rv = resource_variable_ops.ResourceVariable(True) @@ -484,15 +482,12 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(11, result) def testCond_1(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") self._testCond_1(use_gpu=False) - self._testCond_1(use_gpu=True) + # TODO(b/116526896): Enable GPU tests. + # self._testCond_1(use_gpu=True) def testCond_2(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = constant_op.constant(10) @@ -503,8 +498,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(9, result) def testCond_3(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = constant_op.constant(10) @@ -556,8 +549,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(4, count.eval()) def testCond_6(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): v1 = variables.Variable([7]) @@ -583,8 +574,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([11, 12], sess.run(r)) def testCondRef(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = gen_state_ops.variable( @@ -1444,7 +1433,7 @@ class ControlFlowTest(test.TestCase): def testCondWhile_1(self): if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") + return unittest.skip("b/113294340 (enable while_v2)") with self.cached_session(): n = ops.convert_to_tensor(0, name="n") @@ -1457,7 +1446,7 @@ class ControlFlowTest(test.TestCase): def testCondWhile_2(self): if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") + return unittest.skip("b/113294340 (enable while_v2)") with self.cached_session(): n = ops.convert_to_tensor(0) @@ -2633,8 +2622,6 @@ class ControlFlowTest(test.TestCase): self.assertEqual(5.0, result.eval()) def testOneValueCond(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): c = array_ops.placeholder(dtypes.int32, shape=[]) @@ -2651,8 +2638,6 @@ class ControlFlowTest(test.TestCase): self.assertEqual([2], i.eval(feed_dict={c: 0})) def testExampleCond(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = ops.convert_to_tensor([-2.0, 2.0], name="x") diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py index c6a6b2a7fa..f8b1ddb140 100644 --- a/tensorflow/python/ops/cond_v2_impl.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -119,7 +119,11 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): attr_value_pb2.AttrValue(b=True)) # pylint: enable=protected-access - return tuple(tensors[:num_cond_outputs]) + result = tuple(tensors[:num_cond_outputs]) + if len(result) == 1: + return result[0] + else: + return result @ops.RegisterGradient("If") diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 208b56e909..1c75aab578 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -329,8 +329,6 @@ def _random_flip(image, flip_index, seed, scope_name): lambda: image, name=scope ) - if isinstance(result, tuple): - result = result[0] # TODO(b/111124878) remove this logic (CondV2). return fix_image_flip_shape(image, result) elif shape.ndims == 4: batch_size = array_ops.shape(image)[0] |