aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/tape.py
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-10-09 15:07:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 15:12:12 -0700
commit5f69248a692f7b47ea11930621f4f19d0397fe8c (patch)
treee6ae69c17d798afc96ba83644bf2ce6656181856 /tensorflow/python/eager/tape.py
parentc1093a3757224257fed0f7a1959d0fc99d5c757f (diff)
Make defun work under distributed strategies.
The core of the change is to have the gradient tape capture distributed variables instead of plain ResourceVariables. In other words, we move the distribution awareness from defun down to tape and rely on distributed variable magic to provide us with the right variable at runtime. In tower context, we always watch the container (e.g. MirroredVariable). In cross-tower context, we always watch all the components. PiperOrigin-RevId: 216430530
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r--tensorflow/python/eager/tape.py31
1 files changed, 28 insertions, 3 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 399d90223c..ade945f874 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -21,6 +21,15 @@ from __future__ import print_function
import contextlib
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+# There is a circular dependency between this, ops.py, and
+# distribution_strategy_context.
+# TODO(b/117329403): Remove this circular dependency.
+distribution_strategy_context = LazyLoader(
+ "distribute_lib", globals(),
+ "tensorflow.python.training."
+ "distribution_strategy_context")
class Tape(object):
@@ -52,12 +61,28 @@ def watch(tape, tensor):
def watch_variable(tape, variable):
"""Marks this variable to be watched by the given tape."""
- pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
def variable_accessed(variable):
- """Notifies all tapes in the stack that a variable has been accessed."""
- pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
+ """Notifies all tapes in the stack that a variable has been accessed.
+
+ Args:
+ variable: variable to be watched.
+ """
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
def pop_tape(tape):