diff options
author | Igor Ganichev <iga@google.com> | 2018-10-05 11:27:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 11:31:17 -0700 |
commit | 1e446b37620dcdca73e855c83efcc0d14bb68a8c (patch) | |
tree | 3cee841efff64b6c02bfa7ad0639ce3e7401142b /tensorflow/python | |
parent | b1325838aaf902e52fae4b085c6396848c445062 (diff) |
Make gradient tape stack thread local
PiperOrigin-RevId: 215937618
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 20 |
1 files changed, 2 insertions, 18 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 6193f40ce8..6d3ef9a37b 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1228,8 +1228,9 @@ static PyTypeObject TFE_Py_Tape_Type = { // GIL, which is always held when any TFE_Py_* methods are called. We should // revisit this if/when decide to not hold the GIL while manipulating the tape // stack. -static tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set = nullptr; tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() { + thread_local tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set{ + nullptr}; if (tape_set == nullptr) { tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>; } @@ -1264,27 +1265,10 @@ class SafeTapeSet { tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_; }; -// xcode 7 doesn't define thread_local, so for compatibility we implement our -// own. TODO(apassos) remove once we can deprecate xcode 7. -#ifndef __APPLE__ bool* ThreadTapeIsStopped() { thread_local bool thread_tape_is_stopped{false}; return &thread_tape_is_stopped; } -#else -static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr; -bool* ThreadTapeIsStopped() { - if (tape_is_stopped == nullptr) { - tape_is_stopped = new std::unordered_map<std::thread::id, bool>; - } - auto it = tape_is_stopped->find(std::this_thread::get_id()); - if (it != tape_is_stopped->end()) { - return &(it->second); - } - return &(tape_is_stopped->emplace(std::this_thread::get_id(), false) - .first->second); -} -#endif void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; } |