diff options
author | 2017-12-15 13:18:52 -0800 | |
---|---|---|
committer | 2017-12-15 13:22:51 -0800 | |
commit | 6373db931a8444083d12e7ac08d4d62c859a6c64 (patch) | |
tree | 6c970fba9ba0ae95fdff0c7f4dba632fd8b55185 | |
parent | 83b3350baf9c9fc7e733c40df44f2e3f11e23d93 (diff) |
Fix reference counts when watching variables (eager tape)
Adds unit test assertions that variables are properly dealloacted once the tape is deleted and there are no remaining Python references.
PiperOrigin-RevId: 179231752
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 9 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 15 |
2 files changed, 15 insertions, 9 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 7c44d55467..3da22d4c34 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -251,8 +251,7 @@ class BackpropTest(test.TestCase): grad = backprop.gradients_function(fn, [0])(constant_op.constant(1.0))[0] self.assertAllEqual(grad, 1.0) - # TODO(b/70675592): Fix leaked Tensors in this test. - # @test_util.assert_no_new_tensors + @test_util.assert_no_new_tensors def testGPUImplicitGrad(self): if not context.context().num_gpus(): self.skipTest('No GPU found') @@ -377,8 +376,7 @@ class BackpropTest(test.TestCase): self.assertEqual(grad.numpy(), 12.0) del g - # TODO(b/70675592): Fix leaked Tensors in this test. - # @test_util.assert_no_new_tensors + @test_util.assert_no_new_tensors def testGradientTapeVariable(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') with backprop.GradientTape() as g: @@ -513,8 +511,7 @@ class BackpropTest(test.TestCase): self.assertAllEqual(backprop.gradients_function(real_f)(1.0)[0], 2.0) - # TODO(b/70675592): Fix leaked Tensors in this test. - # @test_util.assert_no_new_tensors + @test_util.assert_no_new_tensors def testMultiValueConvertToTensor(self): x = resource_variable_ops.ResourceVariable( initial_value=array_ops.constant([1.0]), name='x') diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index b52d71dc6c..3ba81fb3d0 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -472,9 +472,19 @@ class GradientTape explicit GradientTape(bool persistent) : tensorflow::eager::GradientTape<PyObject, PyObject>(persistent) {} + virtual ~GradientTape() { + for (PyObject* v : watched_variables_) { + Py_DECREF(v); + } + } + void WatchVariable(PyObject* v) { - watched_variables_.insert(v); - Py_INCREF(v); + auto insert_result = watched_variables_.insert(v); + if (insert_result.second) { + // Only increment the reference count if we aren't already watching this + // variable. + Py_INCREF(v); + } PyObject* handle = PyObject_GetAttrString(v, "handle"); if (handle == nullptr) { return; @@ -722,7 +732,6 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { PyObject* result = PySet_New(nullptr); for (PyObject* variable : watched_variables) { PySet_Add(result, variable); - Py_DECREF(variable); } return result; } |