diff options
author | Saurabh Saxena <srbs@google.com> | 2018-08-08 17:33:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 17:41:20 -0700 |
commit | a8aa908df4b18bcb597080a3c6f38e86e87c5587 (patch) | |
tree | 87e840e3c6c3bd59853b542f30ac75f538dfbc2a /tensorflow/python/kernel_tests/cond_v2_test.py | |
parent | f535f733d7529d2e3a3f231ea6d387529fc899da (diff) |
Remove identity ops for ys added during gradient computation. This was added to avoid issues with computing gradients when ys were dependent.
The real issue behind that has however since been fixed so adding identity ops is no longer relevant.
PiperOrigin-RevId: 207974344
Diffstat (limited to 'tensorflow/python/kernel_tests/cond_v2_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 97ce245fc8..4d074218d1 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -78,6 +78,20 @@ class CondV2Test(test.TestCase): self._testCond(true_fn, false_fn, [x, y]) self._testCond(true_fn, false_fn, [y]) + def testMultipleOutputs(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(3.0, name="y") + + def true_fn(): + return x * y, y + + def false_fn(): + return x, y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + def testBasic2(self): x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") @@ -104,8 +118,8 @@ class CondV2Test(test.TestCase): out = cond_v2.cond_v2(pred, true_fn, false_fn) - self.assertEqual(sess.run(out, {pred: True}), [1.0]) - self.assertEqual(sess.run(out, {pred: False}), [2.0]) + self.assertEqual(sess.run(out, {pred: True}), (1.0,)) + self.assertEqual(sess.run(out, {pred: False}), (2.0,)) def _createCond(self, name): pred = constant_op.constant(True, name="pred") |