aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/py_func_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/py_func_test.py')
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py31
1 files changed, 24 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 677253946e..253e43920b 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gc
import re
import numpy as np
@@ -434,13 +435,29 @@ class PyFuncTest(test.TestCase):
# ----- Tests shared by py_func and eager_py_func -----
def testCleanup(self):
- for _ in xrange(1000):
- g = ops.Graph()
- with g.as_default():
- c = constant_op.constant([1.], dtypes.float32)
- _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
- _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
- self.assertLess(script_ops._py_funcs.size(), 100)
+ # Delete everything created by previous tests to avoid side effects.
+ ops.reset_default_graph()
+ gc.collect()
+ initial_size = script_ops._py_funcs.size()
+ # Encapsulate the graph generation, so locals can be deleted.
+ def make_graphs():
+ for _ in xrange(1000):
+ g = ops.Graph()
+ with g.as_default():
+ c = constant_op.constant([1.], dtypes.float32)
+ _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
+ _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
+ # These ops have a reference to 'c' which has a reference to the graph.
+ # Checks if the functions are being deleted though the graph is referenced from them.
+ # (see #18292)
+ _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
+ _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
+
+ # Call garbage collector to enforce deletion.
+ make_graphs()
+ ops.reset_default_graph()
+ gc.collect()
+ self.assertEqual(initial_size, script_ops._py_funcs.size())
# ----- Tests for eager_py_func -----
@test_util.run_in_graph_and_eager_modes()