From 5f69248a692f7b47ea11930621f4f19d0397fe8c Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Tue, 9 Oct 2018 15:07:47 -0700 Subject: Make defun work under distributed strategies. The core of the change is have the gradient tape capture distributed variables instead of plain ResourceVariables. In other words, we move the distribution awareness from defun down to tape and rely on distributed variable magic to provide us with the right variable at runtime. In tower context, we always watch the container (e.g. MirroredVariable). In cross tower context, we always watch all the components. PiperOrigin-RevId: 216430530 --- tensorflow/python/eager/backprop_test.py | 24 +++++++++++++++ tensorflow/python/eager/function.py | 53 +++----------------------------- tensorflow/python/eager/tape.py | 31 +++++++++++++++++-- 3 files changed, 56 insertions(+), 52 deletions(-) (limited to 'tensorflow/python') diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 7e5c9f3cb6..b1b20fafd2 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -258,6 +258,30 @@ class BackpropTest(test.TestCase): loss += v * v self.assertAllEqual(t.gradient(loss, v), 2.0) + def testAutomaticWatchedVariables(self): + with backprop.GradientTape() as t: + self.assertEqual(0, len(t.watched_variables())) + v = resource_variable_ops.ResourceVariable(1.0) + loss = v * v + self.assertAllEqual([v], t.watched_variables()) + + t.reset() + self.assertEqual(0, len(t.watched_variables())) + loss += v * v + self.assertAllEqual([v], t.watched_variables()) + + def testExplicitWatchedVariables(self): + with backprop.GradientTape() as t: + self.assertEqual(0, len(t.watched_variables())) + v = resource_variable_ops.ResourceVariable(1.0) + t.watch(v) + self.assertAllEqual([v], t.watched_variables()) + + t.reset() + self.assertEqual(0, len(t.watched_variables())) + t.watch(v) + self.assertAllEqual([v], t.watched_variables()) + @test_util.assert_no_new_tensors def testGradientNone(self): diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index ff138cad1e..f1a63adce1 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -51,7 +51,6 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -202,6 +201,7 @@ class FuncGraph(ops.Graph): # from the default graph even in eager mode. Maybe it should be part of the # eager context? self._distribution_strategy_stack = graph._distribution_strategy_stack + self._variable_creator_stack = graph._variable_creator_stack # Inherit the graph key, since this is used for matching variables in # optimizers. self._graph_key = graph._graph_key @@ -563,17 +563,6 @@ class Function(object): self._func_graph.inputs, self._func_graph.outputs, self._attrs) self._backward_graph_function = None - # Map holding distributed variables, keyed by resource handle tensors. - self._distributed_variables = {} - strategy = distribution_strategy_context.get_distribution_strategy() - for variable in self._func_graph.variables: - # If variable is not distributed, unwrap returns [variable]. - component_variables = strategy.unwrap(variable) - # Only update the dictionary when the variable is actually distributed. - if (len(component_variables) > 1 or component_variables[0] != variable): - for component_variable in component_variables: - self._distributed_variables[component_variable.handle] = variable - def __call__(self, *args): """Executes the wrapped function. @@ -602,7 +591,6 @@ class Function(object): if v.trainable: tape.variable_accessed(v) - captures = self._resolve_captured_inputs() tensor_inputs = [] for i, arg in enumerate(nest.flatten(args)): if isinstance(arg, resource_variable_ops.ResourceVariable): @@ -615,9 +603,10 @@ class Function(object): raise ValueError("All inputs to `Function`s must be Tensors; " "on invocation of %s, the %d-th input (%s) was not a " "Tensor." % (self._func_graph.name, i, str(arg))) - args = tensor_inputs + captures + args = tensor_inputs + self._captured_inputs - if tape.should_record(tensor_inputs) or tape.should_record(captures): + if (tape.should_record(tensor_inputs) or + tape.should_record(self._captured_inputs)): return self._backprop_call(args) # Only need to override the gradient in graph mode and when we have outputs. @@ -804,32 +793,6 @@ class Function(object): args, backward_function) return self._build_call_outputs(real_outputs) - def _resolve_captured_inputs(self): - """Resolve captured distributed variables to their current values. - - Some inputs can be distributed variables. Such variables yield a different - component (i.e. actual tf.Variable) variables depending on the context of - execution. - - Returns: - a list of resolved captured input tensors. - """ - if self._distributed_variables: - # Loop over each captured input and check if it corresponds to something - # distributed. If so, get its _distributed_container and fetch the - # component appropriate for the current execution context. - resolved_captured_inputs = self._captured_inputs[:] - for i, captured_input in enumerate(self._captured_inputs): - distributed_var = self._distributed_variables.get(captured_input, None) - if distributed_var is not None: - # distributed variables override __getattr__ and substitute the - # right component variable. In here, `distributed_var.handle` - # actually does the equivalent of - # distributed_var.get_current_component_var().handle. - resolved_captured_inputs[i] = distributed_var.handle - return resolved_captured_inputs - return self._captured_inputs - def _build_call_outputs(self, result): """Maps the fdef output list to actual output structure. @@ -1010,14 +973,6 @@ def func_graph_from_py_func(name, for x in _flatten(func_graph.structured_outputs) if x is not None) - # Some captured variables might be components of DistributedValues. - # Instead of storing non-distributed component variables, we - # store their distributed containers so we can retrieve the correct - # component variables at call-time. - strategy = distribution_strategy_context.get_distribution_strategy() - for i, variable in enumerate(variables): - # If variable is not distributed value_container returns itself. - variables[i] = strategy.value_container(variable) func_graph.variables = variables # Register any other functions defined in the graph. diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 399d90223c..ade945f874 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -21,6 +21,15 @@ from __future__ import print_function import contextlib from tensorflow.python import pywrap_tensorflow +from tensorflow.python.util.lazy_loader import LazyLoader + +# There is a circular dependency between this, ops.py, and +# distribution_strategy_context. +# TODO(b/117329403): Remove this circular dependency. +distribution_strategy_context = LazyLoader( + "distribute_lib", globals(), + "tensorflow.python.training." + "distribution_strategy_context") class Tape(object): @@ -52,12 +61,28 @@ def watch(tape, tensor): def watch_variable(tape, variable): """Marks this variable to be watched by the given tape.""" - pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access + strategy = distribution_strategy_context.get_distribution_strategy() + if distribution_strategy_context.get_tower_context(): + variables = [strategy.value_container(variable)] + else: + variables = strategy.unwrap(variable) + for var in variables: + pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access def variable_accessed(variable): - """Notifies all tapes in the stack that a variable has been accessed.""" - pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable) + """Notifies all tapes in the stack that a variable has been accessed. + + Args: + variable: variable to be watched. + """ + strategy = distribution_strategy_context.get_distribution_strategy() + if distribution_strategy_context.get_tower_context(): + variables = [strategy.value_container(variable)] + else: + variables = strategy.unwrap(variable) + for var in variables: + pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) def pop_tape(tape): -- cgit v1.2.3