aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Allen Lavoie <allenl@google.com> 2018-09-25 10:30:48 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-25 10:36:47 -0700
commit 410905d8e8af12e928031aa026683e43b665c8ae (patch)
tree c1d95a5382bf9daa5a2532f23c3e2a5800705b0e
parent 83763d0be3c664f84a776a8c69d49846fbfd1b9e (diff)
Keep only weak references to TensorFlow Optimizer objects in tf.keras
I don't think this annoyed anyone else yet; it's just a nit I noticed while making sure variables can be garbage collected when tracked via tf.keras.

PiperOrigin-RevId: 214462105
-rw-r--r--  tensorflow/python/keras/backend.py         6
-rw-r--r--  tensorflow/python/keras/optimizers_test.py  17
2 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index a46f9edb1e..4589c821e5 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -695,10 +695,8 @@ def track_tf_optimizer(tf_optimizer):
if context.executing_eagerly():
return
graph = ops.get_default_graph()
- if graph not in _GRAPH_TF_OPTIMIZERS:
- _GRAPH_TF_OPTIMIZERS[graph] = set()
- _GRAPH_TF_OPTIMIZERS[graph].add(tf_optimizer)
-
+ optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
+ optimizers.add(tf_optimizer)
def track_variable(v):
"""Tracks the given variable for initialization."""
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 8d7493462e..9664f09fff 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -18,10 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gc
+import weakref
+
import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
@@ -156,6 +160,19 @@ class KerasOptimizersTest(test.TestCase):
with self.assertRaises(NotImplementedError):
optimizer.from_config(None)
+ def test_optimizer_garbage_collection(self):
+ graph = ops.Graph()
+ with graph.as_default():
+ optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
+ keras.backend.track_tf_optimizer(optimizer)
+ optimizer_weak = weakref.ref(optimizer)
+ graph_weak = weakref.ref(graph)
+ del graph, optimizer
+ gc.collect()
+ # Check that the weak references are dead now.
+ self.assertIs(graph_weak(), None)
+ self.assertIs(optimizer_weak(), None)
+
@test_util.run_in_graph_and_eager_modes
def test_tfoptimizer_iterations(self):
with self.cached_session():