aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-01-08 17:59:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-08 18:03:58 -0800
commitcaa2a1b856d1080cfec26b6ab5756aa49114597e (patch)
tree447cce319a13ba32caf20c83d2497e4b292c7c86 /tensorflow
parent6fda97ab7540fb003d46f7c9810d6aab6dbc6c19 (diff)
Fix the threading model of gradient tapes.
The set of tapes needs to be global to enable multithreaded programming (when it's natural for tensors to cross threads during reduction operations) but each thread still needs to be able to locally pause recording while it does gradient-related bookkeeping (like custom gradients or initialization). Also removes a mutex from the thread-local structure since it's unnecessary as we're always holding the GIL while calling across the python-c boundary unless we explicitly release it. PiperOrigin-RevId: 181246570
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/eager/backprop.py18
-rw-r--r--tensorflow/python/eager/function.py5
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc2
-rw-r--r--tensorflow/python/eager/pywrap_tfe.h36
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc115
-rw-r--r--tensorflow/python/eager/tape.py26
-rw-r--r--tensorflow/python/pywrap_tfe.i19
7 files changed, 115 insertions, 106 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index a06feb1669..56a49301a2 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -344,7 +344,7 @@ def implicit_val_and_grad(f):
def grad_fn(*args):
"""Computes the gradient of the wrapped function."""
- tape.push_new_tape()
+ this_tape = tape.push_new_tape()
try:
end_node = f(*args)
if end_node is None:
@@ -352,10 +352,10 @@ def implicit_val_and_grad(f):
"did you forget to return a value from {}?".format(
f.__name__))
finally:
- popped_tape = tape.pop_tape()
+ tape.pop_tape(this_tape)
# Sorting variables by id, which is monotonically increasing in construction
# order. This ensures unique order across executions.
- variables = list(sorted(popped_tape.watched_variables(),
+ variables = list(sorted(this_tape.watched_variables(),
key=lambda v: v.handle._id)) # pylint: disable=protected-access
sources = [x.handle for x in variables]
@@ -363,7 +363,7 @@ def implicit_val_and_grad(f):
raise ValueError("No trainable variables were accessed while the "
"function was being computed.")
grad = imperative_grad.imperative_grad(_default_vspace,
- popped_tape,
+ this_tape,
nest.flatten(end_node),
sources)
return end_node, list(zip(grad, variables))
@@ -652,7 +652,7 @@ def make_vjp(f, params=None):
"""Computes the value and gradient of the decorated function."""
parameter_positions = _get_arg_spec(f, params, args)
assert not kwds, "The gradient function can't take keyword arguments."
- tape.push_new_tape()
+ this_tape = tape.push_new_tape()
try:
sources = []
args = [
@@ -673,12 +673,12 @@ def make_vjp(f, params=None):
flat_result = [gen_array_ops.identity(x) for x in flat_result]
result = nest.pack_sequence_as(result, flat_result)
finally:
- t = tape.pop_tape()
+ tape.pop_tape(this_tape)
def vjp(dy=None):
if dy is not None:
dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
return imperative_grad.imperative_grad(
- _default_vspace, t, nest.flatten(result), sources,
+ _default_vspace, this_tape, nest.flatten(result), sources,
output_gradients=dy)
return result, vjp
@@ -835,11 +835,11 @@ class GradientTape(object):
self._persistent = persistent
def __enter__(self):
- tape.push_new_tape(persistent=self._persistent)
+ self._tape = tape.push_new_tape(persistent=self._persistent)
return self
def __exit__(self, typ, value, traceback):
- self._tape = tape.pop_tape()
+ tape.pop_tape(self._tape)
def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 969e321dd1..f755434ad7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -544,11 +544,12 @@ def _defun_internal(name, func, args, kwds):
func_inputs = _get_defun_inputs(args)
with capture_tensors(captures):
- tape.push_new_tape()
+ this_tape = tape.push_new_tape()
try:
func_outputs = func(*func_inputs, **kwds)
finally:
- variables = tape.pop_tape().watched_variables()
+ tape.pop_tape(this_tape)
+ variables = this_tape.watched_variables()
# Returning a closed-over tensor as an output does not trigger a
# call to convert_to_tensor, so we manually capture all such tensors.
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 91192fea62..6fa076507d 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -332,7 +332,7 @@ void EagerTensor_dealloc(EagerTensor* self) {
tensorflow::ClearDecrefCache();
auto id = self->id;
Py_TYPE(self)->tp_free(self);
- TFE_Py_TapeStackDeleteTrace(id);
+ TFE_Py_TapeSetDeleteTrace(id);
}
// Getter for `_id`.
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index a33b17ada6..cecef42603 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -87,22 +87,25 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
// newly created type, or nullptr on error.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
-// Pushes a new tape into the thread-local stack.
-// `persistent` must be a PyBool_Type, i.e either Py_True or Py_False
-void TFE_Py_TapeStackPushNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set. `persistent` must be a
+// PyBool_Type, i.e either Py_True or Py_False
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
-// Pops the tape from the top of the stack and returns it.
-PyObject* TFE_Py_TapeStackPop();
-
-// Pushes an existing tape onto the stack.
-void TFE_Py_TapeStackPush(PyObject* tape);
+// Removes the passed tape from the set of active tapes.
+void TFE_Py_TapeSetRemove(PyObject* tape);
// Returns true if the tape stack is empty.
-PyObject* TFE_Py_TapeStackIsEmpty();
+PyObject* TFE_Py_TapeSetIsEmpty();
+
+PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
+void TFE_Py_TapeSetWatch(PyObject* tensor);
+void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);
+
+// Stops any gradient recording on the current thread.
+void TFE_Py_TapeSetStopOnThread();
-PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors);
-void TFE_Py_TapeStackWatch(PyObject* tensor);
-void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id);
+// Restarts gradient recording on the current thread.
+void TFE_Py_TapeSetRestartOnThread();
// Records an operation in the gradient tape stack.type is a string for the
// operation type, used in the backprop code. output_tensors should be a list of
@@ -111,13 +114,12 @@ void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id);
// operation. backward_function should be the function to be called during
// backprop to, given the gradients of the output tensors, produce the gradients
// of the input tensors.
-void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
- PyObject* output_tensors,
- PyObject* input_tensor_ids,
- PyObject* backward_function);
+void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
+ PyObject* input_tensor_ids,
+ PyObject* backward_function);
// Watches the given variable object on the given tape.
-void TFE_Py_TapeStackWatchVariable(PyObject* variable);
+void TFE_Py_TapeSetWatchVariable(PyObject* variable);
// Computes a gradient based on information recorded on the tape.`tape` must
// have been produced by TFE_Py_NewTape. `vspace` must be a
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 3ba81fb3d0..bdaeccf286 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -538,62 +538,67 @@ static PyTypeObject TFE_Py_Tape_Type = {
"TFE_Py_Tape objects", /* tp_doc */
};
+// Note: in the current design no mutex is needed here because of the python
+// GIL, which is always held when any TFE_Py_* methods are called. We should
+// revisit this if/when decide to not hold the GIL while manipulating the tape
+// stack.
+static std::unordered_set<TFE_Py_Tape*>* tape_set = nullptr;
+std::unordered_set<TFE_Py_Tape*>* GetTapeSet() {
+ if (tape_set == nullptr) {
+ tape_set = new std::unordered_set<TFE_Py_Tape*>;
+ }
+ return tape_set;
+}
+
// xcode 7 doesn't define thread_local, so for compatibility we implement our
// own. TODO(apassos) remove once we can deprecate xcode 7.
#ifndef __APPLE__
-std::vector<TFE_Py_Tape*>* GetTapeStack() {
- thread_local std::vector<TFE_Py_Tape*> tape_stack;
- return &tape_stack;
+bool* ThreadTapeIsStopped() {
+ thread_local bool thread_tape_is_stopped{false};
+ return &thread_tape_is_stopped;
}
#else
-static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED);
-static std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>*
- tape_stack GUARDED_BY(stack_mu) = nullptr;
-std::vector<TFE_Py_Tape*>* GetTapeStack() {
- tensorflow::mutex_lock ml(stack_mu);
- if (tape_stack == nullptr) {
- tape_stack =
- new std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>;
- }
- auto it = tape_stack->find(std::this_thread::get_id());
- if (it != tape_stack->end()) {
- return it->second;
- }
- return tape_stack
- ->emplace(std::this_thread::get_id(), new std::vector<TFE_Py_Tape*>)
- .first->second;
+static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr;
+bool* ThreadTapeIsStopped() {
+ if (tape_is_stopped == nullptr) {
+ tape_is_stopped = new std::unordered_map<std::thread::id, bool>;
+ }
+ auto it = tape_is_stopped->find(std::this_thread::get_id());
+ if (it != tape_is_stopped->end()) {
+ return &(it->second);
+ }
+ return &(tape_is_stopped->emplace(std::this_thread::get_id(), false)
+ .first->second);
}
#endif
-void TFE_Py_TapeStackPushNew(PyObject* persistent) {
+void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
+
+void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
+
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
- if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return;
+ if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
tape->tape = new GradientTape(persistent == Py_True);
- GetTapeStack()->push_back(tape);
-}
-
-void TFE_Py_TapeStackPush(PyObject* tape) {
Py_INCREF(tape);
- GetTapeStack()->push_back(reinterpret_cast<TFE_Py_Tape*>(tape));
+ GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
+ return reinterpret_cast<PyObject*>(tape);
}
-PyObject* TFE_Py_TapeStackIsEmpty() {
- if (GetTapeStack()->empty()) {
+PyObject* TFE_Py_TapeSetIsEmpty() {
+ if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
-PyObject* TFE_Py_TapeStackPop() {
- auto* stack = GetTapeStack();
- if (stack->empty()) {
- PyErr_SetString(PyExc_RuntimeError, "tape stack is empty.");
- return nullptr;
- }
- TFE_Py_Tape* top = stack->back();
- stack->pop_back();
- return reinterpret_cast<PyObject*>(top);
+void TFE_Py_TapeSetRemove(PyObject* tape) {
+ auto* stack = GetTapeSet();
+ stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
+ // We kept a reference to the tape in the set to ensure it wouldn't get
+ // deleted under us; cleaning it up here.
+ Py_DECREF(tape);
}
static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
@@ -620,12 +625,15 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
return tensor_ids;
}
-PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
+PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
if (tensors == Py_None) {
Py_RETURN_FALSE;
}
- auto* stack = GetTapeStack();
- if (stack->empty()) {
+ if (*ThreadTapeIsStopped()) {
+ Py_RETURN_FALSE;
+ }
+ auto* tape_set = GetTapeSet();
+ if (tape_set->empty()) {
Py_RETURN_FALSE;
}
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -642,7 +650,7 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
tensor_ids.push_back(FastTensorId(item));
}
Py_DECREF(seq);
- for (TFE_Py_Tape* tape : *stack) {
+ for (TFE_Py_Tape* tape : *tape_set) {
if (tape->tape->ShouldRecord(tensor_ids)) {
Py_RETURN_TRUE;
}
@@ -650,12 +658,12 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}
-void TFE_Py_TapeStackWatch(PyObject* tensor) {
+void TFE_Py_TapeSetWatch(PyObject* tensor) {
tensorflow::int64 tensor_id = FastTensorId(tensor);
if (PyErr_Occurred()) {
return;
}
- for (TFE_Py_Tape* tape : *GetTapeStack()) {
+ for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->Watch(tensor_id);
}
}
@@ -720,8 +728,8 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
return list;
}
-void TFE_Py_TapeStackWatchVariable(PyObject* variable) {
- for (TFE_Py_Tape* tape : *GetTapeStack()) {
+void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
+ for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->WatchVariable(variable);
}
}
@@ -736,12 +744,11 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return result;
}
-void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
- PyObject* output_tensors,
- PyObject* input_tensors,
- PyObject* backward_function) {
- auto* stack = GetTapeStack();
- if (stack->empty()) {
+void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
+ PyObject* input_tensors,
+ PyObject* backward_function) {
+ auto* set = GetTapeSet();
+ if (set->empty()) {
return;
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
@@ -776,7 +783,7 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
return;
}
- for (TFE_Py_Tape* tape : *stack) {
+ for (TFE_Py_Tape* tape : *set) {
Py_INCREF(backward_function);
tape->tape->RecordOperation(
op_type_str, output_info, input_ids, backward_function,
@@ -784,8 +791,8 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
}
}
-void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) {
- for (TFE_Py_Tape* tape : *GetTapeStack()) {
+void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
+ for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->DeleteTrace(tensor_id);
}
}
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 14b5238f74..ad82266bec 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -35,7 +35,8 @@ class Tape(object):
def push_new_tape(persistent=False):
"""Pushes a new tape onto the tape stack."""
- pywrap_tensorflow.TFE_Py_TapeStackPushNew(persistent)
+ tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
+ return Tape(tape)
def watch(tensor):
@@ -44,7 +45,7 @@ def watch(tensor):
Args:
tensor: tensor to be watched.
"""
- pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor)
+ pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor)
def watch_variable(variable):
@@ -53,42 +54,39 @@ def watch_variable(variable):
Args:
variable: variable to be watched.
"""
- pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable)
+ pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable)
-def pop_tape():
+def pop_tape(tape):
"""Pops the top tape in the stack, if any."""
- return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop())
+ pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
@contextlib.contextmanager
def stop_recording():
- stack = []
- while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty():
- stack.append(pop_tape()._tape) # pylint: disable=protected-access
try:
+ pywrap_tensorflow.TFE_Py_TapeSetStopOnThread()
yield
finally:
- for tape in reversed(stack):
- pywrap_tensorflow.TFE_Py_TapeStackPush(tape)
+ pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread()
def should_record(tensors):
"""Returns true if any tape in the stack watches any of these tensors."""
- return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors)
+ return pywrap_tensorflow.TFE_Py_TapeSetShouldRecord(tensors)
def record_operation(op_type, output_tensors, input_tensors, backward_function):
"""Records the operation on all tapes in the stack."""
- pywrap_tensorflow.TFE_Py_TapeStackRecordOperation(
+ pywrap_tensorflow.TFE_Py_TapeSetRecordOperation(
op_type, output_tensors, input_tensors, backward_function)
def delete_trace(tensor_id):
"""Deletes traces for this Tensor from all tapes in the stack."""
- pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id)
+ pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id)
def could_possibly_record():
"""Returns True if any tape is active."""
- return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty()
+ return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index d97823c17f..42e4773df3 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -26,15 +26,16 @@ limitations under the License.
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_UID;
-%rename("%s") TFE_Py_TapeStackPushNew;
-%rename("%s") TFE_Py_TapeStackPush;
-%rename("%s") TFE_Py_TapeStackPop;
-%rename("%s") TFE_Py_TapeStackIsEmpty;
-%rename("%s") TFE_Py_TapeStackShouldRecord;
-%rename("%s") TFE_Py_TapeStackWatch;
-%rename("%s") TFE_Py_TapeStackDeleteTrace;
-%rename("%s") TFE_Py_TapeStackRecordOperation;
-%rename("%s") TFE_Py_TapeStackWatchVariable;
+%rename("%s") TFE_Py_TapeSetNew;
+%rename("%s") TFE_Py_TapeSetRemove;
+%rename("%s") TFE_Py_TapeSetStopOnThread;
+%rename("%s") TFE_Py_TapeSetRestartOnThread;
+%rename("%s") TFE_Py_TapeSetIsEmpty;
+%rename("%s") TFE_Py_TapeSetShouldRecord;
+%rename("%s") TFE_Py_TapeSetWatch;
+%rename("%s") TFE_Py_TapeSetDeleteTrace;
+%rename("%s") TFE_Py_TapeSetRecordOperation;
+%rename("%s") TFE_Py_TapeSetWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;