diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-05-07 17:21:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-07 18:26:54 -0700 |
commit | 482ed8eb666d8bc1e5c3f47e5c1e61cc19e0fdb1 (patch) | |
tree | 8ac8a60814358a1bf853433f5df01205d0961663 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | b67d8b278d48a046491b42eccbd5c5c23975d054 (diff) |
Raise an error if we try to take the gradient wrt to the initial value of a loop variable.
Fixes #14101
PiperOrigin-RevId: 195748688
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 77e6f5f1a0..843759fed0 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1847,6 +1847,23 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x) self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) + def testGradInWhileWrtInitialLoopVal(self): + with self.test_session(): + x = array_ops.placeholder(dtypes.float32, shape=(), name="x") + y = x + 1 + + def body(i, v): + z = v * 2 + return i + 1, gradients_impl.gradients(z, x)[0] + + with self.assertRaisesRegexp( + ValueError, + "Cannot compute gradient inside while loop with respect to op 'x'. " + "We do not support taking the gradient wrt or through the initial " + "value of a loop variable. Gradients can be computed through " + "loop invariants or wrt the input parameters to the loop body."): + control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y]) + def testWhileGradInWhile(self): with self.test_session(): n = ops.convert_to_tensor(1.0, name="n") |