path: root/tensorflow/python
author     Igor Ganichev <iga@google.com>                    2018-10-09 15:07:47 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-09 15:12:12 -0700
commit     5f69248a692f7b47ea11930621f4f19d0397fe8c (patch)
tree       e6ae69c17d798afc96ba83644bf2ce6656181856 /tensorflow/python
parent     c1093a3757224257fed0f7a1959d0fc99d5c757f (diff)
Make defun work under distributed strategies.
The core of the change is to have the gradient tape capture distributed variables instead of plain ResourceVariables. In other words, we move the distribution awareness from defun down to the tape and rely on distributed variable magic to provide us with the right variable at runtime.

In tower context, we always watch the container (e.g. MirroredVariable). In cross-tower context, we always watch all the components.

PiperOrigin-RevId: 216430530
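For context, a minimal sketch of the usage pattern this commit enables, written against the contrib-era eager APIs of late 2018 (tf.enable_eager_execution, tf.contrib.distribute.MirroredStrategy, tf.contrib.eager.defun). The device list and exact construction details are assumptions for illustration, not part of this change:

import tensorflow as tf

tf.enable_eager_execution()

# MirroredStrategy replicates variables across devices; under its scope,
# variable creation yields a MirroredVariable container whose components
# live on each device. (The device list here is illustrative.)
strategy = tf.contrib.distribute.MirroredStrategy(["/cpu:0"])

with strategy.scope():
  v = tf.Variable(2.0)

@tf.contrib.eager.defun
def loss_fn():
  # The defun captures the MirroredVariable. With this change, the tape
  # watches the container (in tower context) or all of its components
  # (in cross-tower context) rather than a single plain ResourceVariable.
  return v * v

with tf.GradientTape() as tape:
  loss = loss_fn()
print(tape.gradient(loss, v))  # expect 4.0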
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/eager/backprop_test.py  24
-rw-r--r--  tensorflow/python/eager/function.py        53
-rw-r--r--  tensorflow/python/eager/tape.py            31
3 files changed, 56 insertions(+), 52 deletions(-)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 7e5c9f3cb6..b1b20fafd2 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -258,6 +258,30 @@ class BackpropTest(test.TestCase):
loss += v * v
self.assertAllEqual(t.gradient(loss, v), 2.0)
+ def testAutomaticWatchedVariables(self):
+ with backprop.GradientTape() as t:
+ self.assertEqual(0, len(t.watched_variables()))
+ v = resource_variable_ops.ResourceVariable(1.0)
+ loss = v * v
+ self.assertAllEqual([v], t.watched_variables())
+
+ t.reset()
+ self.assertEqual(0, len(t.watched_variables()))
+ loss += v * v
+ self.assertAllEqual([v], t.watched_variables())
+
+ def testExplicitWatchedVariables(self):
+ with backprop.GradientTape() as t:
+ self.assertEqual(0, len(t.watched_variables()))
+ v = resource_variable_ops.ResourceVariable(1.0)
+ t.watch(v)
+ self.assertAllEqual([v], t.watched_variables())
+
+ t.reset()
+ self.assertEqual(0, len(t.watched_variables()))
+ t.watch(v)
+ self.assertAllEqual([v], t.watched_variables())
+
@test_util.assert_no_new_tensors
def testGradientNone(self):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ff138cad1e..f1a63adce1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -51,7 +51,6 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -202,6 +201,7 @@ class FuncGraph(ops.Graph):
# from the default graph even in eager mode. Maybe it should be part of the
# eager context?
self._distribution_strategy_stack = graph._distribution_strategy_stack
+ self._variable_creator_stack = graph._variable_creator_stack
# Inherit the graph key, since this is used for matching variables in
# optimizers.
self._graph_key = graph._graph_key
@@ -563,17 +563,6 @@ class Function(object):
self._func_graph.inputs, self._func_graph.outputs, self._attrs)
self._backward_graph_function = None
- # Map holding distributed variables, keyed by resource handle tensors.
- self._distributed_variables = {}
- strategy = distribution_strategy_context.get_distribution_strategy()
- for variable in self._func_graph.variables:
- # If variable is not distributed, unwrap returns [variable].
- component_variables = strategy.unwrap(variable)
- # Only update the dictionary when the variable is actually distributed.
- if (len(component_variables) > 1 or component_variables[0] != variable):
- for component_variable in component_variables:
- self._distributed_variables[component_variable.handle] = variable
-
def __call__(self, *args):
"""Executes the wrapped function.
@@ -602,7 +591,6 @@ class Function(object):
if v.trainable:
tape.variable_accessed(v)
- captures = self._resolve_captured_inputs()
tensor_inputs = []
for i, arg in enumerate(nest.flatten(args)):
if isinstance(arg, resource_variable_ops.ResourceVariable):
@@ -615,9 +603,10 @@ class Function(object):
raise ValueError("All inputs to `Function`s must be Tensors; "
"on invocation of %s, the %d-th input (%s) was not a "
"Tensor." % (self._func_graph.name, i, str(arg)))
- args = tensor_inputs + captures
+ args = tensor_inputs + self._captured_inputs
- if tape.should_record(tensor_inputs) or tape.should_record(captures):
+ if (tape.should_record(tensor_inputs) or
+ tape.should_record(self._captured_inputs)):
return self._backprop_call(args)
# Only need to override the gradient in graph mode and when we have outputs.
@@ -804,32 +793,6 @@ class Function(object):
args, backward_function)
return self._build_call_outputs(real_outputs)
- def _resolve_captured_inputs(self):
- """Resolve captured distributed variables to their current values.
-
- Some inputs can be distributed variables. Such variables yield a different
- component (i.e. actual tf.Variable) variables depending on the context of
- execution.
-
- Returns:
- a list of resolved captured input tensors.
- """
- if self._distributed_variables:
- # Loop over each captured input and check if it corresponds to something
- # distributed. If so, get its _distributed_container and fetch the
- # component appropriate for the current execution context.
- resolved_captured_inputs = self._captured_inputs[:]
- for i, captured_input in enumerate(self._captured_inputs):
- distributed_var = self._distributed_variables.get(captured_input, None)
- if distributed_var is not None:
- # distributed variables override __getattr__ and substitute the
- # right component variable. In here, `distributed_var.handle`
- # actually does the equivalent of
- # distributed_var.get_current_component_var().handle.
- resolved_captured_inputs[i] = distributed_var.handle
- return resolved_captured_inputs
- return self._captured_inputs
-
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -1010,14 +973,6 @@ def func_graph_from_py_func(name,
for x in _flatten(func_graph.structured_outputs)
if x is not None)
- # Some captured variables might be components of DistributedValues.
- # Instead of storing non-distributed component variables, we
- # store their distributed containers so we can retrieve the correct
- # component variables at call-time.
- strategy = distribution_strategy_context.get_distribution_strategy()
- for i, variable in enumerate(variables):
- # If variable is not distributed value_container returns itself.
- variables[i] = strategy.value_container(variable)
func_graph.variables = variables
# Register any other functions defined in the graph.
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 399d90223c..ade945f874 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -21,6 +21,15 @@ from __future__ import print_function
import contextlib
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+# There is a circular dependency between this, ops.py, and
+# distribution_strategy_context.
+# TODO(b/117329403): Remove this circular dependency.
+distribution_strategy_context = LazyLoader(
+ "distribute_lib", globals(),
+ "tensorflow.python.training."
+ "distribution_strategy_context")
class Tape(object):
@@ -52,12 +61,28 @@ def watch(tape, tensor):
def watch_variable(tape, variable):
"""Marks this variable to be watched by the given tape."""
- pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
def variable_accessed(variable):
- """Notifies all tapes in the stack that a variable has been accessed."""
- pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
+ """Notifies all tapes in the stack that a variable has been accessed.
+
+ Args:
+ variable: variable to be watched.
+ """
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
def pop_tape(tape):
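To make the new dispatch in watch_variable and variable_accessed concrete, here is a standalone sketch of the rule they both apply (plain Python, not TensorFlow internals; the helper name is hypothetical):

def variables_to_watch(variable, strategy, in_tower_context):
  """Returns the variables a tape should watch in place of `variable`."""
  if in_tower_context:
    # In tower (per-replica) context, watch the distributed container,
    # e.g. the MirroredVariable; value_container returns the variable
    # itself when it is not distributed.
    return [strategy.value_container(variable)]
  # In cross-tower context, watch every component variable; unwrap
  # returns [variable] for a plain, non-distributed ResourceVariable.
  return list(strategy.unwrap(variable))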