author    Tom Hennigan <tomhennigan@google.com>  2018-09-07 02:56:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-07 03:00:56 -0700
commit c35e6928dc0b00ca3ca489e5d34f856499cece4a (patch)
tree 799f8dffe4ea6bf096dcd4ecc178bcf34f4ed19f
parent 5a635e3472e16007830fca533c35b2f63fc4f898 (diff)
Support not automatically watching (trainable) accessed variables in GradientTape.
For more complex use cases this allows fine-grained control over what is tracked by the tape.

PiperOrigin-RevId: 211948236
-rw-r--r--  tensorflow/python/eager/backprop.py  52
-rw-r--r--  tensorflow/python/eager/backprop_test.py  14
-rw-r--r--  tensorflow/python/eager/function.py  2
-rw-r--r--  tensorflow/python/eager/function_test.py  9
-rwxr-xr-x  tensorflow/python/eager/pywrap_tfe.h  12
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc  35
-rw-r--r--  tensorflow/python/eager/tape.py  18
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py  6
-rwxr-xr-x  tensorflow/python/pywrap_tfe.i  3
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt  2
11 files changed, 118 insertions, 37 deletions
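
In short, `tf.GradientTape` gains a `watch_accessed_variables` constructor argument that turns off the automatic watching of trainable variables. The following is a minimal sketch of the intended usage, adapted from the docstring and unit test added below; the variable construction and the `tf.enable_eager_execution()` call are illustrative for this revision (TF 1.x) and are not part of the change itself:

```python
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x at this revision; implicit in TF 2.x

variable_a = tf.contrib.eager.Variable(1.0)  # plain tf.Variable in later releases
variable_b = tf.contrib.eager.Variable(1.0)

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(variable_a)  # only `variable_a` is explicitly watched
  y = variable_a ** 2
  z = variable_b ** 3

dy, dz = tape.gradient([y, z], [variable_a, variable_b])
# dy == 2.0; dz is None because `variable_b` was never watched.
```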
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index e9ebb57689..dda961c5f6 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -693,19 +693,57 @@ class GradientTape(object):
del g # Drop the reference to the tape
```
+ By default GradientTape will automatically watch any trainable variables that
+ are accessed inside the context. If you want fine-grained control over which
+ variables are watched, you can disable automatic tracking by passing
+ `watch_accessed_variables=False` to the tape constructor:
+
+ ```python
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(variable_a)
+ y = variable_a ** 2 # Gradients will be available for `variable_a`.
+ z = variable_b ** 3 # No gradients will be available since `variable_b` is
+ # not being watched.
+ ```
+
+ Note that when using models you should ensure that your variables exist when
+ using `watch_accessed_variables=False`. Otherwise it's quite easy for your
+ first iteration to have no gradients:
+
+ ```python
+ a = tf.keras.layers.Dense(32)
+ b = tf.keras.layers.Dense(32)
+
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(a.variables) # Since `a.build` has not been called at this point
+ # `a.variables` will return an empty list and the
+ # tape will not be watching anything.
+ result = b(a(inputs))
+ tape.gradient(result, a.variables) # The result of this computation will be
+ # a list of `None`s since a's variables
+ # are not being watched.
+ ```
+
Note that only tensors with real or complex dtypes are differentiable.
"""
- def __init__(self, persistent=False):
+ def __init__(self, persistent=False, watch_accessed_variables=True):
"""Creates a new GradientTape.
Args:
persistent: Boolean controlling whether a persistent gradient tape
is created. False by default, which means at most one call can
be made to the gradient() method on this object.
+ watch_accessed_variables: Boolean controlling whether the tape will
+ automatically `watch` any (trainable) variables accessed while the tape
+ is active. Defaults to True, meaning gradients can be requested from any
+ result computed in the tape derived from reading a trainable `Variable`.
+ If False, users must explicitly `watch` any `Variable`s they want to
+ request gradients from.
"""
self._tape = None
self._persistent = persistent
+ self._watch_accessed_variables = watch_accessed_variables
self._recording = False
context.context().start_step()
@@ -727,7 +765,9 @@ class GradientTape(object):
raise ValueError("There is no existing tape.")
tape.push_tape(self._tape)
else:
- self._tape = tape.push_new_tape(persistent=self._persistent)
+ self._tape = tape.push_new_tape(
+ persistent=self._persistent,
+ watch_accessed_variables=self._watch_accessed_variables)
self._recording = True
def _pop_tape(self):
@@ -746,7 +786,13 @@ class GradientTape(object):
tensor: a Tensor or list of Tensors.
"""
for t in nest.flatten(tensor):
- tape.watch(self._tape, _handle_or_self(t))
+ if hasattr(t, "handle"):
+ # There are many variable-like objects; all of them currently have a
+ # `handle` attribute that points to a tensor. If this changes, the internals
+ # of watch_variable need to change as well.
+ tape.watch_variable(self._tape, t)
+ else:
+ tape.watch(self._tape, t)
@tf_contextlib.contextmanager
def stop_recording(self):
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 3319b440b4..65d57d3957 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -956,6 +956,20 @@ class BackpropTest(test.TestCase):
self.assertAllEqual(grad1, grad2)
+ @test_util.run_in_graph_and_eager_modes
+ def testSelectivelyWatchVariables(self):
+ x1 = resource_variable_ops.ResourceVariable(1.0)
+ x2 = resource_variable_ops.ResourceVariable(1.0)
+ with backprop.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(x2)
+ y = x1**2
+ z = x2**3
+ self.assertTupleEqual(tape.watched_variables(), (x2,))
+ dy, dz = tape.gradient([y, z], [x1, x2])
+ self.evaluate([x1.initializer, x2.initializer])
+ self.assertIsNone(dy)
+ self.assertEqual(self.evaluate(dz), 3.0)
+
@test_util.run_in_graph_and_eager_modes
def testDifferentiatingScalarCache(self):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index d56c1457e0..03f12139f6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -519,7 +519,7 @@ class Function(object):
for v in self._func_graph.variables:
if v.trainable:
- tape.watch_variable(v)
+ tape.variable_accessed(v)
captures = self._resolve_captured_inputs()
tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 3c79099d87..37a9957cea 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -616,7 +615,6 @@ class FunctionTest(test.TestCase):
@function.defun
def g(x):
- tape.watch_variable(x)
y = math_ops.add(x, three)
f(y)
@@ -630,7 +628,6 @@ class FunctionTest(test.TestCase):
return math_ops.add(x, three)
def g(x):
- tape.watch_variable(three)
return f(x)
g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0]
@@ -1427,14 +1424,14 @@ class FunctionTest(test.TestCase):
grad_t, = backprop.gradients_function(sq, [0])(t)
self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
- with backprop.GradientTape(persistent=True) as gtape:
- gtape.watch(t)
+ with backprop.GradientTape(persistent=True) as tape:
+ tape.watch(t)
one = matmul(t, b=t, transpose_a=True)
two = matmul(b=t, a=t, transpose_a=True)
three = matmul(a=t, b=t, transpose_a=True)
for output in [one, two, three]:
- self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]])
+ self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]])
def testGradientInFunctionWithKeywordArguments(self):
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 6c1bd76296..f1b4042ec9 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -128,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
// To unset the profiler, pass Py_None as the value of `profiler`.
PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
-// Creates a new tape and adds it to the active set. `persistent` must be a
-// PyBool_Type, i.e either Py_True or Py_False
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set. `persistent` and
+// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables);
// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);
@@ -162,8 +163,11 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
+// Notifies all tapes that a variable has been accessed.
+void TFE_Py_TapeVariableAccessed(PyObject* variable);
+
// Watches the given variable object on the given tape.
-void TFE_Py_TapeSetWatchVariable(PyObject* variable);
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);
// Computes a gradient based on information recorded on the tape.`tape` must
// have been produced by TFE_Py_NewTape. `target` and `sources` must be python
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6ac9ed081a..1ed814258b 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -892,9 +892,10 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
public:
- explicit GradientTape(bool persistent)
+ explicit GradientTape(bool persistent, bool watch_accessed_variables)
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent) {}
+ persistent),
+ watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
for (const IdAndVariable& v : watched_variables_) {
@@ -902,6 +903,12 @@ class GradientTape
}
}
+ void VariableAccessed(PyObject* v) {
+ if (watch_accessed_variables_) {
+ WatchVariable(v);
+ }
+ }
+
void WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
@@ -951,6 +958,7 @@ class GradientTape
}
};
+ bool watch_accessed_variables_;
tensorflow::mutex watched_variables_mu_;
std::set<IdAndVariable, CompareById> watched_variables_
GUARDED_BY(watched_variables_mu_);
@@ -1056,11 +1064,13 @@ void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
- tape->tape = new GradientTape(persistent == Py_True);
+ tape->tape = new GradientTape(persistent == Py_True,
+ watch_accessed_variables == Py_True);
Py_INCREF(tape);
GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
return reinterpret_cast<PyObject*>(tape);
@@ -1233,13 +1243,20 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
return list;
}
-void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
+void TFE_Py_TapeVariableAccessed(PyObject* variable) {
if (*ThreadTapeIsStopped()) {
return;
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- tape->tape->WatchVariable(variable);
+ tape->tape->VariableAccessed(variable);
+ }
+}
+
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
+ if (*ThreadTapeIsStopped()) {
+ return;
}
+ reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
}
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
@@ -1909,14 +1926,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
Py_RETURN_NONE;
}
-void MaybeWatchVariable(PyObject* input) {
+void MaybeNotifyVariableAccessed(PyObject* input) {
DCHECK(CheckResourceVariable(input));
DCHECK(PyObject_HasAttrString(input, "_trainable"));
tensorflow::Safe_PyObjectPtr trainable(
PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
- TFE_Py_TapeSetWatchVariable(input);
+ TFE_Py_TapeVariableAccessed(input);
}
bool CastTensor(const FastPathOpExecInfo& op_exec_info,
@@ -1947,7 +1964,7 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
PyObject* input, tensorflow::Safe_PyObjectPtr* output,
TF_Status* status) {
- MaybeWatchVariable(input);
+ MaybeNotifyVariableAccessed(input);
TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 6eb62afec4..399d90223c 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -33,9 +33,10 @@ class Tape(object):
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
-def push_new_tape(persistent=False):
+def push_new_tape(persistent=False, watch_accessed_variables=True):
"""Pushes a new tape onto the tape stack."""
- tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
+ tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
+ watch_accessed_variables)
return Tape(tape)
@@ -49,13 +50,14 @@ def watch(tape, tensor):
pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
-def watch_variable(variable):
- """Marks this variable to be watched by all tapes in the stack.
+def watch_variable(tape, variable):
+ """Marks this variable to be watched by the given tape."""
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
- Args:
- variable: variable to be watched.
- """
- pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable)
+
+def variable_accessed(variable):
+ """Notifies all tapes in the stack that a variable has been accessed."""
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
def pop_tape(tape):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 9a5629e0eb..55c2eb5fa4 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -750,7 +750,7 @@ class ResourceVariable(variables.RefVariable):
def _read_variable_op(self):
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
result = gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
if not context.executing_eagerly():
@@ -781,7 +781,7 @@ class ResourceVariable(variables.RefVariable):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
return array_ops.identity(value)
@@ -949,7 +949,7 @@ class ResourceVariable(variables.RefVariable):
def _lazy_read(self, op):
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
return _UnreadVariable(
handle=self._handle, dtype=self.dtype, shape=self._shape,
in_graph_mode=self._in_graph_mode,
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 2253edc742..be8f425481 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -52,9 +52,10 @@ limitations under the License.
%rename("%s") TFE_Py_TapeSetShouldRecord;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
-%rename("%s") TFE_Py_TapeSetWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
+%rename("%s") TFE_Py_TapeVariableAccessed;
%rename("%s") TFE_Py_TapeWatch;
+%rename("%s") TFE_Py_TapeWatchVariable;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
index cbf655498c..2f4257a66a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
}
member_method {
name: "gradient"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
index cbf655498c..2f4257a66a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
}
member_method {
name: "gradient"
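
A side note on the internal plumbing changed above: `tape.watch_variable` now targets one specific tape, while the new `tape.variable_accessed` notifies every active tape of a variable read (and is a no-op for tapes built with `watch_accessed_variables=False`). A rough sketch of the distinction, using the private `tensorflow.python.eager.tape` module as it looks after this change; this is internal API and purely illustrative:

```python
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(1.0)

t = tape_lib.push_new_tape(persistent=False, watch_accessed_variables=False)
tape_lib.watch_variable(t, v)   # watch `v` on this one tape only
tape_lib.variable_accessed(v)   # notify all active tapes of an access; this
                                # tape ignores it since auto-watching is off
assert t.watched_variables() == (v,)
tape_lib.pop_tape(t)
```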