diff options
author | Tom Hennigan <tomhennigan@google.com> | 2018-09-01 05:48:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-01 05:51:59 -0700 |
commit | 543bc6a8b98e8d08ce9c25dcafe629e124f266eb (patch) | |
tree | 31071254618c914eadc814253ad6edef8b68da92 /tensorflow/python/eager/tape.py | |
parent | dcbd97c2939ddbfadf9e21f1c0e6aa720b9154d6 (diff) |
Only watch tensors on the current tape rather than all of them.
This allows fine grained control over recording in some cases, for example the
following where we want d2y but not d2z:
x1 = tf.Variable(2.0, trainable=False)
x2 = tf.Variable(2.0, trainable=False)
with tf.GradientTape() as tape1:
with tf.GradientTape() as tape2:
tape1.watch(x1)
tape2.watch([x1, x2])
y = x1 ** 3
z = x2 ** 2
dy, dz = tape2.gradient([y, z], [x1, x2])
d2y, d2z = tape1.gradient([dy, dz], [x1, x2])
assert d2z is None
PiperOrigin-RevId: 211206506
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- | tensorflow/python/eager/tape.py | 10 |
1 files changed, 3 insertions, 7 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index caa217b70c..6eb62afec4 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -44,13 +44,9 @@ def push_tape(tape): pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access -def watch(tensor): - """Marks this tensor to be watched by all tapes in the stack. - - Args: - tensor: tensor to be watched. - """ - pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor) +def watch(tape, tensor): + """Marks this tensor to be watched by the given tape.""" + pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access def watch_variable(variable): |