From df11cce2e600581087f29ef0b85286f7e582572d Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Tue, 9 Oct 2018 09:18:53 -0700 Subject: Throw error when evaluating have variable target in GradientTape. PiperOrigin-RevId: 216368178 --- tensorflow/python/eager/backprop.py | 9 ++++++++- tensorflow/python/eager/backprop_test.py | 12 +++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index deac29111f..44ce69ee60 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -868,6 +868,7 @@ class GradientTape(object): Raises: RuntimeError: if called inside the context of the tape, or if called more than once on a non-persistent tape. + ValueError: if called on variable target. """ if self._tape is None: raise RuntimeError("GradientTape.gradient can only be called once on " @@ -887,6 +888,12 @@ class GradientTape(object): "gradient in order to compute higher order " "derrivatives.", 1) + flat_targets = nest.flatten(target) + for t in flat_targets: + if resource_variable_ops.is_resource_variable(t): + raise ValueError("GradientTape.gradient is not supported for variable " + "targets.") + flat_sources = nest.flatten(sources) flat_sources = [_handle_or_self(x) for x in flat_sources] @@ -896,7 +903,7 @@ class GradientTape(object): flat_grad = imperative_grad.imperative_grad( self._tape, - nest.flatten(target), + flat_targets, flat_sources, output_gradients=output_gradients) diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 32731747b7..7e5c9f3cb6 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -548,6 +548,17 @@ class BackpropTest(test.TestCase): grad = g.gradient(y, [x])[0] self.assertEqual(self.evaluate(grad), 6.0) + @test_util.assert_no_new_tensors + @test_util.run_in_graph_and_eager_modes + def testGadientTapeCalledOnConstantTarget(self): + with backprop.GradientTape() as g: + x = variables.Variable([3.0]) + y = variables.Variable([2.0]) + with self.assertRaisesRegexp( + ValueError, + 'GradientTape.gradient is not supported for variable targets.'): + g.gradient(x, y) + @test_util.run_in_graph_and_eager_modes def testGradientTapeWithCond(self): x = constant_op.constant(3.0) @@ -982,7 +993,6 @@ class BackpropTest(test.TestCase): self.assertIsNone(dy) self.assertEqual(self.evaluate(dz), 3.0) - @test_util.run_in_graph_and_eager_modes def testDifferentiatingScalarCache(self): # In the following test, if x2 = x1 (i.e the objects are the exact same), -- cgit v1.2.3