diff options
Diffstat (limited to 'tensorflow/core/framework/op_segment.cc')
-rw-r--r-- | tensorflow/core/framework/op_segment.cc | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc new file mode 100644 index 0000000000..a39bebd854 --- /dev/null +++ b/tensorflow/core/framework/op_segment.cc @@ -0,0 +1,86 @@ +#include "tensorflow/core/framework/op_segment.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +OpSegment::Item::~Item() { + for (auto kv : name_kernel) delete kv.second; +} + +OpSegment::OpSegment() {} + +OpSegment::~OpSegment() { + for (auto kv : sessions_) delete kv.second; +} + +Status OpSegment::FindOrCreate(const string& session_handle, + const string& node_name, OpKernel** kernel, + CreateKernelFn create_fn) { + { + mutex_lock l(mu_); + auto item = gtl::FindPtrOrNull(sessions_, session_handle); + if (item == nullptr) { + return errors::NotFound("Session ", session_handle, " is not found."); + } + *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name); + if (*kernel != nullptr) { + return Status::OK(); + } + } + Status s = create_fn(kernel); + if (!s.ok()) { + LOG(ERROR) << "Create kernel failed: " << s; + return s; + } + { + mutex_lock l(mu_); + auto item = gtl::FindPtrOrNull(sessions_, session_handle); + if (item == nullptr) { + return errors::NotFound("Session ", session_handle, " is not found."); + } + OpKernel** p_kernel = &(item->name_kernel[node_name]); + if (*p_kernel == nullptr) { + *p_kernel = *kernel; // Inserts 'kernel' in the map. + } else { + delete *kernel; + *kernel = *p_kernel; + } + } + return Status::OK(); +} + +void OpSegment::AddHold(const string& session_handle) { + mutex_lock l(mu_); + Item** item = &sessions_[session_handle]; + if (*item == nullptr) { + *item = new Item; // num_holds == 1 + } else { + ++((*item)->num_holds); + } +} + +void OpSegment::RemoveHold(const string& session_handle) { + Item* item = nullptr; + { + mutex_lock l(mu_); + auto siter = sessions_.find(session_handle); + if (siter == sessions_.end()) { + VLOG(1) << "Session " << session_handle << " is not found."; + return; + } + item = siter->second; + if (--(item->num_holds) > 0) { + return; + } else { + sessions_.erase(siter); + } + } + delete item; +} + +} // end namespace tensorflow |