aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-10-05 11:27:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 11:31:17 -0700
commit1e446b37620dcdca73e855c83efcc0d14bb68a8c (patch)
tree3cee841efff64b6c02bfa7ad0639ce3e7401142b /tensorflow/python
parentb1325838aaf902e52fae4b085c6396848c445062 (diff)
Make gradient tape stack thread local
PiperOrigin-RevId: 215937618
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc20
1 files changed, 2 insertions, 18 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6193f40ce8..6d3ef9a37b 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1228,8 +1228,9 @@ static PyTypeObject TFE_Py_Tape_Type = {
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when decide to not hold the GIL while manipulating the tape
// stack.
-static tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set = nullptr;
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
+ thread_local tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set{
+ nullptr};
if (tape_set == nullptr) {
tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>;
}
@@ -1264,27 +1265,10 @@ class SafeTapeSet {
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_;
};
-// xcode 7 doesn't define thread_local, so for compatibility we implement our
-// own. TODO(apassos) remove once we can deprecate xcode 7.
-#ifndef __APPLE__
bool* ThreadTapeIsStopped() {
thread_local bool thread_tape_is_stopped{false};
return &thread_tape_is_stopped;
}
-#else
-static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr;
-bool* ThreadTapeIsStopped() {
- if (tape_is_stopped == nullptr) {
- tape_is_stopped = new std::unordered_map<std::thread::id, bool>;
- }
- auto it = tape_is_stopped->find(std::this_thread::get_id());
- if (it != tape_is_stopped->end()) {
- return &(it->second);
- }
- return &(tape_is_stopped->emplace(std::this_thread::get_id(), false)
- .first->second);
-}
-#endif
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }