aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/eager/backprop.py9
-rw-r--r--tensorflow/python/eager/backprop_test.py12
2 files changed, 19 insertions, 2 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index deac29111f..44ce69ee60 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -868,6 +868,7 @@ class GradientTape(object):
Raises:
RuntimeError: if called inside the context of the tape, or if called more
than once on a non-persistent tape.
+ ValueError: if called on variable target.
"""
if self._tape is None:
raise RuntimeError("GradientTape.gradient can only be called once on "
@@ -887,6 +888,12 @@ class GradientTape(object):
"gradient in order to compute higher order "
"derrivatives.", 1)
+ flat_targets = nest.flatten(target)
+ for t in flat_targets:
+ if resource_variable_ops.is_resource_variable(t):
+ raise ValueError("GradientTape.gradient is not supported for variable "
+ "targets.")
+
flat_sources = nest.flatten(sources)
flat_sources = [_handle_or_self(x) for x in flat_sources]
@@ -896,7 +903,7 @@ class GradientTape(object):
flat_grad = imperative_grad.imperative_grad(
self._tape,
- nest.flatten(target),
+ flat_targets,
flat_sources,
output_gradients=output_gradients)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 32731747b7..7e5c9f3cb6 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -548,6 +548,17 @@ class BackpropTest(test.TestCase):
grad = g.gradient(y, [x])[0]
self.assertEqual(self.evaluate(grad), 6.0)
+ @test_util.assert_no_new_tensors
+ @test_util.run_in_graph_and_eager_modes
+ def testGadientTapeCalledOnConstantTarget(self):
+ with backprop.GradientTape() as g:
+ x = variables.Variable([3.0])
+ y = variables.Variable([2.0])
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'GradientTape.gradient is not supported for variable targets.'):
+ g.gradient(x, y)
+
@test_util.run_in_graph_and_eager_modes
def testGradientTapeWithCond(self):
x = constant_op.constant(3.0)
@@ -982,7 +993,6 @@ class BackpropTest(test.TestCase):
self.assertIsNone(dy)
self.assertEqual(self.evaluate(dz), 3.0)
-
@test_util.run_in_graph_and_eager_modes
def testDifferentiatingScalarCache(self):
# In the following test, if x2 = x1 (i.e the objects are the exact same),