-rw-r--r--  tensorflow/python/eager/backprop.py                           67
-rw-r--r--  tensorflow/python/eager/backprop_test.py                      15
-rw-r--r--  tensorflow/python/eager/pywrap_tfe.h                           3
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc                      8
-rw-r--r--  tensorflow/python/eager/tape.py                                5
-rw-r--r--  tensorflow/python/pywrap_tfe.i                                 1
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt    8
7 files changed, 107 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 4cdf0a41ad..773c981195 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
@@ -751,6 +752,72 @@ class GradientTape(object):
     for t in nest.flatten(tensor):
       tape.watch(_handle_or_self(t))
 
+  @tf_contextlib.contextmanager
+  def stop_recording(self):
+    """Temporarily stops recording operations on this tape.
+
+    Operations executed while this context manager is active will not be
+    recorded on the tape. This is useful for reducing the memory that would
+    otherwise be spent tracing all computations.
+
+    For example:
+
+    ```
+    with tf.GradientTape(persistent=True) as t:
+      loss = compute_loss(model)
+      with t.stop_recording():
+        # The gradient computation below is not traced, saving memory.
+        grads = t.gradient(loss, model.variables)
+    ```
+
+    Yields:
+      None
+    Raises:
+      RuntimeError: if the tape is not currently recording.
+    """
+    if self._tape is None:
+      raise RuntimeError(
+          "Trying to stop recording a tape which is not recording.")
+    tape.pop_tape(self._tape)
+    try:
+      yield
+    finally:
+      tape.push_tape(self._tape)
+
+  def reset(self):
+    """Clears all information stored in this tape.
+
+    Equivalent to exiting and reentering the tape context manager with a new
+    tape. For example, the following two code blocks are equivalent:
+    ```
+    with tf.GradientTape() as t:
+      loss = loss_fn()
+    with tf.GradientTape() as t:
+      loss += other_loss_fn()
+    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
+
+
+    # The following is equivalent to the above
+    with tf.GradientTape() as t:
+      loss = loss_fn()
+      t.reset()
+      loss += other_loss_fn()
+    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
+    ```
+
+    This is useful if you don't want to exit the context manager for the tape,
+    or can't because the desired reset point is inside a control flow construct:
+
+    ```
+    with tf.GradientTape() as t:
+      loss = ...
+      if loss > k:
+        t.reset()
+    ```
+    """
+    self.__exit__(None, None, None)
+    self.__enter__()
+
   def watched_variables(self):
     # Sorting variables by id, which is monotonically increasing in construction
     # order. This ensures unique order across executions.
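
To see the two new methods end to end: a minimal usage sketch, assuming eager execution is enabled and using only APIs touched by this change (`watch`, `gradient`, `stop_recording`, `reset`); the values in the comments follow from the derivatives, and this example is an illustration rather than part of the commit.

```
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as t:
  t.watch(x)                      # constants are not watched automatically
  y = x * x                       # recorded on the tape
  with t.stop_recording():
    # Not recorded: computing this gradient does not grow the trace.
    dy_dx = t.gradient(y, x)      # 6.0
  t.reset()                       # discard everything recorded so far
  t.watch(x)                      # the fresh tape starts empty; watch x again
  z = 2 * x                       # recorded on the new, empty tape
dz_dx = t.gradient(z, x)          # 2.0; y's trace was dropped by reset()
```
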
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index d4b3c8bb5f..9aaa2e33c9 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -221,6 +221,21 @@ class BackpropTest(test.TestCase):
     self.assertTrue(ordered_variables[0] is v0)
     self.assertTrue(ordered_variables[1] is v1)
 
+  def testTapeStopRecording(self):
+    with backprop.GradientTape() as t:
+      x = constant_op.constant(1.0)
+      with t.stop_recording():
+        y = x * x
+    self.assertEqual(t.gradient(y, x), None)
+
+  def testTapeReset(self):
+    with backprop.GradientTape() as t:
+      v = resource_variable_ops.ResourceVariable(1.0)
+      loss = v * v
+      t.reset()
+      loss += v * v
+    self.assertAllEqual(t.gradient(loss, v), 2.0)
+
   @test_util.assert_no_new_tensors
   def testGradientNone(self):
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 691b613e48..9bc8b9bc72 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -120,6 +120,9 @@ PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
 // Removes the passed tape from the set of active tapes.
 void TFE_Py_TapeSetRemove(PyObject* tape);
 
+// Adds the passed tape to the set of active tapes.
+void TFE_Py_TapeSetAdd(PyObject* tape);
+
 // Returns true if the tape stack is empty.
 PyObject* TFE_Py_TapeSetIsEmpty();
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 48a5b21dc7..0f21a91a83 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1009,6 +1009,14 @@ PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
   return reinterpret_cast<PyObject*>(tape);
 }
 
+void TFE_Py_TapeSetAdd(PyObject* tape) {
+  Py_INCREF(tape);
+  if (!GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape)).second) {
+    // Already exists in the tape set; release the extra reference taken above.
+    Py_DECREF(tape);
+  }
+}
+
 PyObject* TFE_Py_TapeSetIsEmpty() {
   if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
     Py_RETURN_TRUE;
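
A rough Python analogy (illustration only, not part of the change) of the contract `TFE_Py_TapeSetAdd` implements: adding a tape that is already active must be a no-op, and on that path the reference taken up front is released so the failed insert does not leak.

```
active_tapes = set()

def tape_set_add(tape):
  # Emulates insert().second from the C++ version with a membership test:
  # a duplicate add leaves the set unchanged (and, in the C++ code, drops
  # the reference acquired before the insert attempt).
  if tape in active_tapes:
    return
  active_tapes.add(tape)
```
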
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index ad82266bec..caa217b70c 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -39,6 +39,11 @@ def push_new_tape(persistent=False):
   return Tape(tape)
 
 
+def push_tape(tape):
+  """Pushes an existing tape onto the tape stack."""
+  pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape)  # pylint: disable=protected-access
+
+
 def watch(tensor):
   """Marks this tensor to be watched by all tapes in the stack.
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 5ee55301df..fde3223e96 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -42,6 +42,7 @@ limitations under the License.
 %rename("%s") TFE_Py_RecordGradient;
 %rename("%s") TFE_Py_UID;
 %rename("%s") TFE_Py_TapeSetNew;
+%rename("%s") TFE_Py_TapeSetAdd;
 %rename("%s") TFE_Py_TapeSetRemove;
 %rename("%s") TFE_Py_TapeSetStopOnThread;
 %rename("%s") TFE_Py_TapeSetRestartOnThread;
diff --git a/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt
index 7405202b89..cbf655498c 100644
--- a/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt
@@ -11,6 +11,14 @@ tf_class {
     argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "reset"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "stop_recording"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "watch"
     argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
   }