author    Allen Lavoie <allenl@google.com>  2018-09-17 14:24:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-17 14:28:07 -0700
commit    28dd4d9fcbf8cac1008b2ccd2b4be3fa3c25afd1 (patch)
tree      68e901eec6d952589b5a69f3be37d7f04dac8373 /tensorflow/python/eager
parent    4516558acc9763999b19d1af75ab1fcd6562e4f0 (diff)
Keep only weak references to variables in graph functions
This enables cleanup of the variables referenced in defunned methods of objects when the object is garbage collected. Since one PolymorphicFunction is created per @defun, decorated methods before this change held on to all of the variables referenced in that method for any instance of the class (i.e. variables which should have been object-scoped were scoped to the lifetime of the class definition).

Calling a function after a variable it references has been deleted now raises an exception, which means function-local variables (variables created inside the function and referenced nowhere else) are not supported.

PiperOrigin-RevId: 213337256
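The pattern at the heart of the patch is easy to demonstrate outside TensorFlow. The sketch below is hypothetical (Holder and Toy are illustrative names, not from the patch): it stores weakref.ref objects, dereferences them on access, and raises if a referent has already been collected, mirroring what FuncGraph.variables does after this change.

    import weakref

    class Holder(object):
      """Hypothetical stand-in for FuncGraph's weak variable list."""

      def __init__(self):
        self._weak_variables = []

      @property
      def variables(self):
        # Dereference on access; a collected referent comes back as None.
        for weak_v in self._weak_variables:
          v = weak_v()
          if v is None:
            raise AssertionError("A referenced variable has been deleted.")
          yield v

      @variables.setter
      def variables(self, var_list):
        # Store weak references only, so Holder does not extend lifetimes.
        self._weak_variables = [weakref.ref(v) for v in var_list]

    class Toy(object):  # Toy referent standing in for a TF variable.
      pass

    holder = Holder()
    toy = Toy()
    holder.variables = [toy]
    assert list(holder.variables) == [toy]  # Referent alive: yielded normally.
    del toy  # Drop the last strong reference; CPython collects it immediately.
    try:
      list(holder.variables)
    except AssertionError:
      print("referenced variable was garbage collected")

Because only weak references are held, dropping the last strong reference to an object (here, del toy) is enough for its variables to be reclaimed; the holder then fails loudly rather than silently yielding stale state.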
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--  tensorflow/python/eager/function.py       | 53
-rw-r--r--  tensorflow/python/eager/function_test.py  | 54
2 files changed, 71 insertions(+), 36 deletions(-)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index e2874e25b6..4f1a85a274 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,6 +23,7 @@ import collections
import functools
import sys
import threading
+import weakref
import numpy as np
import six
@@ -180,7 +181,7 @@ class FuncGraph(ops.Graph):
self.inputs = []
self.outputs = []
self.structured_outputs = None
- self.variables = []
+ self._weak_variables = []
self.outer_graph = ops.get_default_graph()
self.captures = collections.OrderedDict()
@@ -217,6 +218,31 @@ class FuncGraph(ops.Graph):
self._graph_key = graph._graph_key
# pylint: enable=protected-access
+ @property
+ def variables(self):
+ """A list of variables accessed by this FuncGraph.
+
+ Note that functions keep only weak references to variables. Calling the
+ function after a variable it accesses has been deleted is an error.
+
+ Yields:
+ Strong references to variables accessed by this FuncGraph.
+ """
+ for weak_v in self._weak_variables:
+ v = weak_v()
+ if v is None:
+ raise AssertionError(
+ "Called a function referencing variables which have been deleted. "
+ "This likely means that function-local variables were created and "
+ "not referenced elsewhere in the program. This is generally a "
+ "mistake; consider storing variables in an object attribute on "
+ "first call.")
+ yield v
+
+ @variables.setter
+ def variables(self, var_list):
+ self._weak_variables = [weakref.ref(v) for v in var_list]
+
def create_op(
self,
op_type,
@@ -604,11 +630,6 @@ class Function(object):
return self._func_graph
@property
- def variables(self):
- """Returns all variables touched by this function."""
- return self._func_graph.variables
-
- @property
def inputs(self):
"""Returns tensors in `self.graph` corresponding to arguments."""
return self._func_graph.inputs
@@ -970,7 +991,16 @@ def _encode_arg(arg):
return tuple(
(_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
else:
- return arg
+ try:
+ # If possible, keep only a weak reference to Python objects. Weak
+ # references hash to the same value as the original object.
+ # TODO(allenl): Clean up dead functions and their cache keys if the cache
+ # gets large. Right now creating objects with a defunned method, calling
+ # the method, and losing a reference to the object in a loop will leak
+ # memory here.
+ return weakref.ref(arg)
+ except TypeError:
+ return arg
def _deterministic_dict_values(dictionary):
@@ -1020,7 +1050,6 @@ class PolymorphicFunction(object):
self._kwds_to_include = {}
self._name = name
self._function_cache = collections.OrderedDict()
- self._variables = []
self._function_attributes = attributes or {}
self._lock = threading.Lock()
@@ -1066,12 +1095,6 @@ class PolymorphicFunction(object):
"""Returns the wrapped Python function."""
return self._python_function
- # TODO(akshayka): Remove this property.
- @property
- def variables(self):
- """Returns the union of all variables referenced by cached `Function`s`."""
- return self._variables
-
def get_concrete_function(self, *args, **kwargs):
"""Returns a `Function` object specialized to inputs and execution context.
@@ -1238,8 +1261,6 @@ class PolymorphicFunction(object):
func_graph_from_py_func(self._name, self._python_function, args,
kwds, self._input_signature),
self._function_attributes)
- self._variables.extend(
- [v for v in graph_function.variables if v not in self._variables])
self._function_cache[cache_key] = graph_function
return graph_function, [
t for t in nest.flatten((args, kwds))
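
The _encode_arg change above relies on a documented property of weakref.ref: a reference is hashable whenever its referent is, and while the referent is alive it hashes and compares equal like the referent itself, so cache lookups keyed on a fresh weak reference still hit. A minimal sketch of that property, with a hypothetical Model class standing in for an object passed to a defunned method:

    import weakref

    class Model(object):  # Hypothetical argument to a defunned method.
      pass

    m = Model()
    cache = {weakref.ref(m): "graph_function"}

    # A freshly created reference hashes and compares like the live referent,
    # so it finds the existing cache entry without keeping m alive.
    assert cache[weakref.ref(m)] == "graph_function"

    # As the TODO in the patch notes, dead keys are not yet evicted, so
    # entries for collected objects can accumulate until cleanup is added.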
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index c168b6060c..6326a5b45f 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -21,6 +21,7 @@ import collections
import functools
from multiprocessing.pool import ThreadPool
import sys
+import weakref
import numpy
@@ -74,6 +75,13 @@ class MiniModel(keras_training.Model):
return self.fc(inputs)
+class DefunnedMiniModel(MiniModel):
+
+ @function.defun
+ def call(self, inputs, training=True):
+ return super(DefunnedMiniModel, self).call(inputs, training=training)
+
+
@test_util.with_c_shapes
class FunctionTest(test.TestCase):
@@ -140,8 +148,8 @@ class FunctionTest(test.TestCase):
@function.defun
def f():
- v = resource_variable_ops.ResourceVariable(1.0)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ return self.v.read_value()
self.assertAllEqual(f(), 1.0)
@@ -399,9 +407,9 @@ class FunctionTest(test.TestCase):
@function.defun
def tensor_init():
- v = resource_variable_ops.ResourceVariable(
+ self.v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
- return v.read_value()
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -415,8 +423,8 @@ class FunctionTest(test.TestCase):
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
- v = resource_variable_ops.ResourceVariable(const)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(const)
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -478,13 +486,14 @@ class FunctionTest(test.TestCase):
def testDefunForcesResourceVariables(self):
def variable_creator():
- return variables.Variable(0.0).read_value()
+ self.v = variables.Variable(0.0)
+ return self.v.read_value()
+ self.v = None
defined = function.defun(variable_creator)
defined() # Create the variable.
- self.assertEqual(len(defined.variables), 1)
self.assertIsInstance(
- defined.variables[0], resource_variable_ops.ResourceVariable)
+ self.v, resource_variable_ops.ResourceVariable)
def testDefunDifferentiable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -1184,13 +1193,11 @@ class FunctionTest(test.TestCase):
defined = function.defun(foo)
x = constant_op.constant([1.0])
- self.assertAllEqual(defined.variables, [])
- _ = defined(x)
- self.assertAllEqual(defined.variables, [v])
+ self.assertEqual(1., self.evaluate(defined(x)))
+ v.assign(2.)
x = constant_op.constant([1.0, 2.0])
- _ = defined(x) # ensure the variables list remains the same
- self.assertAllEqual(defined.variables, [v])
+ self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
def testPythonFunctionWithDefaultArgs(self):
@@ -1913,10 +1920,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
value = train()
self.assertEqual(value.numpy(), -1.0)
@@ -1943,10 +1950,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
train()
@@ -2133,6 +2140,13 @@ class AutomaticControlDependenciesTest(test.TestCase):
modify_same_flat(nested_input)
+ def testDecoratedMethodVariableCleanup(self):
+ m = DefunnedMiniModel()
+ m(array_ops.ones([1, 2]))
+ weak_variables = weakref.WeakSet(m.variables)
+ self.assertEqual(2, len(weak_variables))
+ del m
+ self.assertEqual([], list(weak_variables))
if __name__ == '__main__':
ops.enable_eager_execution(