diff options
author | Jacques Pienaar <jpienaar@google.com> | 2018-04-11 09:53:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-11 09:55:43 -0700 |
commit | 8f753859dd50a4c8d25b99a7b57c61e0e5c20578 (patch) | |
tree | d228dda8171c854b58cb1f32359406c0c98551d8 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | 08a12ca6016c34d9476d2e93bd0f2dc9ae60abc5 (diff) |
Add gradient in cond test to match CallGradInLoop.
PiperOrigin-RevId: 192463997
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 75f8644f69..e27eb00818 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -664,6 +664,23 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1})) self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3})) + def testCondGrad_3(self): + with self.test_session(): + c = array_ops.placeholder(dtypes.int32, shape=[]) + ox = constant_op.constant(10.0) + pred = math_ops.less(c, 2) + + def fn1(x): + m = x * x + return gradients_impl.gradients(m, [ox])[0] + + fn2 = lambda: math_ops.multiply(ox, 3.0) + y = math_ops.multiply(7.0, ox) + r = control_flow_ops.cond(pred, lambda: fn1(y), fn2) + + self.assertAllEqual(980.0, r.eval(feed_dict={c: 1})) + self.assertAllEqual(30.0, r.eval(feed_dict={c: 3})) + def testNestedCond_Simple(self): with self.test_session(): x = constant_op.constant(0., name="X") |