-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py  12
-rw-r--r--  tensorflow/python/eager/function.py  53
-rw-r--r--  tensorflow/python/eager/function_test.py  54
-rw-r--r--  tensorflow/python/framework/ops_test.py  12
-rw-r--r--  tensorflow/python/keras/backend.py  4
-rw-r--r--  tensorflow/python/training/gradient_descent_test.py  10
6 files changed, 95 insertions(+), 50 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index c6894e9013..f51e543624 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -1271,7 +1271,17 @@ class MirroredStrategyDefunTest(test.TestCase):
self.evaluate(device_result))
for defun in defuns:
- self.assertEqual(set(mock_model.variables), set(defun.variables))
+ # PolymorphicFunctions are specialized to the current device stack, so
+ # call_for_each_tower has one trace per device. To check that the expected set
+ # of variables was accessed on each trace, we first retrieve each
+ # device-specific graph function.
+ per_device_graph_functions = dist.call_for_each_tower(
+ defun.get_concrete_function,
+ mock_model, *inputs, run_concurrently=False)
+ for device in devices:
+ graph_function = per_device_graph_functions.get(device=device)
+ self.assertEqual(set(mock_model.variables),
+ set(graph_function.graph.variables))
@test_util.run_in_graph_and_eager_modes()
def testVariableInDefun(self):
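
A note on the test rewrite above: variable tracking now lives on each trace's FuncGraph, so the expected-variables check goes through the concrete (graph) function. A minimal sketch of that lookup, with hypothetical names `v` and `f` and eager execution assumed:

from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(2.0)

@function.defun
def f(x):
  return x * v  # the trace reads `v`, so its graph records the access

concrete = f.get_concrete_function(constant_op.constant(1.0))
# FuncGraph.variables yields strong references to the variables the trace
# accessed; internally only weak references are stored.
assert v in set(concrete.graph.variables)
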
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index e2874e25b6..4f1a85a274 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,6 +23,7 @@ import collections
import functools
import sys
import threading
+import weakref
import numpy as np
import six
@@ -180,7 +181,7 @@ class FuncGraph(ops.Graph):
self.inputs = []
self.outputs = []
self.structured_outputs = None
- self.variables = []
+ self._weak_variables = []
self.outer_graph = ops.get_default_graph()
self.captures = collections.OrderedDict()
@@ -217,6 +218,31 @@ class FuncGraph(ops.Graph):
self._graph_key = graph._graph_key
# pylint: enable=protected-access
+ @property
+ def variables(self):
+ """A list of variables accessed by this FuncGraph.
+
+ Note that functions keep only weak references to variables. Calling the
+ function after a variable it accesses has been deleted is an error.
+
+ Yields:
+ Strong references to variables accessed by this FuncGraph.
+ """
+ for weak_v in self._weak_variables:
+ v = weak_v()
+ if v is None:
+ raise AssertionError(
+ "Called a function referencing variables which have been deleted. "
+ "This likely means that function-local variables were created and "
+ "not referenced elsewhere in the program. This is generally a "
+ "mistake; consider storing variables in an object attribute on "
+ "first call.")
+ yield v
+
+ @variables.setter
+ def variables(self, var_list):
+ self._weak_variables = [weakref.ref(v) for v in var_list]
+
def create_op(
self,
op_type,
@@ -604,11 +630,6 @@ class Function(object):
return self._func_graph
@property
- def variables(self):
- """Returns all variables touched by this function."""
- return self._func_graph.variables
-
- @property
def inputs(self):
"""Returns tensors in `self.graph` corresponding to arguments."""
return self._func_graph.inputs
@@ -970,7 +991,16 @@ def _encode_arg(arg):
return tuple(
(_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
else:
- return arg
+ try:
+ # If possible, keep only a weak reference to Python objects. Weak
+ # references hash to the same value as the original object.
+ # TODO(allenl): Clean up dead functions and their cache keys if the cache
+ # gets large. Right now creating objects with a defunned method, calling
+ # the method, and losing a reference to the object in a loop will leak
+ # memory here.
+ return weakref.ref(arg)
+ except TypeError:
+ return arg
def _deterministic_dict_values(dictionary):
@@ -1020,7 +1050,6 @@ class PolymorphicFunction(object):
self._kwds_to_include = {}
self._name = name
self._function_cache = collections.OrderedDict()
- self._variables = []
self._function_attributes = attributes or {}
self._lock = threading.Lock()
@@ -1066,12 +1095,6 @@ class PolymorphicFunction(object):
"""Returns the wrapped Python function."""
return self._python_function
- # TODO(akshayka): Remove this property.
- @property
- def variables(self):
- """Returns the union of all variables referenced by cached `Function`s`."""
- return self._variables
-
def get_concrete_function(self, *args, **kwargs):
"""Returns a `Function` object specialized to inputs and execution context.
@@ -1238,8 +1261,6 @@ class PolymorphicFunction(object):
func_graph_from_py_func(self._name, self._python_function, args,
kwds, self._input_signature),
self._function_attributes)
- self._variables.extend(
- [v for v in graph_function.variables if v not in self._variables])
self._function_cache[cache_key] = graph_function
return graph_function, [
t for t in nest.flatten((args, kwds))
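
The `_encode_arg` change above leans on a standard property of `weakref.ref`: a ref hashes like its referent and compares equal to another ref to the same live object, so a weakly keyed cache entry still matches lookups made with the original object. A small plain-Python illustration (CPython's immediate refcounting collection assumed):

import weakref

class Obj(object):
  pass

o = Obj()
r = weakref.ref(o)
assert hash(r) == hash(o)       # refs hash like their referents
assert r == weakref.ref(o)      # and compare equal while `o` is alive
del o                           # under CPython refcounting, `o` dies here
assert r() is None              # a dead ref dereferences to None
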
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index c168b6060c..6326a5b45f 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -21,6 +21,7 @@ import collections
import functools
from multiprocessing.pool import ThreadPool
import sys
+import weakref
import numpy
@@ -74,6 +75,13 @@ class MiniModel(keras_training.Model):
return self.fc(inputs)
+class DefunnedMiniModel(MiniModel):
+
+ @function.defun
+ def call(self, inputs, training=True):
+ return super(DefunnedMiniModel, self).call(inputs, training=training)
+
+
@test_util.with_c_shapes
class FunctionTest(test.TestCase):
@@ -140,8 +148,8 @@ class FunctionTest(test.TestCase):
@function.defun
def f():
- v = resource_variable_ops.ResourceVariable(1.0)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ return self.v.read_value()
self.assertAllEqual(f(), 1.0)
@@ -399,9 +407,9 @@ class FunctionTest(test.TestCase):
@function.defun
def tensor_init():
- v = resource_variable_ops.ResourceVariable(
+ self.v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
- return v.read_value()
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -415,8 +423,8 @@ class FunctionTest(test.TestCase):
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
- v = resource_variable_ops.ResourceVariable(const)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(const)
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -478,13 +486,14 @@ class FunctionTest(test.TestCase):
def testDefunForcesResourceVariables(self):
def variable_creator():
- return variables.Variable(0.0).read_value()
+ self.v = variables.Variable(0.0)
+ return self.v.read_value()
+ self.v = None
defined = function.defun(variable_creator)
defined() # Create the variable.
- self.assertEqual(len(defined.variables), 1)
self.assertIsInstance(
- defined.variables[0], resource_variable_ops.ResourceVariable)
+ self.v, resource_variable_ops.ResourceVariable)
def testDefunDifferentiable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -1184,13 +1193,11 @@ class FunctionTest(test.TestCase):
defined = function.defun(foo)
x = constant_op.constant([1.0])
- self.assertAllEqual(defined.variables, [])
- _ = defined(x)
- self.assertAllEqual(defined.variables, [v])
+ self.assertEqual(1., self.evaluate(defined(x)))
+ v.assign(2.)
x = constant_op.constant([1.0, 2.0])
- _ = defined(x) # ensure the variables list remains the same
- self.assertAllEqual(defined.variables, [v])
+ self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
def testPythonFunctionWithDefaultArgs(self):
@@ -1913,10 +1920,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
value = train()
self.assertEqual(value.numpy(), -1.0)
@@ -1943,10 +1950,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
train()
@@ -2133,6 +2140,13 @@ class AutomaticControlDependenciesTest(test.TestCase):
modify_same_flat(nested_input)
+ def testDecoratedMethodVariableCleanup(self):
+ m = DefunnedMiniModel()
+ m(array_ops.ones([1, 2]))
+ weak_variables = weakref.WeakSet(m.variables)
+ self.assertEqual(2, len(weak_variables))
+ del m
+ self.assertEqual([], list(weak_variables))
if __name__ == '__main__':
ops.enable_eager_execution(
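
The repeated `v` to `self.v` edits in this file all follow from the weak-reference change: a variable created inside a defun and referenced nowhere else would be garbage-collected after tracing, and later calls would hit the new AssertionError in `FuncGraph.variables`. Storing the variable on an enclosing object keeps a strong reference alive. A sketch of the pattern, with a hypothetical `holder` standing in for the test object:

from tensorflow.python.eager import function
from tensorflow.python.ops import resource_variable_ops

class Holder(object):
  v = None

holder = Holder()

@function.defun
def init_and_read():
  if holder.v is None:
    # The strong reference lives on `holder`, outside the trace, so the
    # function's weak reference to the variable stays valid.
    holder.v = resource_variable_ops.ResourceVariable(1.0)
  return holder.v.read_value()
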
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index d59adf3d48..c3a3437743 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2142,8 +2142,8 @@ class InitScopeTest(test_util.TensorFlowTestCase):
def function_with_variables():
with ops.init_scope():
- v = resource_variable_ops.ResourceVariable(3)
- return v.assign_add(1)
+ self.v = resource_variable_ops.ResourceVariable(3)
+ return self.v.assign_add(1)
with context.eager_mode():
# Each invocation of function_with_variables recreates a variable.
@@ -2188,13 +2188,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
def inner_function():
with ops.init_scope():
- v = resource_variable_ops.ResourceVariable(1)
- return v.assign_add(2)
+ self.v = resource_variable_ops.ResourceVariable(1)
+ return self.v.assign_add(2)
def outer_function(inner=None):
with ops.init_scope():
- v0 = resource_variable_ops.ResourceVariable(0)
- return v0.assign_add(1) + inner()
+ self.v0 = resource_variable_ops.ResourceVariable(0)
+ return self.v0.assign_add(1) + inner()
with context.eager_mode():
# Each invocation of outer_function recreates variables.
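
`ops.init_scope` lifts its body out of any function-building graph into the outermost context, which is why these tests create variables inside it: under eager execution the creation genuinely reruns, as the tests' comments say. A compact sketch; `holder` is again a hypothetical stand-in for the test object:

from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

def function_with_variables(holder):
  with ops.init_scope():
    # init_scope exits any function-building graphs; in an eager context
    # the variable is created eagerly rather than inside a trace.
    holder.v = resource_variable_ops.ResourceVariable(3)
  return holder.v.assign_add(1)
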
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 5e1722ba20..60ed8e8c8a 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -696,14 +696,14 @@ def track_variable(v):
return
graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
if graph not in _GRAPH_VARIABLES:
- _GRAPH_VARIABLES[graph] = set()
+ _GRAPH_VARIABLES[graph] = weakref.WeakSet()
_GRAPH_VARIABLES[graph].add(v)
def _get_variables(graph=None):
"""Returns variables corresponding to the given graph for initialization."""
assert not context.executing_eagerly()
- variables = _GRAPH_VARIABLES.get(graph, set())
+ variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
variables.update(opt.optimizer.variables())
return variables
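
The backend change swaps a strong set for a `weakref.WeakSet`, so Keras's per-graph registry no longer pins variables alive, and the `setdefault` in `_get_variables` also registers the set so later `track_variable` calls share it. A stand-alone illustration of the WeakSet behavior (`Var` and `_registry` are hypothetical stand-ins):

import weakref

_registry = {}  # hypothetical stand-in for _GRAPH_VARIABLES

def track(graph, v):
  _registry.setdefault(graph, weakref.WeakSet()).add(v)

class Var(object):  # stand-in for a variable object
  pass

v = Var()
track('graph-key', v)
assert len(_registry['graph-key']) == 1
del v  # with CPython refcounting, the variable is collected here
assert len(_registry['graph-key']) == 0  # the WeakSet dropped it
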
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index 56d82a5b88..1ddea598e5 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -252,12 +252,12 @@ class GradientDescentOptimizerTest(test.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(1.0)
def step():
- v = resource_variable_ops.ResourceVariable(1.0)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
with backprop.GradientTape() as tape:
- loss = v ** 2
- grad = tape.gradient(loss, v)
- optimizer.apply_gradients([(grad, v)])
- return v.read_value()
+ loss = self.v ** 2
+ grad = tape.gradient(loss, self.v)
+ optimizer.apply_gradients([(grad, self.v)])
+ return self.v.read_value()
compiled_step = function.defun(step)
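
As elsewhere in the commit, `step` now parks its variable on the test object so the traced function's weak reference stays valid between calls. For reference, the arithmetic this test exercises, in a free-standing eager sketch:

from tensorflow.python.eager import backprop
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import gradient_descent

v = resource_variable_ops.ResourceVariable(1.0)
with backprop.GradientTape() as tape:
  loss = v ** 2                  # loss = 1.0 at v = 1.0
grad = tape.gradient(loss, v)    # d(v**2)/dv = 2v = 2.0
gradient_descent.GradientDescentOptimizer(1.0).apply_gradients([(grad, v)])
assert v.numpy() == -1.0         # v <- 1.0 - 1.0 * 2.0
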