author | Tom Hennigan <tomhennigan@google.com> | 2018-09-07 02:56:32 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 03:00:56 -0700
commit | c35e6928dc0b00ca3ca489e5d34f856499cece4a (patch)
tree | 799f8dffe4ea6bf096dcd4ecc178bcf34f4ed19f
parent | 5a635e3472e16007830fca533c35b2f63fc4f898 (diff)
Support not automatically watching (trainable) accessed variables in GradientTape.
For more complex use cases, this allows fine-grained control over what is
tracked by the tape.
PiperOrigin-RevId: 211948236
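In user code, the new flag makes selective watching explicit. A minimal sketch of the behavior this commit enables, mirroring the new `testSelectivelyWatchVariables` test below (assumes eager execution is enabled; the variable names are illustrative):

```python
import tensorflow as tf

x1 = tf.Variable(1.0)  # Trainable, but deliberately not watched below.
x2 = tf.Variable(1.0)

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(x2)   # Only `x2` is tracked by the tape.
  y = x1 ** 2      # Reads `x1`, but the tape is not auto-watching it.
  z = x2 ** 3      # Recorded, since `x2` was explicitly watched.

dy, dz = tape.gradient([y, z], [x1, x2])
print(dy)  # None: `x1` was never watched.
print(dz)  # 3.0: d(x2**3)/dx2 evaluated at x2 == 1.
```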
-rw-r--r-- | tensorflow/python/eager/backprop.py | 52
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 14
-rw-r--r-- | tensorflow/python/eager/function.py | 2
-rw-r--r-- | tensorflow/python/eager/function_test.py | 9
-rwxr-xr-x | tensorflow/python/eager/pywrap_tfe.h | 12
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 35
-rw-r--r-- | tensorflow/python/eager/tape.py | 18
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 6
-rwxr-xr-x | tensorflow/python/pywrap_tfe.i | 3
-rw-r--r-- | tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt | 2
-rw-r--r-- | tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt | 2

11 files changed, 118 insertions(+), 37 deletions(-)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index e9ebb57689..dda961c5f6 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -693,19 +693,57 @@ class GradientTape(object):
     del g  # Drop the reference to the tape
   ```

+  By default GradientTape will automatically watch any trainable variables that
+  are accessed inside the context. If you want fine-grained control over which
+  variables are watched you can disable automatic tracking by passing
+  `watch_accessed_variables=False` to the tape constructor:
+
+  ```python
+  with tf.GradientTape(watch_accessed_variables=False) as tape:
+    tape.watch(variable_a)
+    y = variable_a ** 2  # Gradients will be available for `variable_a`.
+    z = variable_b ** 3  # No gradients will be available since `variable_b` is
+                         # not being watched.
+  ```
+
+  Note that when using models you should ensure that your variables exist when
+  using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
+  first iteration not have any gradients:
+
+  ```python
+  a = tf.keras.layers.Dense(32)
+  b = tf.keras.layers.Dense(32)
+
+  with tf.GradientTape(watch_accessed_variables=False) as tape:
+    tape.watch(a.variables)  # Since `a.build` has not been called at this point
+                             # `a.variables` will return an empty list and the
+                             # tape will not be watching anything.
+    result = b(a(inputs))
+    tape.gradient(result, a.variables)  # The result of this computation will be
+                                        # a list of `None`s since a's variables
+                                        # are not being watched.
+  ```
+
   Note that only tensors with real or complex dtypes are differentiable.
   """

-  def __init__(self, persistent=False):
+  def __init__(self, persistent=False, watch_accessed_variables=True):
     """Creates a new GradientTape.

     Args:
       persistent: Boolean controlling whether a persistent gradient tape
         is created. False by default, which means at most one call can
         be made to the gradient() method on this object.
+      watch_accessed_variables: Boolean controlling whether the tape will
+        automatically `watch` any (trainable) variables accessed while the tape
+        is active. Defaults to True meaning gradients can be requested from any
+        result computed in the tape derived from reading a trainable `Variable`.
+        If False users must explicitly `watch` any `Variable`s they want to
+        request gradients from.
     """
     self._tape = None
     self._persistent = persistent
+    self._watch_accessed_variables = watch_accessed_variables
     self._recording = False
     context.context().start_step()
@@ -727,7 +765,9 @@ class GradientTape(object):
         raise ValueError("There is no existing tape.")
       tape.push_tape(self._tape)
     else:
-      self._tape = tape.push_new_tape(persistent=self._persistent)
+      self._tape = tape.push_new_tape(
+          persistent=self._persistent,
+          watch_accessed_variables=self._watch_accessed_variables)
     self._recording = True

   def _pop_tape(self):
@@ -746,7 +786,13 @@ class GradientTape(object):
       tensor: a Tensor or list of Tensors.
     """
     for t in nest.flatten(tensor):
-      tape.watch(self._tape, _handle_or_self(t))
+      if hasattr(t, "handle"):
+        # There are many variable-like objects, all of them currently have a
+        # `handle` attribute that points to a tensor. If this changes, internals
+        # of watch_variable need to change as well.
+        tape.watch_variable(self._tape, t)
+      else:
+        tape.watch(self._tape, t)

   @tf_contextlib.contextmanager
   def stop_recording(self):
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 3319b440b4..65d57d3957 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -956,6 +956,20 @@ class BackpropTest(test.TestCase):

     self.assertAllEqual(grad1, grad2)

+  @test_util.run_in_graph_and_eager_modes
+  def testSelectivelyWatchVariables(self):
+    x1 = resource_variable_ops.ResourceVariable(1.0)
+    x2 = resource_variable_ops.ResourceVariable(1.0)
+    with backprop.GradientTape(watch_accessed_variables=False) as tape:
+      tape.watch(x2)
+      y = x1**2
+      z = x2**3
+    self.assertTupleEqual(tape.watched_variables(), (x2,))
+    dy, dz = tape.gradient([y, z], [x1, x2])
+    self.evaluate([x1.initializer, x2.initializer])
+    self.assertIsNone(dy)
+    self.assertEqual(self.evaluate(dz), 3.0)
+
   @test_util.run_in_graph_and_eager_modes
   def testDifferentiatingScalarCache(self):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index d56c1457e0..03f12139f6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -519,7 +519,7 @@ class Function(object):

     for v in self._func_graph.variables:
       if v.trainable:
-        tape.watch_variable(v)
+        tape.variable_accessed(v)

     captures = self._resolve_captured_inputs()
     tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 3c79099d87..37a9957cea 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -616,7 +615,6 @@ class FunctionTest(test.TestCase):

     @function.defun
     def g(x):
-      tape.watch_variable(x)
       y = math_ops.add(x, three)
       f(y)
@@ -630,7 +628,6 @@ class FunctionTest(test.TestCase):
       return math_ops.add(x, three)

     def g(x):
-      tape.watch_variable(three)
       return f(x)

     g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0]
@@ -1427,14 +1424,14 @@ class FunctionTest(test.TestCase):
     grad_t, = backprop.gradients_function(sq, [0])(t)
     self.assertAllEqual(grad_t, [[6, 6], [14, 14]])

-    with backprop.GradientTape(persistent=True) as gtape:
-      gtape.watch(t)
+    with backprop.GradientTape(persistent=True) as tape:
+      tape.watch(t)
       one = matmul(t, b=t, transpose_a=True)
       two = matmul(b=t, a=t, transpose_a=True)
       three = matmul(a=t, b=t, transpose_a=True)

     for output in [one, two, three]:
-      self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]])
+      self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]])

   def testGradientInFunctionWithKeywordArguments(self):
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 6c1bd76296..f1b4042ec9 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -128,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
 // To unset the profiler, pass Py_None as the value of `profiler`.
 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);

-// Creates a new tape and adds it to the active set. `persistent` must be a
-// PyBool_Type, i.e either Py_True or Py_False
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set. `persistent` and
+// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+                            PyObject* watch_accessed_variables);

 // Removes the passed tape from the set of active tapes.
 void TFE_Py_TapeSetRemove(PyObject* tape);
@@ -162,8 +163,11 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
                                    PyObject* input_tensor_ids,
                                    PyObject* backward_function);

+// Notifies all tapes that a variable has been accessed.
+void TFE_Py_TapeVariableAccessed(PyObject* variable);
+
 // Watches the given variable object on the given tape.
-void TFE_Py_TapeSetWatchVariable(PyObject* variable);
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);

 // Computes a gradient based on information recorded on the tape. `tape` must
 // have been produced by TFE_Py_NewTape. `target` and `sources` must be python
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6ac9ed081a..1ed814258b 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -892,9 +892,10 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
  public:
-  explicit GradientTape(bool persistent)
+  explicit GradientTape(bool persistent, bool watch_accessed_variables)
       : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
-            persistent) {}
+            persistent),
+        watch_accessed_variables_(watch_accessed_variables) {}

   virtual ~GradientTape() {
     for (const IdAndVariable& v : watched_variables_) {
@@ -902,6 +903,12 @@ class GradientTape
     }
   }

+  void VariableAccessed(PyObject* v) {
+    if (watch_accessed_variables_) {
+      WatchVariable(v);
+    }
+  }
+
   void WatchVariable(PyObject* v) {
     tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
     if (handle == nullptr) {
@@ -951,6 +958,7 @@ class GradientTape
     }
   };

+  bool watch_accessed_variables_;
   tensorflow::mutex watched_variables_mu_;
   std::set<IdAndVariable, CompareById> watched_variables_
       GUARDED_BY(watched_variables_mu_);
@@ -1056,11 +1064,13 @@ void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }

 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }

-PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+                            PyObject* watch_accessed_variables) {
   TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
   if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
   TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
-  tape->tape = new GradientTape(persistent == Py_True);
+  tape->tape = new GradientTape(persistent == Py_True,
+                                watch_accessed_variables == Py_True);
   Py_INCREF(tape);
   GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
   return reinterpret_cast<PyObject*>(tape);
@@ -1233,13 +1243,20 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
   return list;
 }

-void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
+void TFE_Py_TapeVariableAccessed(PyObject* variable) {
   if (*ThreadTapeIsStopped()) {
     return;
   }
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
-    tape->tape->WatchVariable(variable);
+    tape->tape->VariableAccessed(variable);
+  }
+}
+
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
+  if (*ThreadTapeIsStopped()) {
+    return;
   }
+  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
 }

 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
@@ -1909,14 +1926,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   Py_RETURN_NONE;
 }

-void MaybeWatchVariable(PyObject* input) {
+void MaybeNotifyVariableAccessed(PyObject* input) {
   DCHECK(CheckResourceVariable(input));
   DCHECK(PyObject_HasAttrString(input, "_trainable"));

   tensorflow::Safe_PyObjectPtr trainable(
       PyObject_GetAttrString(input, "_trainable"));
   if (trainable.get() == Py_False) return;
-  TFE_Py_TapeSetWatchVariable(input);
+  TFE_Py_TapeVariableAccessed(input);
 }

 bool CastTensor(const FastPathOpExecInfo& op_exec_info,
@@ -1947,7 +1964,7 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
                     PyObject* input, tensorflow::Safe_PyObjectPtr* output,
                     TF_Status* status) {
-  MaybeWatchVariable(input);
+  MaybeNotifyVariableAccessed(input);

   TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
   auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 6eb62afec4..399d90223c 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -33,9 +33,10 @@ class Tape(object):
     return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)


-def push_new_tape(persistent=False):
+def push_new_tape(persistent=False, watch_accessed_variables=True):
   """Pushes a new tape onto the tape stack."""
-  tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
+  tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
+                                             watch_accessed_variables)
   return Tape(tape)


@@ -49,13 +50,14 @@ def watch(tape, tensor):
   pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor)  # pylint: disable=protected-access


-def watch_variable(variable):
-  """Marks this variable to be watched by all tapes in the stack.
+def watch_variable(tape, variable):
+  """Marks this variable to be watched by the given tape."""
+  pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable)  # pylint: disable=protected-access

-  Args:
-    variable: variable to be watched.
- """ - pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable) + +def variable_accessed(variable): + """Notifies all tapes in the stack that a variable has been accessed.""" + pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable) def pop_tape(tape): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 9a5629e0eb..55c2eb5fa4 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -750,7 +750,7 @@ class ResourceVariable(variables.RefVariable): def _read_variable_op(self): if self.trainable: - tape.watch_variable(self) + tape.variable_accessed(self) result = gen_resource_variable_ops.read_variable_op(self._handle, self._dtype) if not context.executing_eagerly(): @@ -781,7 +781,7 @@ class ResourceVariable(variables.RefVariable): """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name) as name: if self.trainable: - tape.watch_variable(self) + tape.variable_accessed(self) value = gen_resource_variable_ops.resource_gather( self._handle, indices, dtype=self._dtype, name=name) return array_ops.identity(value) @@ -949,7 +949,7 @@ class ResourceVariable(variables.RefVariable): def _lazy_read(self, op): if self.trainable: - tape.watch_variable(self) + tape.variable_accessed(self) return _UnreadVariable( handle=self._handle, dtype=self.dtype, shape=self._shape, in_graph_mode=self._in_graph_mode, diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 2253edc742..be8f425481 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -52,9 +52,10 @@ limitations under the License. %rename("%s") TFE_Py_TapeSetShouldRecord; %rename("%s") TFE_Py_TapeSetDeleteTrace; %rename("%s") TFE_Py_TapeSetRecordOperation; -%rename("%s") TFE_Py_TapeSetWatchVariable; %rename("%s") TFE_Py_TapeGradient; +%rename("%s") TFE_Py_TapeVariableAccessed; %rename("%s") TFE_Py_TapeWatch; +%rename("%s") TFE_Py_TapeWatchVariable; %rename("%s") TFE_Py_TapeWatchedVariables; %rename("%s") TFE_NewContextOptions; %rename("%s") TFE_ContextOptionsSetConfig; diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt index cbf655498c..2f4257a66a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "<type \'object\'>" member_method { name: "__init__" - argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], " } member_method { name: "gradient" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt index cbf655498c..2f4257a66a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "<type \'object\'>" member_method { name: "__init__" - argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], " } member_method { name: "gradient" |