aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-17 14:24:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 14:28:07 -0700
commit28dd4d9fcbf8cac1008b2ccd2b4be3fa3c25afd1 (patch)
tree68e901eec6d952589b5a69f3be37d7f04dac8373 /tensorflow/python/keras/backend.py
parent4516558acc9763999b19d1af75ab1fcd6562e4f0 (diff)
Keep only weak references to variables in graph functions
This enables cleanup of the variables referenced in defunned methods of objects when the object is garbage collected. Since one PolymorphicFunction is created per @defun, decorated methods before this change held on to all of the variables referenced in that method for any instance of the class (i.e. variables which should have been object-scoped were scoped to the lifetime of the class definition). Raises an exception if variables used in the function have been deleted when it is called, which means no local variables. PiperOrigin-RevId: 213337256
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 5e1722ba20..60ed8e8c8a 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -696,14 +696,14 @@ def track_variable(v):
return
graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
if graph not in _GRAPH_VARIABLES:
- _GRAPH_VARIABLES[graph] = set()
+ _GRAPH_VARIABLES[graph] = weakref.WeakSet()
_GRAPH_VARIABLES[graph].add(v)
def _get_variables(graph=None):
"""Returns variables corresponding to the given graph for initialization."""
assert not context.executing_eagerly()
- variables = _GRAPH_VARIABLES.get(graph, set())
+ variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
variables.update(opt.optimizer.variables())
return variables