diff options
author | 2017-11-14 11:01:11 -0800 | |
---|---|---|
committer | 2017-11-14 12:10:39 -0800 | |
commit | a2c3dab386857cd4fe63990c6bb3aa791e3fcaf3 (patch) | |
tree | 877f252b7cdab63cf8da414134a105141daa9b51 /tensorflow/python | |
parent | c674e27bfd68a6c990e694b6afd901bfeeaa006d (diff) |
Tape stack in C++ instead of Python.
PiperOrigin-RevId: 175704617
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe.h | 49 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 121 | ||||
-rw-r--r-- | tensorflow/python/eager/tape.py | 121 | ||||
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 15 |
4 files changed, 153 insertions, 153 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index a67519f9a2..f96245f7a5 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -87,22 +87,36 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); // newly created type, or nullptr on error. PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); -PyObject* TFE_Py_NewTape(); -PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors); -void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id); -void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id); - -// Records an operation in the gradient tape. `tape` should point to an object -// returned by TFE_Py_NewTape. op_type is a string for the operation type, used -// in the backprop code. output_tensors should be a list of python ops.Tensor -// objects. input_tensor_ids should be a list of python integers with the ids of -// the input tensors of the recorded operation. backward_function should be the -// function to be called during backprop to, given the gradients of the output -// tensors, produce the gradients of the input tensors. -void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensor_ids, - PyObject* backward_function); +// Pushes a new tape into the thread-local stack. +void TFE_Py_TapeStackPushNew(); + +// Pops the tape from the top of the stack and returns it. +PyObject* TFE_Py_TapeStackPop(); + +// Pushes an existing tape onto the stack. +void TFE_Py_TapeStackPush(PyObject* tape); + +// Returns true if the tape stack is empty. +PyObject* TFE_Py_TapeStackIsEmpty(); + +PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors); +void TFE_Py_TapeStackWatch(PyObject* tensor); +void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id); + +// Records an operation in the gradient tape stack.type is a string for the +// operation type, used in the backprop code. 
output_tensors should be a list of +// python ops.Tensor objects. input_tensor_ids should be a list of python +// integers with the ids of the input tensors of the recorded +// operation. backward_function should be the function to be called during +// backprop to, given the gradients of the output tensors, produce the gradients +// of the input tensors. +void TFE_Py_TapeStackRecordOperation(PyObject* op_type, + PyObject* output_tensors, + PyObject* input_tensor_ids, + PyObject* backward_function); + +// Watches the given variable object on the given tape. +void TFE_Py_TapeStackWatchVariable(PyObject* variable); // Computes a gradient based on information recorded on the tape.`tape` must // have been produced by TFE_Py_NewTape. `vspace` must be a @@ -114,9 +128,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, PyObject* target, PyObject* sources, PyObject* output_gradients, TF_Status* status); -// Watches the given variable object on the given tape. -void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable); - // Returns the set of variables watched by the given tape. PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 5cb1313c4b..387eec1358 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <thread> + #include "tensorflow/python/eager/pywrap_tfe.h" #include "tensorflow/c/c_api.h" @@ -525,12 +527,65 @@ static PyTypeObject TFE_Py_Tape_Type = { "TFE_Py_Tape objects", /* tp_doc */ }; -PyObject* TFE_Py_NewTape() { +// xcode 7 doesn't define thread_local, so for compatibility we implement our +// own. TODO(apassos) remove once we can deprecate xcode 7. 
+#ifndef __APPLE__ +thread_local std::vector<TFE_Py_Tape*>* tape_stack = nullptr; +std::vector<TFE_Py_Tape*>* GetTapeStack() { + if (tape_stack == nullptr) { + tape_stack = new std::vector<TFE_Py_Tape*>; + } + return tape_stack; +} +#else +static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED); +static std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>* + tape_stack GUARDED_BY(stack_mu) = nullptr; +std::vector<TFE_Py_Tape*>* GetTapeStack() { + tensorflow::mutex_lock ml(stack_mu); + if (tape_stack == nullptr) { + tape_stack = + new std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>; + } + auto it = tape_stack->find(std::this_thread::get_id()); + if (it != tape_stack->end()) { + return it->second; + } + return tape_stack + ->emplace(std::this_thread::get_id(), new std::vector<TFE_Py_Tape*>) + .first->second; +} +#endif + +void TFE_Py_TapeStackPushNew() { TFE_Py_Tape_Type.tp_new = PyType_GenericNew; - if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; + if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return; TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); tape->tape = new GradientTape(); - return reinterpret_cast<PyObject*>(tape); + GetTapeStack()->push_back(tape); +} + +void TFE_Py_TapeStackPush(PyObject* tape) { + Py_INCREF(tape); + GetTapeStack()->push_back(reinterpret_cast<TFE_Py_Tape*>(tape)); +} + +PyObject* TFE_Py_TapeStackIsEmpty() { + if (GetTapeStack()->empty()) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +PyObject* TFE_Py_TapeStackPop() { + auto* stack = GetTapeStack(); + if (stack->empty()) { + PyErr_SetString(PyExc_RuntimeError, "tape stack is empty."); + return nullptr; + } + TFE_Py_Tape* top = stack->back(); + stack->pop_back(); + return reinterpret_cast<PyObject*>(top); } static std::vector<tensorflow::int64> MakeIntList(PyObject* list) { @@ -557,10 +612,14 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) { return tensor_ids; } -PyObject* TFE_Py_TapeShouldRecord(PyObject* 
py_tape, PyObject* tensors) { +PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { if (tensors == Py_None) { Py_RETURN_FALSE; } + auto* stack = GetTapeStack(); + if (stack->empty()) { + Py_RETURN_FALSE; + } PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); if (seq == nullptr) { return nullptr; @@ -575,16 +634,22 @@ PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) { tensor_ids.push_back(FastTensorId(item)); } Py_DECREF(seq); - TFE_Py_Tape* tape = reinterpret_cast<TFE_Py_Tape*>(py_tape); - if (tape->tape->ShouldRecord(tensor_ids)) { - Py_RETURN_TRUE; - } else { - Py_RETURN_FALSE; + for (TFE_Py_Tape* tape : *stack) { + if (tape->tape->ShouldRecord(tensor_ids)) { + Py_RETURN_TRUE; + } } + Py_RETURN_FALSE; } -void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id) { - reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id); +void TFE_Py_TapeStackWatch(PyObject* tensor) { + tensorflow::int64 tensor_id = FastTensorId(tensor); + if (PyErr_Occurred()) { + return; + } + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->Watch(tensor_id); + } } static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { @@ -646,8 +711,10 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) { return list; } -void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) { - reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable); +void TFE_Py_TapeStackWatchVariable(PyObject* variable) { + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->WatchVariable(variable); + } } PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { @@ -661,10 +728,14 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { return result; } -void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensors, - PyObject* backward_function) { +void TFE_Py_TapeStackRecordOperation(PyObject* op_type, + PyObject* output_tensors, + PyObject* 
input_tensors, + PyObject* backward_function) { + auto* stack = GetTapeStack(); + if (stack->empty()) { + return; + } std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors); std::vector<tensorflow::eager::TapeTensor> output_info; PyObject* seq = PySequence_Fast(output_tensors, @@ -697,14 +768,18 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, return; } - Py_INCREF(backward_function); - reinterpret_cast<TFE_Py_Tape*>(tape)->tape->RecordOperation( - op_type_str, output_info, input_ids, backward_function, - [backward_function]() { Py_DECREF(backward_function); }); + for (TFE_Py_Tape* tape : *stack) { + Py_INCREF(backward_function); + tape->tape->RecordOperation( + op_type_str, output_info, input_ids, backward_function, + [backward_function]() { Py_DECREF(backward_function); }); + } } -void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) { - reinterpret_cast<TFE_Py_Tape*>(tape)->tape->DeleteTrace(tensor_id); +void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) { + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->DeleteTrace(tensor_id); + } } class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> { diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index fb6b62a3e0..440c84b7ea 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -18,106 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import contextlib -import threading from tensorflow.python import pywrap_tensorflow -def tid(tensor): - return tensor._id # pylint: disable=protected-access - - -class TapeEntry( - collections.namedtuple("TapeEntry", [ - "op_type", - "output_ids", "input_ids", "backward_function", - "output_shape_and_dtype", - ])): - """Entry in the gradient tape. 
- - Represents the execution of one op or function, with instructions for doing - its backward pass and useful information for it. - - Args: - output_ids: tensor_id(t) for each output tensor T - input_ids: tensor_id(t) for each input tensor T - backward_function: function to be called with the downstream gradients and - side outputs as arguments which computes the backward pass. - output_shape_and_dtype: a list of (shape_tuple, dtype) for every output - tensor_id - """ - - -def _tensor_shape(t): - return t._shape_tuple() # pylint: disable=protected-access - - class Tape(object): """Represents a gradient propagation trace.""" - def __init__(self): - self._tape = pywrap_tensorflow.TFE_Py_NewTape() - - def should_record(self, tensors): - """Returns true if any tensor should be recorded. - - Args: - tensors: some tensors. - - Returns: - True if any of the tensors is in the tape. - """ - return pywrap_tensorflow.TFE_Py_TapeShouldRecord( - self._tape, tensors) - - def watch(self, tensor): - """Adds a tensor to the tape.""" - pywrap_tensorflow.TFE_Py_TapeWatch(self._tape, tid(tensor)) - - def watch_variable(self, v): - pywrap_tensorflow.TFE_Py_TapeWatchVariable(self._tape, v) + def __init__(self, tape): + self._tape = tape def watched_variables(self): return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) - def record_operation(self, op_type, output_tensors, input_tensors, - backward_function): - """Records an operation in the tape.""" - pywrap_tensorflow.TFE_Py_TapeRecordOperation( - self._tape, - op_type, - output_tensors, - input_tensors, - backward_function) - - def _delete_tensor_id(self, i): - pywrap_tensorflow.TFE_Py_TapeDeleteTrace(self._tape, i) - - def delete_trace(self, tensor_id): - """Deletes any trace we have for this tensor.""" - self._delete_tensor_id(tensor_id) - - -class _TapeStack(threading.local): - - def __init__(self): - super(_TapeStack, self).__init__() - self._stack = [] - - @property - def stack(self): - return self._stack - - -# The 
global tape stack. -_tape_stack = _TapeStack() - def push_new_tape(): """Pushes a new tape onto the tape stack.""" - _tape_stack.stack.append(Tape()) + pywrap_tensorflow.TFE_Py_TapeStackPushNew() def watch(tensor): @@ -126,8 +44,7 @@ def watch(tensor): Args: tensor: tensor to be watched. """ - for t in _tape_stack.stack: - t.watch(tensor) + pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor) def watch_variable(variable): @@ -136,48 +53,42 @@ def watch_variable(variable): Args: variable: variable to be watched. """ - for t in _tape_stack.stack: - t.watch_variable(variable) + pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable) def pop_tape(): """Pops the top tape in the stack, if any.""" - if _tape_stack.stack: - return _tape_stack.stack.pop() - return None + return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop()) @contextlib.contextmanager def stop_recording(): - old = _tape_stack.stack - _tape_stack._stack = [] # pylint: disable=protected-access + stack = [] + while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty(): + stack.append(pop_tape()._tape) # pylint: disable=protected-access try: yield finally: - _tape_stack._stack = old # pylint: disable=protected-access + for tape in reversed(stack): + pywrap_tensorflow.TFE_Py_TapeStackPush(tape) def should_record(tensors): """Returns true if any tape in the stack watches any of these tensors.""" - if not _tape_stack.stack: - return False - return any(x.should_record(tensors) for x in _tape_stack.stack) + return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors) def record_operation(op_type, output_tensors, input_tensors, backward_function): """Records the operation on all tapes in the stack.""" - for t in _tape_stack.stack: - t.record_operation(op_type, output_tensors, - input_tensors, - backward_function) + pywrap_tensorflow.TFE_Py_TapeStackRecordOperation( + op_type, output_tensors, input_tensors, backward_function) def delete_trace(tensor_id): """Deletes traces for this Tensor from all tapes in the stack.""" - for t 
in _tape_stack.stack: - t.delete_trace(tensor_id) + pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id) def could_possibly_record(): """Returns True if any tape is active.""" - return len(_tape_stack.stack) > 0 # pylint: disable=g-explicit-length-test + return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty() diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 5ca0e57286..82b154164e 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -24,13 +24,16 @@ limitations under the License. %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_UID; -%rename("%s") TFE_Py_NewTape; -%rename("%s") TFE_Py_TapeShouldRecord; -%rename("%s") TFE_Py_TapeWatch; -%rename("%s") TFE_Py_TapeDeleteTrace; -%rename("%s") TFE_Py_TapeRecordOperation; +%rename("%s") TFE_Py_TapeStackPushNew; +%rename("%s") TFE_Py_TapeStackPush; +%rename("%s") TFE_Py_TapeStackPop; +%rename("%s") TFE_Py_TapeStackIsEmpty; +%rename("%s") TFE_Py_TapeStackShouldRecord; +%rename("%s") TFE_Py_TapeStackWatch; +%rename("%s") TFE_Py_TapeStackDeleteTrace; +%rename("%s") TFE_Py_TapeStackRecordOperation; +%rename("%s") TFE_Py_TapeStackWatchVariable; %rename("%s") TFE_Py_TapeGradient; -%rename("%s") TFE_Py_TapeWatchVariable; %rename("%s") TFE_Py_TapeWatchedVariables; %rename("%s") TFE_NewContextOptions; %rename("%s") TFE_ContextOptionsSetConfig; |