diff options
author | Allen Lavoie <allenl@google.com> | 2018-09-25 10:30:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 10:36:47 -0700 |
commit | 410905d8e8af12e928031aa026683e43b665c8ae (patch) | |
tree | c1d95a5382bf9daa5a2532f23c3e2a5800705b0e /tensorflow/python/keras | |
parent | 83763d0be3c664f84a776a8c69d49846fbfd1b9e (diff) |
Keep only weak references to TensorFlow Optimizer objects in tf.keras
I don't think this annoyed anyone else yet, it's just a nit I noticed while making sure variables can be garbage collected when tracked via tf.keras.
PiperOrigin-RevId: 214462105
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/backend.py | 6 | ||||
-rw-r--r-- | tensorflow/python/keras/optimizers_test.py | 17 |
2 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index a46f9edb1e..4589c821e5 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -695,10 +695,8 @@ def track_tf_optimizer(tf_optimizer): if context.executing_eagerly(): return graph = ops.get_default_graph() - if graph not in _GRAPH_TF_OPTIMIZERS: - _GRAPH_TF_OPTIMIZERS[graph] = set() - _GRAPH_TF_OPTIMIZERS[graph].add(tf_optimizer) - + optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet()) + optimizers.add(tf_optimizer) def track_variable(v): """Tracks the given variable for initialization.""" diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py index 8d7493462e..9664f09fff 100644 --- a/tensorflow/python/keras/optimizers_test.py +++ b/tensorflow/python/keras/optimizers_test.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gc +import weakref + import numpy as np from tensorflow.python import keras from tensorflow.python.eager import context +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test @@ -156,6 +160,19 @@ class KerasOptimizersTest(test.TestCase): with self.assertRaises(NotImplementedError): optimizer.from_config(None) + def test_optimizer_garbage_collection(self): + graph = ops.Graph() + with graph.as_default(): + optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01)) + keras.backend.track_tf_optimizer(optimizer) + optimizer_weak = weakref.ref(optimizer) + graph_weak = weakref.ref(graph) + del graph, optimizer + gc.collect() + # Check that the weak references are dead now. + self.assertIs(graph_weak(), None) + self.assertIs(optimizer_weak(), None) + @test_util.run_in_graph_and_eager_modes def test_tfoptimizer_iterations(self): with self.cached_session(): |