aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-09-12 13:32:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 13:37:15 -0700
commit52d9dbfa8ed7bc8b91f1a1be706cf77314b1c687 (patch)
tree54c82f0460f003f6f5ba75a9ab081017909cf1da /tensorflow/python/keras/backend.py
parent5d1de24583aabeb2cb883ab197ae2b8d5446c565 (diff)
Use WeakKeyDictionaries for global Keras {graph->...} maps
These globals were holding onto graphs including FuncGraphs, which held onto captured tensors leaving garbage around. This change also adds a test to catch garbage like this in the future. To make the test work, I needed to manually breakup some reference cycles caused by OrderedDicts. We should probably have a custom impl of OrderedDict similar to the one in Python3 and avoid these issues. PiperOrigin-RevId: 212694290
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py23
1 files changed, 16 insertions, 7 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 7768caeaf0..529b07dc12 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -73,7 +73,16 @@ _SESSION = None
# This dictionary holds a mapping {graph: learning_phase}.
# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = {}
+_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
+
+
+# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
+# We keep a separate reference to it to make sure it does not get removed from
+# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
+# string because strings are not weakly-referencable.
+class _DummyEagerGraph(object):
+ pass
+_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
# This boolean flag can be set to True to leave variable initialization
# up to the user.
@@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
# This dictionary holds a mapping between a graph and variables to initialize
# in the graph.
-_GRAPH_VARIABLES = {}
+_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph.
-_GRAPH_TF_OPTIMIZERS = {}
+_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
@tf_export('keras.backend.backend')
@@ -359,10 +368,10 @@ def learning_phase():
Learning phase (scalar integer tensor or Python integer).
"""
if context.executing_eagerly():
- if 'eager' not in _GRAPH_LEARNING_PHASES:
+ if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
# Fallback to inference mode as default.
return 0
- return _GRAPH_LEARNING_PHASES['eager']
+ return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
graph = ops.get_default_graph()
if graph not in _GRAPH_LEARNING_PHASES:
@@ -386,7 +395,7 @@ def set_learning_phase(value):
if value not in {0, 1}:
raise ValueError('Expected learning phase to be 0 or 1.')
if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES['eager'] = value
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@@ -415,7 +424,7 @@ def learning_phase_scope(value):
finally:
# Restore learning phase to initial value.
if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES['eager'] = previous_value
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value