diff options
Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 168 |
1 files changed, 167 insertions, 1 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 96944f27cd..bb447e0393 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -1487,6 +1487,8 @@ inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim, return memory::desc(md); } +template <typename T> +inline primitive FindOrCreateReorder(const memory* from, const memory* to); /* * Class to represent all the resources corresponding to a tensor in TensorFlow * that are required to execute an operation (such as Convolution). @@ -1733,6 +1735,24 @@ class MklDnnData { return false; } + /// TODO: this is a faster path with reorder primitive cache compared with + /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove + /// slow path in the future + inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) { + CHECK_NOTNULL(user_memory_); + if (IsReorderNeeded(op_pd)) { + // TODO(nhasabni): can we remove dynamic memory allocation? + // primitive reuse don't allow two same reorder prim in + // one stream, so submit it immediately + reorder_memory_ = new memory(op_pd); + std::vector<primitive> net; + net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_)); + stream(stream::kind::eager).submit(net).wait(); + return true; + } + return false; + } + /// Overloaded version of above function that accepts memory buffer /// where output of reorder needs to be stored. /// @@ -1758,6 +1778,26 @@ class MklDnnData { return false; } + /// TODO: this is a faster path with reorder primitive cache compared with + /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove + /// slow path in the future + inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd, + void* reorder_data_handle) { + CHECK_NOTNULL(reorder_data_handle); + CHECK_NOTNULL(user_memory_); + if (IsReorderNeeded(op_pd)) { + // TODO(nhasabni): can we remove dynamic memory allocation? + // primitive reuse don't allow two same reorder prim in + // one stream, so submit it immediately + std::vector<primitive> net; + reorder_memory_ = new memory(op_pd, reorder_data_handle); + net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_)); + stream(stream::kind::eager).submit(net).wait(); + return true; + } + return false; + } + /// Another overloaded version of CheckReorderToOpMem that accepts Tensor /// where output of reorder needs to be stored. /// @@ -1776,6 +1816,15 @@ class MklDnnData { return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net); } + /// TODO: this is a faster path with reorder primitive cache compared with + /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove + /// slow path in the future + inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd, + Tensor* reorder_tensor) { + CHECK_NOTNULL(reorder_tensor); + return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor)); + } + /// Function to handle output reorder /// /// This function performs very similar functionality as input reordering @@ -1812,6 +1861,20 @@ class MklDnnData { CHECK_NOTNULL(reorder_memory_); net->push_back(CreateReorder(reorder_memory_, user_memory_)); } + + /// TODO: this is a faster path with reorder primitive cache compared with + /// InsertReorderToUserMem(std::vector<primitive>* net), will remove + /// slow path in the future + inline void InsertReorderToUserMem() { + CHECK_NOTNULL(user_memory_); + CHECK_NOTNULL(reorder_memory_); + // primitive reuse don't allow two same reorder prim in + // one stream, so submit it immediately + std::vector<primitive> net; + net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_)); + stream(stream::kind::eager).submit(net).wait(); + } + }; /// Base class for operations with reuse of primitives @@ -1851,7 +1914,7 @@ class MklPrimitiveFactory { } private: - static inline std::unordered_map<std::string, MklPrimitive*> &GetHashMap() { + static inline std::unordered_map<std::string, MklPrimitive*>& GetHashMap() { static thread_local std::unordered_map<std::string, MklPrimitive*> map_; return map_; } @@ -1894,6 +1957,109 @@ class FactoryKeyCreator { } }; +class MklReorderPrimitive : public MklPrimitive { + public: + explicit MklReorderPrimitive(const memory* from, const memory* to) { + Setup(from, to); + } + ~MklReorderPrimitive() {} + + std::shared_ptr<primitive> GetPrimitive() { + return context_.reorder_prim; + } + + void SetMemory(const memory* from, const memory* to) { + context_.src_mem->set_data_handle(from->get_data_handle()); + context_.dst_mem->set_data_handle(to->get_data_handle()); + } + + private: + struct ReorderContext { + std::shared_ptr<mkldnn::memory> src_mem; + std::shared_ptr<mkldnn::memory> dst_mem; + std::shared_ptr<primitive> reorder_prim; + ReorderContext(): + src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) { + } + } context_; + + engine cpu_engine_ = engine(engine::cpu, 0); + + void Setup(const memory* from, const memory* to) { + context_.src_mem.reset(new memory( + {from->get_primitive_desc().desc(), cpu_engine_}, DummyData)); + context_.dst_mem.reset(new memory( + {to->get_primitive_desc().desc(), cpu_engine_}, DummyData)); + context_.reorder_prim = std::make_shared<mkldnn::reorder>( + reorder(*context_.src_mem, *context_.dst_mem)); + } +}; + +template <typename T> +class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> { + public: + static MklReorderPrimitive* Get(const memory* from, + const memory* to) { + auto reorderPrim = static_cast<MklReorderPrimitive*>( + MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to)); + if (reorderPrim == nullptr) { + reorderPrim = new MklReorderPrimitive(from, to); + MklReorderPrimitiveFactory<T>::GetInstance().SetReorder( + from, to, reorderPrim); + } + reorderPrim->SetMemory(from, to); + return reorderPrim; + } + + static MklReorderPrimitiveFactory & GetInstance() { + static MklReorderPrimitiveFactory instance_; + return instance_; + } + + private: + MklReorderPrimitiveFactory() {}; + ~MklReorderPrimitiveFactory() {}; + + static std::string CreateKey(const memory* from, const memory* to) { + std::string prefix = "reorder"; + FactoryKeyCreator key_creator; + auto const &from_desc = from->get_primitive_desc().desc().data; + auto const &to_desc = to->get_primitive_desc().desc().data; + memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]); + memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]); + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(static_cast<int>(from_desc.format)); + key_creator.AddAsKey(static_cast<int>(from_desc.data_type)); + key_creator.AddAsKey(from_dims); + key_creator.AddAsKey(static_cast<int>(to_desc.format)); + key_creator.AddAsKey(static_cast<int>(to_desc.data_type)); + key_creator.AddAsKey(to_dims); + return key_creator.GetKey(); + } + + MklPrimitive* GetReorder(const memory* from, const memory* to) { + std::string key = CreateKey(from, to); + return this->GetOp(key); + } + + void SetReorder(const memory* from, const memory* to, MklPrimitive* op) { + std::string key = CreateKey(from, to); + this->SetOp(key, op); + } +}; + + /// Fuction to find(or create) a reorder from memory pointed by from to memory pointed + /// by to, it will created primitive or get primitive from pool if it is cached. + /// Returns the primitive. + template <typename T> + inline primitive FindOrCreateReorder(const memory* from, const memory* to) { + CHECK_NOTNULL(from); + CHECK_NOTNULL(to); + MklReorderPrimitive *reorder_prim = + MklReorderPrimitiveFactory<T>::Get(from, to); + return *reorder_prim->GetPrimitive(); + } + #endif // INTEL_MKL_DNN } // namespace tensorflow |