aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-24 21:29:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 21:34:13 -0700
commitc1644948d23cae271b140d67101c1a386e5495fd (patch)
tree002efca36c4f95f75b08358343c3701de014880b /tensorflow/python/kernel_tests
parent9875df75c308d7498e601ae9a4b57db6aad47056 (diff)
Unpack output of cond_v2 if it is a singleton to match behavior of cond.
PiperOrigin-RevId: 214381126
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py31
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py23
2 files changed, 19 insertions, 35 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 5c0e24117f..377c041675 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -131,7 +131,7 @@ class CondV2Test(test.TestCase):
def false_fn():
return x + 1
- return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op
+ return cond_v2.cond_v2(pred, true_fn, false_fn, name=name).op
def testDefaultName(self):
with ops.Graph().as_default():
@@ -569,8 +569,7 @@ class CondV2Test(test.TestCase):
ops.add_to_collection("pred", pred)
cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
- for c in cond:
- ops.add_to_collection("cond", c)
+ ops.add_to_collection("cond", cond)
meta_graph = saver.export_meta_graph()
with ops.Graph().as_default() as g:
@@ -672,7 +671,7 @@ class CondV2CollectionTest(test.TestCase):
return math_ops.add(x_const, y_const)
cnd = cond_v2.cond_v2(True, fn, fn)
- self.assertEquals(cnd[0].eval(), 7)
+ self.assertEquals(cnd.eval(), 7)
def testCollectionTensorValueAccessInCond(self):
"""Read tensors from collections inside of cond_v2 & use them."""
@@ -689,7 +688,7 @@ class CondV2CollectionTest(test.TestCase):
return math_ops.add(x_read, y_read)
cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
- self.assertEquals(cnd[0].eval(), 7)
+ self.assertEquals(cnd.eval(), 7)
def testCollectionIntValueWriteInCond(self):
"""Make sure Int writes to collections work inside of cond_v2."""
@@ -709,7 +708,7 @@ class CondV2CollectionTest(test.TestCase):
cnd = cond_v2.cond_v2(
True, true_fn,
false_fn)
- self.assertEquals(cnd[0].eval(), 14)
+ self.assertEquals(cnd.eval(), 14)
read_z_collection = ops.get_collection("z")
self.assertEquals(read_z_collection, [7])
@@ -782,10 +781,10 @@ class CondV2ContainerTest(test.TestCase):
with ops.container("l1"):
cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
- self.assertEquals(cnd_true[0].eval(), 2)
+ self.assertEquals(cnd_true.eval(), 2)
cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
- self.assertEquals(cnd_false[0].eval(), 6)
+ self.assertEquals(cnd_false.eval(), 6)
v4 = variables.Variable([3])
q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
@@ -813,7 +812,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -822,7 +821,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(a.op):
with ops.colocate_with(b.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
def testColocateWithInAndOutOfCond(self):
with ops.Graph().as_default() as g:
@@ -838,7 +837,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
d = constant_op.constant([2.0], name="d")
self.assertEqual([b"loc:@a"], d.op.colocation_groups())
@@ -859,7 +858,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(b.op):
c = math_ops.add(a, a, name="c")
return c
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
@@ -881,7 +880,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -889,7 +888,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:GPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
def testDeviceInAndOutOfCond(self):
with ops.Graph().as_default() as g:
@@ -903,7 +902,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
d = constant_op.constant(4.0)
self.assertEqual("/device:CPU:0", d.op.device)
@@ -922,7 +921,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.device("/device:CPU:0"):
a = constant_op.constant([2.0], name="a")
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 2996539004..fc4d2a3809 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -422,8 +422,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r.values.get_shape(), (2,))
def testCondResource(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
rv = resource_variable_ops.ResourceVariable(True)
@@ -484,15 +482,12 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, result)
def testCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
self._testCond_1(use_gpu=False)
- self._testCond_1(use_gpu=True)
+ # TODO(b/116526896): Enable GPU tests.
+ # self._testCond_1(use_gpu=True)
def testCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = constant_op.constant(10)
@@ -503,8 +498,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(9, result)
def testCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = constant_op.constant(10)
@@ -556,8 +549,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(4, count.eval())
def testCond_6(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
v1 = variables.Variable([7])
@@ -583,8 +574,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([11, 12], sess.run(r))
def testCondRef(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = gen_state_ops.variable(
@@ -1444,7 +1433,7 @@ class ControlFlowTest(test.TestCase):
def testCondWhile_1(self):
if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
+ return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1457,7 +1446,7 @@ class ControlFlowTest(test.TestCase):
def testCondWhile_2(self):
if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
+ return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -2633,8 +2622,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5.0, result.eval())
def testOneValueCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -2651,8 +2638,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([2], i.eval(feed_dict={c: 0}))
def testExampleCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = ops.convert_to_tensor([-2.0, 2.0], name="x")