aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/backprop_test.py
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-10-19 14:16:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-19 15:43:16 -0700
commiteb978292e0ac46dd16c820b9989ad1776295517a (patch)
tree6eb7fe22e62da6ed4d68ca01099f3f5cc49bba21 /tensorflow/python/eager/backprop_test.py
parent7fe3744373751ee6a79bb23c6c20343a91d07b28 (diff)
Context-manager-based gradient API
PiperOrigin-RevId: 172796719
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r--tensorflow/python/eager/backprop_test.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 5161095683..95d5f0adcb 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -277,6 +277,27 @@ class BackpropTest(test.TestCase):
self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
+ def testGradientTape(self):
+ with backprop.GradientTape() as g:
+ x = constant_op.constant(3.0)
+ g.watch(x)
+ y = x * x
+ with backprop.GradientTape() as gg:
+ gg.watch(y)
+ z = 2 * y
+ inner_grad = gg.gradient(z, [y])[0]
+ self.assertEqual(inner_grad.numpy(), 2.0)
+ y += inner_grad
+ grad = g.gradient(y, [x])[0]
+ self.assertEqual(grad.numpy(), 6.0)
+
+ def testGradientTapeVariable(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+ with backprop.GradientTape() as g:
+ y = v * v
+ grad = g.gradient(y, [v])[0]
+ self.assertAllEqual(grad, 2.0)
+
def testEmptyParamsForValueAndGradFunction(self):
def fn(a, b):
return a * b