diff options
author | Alexandre Passos <apassos@google.com> | 2017-10-19 14:16:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-19 15:43:16 -0700 |
commit | eb978292e0ac46dd16c820b9989ad1776295517a (patch) | |
tree | 6eb7fe22e62da6ed4d68ca01099f3f5cc49bba21 /tensorflow/python/eager/backprop_test.py | |
parent | 7fe3744373751ee6a79bb23c6c20343a91d07b28 (diff) |
Context-manager-based gradient API
PiperOrigin-RevId: 172796719
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 5161095683..95d5f0adcb 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -277,6 +277,27 @@ class BackpropTest(test.TestCase): self.assertEqual(backprop.implicit_grad(f)()[0][0], None) + def testGradientTape(self): + with backprop.GradientTape() as g: + x = constant_op.constant(3.0) + g.watch(x) + y = x * x + with backprop.GradientTape() as gg: + gg.watch(y) + z = 2 * y + inner_grad = gg.gradient(z, [y])[0] + self.assertEqual(inner_grad.numpy(), 2.0) + y += inner_grad + grad = g.gradient(y, [x])[0] + self.assertEqual(grad.numpy(), 6.0) + + def testGradientTapeVariable(self): + v = resource_variable_ops.ResourceVariable(1.0) + with backprop.GradientTape() as g: + y = v * v + grad = g.gradient(y, [v])[0] + self.assertAllEqual(grad, 2.0) + def testEmptyParamsForValueAndGradFunction(self): def fn(a, b): return a * b |