diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/framework/op_segment.h |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/framework/op_segment.h')
-rw-r--r-- | tensorflow/core/framework/op_segment.h | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h new file mode 100644 index 0000000000..55249d2a38 --- /dev/null +++ b/tensorflow/core/framework/op_segment.h @@ -0,0 +1,67 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ +#define TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ + +#include <string> +#include <unordered_map> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// OpSegment keeps track of OpKernels registered for sessions running +// on a device. +// +// The implementation maintains a two-level map. The 1st level maps +// session handle to the map of registered OpKernels. The 2nd level +// map maps node names to instantiated OpKernel objects. +// +// Each 2-nd level map is reference-counted and the caller can call +// AddHold to obtain a reference on all kernels of a session and +// ensure these kernels are alive until a corresponding RemoveHold is +// called on the same session. +class OpSegment { + public: + OpSegment(); + ~OpSegment(); + + // A hold can be placed on a session, preventing all its kernels + // from being deleted. + void AddHold(const string& session_handle); + void RemoveHold(const string& session_handle); + + // If the kernel for "node_name" has been created in the + // "session_handle", returns the existing op kernel in "*kernel". + // Otherwise, creates the kernel by calling create_fn(), cache it, + // and returns it in "*kernel". If create_fn() fails, returns the + // error. + // + // OpSegment keeps the ownership of the returned "*kernel". + typedef std::function<Status(OpKernel**)> CreateKernelFn; + Status FindOrCreate(const string& session_handle, const string& node_name, + OpKernel** kernel, CreateKernelFn create_fn); + + private: + // op name -> OpKernel + typedef std::unordered_map<string, OpKernel*> KernelMap; + struct Item { + int num_holds = 1; // Num of holds put on the session. + KernelMap name_kernel; // op name -> kernel. + ~Item(); + }; + + // session handle -> item. + // Session handles are produced by strings::FpToString() + typedef std::unordered_map<string, Item*> SessionMap; + + mutable mutex mu_; + SessionMap sessions_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(OpSegment); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ |