aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/backprop_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-07 13:48:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 13:52:34 -0700
commitbcc64f892a2fb264cbd92dedbe68d6fc779f2ea6 (patch)
tree985530fb787c202e9bb5a4ba2d67a1024d11d8db /tensorflow/python/eager/backprop_test.py
parent6d25ff7a641772ccbe508d5df200aeddc101c028 (diff)
Fix issue where re-entering a GradientTape context clears the tape
* Changed behavior of GradientTape._push_tape() such that it always uses the existing tape if one is present. * GradientTape.reset() clears the tape * Added testGradientTapeReEnterContext to backprop_test.py PiperOrigin-RevId: 212029862
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r--tensorflow/python/eager/backprop_test.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 65d57d3957..f938ed5df8 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -474,6 +474,18 @@ class BackpropTest(test.TestCase):
self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
@test_util.assert_no_new_tensors
+ def testGradientTapeReEnterContext(self):
+ g = backprop.GradientTape()
+ with g:
+ x = constant_op.constant(3.0)
+ g.watch(x)
+ y = 2*x
+ with g:
+ z = 2*y
+ grad = g.gradient(target=z, sources=[x])
+ self.assertEqual(self.evaluate(grad), [4.0])
+
+ @test_util.assert_no_new_tensors
@test_util.run_in_graph_and_eager_modes
def testGradientTapeRepeatedSource(self):
with backprop.GradientTape(persistent=False) as g: