aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-05-07 17:21:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 18:26:54 -0700
commit482ed8eb666d8bc1e5c3f47e5c1e61cc19e0fdb1 (patch)
tree8ac8a60814358a1bf853433f5df01205d0961663 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parentb67d8b278d48a046491b42eccbd5c5c23975d054 (diff)
Raise an error if we try to take the gradient wrt to the initial value of a loop variable.
Fixes #14101 PiperOrigin-RevId: 195748688
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 77e6f5f1a0..843759fed0 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1847,6 +1847,23 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ def testGradInWhileWrtInitialLoopVal(self):
+ with self.test_session():
+ x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
+ y = x + 1
+
+ def body(i, v):
+ z = v * 2
+ return i + 1, gradients_impl.gradients(z, x)[0]
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Cannot compute gradient inside while loop with respect to op 'x'. "
+ "We do not support taking the gradient wrt or through the initial "
+ "value of a loop variable. Gradients can be computed through "
+ "loop invariants or wrt the input parameters to the loop body."):
+ control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+
def testWhileGradInWhile(self):
with self.test_session():
n = ops.convert_to_tensor(1.0, name="n")