aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-09-12 13:32:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 13:37:15 -0700
commit52d9dbfa8ed7bc8b91f1a1be706cf77314b1c687 (patch)
tree54c82f0460f003f6f5ba75a9ab081017909cf1da /tensorflow/python/framework
parent5d1de24583aabeb2cb883ab197ae2b8d5446c565 (diff)
Use WeakKeyDictionaries for global Keras {graph->...} maps
These globals were holding onto graphs including FuncGraphs, which held onto captured tensors leaving garbage around. This change also adds a test to catch garbage like this in the future. To make the test work, I needed to manually breakup some reference cycles caused by OrderedDicts. We should probably have a custom impl of OrderedDict similar to the one in Python3 and avoid these issues. PiperOrigin-RevId: 212694290
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/ops.py19
-rw-r--r--tensorflow/python/framework/test_util.py40
2 files changed, 44 insertions, 15 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 75678cbc01..343f52fe8f 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util
+from tensorflow.python.util import memory
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_stack
from tensorflow.python.util.deprecation import deprecated_args
@@ -5824,23 +5825,11 @@ def dismantle_graph(graph):
graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
after this function runs.
"""
- # pylint: disable=protected-access
- # OrderedDict, constructed on Graph creation, makes a simple reference loop
- # and hides it in an __attribute in some Python versions. We don't need to
- # throw an error if we can't find it, but if we do find it we can break the
- # loop to avoid creating work for the garbage collector.
- graph_operations = graph.get_operations()
- problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
- # pylint: enable=protected-access
- if problematic_cycle:
- try:
- del problematic_cycle[0][:]
- except TypeError:
- # This is probably not one of the problematic Python versions. Continue
- # with the rest of our cleanup.
- pass
+ memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access
+
# Now clean up Operation<->Graph reference cycles by clearing all of the
# attributes for the Graph and its ops.
+ graph_operations = graph.get_operations()
for op in graph_operations:
op.__dict__ = {}
graph.__dict__ = {}
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 6a2c897f3f..1cc3bb4628 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -69,6 +69,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare
@@ -2008,3 +2009,42 @@ def set_producer_version(graph, producer_version):
with graph.as_default():
importer.import_graph_def(graph_def)
assert graph.graph_def_versions.producer, producer_version
+
+
+def dismantle_func_graph(func_graph):
+ """Removes reference cycles in `func_graph` FuncGraph.
+
+ Helpful for making sure the garbage collector doesn't need to run when
+ the FuncGraph goes out of scope, e.g. in tests using defun with
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+ Args:
+ func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
+ after this function.
+ """
+ # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+ # Clearing captures using clear() leaves some cycles around.
+ while func_graph.captures:
+ func_graph.captures.popitem()
+ memory.dismantle_ordered_dict(func_graph.captures)
+ ops.dismantle_graph(func_graph)
+
+
+def dismantle_polymorphic_function(func):
+ """Removes reference cycles in PolymorphicFunction `func`.
+
+ Helpful for making sure the garbage collector doesn't need to run when
+ PolymorphicFunction goes out of scope, e.g. in tests using defun with
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+ Args:
+ func: A `PolymorphicFunction` object to destroy. `func` is unusable
+ after this function.
+ """
+ # TODO(b/115366440): Delete this method when a custom OrderedDict is added
+ cache = func._function_cache # pylint: disable=protected-access
+ for concrete_func in cache.values():
+ dismantle_func_graph(concrete_func.graph)
+ while cache:
+ cache.popitem()
+ memory.dismantle_ordered_dict(cache)