diff options
Diffstat (limited to 'tensorflow/python/eager/tape_test.py')
-rw-r--r-- | tensorflow/python/eager/tape_test.py | 20 |
1 file changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py index b490bac66d..c97cb62125 100644 --- a/tensorflow/python/eager/tape_test.py +++ b/tensorflow/python/eager/tape_test.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import custom_gradient +from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -165,6 +166,25 @@ class TapeTest(test.TestCase): g, = backprop.gradients_function(fn, [0])(t) self.assertAllEqual(g, 1.0) + def testTapeGC(self): + # TODO(apassos) figure out how to test this without using tape internal + # APIs. + tape.push_new_tape() + + def f(): + x = constant_op.constant(1.0) + tape.watch(x) + x = gradient_is_constant(x) + x = gradient_is_constant(x) + x = gradient_is_constant(x) + + f() + t = tape.pop_tape() + tensor_tape, op_tape = t.export() + self.assertEqual(len(tensor_tape), 1) # The watched tensor will remain on + # the tape + self.assertEqual(len(op_tape), 0) # No operations should remain on the tape + def testCustomGradientGraphMode(self): with context.graph_mode(), self.test_session(): |