-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 12
-rw-r--r-- tensorflow/python/eager/function.py | 53
-rw-r--r-- tensorflow/python/eager/function_test.py | 54
-rw-r--r-- tensorflow/python/framework/ops_test.py | 12
-rw-r--r-- tensorflow/python/keras/backend.py | 4
-rw-r--r-- tensorflow/python/training/gradient_descent_test.py | 10
6 files changed, 95 insertions, 50 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index c6894e9013..f51e543624 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -1271,7 +1271,17 @@ class MirroredStrategyDefunTest(test.TestCase):
                        self.evaluate(device_result))
 
     for defun in defuns:
-      self.assertEqual(set(mock_model.variables), set(defun.variables))
+      # PolymorphicFunctions are specialized to the current device stack, so
+      # call_for_each has one trace per device. To check that the expected set
+      # of variables was accessed on each trace, we first retrieve each
+      # device-specific graph function.
+      per_device_graph_functions = dist.call_for_each_tower(
+          defun.get_concrete_function,
+          mock_model, *inputs, run_concurrently=False)
+      for device in devices:
+        graph_function = per_device_graph_functions.get(device=device)
+        self.assertEqual(set(mock_model.variables),
+                         set(graph_function.graph.variables))
 
   @test_util.run_in_graph_and_eager_modes()
   def testVariableInDefun(self):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index e2874e25b6..4f1a85a274 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,6 +23,7 @@ import collections
 import functools
 import sys
 import threading
+import weakref
 
 import numpy as np
 import six
@@ -180,7 +181,7 @@ class FuncGraph(ops.Graph):
     self.inputs = []
     self.outputs = []
     self.structured_outputs = None
-    self.variables = []
+    self._weak_variables = []
     self.outer_graph = ops.get_default_graph()
     self.captures = collections.OrderedDict()
 
@@ -217,6 +218,31 @@ class FuncGraph(ops.Graph):
       self._graph_key = graph._graph_key
       # pylint: enable=protected-access
 
+  @property
+  def variables(self):
+    """A list of variables accessed by this FuncGraph.
+
+    Note that functions keep only weak references to variables. Calling the
+    function after a variable it accesses has been deleted is an error.
+
+    Yields:
+      Strong references to variables accessed by this FuncGraph.
+    """
+    for weak_v in self._weak_variables:
+      v = weak_v()
+      if v is None:
+        raise AssertionError(
+            "Called a function referencing variables which have been deleted. "
+            "This likely means that function-local variables were created and "
+            "not referenced elsewhere in the program. This is generally a "
+            "mistake; consider storing variables in an object attribute on "
+            "first call.")
+      yield v
+
+  @variables.setter
+  def variables(self, var_list):
+    self._weak_variables = [weakref.ref(v) for v in var_list]
+
   def create_op(
       self,
       op_type,
@@ -604,11 +630,6 @@ class Function(object):
     return self._func_graph
 
   @property
-  def variables(self):
-    """Returns all variables touched by this function."""
-    return self._func_graph.variables
-
-  @property
   def inputs(self):
     """Returns tensors in `self.graph` corresponding to arguments."""
     return self._func_graph.inputs
@@ -970,7 +991,16 @@ def _encode_arg(arg):
     return tuple(
         (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
   else:
-    return arg
+    try:
+      # If possible, keep only a weak reference to Python objects. Weak
+      # references hash to the same value as the original object.
+      # TODO(allenl): Clean up dead functions and their cache keys if the cache
+      # gets large. Right now creating objects with a defunned method, calling
+      # the method, and losing a reference to the object in a loop will leak
+      # memory here.
+      return weakref.ref(arg)
+    except TypeError:
+      return arg
 
 
 def _deterministic_dict_values(dictionary):
@@ -1020,7 +1050,6 @@ class PolymorphicFunction(object):
     self._kwds_to_include = {}
     self._name = name
     self._function_cache = collections.OrderedDict()
-    self._variables = []
     self._function_attributes = attributes or {}
     self._lock = threading.Lock()
@@ -1066,12 +1095,6 @@ class PolymorphicFunction(object):
     """Returns the wrapped Python function."""
     return self._python_function
 
-  # TODO(akshayka): Remove this property.
-  @property
-  def variables(self):
-    """Returns the union of all variables referenced by cached `Function`s`."""
-    return self._variables
-
   def get_concrete_function(self, *args, **kwargs):
     """Returns a `Function` object specialized to inputs and execution context.
 
@@ -1238,8 +1261,6 @@ class PolymorphicFunction(object):
           func_graph_from_py_func(self._name, self._python_function, args,
                                   kwds, self._input_signature),
           self._function_attributes)
-      self._variables.extend(
-          [v for v in graph_function.variables if v not in self._variables])
       self._function_cache[cache_key] = graph_function
       return graph_function, [
           t for t in nest.flatten((args, kwds))
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index c168b6060c..6326a5b45f 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -21,6 +21,7 @@ import collections
 import functools
 from multiprocessing.pool import ThreadPool
 import sys
+import weakref
 
 import numpy
 
@@ -74,6 +75,13 @@ class MiniModel(keras_training.Model):
     return self.fc(inputs)
 
 
+class DefunnedMiniModel(MiniModel):
+
+  @function.defun
+  def call(self, inputs, training=True):
+    return super(DefunnedMiniModel, self).call(inputs, training=training)
+
+
 @test_util.with_c_shapes
 class FunctionTest(test.TestCase):
 
@@ -140,8 +148,8 @@ class FunctionTest(test.TestCase):
 
     @function.defun
     def f():
-      v = resource_variable_ops.ResourceVariable(1.0)
-      return v.read_value()
+      self.v = resource_variable_ops.ResourceVariable(1.0)
+      return self.v.read_value()
 
     self.assertAllEqual(f(), 1.0)
 
@@ -399,9 +407,9 @@ class FunctionTest(test.TestCase):
 
     @function.defun
     def tensor_init():
-      v = resource_variable_ops.ResourceVariable(
+      self.v = resource_variable_ops.ResourceVariable(
           lambda: constant_op.constant(2.0))
-      return v.read_value()
+      return self.v.read_value()
 
     value = tensor_init()
     if not context.executing_eagerly():
@@ -415,8 +423,8 @@ class FunctionTest(test.TestCase):
     def tensor_init():
       with ops.init_scope():
         const = constant_op.constant(2.0)
-      v = resource_variable_ops.ResourceVariable(const)
-      return v.read_value()
+      self.v = resource_variable_ops.ResourceVariable(const)
+      return self.v.read_value()
 
     value = tensor_init()
     if not context.executing_eagerly():
@@ -478,13 +486,14 @@ class FunctionTest(test.TestCase):
   def testDefunForcesResourceVariables(self):
 
     def variable_creator():
-      return variables.Variable(0.0).read_value()
+      self.v = variables.Variable(0.0)
+      return self.v.read_value()
 
+    self.v = None
     defined = function.defun(variable_creator)
     defined()  # Create the variable.
-    self.assertEqual(len(defined.variables), 1)
     self.assertIsInstance(
-        defined.variables[0], resource_variable_ops.ResourceVariable)
+        self.v, resource_variable_ops.ResourceVariable)
 
   def testDefunDifferentiable(self):
     v = resource_variable_ops.ResourceVariable(1.0)
@@ -1184,13 +1193,11 @@ class FunctionTest(test.TestCase):
     defined = function.defun(foo)
 
     x = constant_op.constant([1.0])
-    self.assertAllEqual(defined.variables, [])
-    _ = defined(x)
-    self.assertAllEqual(defined.variables, [v])
+    self.assertEqual(1., self.evaluate(defined(x)))
+    v.assign(2.)
 
     x = constant_op.constant([1.0, 2.0])
-    _ = defined(x)  # ensure the variables list remains the same
-    self.assertAllEqual(defined.variables, [v])
+    self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
 
 
   def testPythonFunctionWithDefaultArgs(self):
@@ -1913,10 +1920,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
 
     @function.defun
     def train():
-      v = resource_variable_ops.ResourceVariable(1.0)
-      grad = backprop.implicit_grad(loss)(v)
+      self.v = resource_variable_ops.ResourceVariable(1.0)
+      grad = backprop.implicit_grad(loss)(self.v)
       optimizer.apply_gradients(grad)
-      return v.read_value()
+      return self.v.read_value()
 
     value = train()
     self.assertEqual(value.numpy(), -1.0)
@@ -1943,10 +1950,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
 
     @function.defun
     def train():
-      v = resource_variable_ops.ResourceVariable(1.0)
-      grad = backprop.implicit_grad(loss)(v)
+      self.v = resource_variable_ops.ResourceVariable(1.0)
+      grad = backprop.implicit_grad(loss)(self.v)
       optimizer.apply_gradients(grad)
-      return v.read_value()
+      return self.v.read_value()
 
     train()
 
@@ -2133,6 +2140,13 @@ class AutomaticControlDependenciesTest(test.TestCase):
 
     modify_same_flat(nested_input)
 
+  def testDecoratedMethodVariableCleanup(self):
+    m = DefunnedMiniModel()
+    m(array_ops.ones([1, 2]))
+    weak_variables = weakref.WeakSet(m.variables)
+    self.assertEqual(2, len(weak_variables))
+    del m
+    self.assertEqual([], list(weak_variables))
 
 if __name__ == '__main__':
   ops.enable_eager_execution(
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index d59adf3d48..c3a3437743 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2142,8 +2142,8 @@ class InitScopeTest(test_util.TensorFlowTestCase):
 
     def function_with_variables():
       with ops.init_scope():
-        v = resource_variable_ops.ResourceVariable(3)
-      return v.assign_add(1)
+        self.v = resource_variable_ops.ResourceVariable(3)
+      return self.v.assign_add(1)
 
     with context.eager_mode():
       # Each invocation of function_with_variables recreates a variable.
@@ -2188,13 +2188,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
 
     def inner_function():
      with ops.init_scope():
-        v = resource_variable_ops.ResourceVariable(1)
-      return v.assign_add(2)
+        self.v = resource_variable_ops.ResourceVariable(1)
+      return self.v.assign_add(2)
 
     def outer_function(inner=None):
       with ops.init_scope():
-        v0 = resource_variable_ops.ResourceVariable(0)
-      return v0.assign_add(1) + inner()
+        self.v0 = resource_variable_ops.ResourceVariable(0)
+      return self.v0.assign_add(1) + inner()
 
     with context.eager_mode():
       # Each invocation of outer_function recreates variables.
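[Aside, not part of the diff] The `_encode_arg` change in function.py above works because a `weakref.ref` hashes to the same value as its referent and, while the referent is alive, compares equal to any other reference to the same object, so it can serve as a dictionary cache key without keeping the argument reachable. A minimal standalone sketch of that property; the `Model` class here is hypothetical:

    import weakref

    class Model(object):
      """Hypothetical hashable object used as a function-cache key."""

    m = Model()
    key = weakref.ref(m)

    # While `m` is alive, the weak reference hashes and compares like `m`
    # itself, so a dict lookup with a fresh reference still hits the entry.
    assert hash(key) == hash(m)
    cache = {key: 'graph function for m'}
    assert cache[weakref.ref(m)] == 'graph function for m'

    del m  # The key now dereferences to None instead of pinning `m` in
    # memory; per the TODO in the diff, dead entries are not yet evicted.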
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 5e1722ba20..60ed8e8c8a 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -696,14 +696,14 @@ def track_variable(v):
     return
   graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
   if graph not in _GRAPH_VARIABLES:
-    _GRAPH_VARIABLES[graph] = set()
+    _GRAPH_VARIABLES[graph] = weakref.WeakSet()
   _GRAPH_VARIABLES[graph].add(v)
 
 
 def _get_variables(graph=None):
   """Returns variables corresponding to the given graph for initialization."""
   assert not context.executing_eagerly()
-  variables = _GRAPH_VARIABLES.get(graph, set())
+  variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
   for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
     variables.update(opt.optimizer.variables())
   return variables
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index 56d82a5b88..1ddea598e5 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -252,12 +252,12 @@ class GradientDescentOptimizerTest(test.TestCase):
     optimizer = gradient_descent.GradientDescentOptimizer(1.0)
 
     def step():
-      v = resource_variable_ops.ResourceVariable(1.0)
+      self.v = resource_variable_ops.ResourceVariable(1.0)
       with backprop.GradientTape() as tape:
-        loss = v ** 2
-      grad = tape.gradient(loss, v)
-      optimizer.apply_gradients([(grad, v)])
-      return v.read_value()
+        loss = self.v ** 2
+      grad = tape.gradient(loss, self.v)
+      optimizer.apply_gradients([(grad, self.v)])
+      return self.v.read_value()
 
     compiled_step = function.defun(step)
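[Aside, not part of the diff] With this change a traced function no longer keeps its variables alive: whoever creates a variable inside a defun must also hold a strong reference to it, which is why the tests above stash variables on `self`. A minimal sketch of the pattern, assuming eager execution is enabled and using the contrib-era APIs that appear in this diff; `Holder` is a hypothetical container:

    from tensorflow.python.eager import function
    from tensorflow.python.ops import resource_variable_ops

    class Holder(object):
      """Hypothetical object whose attribute keeps the variable alive."""

    holder = Holder()

    @function.defun
    def create_and_read():
      # Store the variable somewhere the caller owns. The traced function
      # keeps only a weak reference, so without this strong reference the
      # variable could be garbage collected and a later call would raise
      # the AssertionError added to FuncGraph.variables above.
      holder.v = resource_variable_ops.ResourceVariable(1.0)
      return holder.v.read_value()

    print(create_and_read())  # 1.0; `holder.v` keeps the variable reachable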