aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-24 21:29:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 21:34:13 -0700
commitc1644948d23cae271b140d67101c1a386e5495fd (patch)
tree002efca36c4f95f75b08358343c3701de014880b /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent9875df75c308d7498e601ae9a4b57db6aad47056 (diff)
Unpack output of cond_v2 if it is a singleton to match behavior of cond.
PiperOrigin-RevId: 214381126
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py23
1 files changed, 4 insertions, 19 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 2996539004..fc4d2a3809 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -422,8 +422,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r.values.get_shape(), (2,))
def testCondResource(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
rv = resource_variable_ops.ResourceVariable(True)
@@ -484,15 +482,12 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, result)
def testCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
self._testCond_1(use_gpu=False)
- self._testCond_1(use_gpu=True)
+ # TODO(b/116526896): Enable GPU tests.
+ # self._testCond_1(use_gpu=True)
def testCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = constant_op.constant(10)
@@ -503,8 +498,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(9, result)
def testCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = constant_op.constant(10)
@@ -556,8 +549,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(4, count.eval())
def testCond_6(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
v1 = variables.Variable([7])
@@ -583,8 +574,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([11, 12], sess.run(r))
def testCondRef(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = gen_state_ops.variable(
@@ -1444,7 +1433,7 @@ class ControlFlowTest(test.TestCase):
def testCondWhile_1(self):
if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
+ return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1457,7 +1446,7 @@ class ControlFlowTest(test.TestCase):
def testCondWhile_2(self):
if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
+ return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -2633,8 +2622,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5.0, result.eval())
def testOneValueCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -2651,8 +2638,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([2], i.eval(feed_dict={c: 0}))
def testExampleCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = ops.convert_to_tensor([-2.0, 2.0], name="x")