aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-04-11 09:53:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-11 09:55:43 -0700
commit8f753859dd50a4c8d25b99a7b57c61e0e5c20578 (patch)
treed228dda8171c854b58cb1f32359406c0c98551d8 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent08a12ca6016c34d9476d2e93bd0f2dc9ae60abc5 (diff)
Add gradient in cond test to match CallGradInLoop.
PiperOrigin-RevId: 192463997
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 75f8644f69..e27eb00818 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -664,6 +664,23 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
+ def testCondGrad_3(self):
+ with self.test_session():
+ c = array_ops.placeholder(dtypes.int32, shape=[])
+ ox = constant_op.constant(10.0)
+ pred = math_ops.less(c, 2)
+
+ def fn1(x):
+ m = x * x
+ return gradients_impl.gradients(m, [ox])[0]
+
+ fn2 = lambda: math_ops.multiply(ox, 3.0)
+ y = math_ops.multiply(7.0, ox)
+ r = control_flow_ops.cond(pred, lambda: fn1(y), fn2)
+
+ self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
+ self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
+
def testNestedCond_Simple(self):
with self.test_session():
x = constant_op.constant(0., name="X")