diff options
author | Igor Ganichev <iga@google.com> | 2018-09-12 13:32:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 13:37:15 -0700 |
commit | 52d9dbfa8ed7bc8b91f1a1be706cf77314b1c687 (patch) | |
tree | 54c82f0460f003f6f5ba75a9ab081017909cf1da /tensorflow/python/framework | |
parent | 5d1de24583aabeb2cb883ab197ae2b8d5446c565 (diff) |
Use WeakKeyDictionaries for global Keras {graph->...} maps
These globals were holding onto graphs including FuncGraphs, which
held onto captured tensors leaving garbage around.
This change also adds a test to catch garbage like this in the future.
To make the test work, I needed to manually breakup some reference
cycles caused by OrderedDicts. We should probably have a custom impl
of OrderedDict similar to the one in Python3 and avoid these issues.
PiperOrigin-RevId: 212694290
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/ops.py | 19 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util.py | 40 |
2 files changed, 44 insertions, 15 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 75678cbc01..343f52fe8f 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils from tensorflow.python.util import deprecation from tensorflow.python.util import function_utils from tensorflow.python.util import lock_util +from tensorflow.python.util import memory from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_stack from tensorflow.python.util.deprecation import deprecated_args @@ -5824,23 +5825,11 @@ def dismantle_graph(graph): graph: A `Graph` object to destroy. Neither it nor any of its ops are usable after this function runs. """ - # pylint: disable=protected-access - # OrderedDict, constructed on Graph creation, makes a simple reference loop - # and hides it in an __attribute in some Python versions. We don't need to - # throw an error if we can't find it, but if we do find it we can break the - # loop to avoid creating work for the garbage collector. - graph_operations = graph.get_operations() - problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None) - # pylint: enable=protected-access - if problematic_cycle: - try: - del problematic_cycle[0][:] - except TypeError: - # This is probably not one of the problematic Python versions. Continue - # with the rest of our cleanup. - pass + memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access + # Now clean up Operation<->Graph reference cycles by clearing all of the # attributes for the Graph and its ops. + graph_operations = graph.get_operations() for op in graph_operations: op.__dict__ = {} graph.__dict__ = {} diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 6a2c897f3f..1cc3bb4628 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -69,6 +69,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat +from tensorflow.python.util import memory from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect from tensorflow.python.util.protobuf import compare @@ -2008,3 +2009,42 @@ def set_producer_version(graph, producer_version): with graph.as_default(): importer.import_graph_def(graph_def) assert graph.graph_def_versions.producer, producer_version + + +def dismantle_func_graph(func_graph): + """Removes reference cycles in `func_graph` FuncGraph. + + Helpful for making sure the garbage collector doesn't need to run when + the FuncGraph goes out of scope, e.g. in tests using defun with + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True). + + Args: + func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable + after this function. + """ + # TODO(b/115366440): Delete this method when a custom OrderedDict is added. + # Clearing captures using clear() leaves some cycles around. + while func_graph.captures: + func_graph.captures.popitem() + memory.dismantle_ordered_dict(func_graph.captures) + ops.dismantle_graph(func_graph) + + +def dismantle_polymorphic_function(func): + """Removes reference cycles in PolymorphicFunction `func`. + + Helpful for making sure the garbage collector doesn't need to run when + PolymorphicFunction goes out of scope, e.g. in tests using defun with + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True). + + Args: + func: A `PolymorphicFunction` object to destroy. `func` is unusable + after this function. + """ + # TODO(b/115366440): Delete this method when a custom OrderedDict is added + cache = func._function_cache # pylint: disable=protected-access + for concrete_func in cache.values(): + dismantle_func_graph(concrete_func.graph) + while cache: + cache.popitem() + memory.dismantle_ordered_dict(cache) |