aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/tape_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/tape_test.py')
-rw-r--r--tensorflow/python/eager/tape_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index b490bac66d..c97cb62125 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
+from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -165,6 +166,25 @@ class TapeTest(test.TestCase):
g, = backprop.gradients_function(fn, [0])(t)
self.assertAllEqual(g, 1.0)
+ def testTapeGC(self):
+ # TODO(apassos) figure out how to test this without using tape internal
+ # APIs.
+ tape.push_new_tape()
+
+ def f():
+ x = constant_op.constant(1.0)
+ tape.watch(x)
+ x = gradient_is_constant(x)
+ x = gradient_is_constant(x)
+ x = gradient_is_constant(x)
+
+ f()
+ t = tape.pop_tape()
+ tensor_tape, op_tape = t.export()
+ self.assertEqual(len(tensor_tape), 1) # The watched tensor will remain on
+ # the tape
+ self.assertEqual(len(op_tape), 0) # No operations should remain on the tape
+
def testCustomGradientGraphMode(self):
with context.graph_mode(), self.test_session():