diff options
author | 2018-09-04 14:03:08 -0700 | |
---|---|---|
committer | 2018-09-04 14:26:01 -0700 | |
commit | d29eb6d1c9d1e4b2f601864f53878674f219fe6f (patch) | |
tree | 4292df644dad3f34e813ddb742c31aa730e64fba /tensorflow/python/framework | |
parent | 06e8109af2e5ae5bc149e25fc64fbf66d6c8b817 (diff) |
Remove reference cycles when constructing distribution objects
self -> _parameters -> self cycles were creating work for Python's garbage collector in training loops, where Distribution objects may be created repeatedly when executing eagerly. This CL just fixes that narrow memory issue; I'm not convinced dict(locals()) is super efficient, so we may want to follow up on that for performance.
Adds a few unit tests tests with run_test_in_graph_and_eager_modes(assert_no_eager_garbage=True). It'd be nice to expand this coverage over time.
Includes a small test_util simplification to support this (TFP tests don't like reset_default_graph for some reason). Testing for cycles in the TFP repo will need to wait on the Normal changes from the TF repo syncing.
PiperOrigin-RevId: 211520394
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index b5388ad0b2..3b63e49a84 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -535,15 +535,16 @@ def assert_no_new_tensors(f): tensors_before = set( id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) - if context.executing_eagerly(): - f(self, **kwargs) - ops.reset_default_graph() - else: - # Run the test in a new graph so that collections get cleared when it's - # done, but inherit the graph key so optimizers behave. - outside_graph_key = ops.get_default_graph()._graph_key - with ops.Graph().as_default(): - ops.get_default_graph()._graph_key = outside_graph_key + outside_executed_eagerly = context.executing_eagerly() + # Run the test in a new graph so that collections get cleared when it's + # done, but inherit the graph key so optimizers behave. + outside_graph_key = ops.get_default_graph()._graph_key + with ops.Graph().as_default(): + ops.get_default_graph()._graph_key = outside_graph_key + if outside_executed_eagerly: + with context.eager_mode(): + f(self, **kwargs) + else: f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. |