aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-04 14:03:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 14:26:01 -0700
commitd29eb6d1c9d1e4b2f601864f53878674f219fe6f (patch)
tree4292df644dad3f34e813ddb742c31aa730e64fba /tensorflow/python/framework
parent06e8109af2e5ae5bc149e25fc64fbf66d6c8b817 (diff)
Remove reference cycles when constructing distribution objects
self -> _parameters -> self cycles were creating work for Python's garbage collector in training loops, where Distribution objects may be created repeatedly when executing eagerly. This CL just fixes that narrow memory issue; I'm not convinced dict(locals()) is super efficient, so we may want to follow up on that for performance. Adds a few unit tests tests with run_test_in_graph_and_eager_modes(assert_no_eager_garbage=True). It'd be nice to expand this coverage over time. Includes a small test_util simplification to support this (TFP tests don't like reset_default_graph for some reason). Testing for cycles in the TFP repo will need to wait on the Normal changes from the TF repo syncing. PiperOrigin-RevId: 211520394
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/test_util.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b5388ad0b2..3b63e49a84 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -535,15 +535,16 @@ def assert_no_new_tensors(f):
tensors_before = set(
id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
- if context.executing_eagerly():
- f(self, **kwargs)
- ops.reset_default_graph()
- else:
- # Run the test in a new graph so that collections get cleared when it's
- # done, but inherit the graph key so optimizers behave.
- outside_graph_key = ops.get_default_graph()._graph_key
- with ops.Graph().as_default():
- ops.get_default_graph()._graph_key = outside_graph_key
+ outside_executed_eagerly = context.executing_eagerly()
+ # Run the test in a new graph so that collections get cleared when it's
+ # done, but inherit the graph key so optimizers behave.
+ outside_graph_key = ops.get_default_graph()._graph_key
+ with ops.Graph().as_default():
+ ops.get_default_graph()._graph_key = outside_graph_key
+ if outside_executed_eagerly:
+ with context.eager_mode():
+ f(self, **kwargs)
+ else:
f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.