aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-25 10:30:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 10:36:47 -0700
commit410905d8e8af12e928031aa026683e43b665c8ae (patch)
treec1d95a5382bf9daa5a2532f23c3e2a5800705b0e /tensorflow/python/keras/backend.py
parent83763d0be3c664f84a776a8c69d49846fbfd1b9e (diff)
Keep only weak references to TensorFlow Optimizer objects in tf.keras
I don't think this annoyed anyone else yet, it's just a nit I noticed while making sure variables can be garbage collected when tracked via tf.keras. PiperOrigin-RevId: 214462105
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index a46f9edb1e..4589c821e5 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -695,10 +695,8 @@ def track_tf_optimizer(tf_optimizer):
if context.executing_eagerly():
return
graph = ops.get_default_graph()
- if graph not in _GRAPH_TF_OPTIMIZERS:
- _GRAPH_TF_OPTIMIZERS[graph] = set()
- _GRAPH_TF_OPTIMIZERS[graph].add(tf_optimizer)
-
+ optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
+ optimizers.add(tf_optimizer)
def track_variable(v):
"""Tracks the given variable for initialization."""