| author | Igor Ganichev <iga@google.com> | 2018-09-12 13:32:04 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 13:37:15 -0700 |
| commit | 52d9dbfa8ed7bc8b91f1a1be706cf77314b1c687 | |
| tree | 54c82f0460f003f6f5ba75a9ab081017909cf1da /tensorflow/python/keras/backend.py | |
| parent | 5d1de24583aabeb2cb883ab197ae2b8d5446c565 | |
Use WeakKeyDictionaries for global Keras {graph->...} maps
These globals were holding onto graphs, including FuncGraphs, which in turn
held onto captured tensors, leaving garbage around.
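
The key property being relied on: a `weakref.WeakKeyDictionary` drops an entry as soon as its key has no strong references left elsewhere, so the global map can no longer keep dead graphs (and everything they capture) alive. A minimal sketch of that behavior; the `Graph` class here is an illustrative stand-in, not TensorFlow's:

```python
import gc
import weakref


class Graph(object):
  """Stand-in for a graph object; anything weakly-referenceable works."""


strong_map = {}                         # what the globals used to be
weak_map = weakref.WeakKeyDictionary()  # what this change switches to

g1 = Graph()
g2 = Graph()
strong_map[g1] = 'learning_phase'
weak_map[g2] = 'learning_phase'

del g1, g2      # drop the last references held outside the maps
gc.collect()    # not strictly needed in CPython, but makes the point robust

print(len(strong_map))  # 1 -- the plain dict pins its key in memory
print(len(weak_map))    # 0 -- the weak map let its key be collected
```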
This change also adds a test to catch garbage like this in the future.
To make the test work, I needed to manually break up some reference
cycles caused by OrderedDicts. We should probably have a custom
implementation of OrderedDict, similar to the one in Python 3, to avoid
these issues.
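
For background on why the cycles matter: objects in a reference cycle are never reclaimed by CPython's reference counting, only by the cyclic garbage collector, and Python 2's pure-Python `OrderedDict` creates such a cycle internally (its sentinel node links back to itself). A rough sketch of how a "leaves no garbage" test can be written with `gc.DEBUG_SAVEALL`; this illustrates the technique only and is not the actual TensorFlow test:

```python
import gc


def leaves_no_garbage(fn):
  """Returns True if fn() leaves nothing behind for the cycle collector."""
  gc.collect()                      # start from a clean slate
  gc.set_debug(gc.DEBUG_SAVEALL)    # keep, rather than free, collected cycles
  fn()
  found = gc.collect()              # number of objects found in cycles
  gc.set_debug(0)
  del gc.garbage[:]                 # release what DEBUG_SAVEALL held back
  return found == 0


def makes_cycle():
  node = {}
  node['self'] = node   # refcounting alone can never reclaim this dict


def cycle_free():
  data = {'key': 'value'}   # reclaimed immediately by refcounting


print(leaves_no_garbage(makes_cycle))  # False -- a cycle was left behind
print(leaves_no_garbage(cycle_free))   # True
```

Under a scheme like this, a single leftover OrderedDict is enough to fail the test, which is why the cycles have to be broken by hand.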
PiperOrigin-RevId: 212694290
Diffstat (limited to 'tensorflow/python/keras/backend.py')
| -rw-r--r-- | tensorflow/python/keras/backend.py | 23 |
1 file changed, 16 insertions(+), 7 deletions(-)
```diff
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 7768caeaf0..529b07dc12 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -73,7 +73,16 @@ _SESSION = None
 # This dictionary holds a mapping {graph: learning_phase}.
 # A learning phase is a bool tensor used to run Keras models in
 # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = {}
+_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
+
+
+# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
+# We keep a separate reference to it to make sure it does not get removed from
+# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
+# string because strings are not weakly-referencable.
+class _DummyEagerGraph(object):
+  pass
+_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
 
 # This boolean flag can be set to True to leave variable initialization
 # up to the user.
@@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
 
 # This dictionary holds a mapping between a graph and variables to initialize
 # in the graph.
-_GRAPH_VARIABLES = {}
+_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
 
 # This dictionary holds a mapping between a graph and TF optimizers created in
 # the graph.
-_GRAPH_TF_OPTIMIZERS = {}
+_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
 
 
 @tf_export('keras.backend.backend')
@@ -359,10 +368,10 @@ def learning_phase():
       Learning phase (scalar integer tensor or Python integer).
   """
   if context.executing_eagerly():
-    if 'eager' not in _GRAPH_LEARNING_PHASES:
+    if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
       # Fallback to inference mode as default.
       return 0
-    return _GRAPH_LEARNING_PHASES['eager']
+    return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
   graph = ops.get_default_graph()
   if graph not in _GRAPH_LEARNING_PHASES:
@@ -386,7 +395,7 @@ def set_learning_phase(value):
   if value not in {0, 1}:
     raise ValueError('Expected learning phase to be 0 or 1.')
   if context.executing_eagerly():
-    _GRAPH_LEARNING_PHASES['eager'] = value
+    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
   else:
     _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@@ -415,7 +424,7 @@ def learning_phase_scope(value):
   finally:
     # Restore learning phase to initial value.
     if context.executing_eagerly():
-      _GRAPH_LEARNING_PHASES['eager'] = previous_value
+      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
     else:
       _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
```
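
A note on why the patch replaces the `'eager'` string key with `_DummyEagerGraph`, as its in-diff comment says: `WeakKeyDictionary` takes weak references to its keys, and CPython's `str` does not support weak references. An instance of a plain class does, but it then needs a module-level strong reference (`_DUMMY_EAGER_GRAPH`), or its entry would vanish immediately. A small sketch of both points:

```python
import weakref


class _DummyEagerGraph(object):
  pass


d = weakref.WeakKeyDictionary()

try:
  d['eager'] = 1           # the old string key cannot survive the switch ...
except TypeError as e:
  print(e)                 # ... cannot create weak reference to 'str' object

d[_DummyEagerGraph()] = 1  # no strong reference kept to this key
print(len(d))              # 0 -- in CPython, refcounting reclaims it at once

_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
d[_DUMMY_EAGER_GRAPH] = 1  # the module-level reference keeps the entry alive
print(len(d))              # 1
```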