author     Allen Lavoie <allenl@google.com>                  2017-12-15 11:39:54 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-12-15 11:43:21 -0800
commit     8d3690c5649fb6dac481e15eda365e73aeaab84a (patch)
tree       bf49e45897b518afa5e8842ea97ee434cb7b1772
parent     af5a45260eb9393195dd8c02de7a258300e3ea90 (diff)
Plug an eager memory leak, add tests for reference counts.
There are still some slightly less serious leaks. Will follow up with a fix once I track those down.

PiperOrigin-RevId: 179220052
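To make the "reference counts" in the title concrete: a minimal sketch, not part of this change, of how a missing Py_DECREF in C-API code surfaces on the Python side (run under eager execution so the constant is a standalone EagerTensor; constant_op is the internal module used throughout the diffs below):

import sys

from tensorflow.python.framework import constant_op

x = constant_op.constant(3.0)
# getrefcount counts its own argument as one reference, so an
# otherwise-unreferenced Tensor reports 2; a C API call that forgot a
# Py_DECREF inflates this count for the life of the process.
print(sys.getrefcount(x))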
-rw-r--r--  tensorflow/c/eager/tape.h                       5
-rw-r--r--  tensorflow/python/BUILD                        15
-rw-r--r--  tensorflow/python/eager/backprop_test.py       30
-rw-r--r--  tensorflow/python/framework/test_util.py       64
-rw-r--r--  tensorflow/python/framework/test_util_test.py  20
5 files changed, 131 insertions, 3 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 20ed037c52..17c9c8cc9a 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -530,6 +530,11 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
     if (!persistent_) {
       vspace.ReleaseBackwardFunction(trace.backward_function);
     }
+    for (Gradient* grad : out_gradients) {
+      if (grad != nullptr) {
+        vspace.DeleteGradient(grad);
+      }
+    }
   }
   VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
           << trace.input_tensor_id.size() << " sources";
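The added loop deletes each op's out_gradients once its backward function has consumed them; previously the tape never called DeleteGradient on them, so every backward pass left those gradient Tensors alive. A minimal sketch, not part of the patch, of the gc-based accounting that the new assert_no_new_tensors decorator (further down in this change) uses to detect exactly this, assuming eager execution is enabled:

import gc

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops


def live_tensor_ids():
  # Collect garbage first so only genuinely reachable Tensors are counted.
  gc.collect()
  return set(id(obj) for obj in gc.get_objects() if isinstance(obj, ops.Tensor))


before = live_tensor_ids()
backprop.gradients_function(lambda x: x * x)(constant_op.constant(3.0))
leaked = live_tensor_ids() - before
# Before this fix the difference could be non-empty even though the result was
# discarded. Module-level caches (e.g. backprop's cached zeros) also hold
# Tensors legitimately, which is why the decorator clears them before comparing.
print('Tensors surviving the call: %d' % len(leaked))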
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 45383eda99..80f3ec6681 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -800,15 +800,23 @@ py_library(
     srcs = ["framework/test_util.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":array_ops",
         ":client",
         ":errors",
-        ":framework",
         ":framework_for_generated_wrappers",
         ":platform",
         ":platform_test",
         ":pywrap_tensorflow",
+        ":random_seed",
+        ":resource_variable_ops",
+        ":session",
         ":training",
         ":util",
+        ":variables",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/eager:backprop",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:tape",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -1215,6 +1223,11 @@ py_test(
         ":framework_test_lib",
         ":platform_test",
         ":random_ops",
+        ":resource_variable_ops",
+        ":session",
+        ":variables",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 90c0e47ff9..7c44d55467 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import gradients
@@ -151,6 +152,7 @@ class BackpropTest(test.TestCase):
     opt.apply_gradients([(grad, embedding)])
     self.assertAllClose(expected, embedding.read_value())
 
+  @test_util.assert_no_new_tensors
   def testGradientNone(self):
 
     def loss(x, l):
@@ -165,6 +167,7 @@ class BackpropTest(test.TestCase):
     g, = backprop.gradients_function(loss, [0])(logits, labels)
     self.assertAllEqual(g.numpy(), [[-0.5, 0.5]])
 
+  @test_util.assert_no_new_tensors
   def testSecondGrad(self):
 
     def first(x):
@@ -181,6 +184,7 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(second, [0])(f)[0]
     self.assertAllEqual([[0.0]], grad)
 
+  @test_util.assert_no_new_tensors
   def testMakeVJP(self):
 
     def f(x):
@@ -191,6 +195,7 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(result, 9.0)
     self.assertAllEqual(vjp(2.0)[0], 12.0)
 
+  @test_util.assert_no_new_tensors
   def testGradGrad(self):
 
     def sq(x):
@@ -204,6 +209,7 @@ class BackpropTest(test.TestCase):
 
     self.assertAllEqual(gradgrad(constant_op.constant(3.0))[0], 2.0)
 
+  @test_util.assert_no_new_tensors
   def testGradGradExp(self):
 
     def grad(x):
@@ -214,11 +220,13 @@ class BackpropTest(test.TestCase):
 
     self.assertAllEqual(gradgrad(constant_op.constant(0.0))[0], 1.0)
 
+  @test_util.assert_no_new_tensors
   def testStopGradient(self):
     grad = backprop.gradients_function(
         lambda x: array_ops.stop_gradient(math_ops.argmax(x)))
     self.assertAllEqual(grad([0.0])[0], None)
 
+  @test_util.assert_no_new_tensors
   def testArgmax(self):
     def argmax(x):
       i = math_ops.argmax(x)
@@ -227,6 +235,7 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(argmax)
     self.assertAllEqual(grad([0.0])[0], None)
 
+  @test_util.assert_no_new_tensors
   def testGPU(self):
     if not context.context().num_gpus():
       self.skipTest('No GPUs found')
@@ -242,6 +251,8 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(fn, [0])(constant_op.constant(1.0))[0]
     self.assertAllEqual(grad, 1.0)
 
+  # TODO(b/70675592): Fix leaked Tensors in this test.
+  # @test_util.assert_no_new_tensors
   def testGPUImplicitGrad(self):
     if not context.context().num_gpus():
       self.skipTest('No GPU found')
@@ -257,6 +268,7 @@ class BackpropTest(test.TestCase):
     self.assertEqual(
         backprop.implicit_grad(f)()[0][0].cpu().numpy(), 1.0)
 
+  @test_util.assert_no_new_tensors
   def testCPU(self):
 
     def fn(x):
@@ -267,6 +279,7 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(fn, [0])(constant_op.constant(1.0))[0]
     self.assertAllEqual(grad, 1.0)
 
+  @test_util.assert_no_new_tensors
   def testTensorCopyGPU2CPU2GPU(self):
     if not context.context().num_gpus():
       self.skipTest('No GPUs found')
@@ -281,6 +294,7 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(f, [0])(a, b)[0]
     self.assertAllEqual(grad, 1.0)
 
+  @test_util.assert_no_new_tensors
   def testEmptyParams(self):
 
     def fn(a, b):
@@ -292,6 +306,7 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(dx, y.numpy())
     self.assertAllEqual(dy, x.numpy())
 
+  @test_util.assert_no_new_tensors
   def testUnconnectedNone(self):
     v = resource_variable_ops.ResourceVariable(
         1.0, name='testUnconnectedNone')
@@ -302,6 +317,7 @@ class BackpropTest(test.TestCase):
 
     self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
 
+  @test_util.assert_no_new_tensors
   def testGradientTape(self):
     with backprop.GradientTape() as g:
       x = constant_op.constant(3.0)
@@ -316,6 +332,7 @@ class BackpropTest(test.TestCase):
     grad = g.gradient(y, [x])[0]
     self.assertEqual(grad.numpy(), 6.0)
 
+  @test_util.assert_no_new_tensors
   def testGradientTapeGradientCalledMultipleTimes(self):
     with backprop.GradientTape() as g:
       x = constant_op.constant(3.0)
@@ -327,6 +344,7 @@ class BackpropTest(test.TestCase):
         RuntimeError, 'GradientTape.gradient can only be called once'):
       g.gradient(y, [x])
 
+  @test_util.assert_no_new_tensors
   def testPersistentTape(self):
     with backprop.GradientTape(persistent=True) as g:
       x = constant_op.constant(3.0)
@@ -339,6 +357,7 @@ class BackpropTest(test.TestCase):
     self.assertEqual(dy_dx.numpy(), 2*3)
     del g
 
+  @test_util.assert_no_new_tensors
   def testPersistentNestedTape(self):
     with backprop.GradientTape(persistent=True) as g:
       x = constant_op.constant(3.0)
@@ -358,6 +377,8 @@ class BackpropTest(test.TestCase):
     self.assertEqual(grad.numpy(), 12.0)
     del g
 
+  # TODO(b/70675592): Fix leaked Tensors in this test.
+  # @test_util.assert_no_new_tensors
   def testGradientTapeVariable(self):
     v = resource_variable_ops.ResourceVariable(1.0, name='v')
     with backprop.GradientTape() as g:
@@ -365,6 +386,7 @@ class BackpropTest(test.TestCase):
     grad = g.gradient(y, [v])[0]
     self.assertAllEqual(grad, 2.0)
 
+  @test_util.assert_no_new_tensors
   def testEmptyParamsForValueAndGradFunction(self):
     def fn(a, b):
       return a * b
@@ -377,6 +399,7 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(dx, y)
     self.assertAllEqual(dy, x)
 
+  @test_util.assert_no_new_tensors
   def testNonEmptyParamsForValueAndGradFunction(self):
     def fn(a, b):
       return a * b
@@ -389,6 +412,7 @@ class BackpropTest(test.TestCase):
     self.assertEqual(1, len(grads))
     self.assertAllEqual(grads[0], x)
 
+  @test_util.assert_no_new_tensors
   def testTensorCopyCPU2GPU2CPU(self):
     if not context.context().num_gpus():
       self.skipTest('No GPUs found')
@@ -473,6 +497,7 @@ class BackpropTest(test.TestCase):
 
     self.assertAllEqual(backprop.gradients_function(f)(1.0)[0], 3.0)
 
+  @test_util.assert_no_new_tensors
   def testExceptionSafety(self):
 
     def f(unused_x):
@@ -488,6 +513,8 @@ class BackpropTest(test.TestCase):
 
     self.assertAllEqual(backprop.gradients_function(real_f)(1.0)[0], 2.0)
 
+  # TODO(b/70675592): Fix leaked Tensors in this test.
+  # @test_util.assert_no_new_tensors
   def testMultiValueConvertToTensor(self):
     x = resource_variable_ops.ResourceVariable(
         initial_value=array_ops.constant([1.0]), name='x')
@@ -548,6 +575,7 @@ class BackpropTest(test.TestCase):
         initial_value=1., name='testSameObjectForMultipleArguments.Variable')
     self.assertAllEqual([1., 1.], np_g(v, v))
 
+  @test_util.assert_no_new_tensors
   def testImplicitGradientsCustomGradientAndCachedVariableValue(self):
 
     @custom_gradient.custom_gradient
@@ -573,6 +601,7 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(7, grad)
     self.assertAllEqual(x, var)
 
+  @test_util.assert_no_new_tensors
   def testCustomGradient(self):
 
     @custom_gradient.custom_gradient
@@ -599,6 +628,7 @@ class BackpropTest(test.TestCase):
       var.assign_sub(lr*grad)
     self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
 
+  @test_util.assert_no_new_tensors
   def testCustomGradientIdentity(self):
 
     @custom_gradient.custom_gradient
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 8875d45a07..7627fb3e69 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -47,6 +47,7 @@ from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.client import device_lib
 from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import device as pydev
@@ -57,6 +58,7 @@ from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import versions
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
@@ -455,6 +457,62 @@ class IsolateTest(object):
                                   type_arg, value_arg, traceback_arg)
 
 
+def assert_no_new_tensors(f):
+  """Decorator for asserting that no new Tensors persist after a test.
+
+  Mainly useful for checking that code using the Python C API has correctly
+  manipulated reference counts.
+
+  Clears the caches that it knows about, runs the garbage collector, then checks
+  that there are no Tensor or Tensor-like objects still around. This includes
+  Tensors to which something still has a reference (e.g. from missing
+  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
+  of the objects has __del__ defined).
+
+  Args:
+    f: The test case to run.
+  Returns:
+    The decorated test case.
+  """
+
+  def decorator(self, **kwargs):
+    """Finds existing Tensors, runs the test, checks for new Tensors."""
+
+    def _is_tensor(obj):
+      try:
+        return (isinstance(obj, ops.Tensor) or
+                isinstance(obj, variables.Variable))
+      except ReferenceError:
+        # If the object no longer exists, we don't care about it.
+        return False
+
+    tensors_before = set(id(obj) for obj in gc.get_objects() if _is_tensor(obj))
+    outside_container_prefix = ops.get_default_graph()._container_prefix
+    with IsolateTest():
+      # Run the test in a new graph so that collections get cleared when it's
+      # done, but inherit the container prefix so that we can print the values
+      # of variables which get leaked when executing eagerly.
+      ops.get_default_graph()._container_prefix = outside_container_prefix
+      f(self, **kwargs)
+      # Make an effort to clear caches, which would otherwise look like leaked
+      # Tensors.
+      backprop._last_zero = [None]
+      backprop._shape_dtype = [None, None]
+      context.get_default_context().scalar_cache().clear()
+      gc.collect()
+      tensors_after = [
+          obj for obj in gc.get_objects()
+          if _is_tensor(obj) and id(obj) not in tensors_before
+      ]
+      if tensors_after:
+        raise AssertionError(("%d Tensors not deallocated after test: %s" % (
+            len(tensors_after),
+            str(tensors_after),
+        )))
+
+  return decorator
+
+
 def assert_no_garbage_created(f):
   """Test method decorator to assert that no garbage has been created.
 
@@ -509,7 +567,8 @@ def run_in_graph_and_eager_modes(
       garbage for legitimate reasons (e.g. they define a class which inherits
       from `object`), and because DEBUG_SAVEALL is sticky in some Python
       interpreters (meaning that tests which rely on objects being collected
-      elsewhere in the unit test file will not work).
+      elsewhere in the unit test file will not work). Additionally, checks that
+      nothing still has a reference to Tensors that the test allocated.
   Returns:
     Returns a decorator that will run the decorated test function
     using both a graph and using eager execution.
@@ -546,7 +605,8 @@
       f(self, **kwargs)
 
     if assert_no_eager_garbage:
-      run_eager_mode = assert_no_garbage_created(run_eager_mode)
+      run_eager_mode = assert_no_new_tensors(
+          assert_no_garbage_created(run_eager_mode))
 
     with context.eager_mode():
       with IsolateTest():
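To show how the new check composes in practice: a hedged usage sketch, with a hypothetical test class (not part of this change), applying the decorator directly and stacking it with assert_no_garbage_created in the same order run_in_graph_and_eager_modes now uses:

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class MemoryCheckedTest(test.TestCase):  # hypothetical example class

  # The Tensor check wraps the garbage check, matching the
  # assert_no_new_tensors(assert_no_garbage_created(...)) composition above.
  @test_util.assert_no_new_tensors
  @test_util.assert_no_garbage_created
  def testNoLingeringTensors(self):
    with context.eager_mode():
      # Tensors created here must be unreachable once the test returns;
      # stashing one on `self` would trip the decorator.
      x = constant_op.constant([3.0])
      self.assertAllEqual([6.0], (x * 2.0).numpy())


if __name__ == '__main__':
  test.main()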
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 90b5290626..f6aed118ca 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -373,6 +373,26 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
 
     ReferenceCycleTest().test_has_no_cycle()
 
+  def test_no_leaked_tensor_decorator(self):
+
+    class LeakedTensorTest(object):
+
+      def __init__(inner_self):  # pylint: disable=no-self-argument
+        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name
+
+      @test_util.assert_no_new_tensors
+      def test_has_leak(self):
+        self.a = constant_op.constant([3.])
+
+      @test_util.assert_no_new_tensors
+      def test_has_no_leak(self):
+        constant_op.constant([3.])
+
+    with self.assertRaisesRegexp(AssertionError, "Tensors not deallocated"):
+      LeakedTensorTest().test_has_leak()
+
+    LeakedTensorTest().test_has_no_leak()
+
 
 @test_util.with_c_api
 class IsolationTest(test_util.TensorFlowTestCase):