diff options
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- | tensorflow/python/eager/tape.py | 31 |
1 files changed, 28 insertions, 3 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 399d90223c..ade945f874 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -21,6 +21,15 @@ from __future__ import print_function import contextlib from tensorflow.python import pywrap_tensorflow +from tensorflow.python.util.lazy_loader import LazyLoader + +# There is a circular dependency between this, ops.py, and +# distribution_strategy_context. +# TODO(b/117329403): Remove this circular dependency. +distribution_strategy_context = LazyLoader( + "distribute_lib", globals(), + "tensorflow.python.training." + "distribution_strategy_context") class Tape(object): @@ -52,12 +61,28 @@ def watch(tape, tensor): def watch_variable(tape, variable): """Marks this variable to be watched by the given tape.""" - pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access + strategy = distribution_strategy_context.get_distribution_strategy() + if distribution_strategy_context.get_tower_context(): + variables = [strategy.value_container(variable)] + else: + variables = strategy.unwrap(variable) + for var in variables: + pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access def variable_accessed(variable): - """Notifies all tapes in the stack that a variable has been accessed.""" - pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable) + """Notifies all tapes in the stack that a variable has been accessed. + + Args: + variable: variable to be watched. + """ + strategy = distribution_strategy_context.get_distribution_strategy() + if distribution_strategy_context.get_tower_context(): + variables = [strategy.value_container(variable)] + else: + variables = strategy.unwrap(variable) + for var in variables: + pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) def pop_tape(tape): |