aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_segment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/op_segment.cc')
-rw-r--r--tensorflow/core/framework/op_segment.cc86
1 files changed, 86 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
new file mode 100644
index 0000000000..a39bebd854
--- /dev/null
+++ b/tensorflow/core/framework/op_segment.cc
@@ -0,0 +1,86 @@
+#include "tensorflow/core/framework/op_segment.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+OpSegment::Item::~Item() {
+ for (auto kv : name_kernel) delete kv.second;
+}
+
+OpSegment::OpSegment() {}
+
+OpSegment::~OpSegment() {
+ for (auto kv : sessions_) delete kv.second;
+}
+
+Status OpSegment::FindOrCreate(const string& session_handle,
+ const string& node_name, OpKernel** kernel,
+ CreateKernelFn create_fn) {
+ {
+ mutex_lock l(mu_);
+ auto item = gtl::FindPtrOrNull(sessions_, session_handle);
+ if (item == nullptr) {
+ return errors::NotFound("Session ", session_handle, " is not found.");
+ }
+ *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name);
+ if (*kernel != nullptr) {
+ return Status::OK();
+ }
+ }
+ Status s = create_fn(kernel);
+ if (!s.ok()) {
+ LOG(ERROR) << "Create kernel failed: " << s;
+ return s;
+ }
+ {
+ mutex_lock l(mu_);
+ auto item = gtl::FindPtrOrNull(sessions_, session_handle);
+ if (item == nullptr) {
+ return errors::NotFound("Session ", session_handle, " is not found.");
+ }
+ OpKernel** p_kernel = &(item->name_kernel[node_name]);
+ if (*p_kernel == nullptr) {
+ *p_kernel = *kernel; // Inserts 'kernel' in the map.
+ } else {
+ delete *kernel;
+ *kernel = *p_kernel;
+ }
+ }
+ return Status::OK();
+}
+
+void OpSegment::AddHold(const string& session_handle) {
+ mutex_lock l(mu_);
+ Item** item = &sessions_[session_handle];
+ if (*item == nullptr) {
+ *item = new Item; // num_holds == 1
+ } else {
+ ++((*item)->num_holds);
+ }
+}
+
+void OpSegment::RemoveHold(const string& session_handle) {
+ Item* item = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto siter = sessions_.find(session_handle);
+ if (siter == sessions_.end()) {
+ VLOG(1) << "Session " << session_handle << " is not found.";
+ return;
+ }
+ item = siter->second;
+ if (--(item->num_holds) > 0) {
+ return;
+ } else {
+ sessions_.erase(siter);
+ }
+ }
+ delete item;
+}
+
+} // end namespace tensorflow