aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/tape.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r--tensorflow/python/eager/tape.py31
1 files changed, 28 insertions, 3 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 399d90223c..ade945f874 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -21,6 +21,15 @@ from __future__ import print_function
import contextlib
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+# There is a circular dependency between this, ops.py, and
+# distribution_strategy_context.
+# TODO(b/117329403): Remove this circular dependency.
+distribution_strategy_context = LazyLoader(
+ "distribute_lib", globals(),
+ "tensorflow.python.training."
+ "distribution_strategy_context")
class Tape(object):
@@ -52,12 +61,28 @@ def watch(tape, tensor):
def watch_variable(tape, variable):
"""Marks this variable to be watched by the given tape."""
- pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
def variable_accessed(variable):
- """Notifies all tapes in the stack that a variable has been accessed."""
- pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
+ """Notifies all tapes in the stack that a variable has been accessed.
+
+ Args:
+ variable: variable to be watched.
+ """
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
def pop_tape(tape):