diff options
author | Igor Ganichev <iga@google.com> | 2018-10-09 15:07:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 15:12:12 -0700 |
commit | 5f69248a692f7b47ea11930621f4f19d0397fe8c (patch) | |
tree | e6ae69c17d798afc96ba83644bf2ce6656181856 /tensorflow/python/eager/tape.py | |
parent | c1093a3757224257fed0f7a1959d0fc99d5c757f (diff) |
Make defun work under distributed strategies.
The core of the change is have the gradient tape capture
distributed variables instead of plain ResourceVariables.
In other words, we move the distribution awareness from defun
down to tape and rely on distributed variable magic to provide us
with the right variable at runtime.
In tower context, we always watch the container (e.g. MirroredVariable).
In cross tower context, we always watch all the components.
PiperOrigin-RevId: 216430530
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- | tensorflow/python/eager/tape.py | 31 |
1 files changed, 28 insertions, 3 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 399d90223c..ade945f874 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -21,6 +21,15 @@ from __future__ import print_function import contextlib from tensorflow.python import pywrap_tensorflow +from tensorflow.python.util.lazy_loader import LazyLoader + +# There is a circular dependency between this, ops.py, and +# distribution_strategy_context. +# TODO(b/117329403): Remove this circular dependency. +distribution_strategy_context = LazyLoader( + "distribute_lib", globals(), + "tensorflow.python.training." + "distribution_strategy_context") class Tape(object): @@ -52,12 +61,28 @@ def watch(tape, tensor): def watch_variable(tape, variable): """Marks this variable to be watched by the given tape.""" - pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access + strategy = distribution_strategy_context.get_distribution_strategy() + if distribution_strategy_context.get_tower_context(): + variables = [strategy.value_container(variable)] + else: + variables = strategy.unwrap(variable) + for var in variables: + pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access def variable_accessed(variable): - """Notifies all tapes in the stack that a variable has been accessed.""" - pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable) + """Notifies all tapes in the stack that a variable has been accessed. + + Args: + variable: variable to be watched. + """ + strategy = distribution_strategy_context.get_distribution_strategy() + if distribution_strategy_context.get_tower_context(): + variables = [strategy.value_container(variable)] + else: + variables = strategy.unwrap(variable) + for var in variables: + pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) def pop_tape(tape): |