diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-07 13:48:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 13:52:34 -0700 |
commit | bcc64f892a2fb264cbd92dedbe68d6fc779f2ea6 (patch) | |
tree | 985530fb787c202e9bb5a4ba2d67a1024d11d8db /tensorflow/python/eager/backprop_test.py | |
parent | 6d25ff7a641772ccbe508d5df200aeddc101c028 (diff) |
Fix issue where re-entering a GradientTape context clears the tape
* Changed behavior of GradientTape._push_tape() such that it always uses
the existing tape if one is present.
* GradientTape.reset() clears the tape
* Added testGradientTapeReEnterContext to backprop_test.py
PiperOrigin-RevId: 212029862
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 65d57d3957..f938ed5df8 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -474,6 +474,18 @@ class BackpropTest(test.TestCase): self.assertEqual(backprop.implicit_grad(f)()[0][0], None) @test_util.assert_no_new_tensors + def testGradientTapeReEnterContext(self): + g = backprop.GradientTape() + with g: + x = constant_op.constant(3.0) + g.watch(x) + y = 2*x + with g: + z = 2*y + grad = g.gradient(target=z, sources=[x]) + self.assertEqual(self.evaluate(grad), [4.0]) + + @test_util.assert_no_new_tensors @test_util.run_in_graph_and_eager_modes def testGradientTapeRepeatedSource(self): with backprop.GradientTape(persistent=False) as g: |