diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/py_func_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/py_func_test.py | 31 |
1 files changed, 24 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 677253946e..253e43920b 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gc import re import numpy as np @@ -434,13 +435,29 @@ class PyFuncTest(test.TestCase): # ----- Tests shared by py_func and eager_py_func ----- def testCleanup(self): - for _ in xrange(1000): - g = ops.Graph() - with g.as_default(): - c = constant_op.constant([1.], dtypes.float32) - _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32]) - _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32]) - self.assertLess(script_ops._py_funcs.size(), 100) + # Delete everything created by previous tests to avoid side effects. + ops.reset_default_graph() + gc.collect() + initial_size = script_ops._py_funcs.size() + # Encapsulate the graph generation, so locals can be deleted. + def make_graphs(): + for _ in xrange(1000): + g = ops.Graph() + with g.as_default(): + c = constant_op.constant([1.], dtypes.float32) + _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32]) + _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32]) + # These ops have a reference to 'c' which has a reference to the graph. + # Checks if the functions are being deleted though the graph is referenced from them. + # (see #18292) + _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) + _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) + + # Call garbage collector to enforce deletion. + make_graphs() + ops.reset_default_graph() + gc.collect() + self.assertEqual(initial_size, script_ops._py_funcs.size()) # ----- Tests for eager_py_func ----- @test_util.run_in_graph_and_eager_modes() |