diff options
author | Saurabh Saxena <srbs@google.com> | 2018-09-24 21:29:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 21:34:13 -0700 |
commit | c1644948d23cae271b140d67101c1a386e5495fd (patch) | |
tree | 002efca36c4f95f75b08358343c3701de014880b /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | 9875df75c308d7498e601ae9a4b57db6aad47056 (diff) |
Unpack output of cond_v2 if it is a singleton to match behavior of cond.
PiperOrigin-RevId: 214381126
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 23 |
1 files changed, 4 insertions, 19 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 2996539004..fc4d2a3809 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -422,8 +422,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(r.values.get_shape(), (2,)) def testCondResource(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): rv = resource_variable_ops.ResourceVariable(True) @@ -484,15 +482,12 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(11, result) def testCond_1(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") self._testCond_1(use_gpu=False) - self._testCond_1(use_gpu=True) + # TODO(b/116526896): Enable GPU tests. + # self._testCond_1(use_gpu=True) def testCond_2(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = constant_op.constant(10) @@ -503,8 +498,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(9, result) def testCond_3(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = constant_op.constant(10) @@ -556,8 +549,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(4, count.eval()) def testCond_6(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): v1 = variables.Variable([7]) @@ -583,8 +574,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([11, 12], sess.run(r)) def testCondRef(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = gen_state_ops.variable( @@ -1444,7 +1433,7 @@ class ControlFlowTest(test.TestCase): def testCondWhile_1(self): if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") + return unittest.skip("b/113294340 (enable while_v2)") with self.cached_session(): n = ops.convert_to_tensor(0, name="n") @@ -1457,7 +1446,7 @@ class ControlFlowTest(test.TestCase): def testCondWhile_2(self): if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") + return unittest.skip("b/113294340 (enable while_v2)") with self.cached_session(): n = ops.convert_to_tensor(0) @@ -2633,8 +2622,6 @@ class ControlFlowTest(test.TestCase): self.assertEqual(5.0, result.eval()) def testOneValueCond(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): c = array_ops.placeholder(dtypes.int32, shape=[]) @@ -2651,8 +2638,6 @@ class ControlFlowTest(test.TestCase): self.assertEqual([2], i.eval(feed_dict={c: 0})) def testExampleCond(self): - if control_flow_ops.ENABLE_COND_V2: - return unittest.skip("b/111124878 (don't return tuple)") with self.cached_session(): x = ops.convert_to_tensor([-2.0, 2.0], name="x") |