aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/cond_v2_test.py
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-08-08 17:33:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 17:41:20 -0700
commita8aa908df4b18bcb597080a3c6f38e86e87c5587 (patch)
tree87e840e3c6c3bd59853b542f30ac75f538dfbc2a /tensorflow/python/kernel_tests/cond_v2_test.py
parentf535f733d7529d2e3a3f231ea6d387529fc899da (diff)
Remove identity ops for ys added during gradient computation. This was added to avoid issues with computing gradients when ys were dependent.
The real issue behind that has however since been fixed so adding identity ops is no longer relevant. PiperOrigin-RevId: 207974344
Diffstat (limited to 'tensorflow/python/kernel_tests/cond_v2_test.py')
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py18
1 files changed, 16 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 97ce245fc8..4d074218d1 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -78,6 +78,20 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [x, y])
self._testCond(true_fn, false_fn, [y])
+ def testMultipleOutputs(self):
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(3.0, name="y")
+
+ def true_fn():
+ return x * y, y
+
+ def false_fn():
+ return x, y * 3.0
+
+ self._testCond(true_fn, false_fn, [x])
+ self._testCond(true_fn, false_fn, [x, y])
+ self._testCond(true_fn, false_fn, [y])
+
def testBasic2(self):
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -104,8 +118,8 @@ class CondV2Test(test.TestCase):
out = cond_v2.cond_v2(pred, true_fn, false_fn)
- self.assertEqual(sess.run(out, {pred: True}), [1.0])
- self.assertEqual(sess.run(out, {pred: False}), [2.0])
+ self.assertEqual(sess.run(out, {pred: True}), (1.0,))
+ self.assertEqual(sess.run(out, {pred: False}), (2.0,))
def _createCond(self, name):
pred = constant_op.constant(True, name="pred")