author    Alexandre Passos <apassos@google.com>    2017-11-14 11:01:11 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-11-14 12:10:39 -0800
commit    a2c3dab386857cd4fe63990c6bb3aa791e3fcaf3 (patch)
tree      877f252b7cdab63cf8da414134a105141daa9b51 /tensorflow/python
parent    c674e27bfd68a6c990e694b6afd901bfeeaa006d (diff)
Tape stack in C++ instead of python.
PiperOrigin-RevId: 175704617
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/eager/pywrap_tfe.h       |  49
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc  | 121
-rw-r--r--  tensorflow/python/eager/tape.py            | 121
-rw-r--r--  tensorflow/python/pywrap_tfe.i             |  15
4 files changed, 153 insertions, 153 deletions
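
This commit replaces tape.py's Python-level thread-local tape stack with one maintained in C++ behind pywrap_tensorflow; the Python module keeps its public surface and forwards to the new TFE_Py_TapeStack* entry points. A minimal hedged sketch of how the module is driven after this change (the tensor x and the surrounding eager setup are illustrative, not part of this diff):

    from tensorflow.python.eager import tape

    tape.push_new_tape()  # TFE_Py_TapeStackPushNew: new tape on this thread's stack
    tape.watch(x)         # TFE_Py_TapeStackWatch: watch x on every tape in the stack
    # ... eager ops run here; each calls tape.record_operation(...) internally ...
    t = tape.pop_tape()   # wraps the TFE_Py_Tape popped from the C++ stack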
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index a67519f9a2..f96245f7a5 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -87,22 +87,36 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
// newly created type, or nullptr on error.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
-PyObject* TFE_Py_NewTape();
-PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors);
-void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id);
-void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id);
-
-// Records an operation in the gradient tape. `tape` should point to an object
-// returned by TFE_Py_NewTape. op_type is a string for the operation type, used
-// in the backprop code. output_tensors should be a list of python ops.Tensor
-// objects. input_tensor_ids should be a list of python integers with the ids of
-// the input tensors of the recorded operation. backward_function should be the
-// function to be called during backprop to, given the gradients of the output
-// tensors, produce the gradients of the input tensors.
-void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
- PyObject* output_tensors,
- PyObject* input_tensor_ids,
- PyObject* backward_function);
+// Pushes a new tape into the thread-local stack.
+void TFE_Py_TapeStackPushNew();
+
+// Pops the tape from the top of the stack and returns it.
+PyObject* TFE_Py_TapeStackPop();
+
+// Pushes an existing tape onto the stack.
+void TFE_Py_TapeStackPush(PyObject* tape);
+
+// Returns true if the tape stack is empty.
+PyObject* TFE_Py_TapeStackIsEmpty();
+
+PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors);
+void TFE_Py_TapeStackWatch(PyObject* tensor);
+void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id);
+
+// Records an operation in the gradient tape stack. op_type is a string for the
+// operation type, used in the backprop code. output_tensors should be a list of
+// python ops.Tensor objects. input_tensor_ids should be a list of python
+// integers with the ids of the input tensors of the recorded
+// operation. backward_function should be the function to be called during
+// backprop to, given the gradients of the output tensors, produce the gradients
+// of the input tensors.
+void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
+ PyObject* output_tensors,
+ PyObject* input_tensor_ids,
+ PyObject* backward_function);
+
+// Watches the given variable object on the given tape.
+void TFE_Py_TapeStackWatchVariable(PyObject* variable);
// Computes a gradient based on information recorded on the tape. `tape` must
// have been produced by TFE_Py_NewTape. `vspace` must be a
@@ -114,9 +128,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
PyObject* target, PyObject* sources,
PyObject* output_gradients, TF_Status* status);
-// Watches the given variable object on the given tape.
-void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);
-
// Returns the set of variables watched by the given tape.
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
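
The record-operation contract documented above is easiest to see with a concrete case. A hedged Python-side sketch for a single multiply y = x * c; mul_backward, x, y, and c are illustrative names, not part of this change:

    def mul_backward(dy):
      # Given the gradient flowing into the output y, return a list with one
      # gradient per recorded input: here d(loss)/dx = dy * c.
      return [dy * c]

    tape.record_operation("Mul", [y], [x], mul_backward)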
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 5cb1313c4b..387eec1358 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <thread>
+
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/c/c_api.h"
@@ -525,12 +527,65 @@ static PyTypeObject TFE_Py_Tape_Type = {
"TFE_Py_Tape objects", /* tp_doc */
};
-PyObject* TFE_Py_NewTape() {
+// Xcode 7 doesn't define thread_local, so for compatibility we implement our
+// own. TODO(apassos) remove once we can deprecate Xcode 7.
+#ifndef __APPLE__
+thread_local std::vector<TFE_Py_Tape*>* tape_stack = nullptr;
+std::vector<TFE_Py_Tape*>* GetTapeStack() {
+ if (tape_stack == nullptr) {
+ tape_stack = new std::vector<TFE_Py_Tape*>;
+ }
+ return tape_stack;
+}
+#else
+static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED);
+static std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>*
+ tape_stack GUARDED_BY(stack_mu) = nullptr;
+std::vector<TFE_Py_Tape*>* GetTapeStack() {
+ tensorflow::mutex_lock ml(stack_mu);
+ if (tape_stack == nullptr) {
+ tape_stack =
+ new std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>;
+ }
+ auto it = tape_stack->find(std::this_thread::get_id());
+ if (it != tape_stack->end()) {
+ return it->second;
+ }
+ return tape_stack
+ ->emplace(std::this_thread::get_id(), new std::vector<TFE_Py_Tape*>)
+ .first->second;
+}
+#endif
+
+void TFE_Py_TapeStackPushNew() {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
- if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
+ if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
tape->tape = new GradientTape();
- return reinterpret_cast<PyObject*>(tape);
+ GetTapeStack()->push_back(tape);
+}
+
+void TFE_Py_TapeStackPush(PyObject* tape) {
+ Py_INCREF(tape);
+ GetTapeStack()->push_back(reinterpret_cast<TFE_Py_Tape*>(tape));
+}
+
+PyObject* TFE_Py_TapeStackIsEmpty() {
+ if (GetTapeStack()->empty()) {
+ Py_RETURN_TRUE;
+ }
+ Py_RETURN_FALSE;
+}
+
+PyObject* TFE_Py_TapeStackPop() {
+ auto* stack = GetTapeStack();
+ if (stack->empty()) {
+ PyErr_SetString(PyExc_RuntimeError, "tape stack is empty.");
+ return nullptr;
+ }
+ TFE_Py_Tape* top = stack->back();
+ stack->pop_back();
+ return reinterpret_cast<PyObject*>(top);
}
static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
@@ -557,10 +612,14 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
return tensor_ids;
}
-PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) {
+PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
if (tensors == Py_None) {
Py_RETURN_FALSE;
}
+ auto* stack = GetTapeStack();
+ if (stack->empty()) {
+ Py_RETURN_FALSE;
+ }
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
return nullptr;
@@ -575,16 +634,22 @@ PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) {
tensor_ids.push_back(FastTensorId(item));
}
Py_DECREF(seq);
- TFE_Py_Tape* tape = reinterpret_cast<TFE_Py_Tape*>(py_tape);
- if (tape->tape->ShouldRecord(tensor_ids)) {
- Py_RETURN_TRUE;
- } else {
- Py_RETURN_FALSE;
+ for (TFE_Py_Tape* tape : *stack) {
+ if (tape->tape->ShouldRecord(tensor_ids)) {
+ Py_RETURN_TRUE;
+ }
}
+ Py_RETURN_FALSE;
}
-void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id) {
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
+void TFE_Py_TapeStackWatch(PyObject* tensor) {
+ tensorflow::int64 tensor_id = FastTensorId(tensor);
+ if (PyErr_Occurred()) {
+ return;
+ }
+ for (TFE_Py_Tape* tape : *GetTapeStack()) {
+ tape->tape->Watch(tensor_id);
+ }
}
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
@@ -646,8 +711,10 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
return list;
}
-void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
+void TFE_Py_TapeStackWatchVariable(PyObject* variable) {
+ for (TFE_Py_Tape* tape : *GetTapeStack()) {
+ tape->tape->WatchVariable(variable);
+ }
}
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
@@ -661,10 +728,14 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return result;
}
-void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
- PyObject* output_tensors,
- PyObject* input_tensors,
- PyObject* backward_function) {
+void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
+ PyObject* output_tensors,
+ PyObject* input_tensors,
+ PyObject* backward_function) {
+ auto* stack = GetTapeStack();
+ if (stack->empty()) {
+ return;
+ }
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
@@ -697,14 +768,18 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
return;
}
- Py_INCREF(backward_function);
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->RecordOperation(
- op_type_str, output_info, input_ids, backward_function,
- [backward_function]() { Py_DECREF(backward_function); });
+ for (TFE_Py_Tape* tape : *stack) {
+ Py_INCREF(backward_function);
+ tape->tape->RecordOperation(
+ op_type_str, output_info, input_ids, backward_function,
+ [backward_function]() { Py_DECREF(backward_function); });
+ }
}
-void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) {
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->DeleteTrace(tensor_id);
+void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) {
+ for (TFE_Py_Tape* tape : *GetTapeStack()) {
+ tape->tape->DeleteTrace(tensor_id);
+ }
}
class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
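
With the stack in C++, the should-record check above short-circuits before doing any work: an empty stack returns False immediately, and otherwise the first tape whose ShouldRecord matches answers True for the whole stack. Seen from Python, the fast path looks roughly like this (tensor and function names illustrative):

    if tape.should_record([x, w]):  # TFE_Py_TapeStackShouldRecord on the whole stack
      tape.record_operation("MatMul", [y], [x, w], matmul_backward)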
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index fb6b62a3e0..440c84b7ea 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -18,106 +18,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
import contextlib
-import threading
from tensorflow.python import pywrap_tensorflow
-def tid(tensor):
- return tensor._id # pylint: disable=protected-access
-
-
-class TapeEntry(
- collections.namedtuple("TapeEntry", [
- "op_type",
- "output_ids", "input_ids", "backward_function",
- "output_shape_and_dtype",
- ])):
- """Entry in the gradient tape.
-
- Represents the execution of one op or function, with instructions for doing
- its backward pass and useful information for it.
-
- Args:
- output_ids: tensor_id(t) for each output tensor T
- input_ids: tensor_id(t) for each input tensor T
- backward_function: function to be called with the downstream gradients and
- side outputs as arguments which computes the backward pass.
- output_shape_and_dtype: a list of (shape_tuple, dtype) for every output
- tensor_id
- """
-
-
-def _tensor_shape(t):
- return t._shape_tuple() # pylint: disable=protected-access
-
-
class Tape(object):
"""Represents a gradient propagation trace."""
- def __init__(self):
- self._tape = pywrap_tensorflow.TFE_Py_NewTape()
-
- def should_record(self, tensors):
- """Returns true if any tensor should be recorded.
-
- Args:
- tensors: some tensors.
-
- Returns:
- True if any of the tensors is in the tape.
- """
- return pywrap_tensorflow.TFE_Py_TapeShouldRecord(
- self._tape, tensors)
-
- def watch(self, tensor):
- """Adds a tensor to the tape."""
- pywrap_tensorflow.TFE_Py_TapeWatch(self._tape, tid(tensor))
-
- def watch_variable(self, v):
- pywrap_tensorflow.TFE_Py_TapeWatchVariable(self._tape, v)
+ def __init__(self, tape):
+ self._tape = tape
def watched_variables(self):
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
- def record_operation(self, op_type, output_tensors, input_tensors,
- backward_function):
- """Records an operation in the tape."""
- pywrap_tensorflow.TFE_Py_TapeRecordOperation(
- self._tape,
- op_type,
- output_tensors,
- input_tensors,
- backward_function)
-
- def _delete_tensor_id(self, i):
- pywrap_tensorflow.TFE_Py_TapeDeleteTrace(self._tape, i)
-
- def delete_trace(self, tensor_id):
- """Deletes any trace we have for this tensor."""
- self._delete_tensor_id(tensor_id)
-
-
-class _TapeStack(threading.local):
-
- def __init__(self):
- super(_TapeStack, self).__init__()
- self._stack = []
-
- @property
- def stack(self):
- return self._stack
-
-
-# The global tape stack.
-_tape_stack = _TapeStack()
-
def push_new_tape():
"""Pushes a new tape onto the tape stack."""
- _tape_stack.stack.append(Tape())
+ pywrap_tensorflow.TFE_Py_TapeStackPushNew()
def watch(tensor):
@@ -126,8 +44,7 @@ def watch(tensor):
Args:
tensor: tensor to be watched.
"""
- for t in _tape_stack.stack:
- t.watch(tensor)
+ pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor)
def watch_variable(variable):
@@ -136,48 +53,42 @@ def watch_variable(variable):
Args:
variable: variable to be watched.
"""
- for t in _tape_stack.stack:
- t.watch_variable(variable)
+ pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable)
def pop_tape():
"""Pops the top tape in the stack, if any."""
- if _tape_stack.stack:
- return _tape_stack.stack.pop()
- return None
+ return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop())
@contextlib.contextmanager
def stop_recording():
- old = _tape_stack.stack
- _tape_stack._stack = [] # pylint: disable=protected-access
+ stack = []
+ while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty():
+ stack.append(pop_tape()._tape) # pylint: disable=protected-access
try:
yield
finally:
- _tape_stack._stack = old # pylint: disable=protected-access
+ for tape in reversed(stack):
+ pywrap_tensorflow.TFE_Py_TapeStackPush(tape)
def should_record(tensors):
"""Returns true if any tape in the stack watches any of these tensors."""
- if not _tape_stack.stack:
- return False
- return any(x.should_record(tensors) for x in _tape_stack.stack)
+ return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors)
def record_operation(op_type, output_tensors, input_tensors, backward_function):
"""Records the operation on all tapes in the stack."""
- for t in _tape_stack.stack:
- t.record_operation(op_type, output_tensors,
- input_tensors,
- backward_function)
+ pywrap_tensorflow.TFE_Py_TapeStackRecordOperation(
+ op_type, output_tensors, input_tensors, backward_function)
def delete_trace(tensor_id):
"""Deletes traces for this Tensor from all tapes in the stack."""
- for t in _tape_stack.stack:
- t.delete_trace(tensor_id)
+ pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id)
def could_possibly_record():
"""Returns True if any tape is active."""
- return len(_tape_stack.stack) > 0 # pylint: disable=g-explicit-length-test
+ return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty()
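
Because the stack is no longer a Python list that can be swapped out wholesale, stop_recording above drains the C++ stack into a local list, yields, and pushes the saved tapes back in their original order. A short hedged usage sketch (x is illustrative):

    with tape.stop_recording():
      z = x * x  # nothing executed here is recorded on any tape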
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 5ca0e57286..82b154164e 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -24,13 +24,16 @@ limitations under the License.
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_UID;
-%rename("%s") TFE_Py_NewTape;
-%rename("%s") TFE_Py_TapeShouldRecord;
-%rename("%s") TFE_Py_TapeWatch;
-%rename("%s") TFE_Py_TapeDeleteTrace;
-%rename("%s") TFE_Py_TapeRecordOperation;
+%rename("%s") TFE_Py_TapeStackPushNew;
+%rename("%s") TFE_Py_TapeStackPush;
+%rename("%s") TFE_Py_TapeStackPop;
+%rename("%s") TFE_Py_TapeStackIsEmpty;
+%rename("%s") TFE_Py_TapeStackShouldRecord;
+%rename("%s") TFE_Py_TapeStackWatch;
+%rename("%s") TFE_Py_TapeStackDeleteTrace;
+%rename("%s") TFE_Py_TapeStackRecordOperation;
+%rename("%s") TFE_Py_TapeStackWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
-%rename("%s") TFE_Py_TapeWatchVariable;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;