aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/mkl_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
-rw-r--r--tensorflow/core/util/mkl_util.h168
1 files changed, 167 insertions, 1 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 96944f27cd..bb447e0393 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -1487,6 +1487,8 @@ inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
return memory::desc(md);
}
+template <typename T>
+inline primitive FindOrCreateReorder(const memory* from, const memory* to);
/*
* Class to represent all the resources corresponding to a tensor in TensorFlow
* that are required to execute an operation (such as Convolution).
@@ -1733,6 +1735,24 @@ class MklDnnData {
return false;
}
+ /// TODO: This is a faster path that uses the reorder primitive cache; compared
+ /// with CheckReorderToOpMem(..., std::vector<primitive>* net), the slow
+ /// path will be removed in the future.
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
+ CHECK_NOTNULL(user_memory_);
+ if (IsReorderNeeded(op_pd)) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ // primitive reuse doesn't allow two identical reorder primitives in
+ // one stream, so submit it immediately
+ reorder_memory_ = new memory(op_pd);
+ std::vector<primitive> net;
+ net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
+ stream(stream::kind::eager).submit(net).wait();
+ return true;
+ }
+ return false;
+ }
+
/// Overloaded version of above function that accepts memory buffer
/// where output of reorder needs to be stored.
///
@@ -1758,6 +1778,26 @@ class MklDnnData {
return false;
}
+ /// TODO: This is a faster path that uses the reorder primitive cache; compared
+ /// with CheckReorderToOpMem(..., std::vector<primitive>* net), the slow
+ /// path will be removed in the future.
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ void* reorder_data_handle) {
+ CHECK_NOTNULL(reorder_data_handle);
+ CHECK_NOTNULL(user_memory_);
+ if (IsReorderNeeded(op_pd)) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ // primitive reuse doesn't allow two identical reorder primitives in
+ // one stream, so submit it immediately
+ std::vector<primitive> net;
+ reorder_memory_ = new memory(op_pd, reorder_data_handle);
+ net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
+ stream(stream::kind::eager).submit(net).wait();
+ return true;
+ }
+ return false;
+ }
+
/// Another overloaded version of CheckReorderToOpMem that accepts Tensor
/// where output of reorder needs to be stored.
///
@@ -1776,6 +1816,15 @@ class MklDnnData {
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net);
}
+ /// TODO: This is a faster path that uses the reorder primitive cache; compared
+ /// with CheckReorderToOpMem(..., std::vector<primitive>* net), the slow
+ /// path will be removed in the future.
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ Tensor* reorder_tensor) {
+ CHECK_NOTNULL(reorder_tensor);
+ return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
+ }
+
/// Function to handle output reorder
///
/// This function performs very similar functionality as input reordering
@@ -1812,6 +1861,20 @@ class MklDnnData {
CHECK_NOTNULL(reorder_memory_);
net->push_back(CreateReorder(reorder_memory_, user_memory_));
}
+
+ /// TODO: This is a faster path that uses the reorder primitive cache; compared
+ /// with InsertReorderToUserMem(std::vector<primitive>* net), the slow
+ /// path will be removed in the future.
+ inline void InsertReorderToUserMem() {
+ CHECK_NOTNULL(user_memory_);
+ CHECK_NOTNULL(reorder_memory_);
+ // primitive reuse doesn't allow two identical reorder primitives in
+ // one stream, so submit it immediately
+ std::vector<primitive> net;
+ net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
+ stream(stream::kind::eager).submit(net).wait();
+ }
+
};
/// Base class for operations with reuse of primitives
@@ -1851,7 +1914,7 @@ class MklPrimitiveFactory {
}
private:
- static inline std::unordered_map<std::string, MklPrimitive*> &GetHashMap() {
+ static inline std::unordered_map<std::string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<std::string, MklPrimitive*> map_;
return map_;
}
@@ -1894,6 +1957,109 @@ class FactoryKeyCreator {
}
};
+class MklReorderPrimitive : public MklPrimitive {
+ public:
+ explicit MklReorderPrimitive(const memory* from, const memory* to) {
+ Setup(from, to);
+ }
+ ~MklReorderPrimitive() {}
+
+ std::shared_ptr<primitive> GetPrimitive() {
+ return context_.reorder_prim;
+ }
+
+ void SetMemory(const memory* from, const memory* to) {
+ context_.src_mem->set_data_handle(from->get_data_handle());
+ context_.dst_mem->set_data_handle(to->get_data_handle());
+ }
+
+ private:
+ struct ReorderContext {
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> dst_mem;
+ std::shared_ptr<primitive> reorder_prim;
+ ReorderContext():
+ src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {
+ }
+ } context_;
+
+ engine cpu_engine_ = engine(engine::cpu, 0);
+
+ void Setup(const memory* from, const memory* to) {
+ context_.src_mem.reset(new memory(
+ {from->get_primitive_desc().desc(), cpu_engine_}, DummyData));
+ context_.dst_mem.reset(new memory(
+ {to->get_primitive_desc().desc(), cpu_engine_}, DummyData));
+ context_.reorder_prim = std::make_shared<mkldnn::reorder>(
+ reorder(*context_.src_mem, *context_.dst_mem));
+ }
+};
+
+template <typename T>
+class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklReorderPrimitive* Get(const memory* from,
+ const memory* to) {
+ auto reorderPrim = static_cast<MklReorderPrimitive*>(
+ MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
+ if (reorderPrim == nullptr) {
+ reorderPrim = new MklReorderPrimitive(from, to);
+ MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(
+ from, to, reorderPrim);
+ }
+ reorderPrim->SetMemory(from, to);
+ return reorderPrim;
+ }
+
+ static MklReorderPrimitiveFactory & GetInstance() {
+ static MklReorderPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklReorderPrimitiveFactory() {};
+ ~MklReorderPrimitiveFactory() {};
+
+ static std::string CreateKey(const memory* from, const memory* to) {
+ std::string prefix = "reorder";
+ FactoryKeyCreator key_creator;
+ auto const &from_desc = from->get_primitive_desc().desc().data;
+ auto const &to_desc = to->get_primitive_desc().desc().data;
+ memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
+ memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(static_cast<int>(from_desc.format));
+ key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
+ key_creator.AddAsKey(from_dims);
+ key_creator.AddAsKey(static_cast<int>(to_desc.format));
+ key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
+ key_creator.AddAsKey(to_dims);
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetReorder(const memory* from, const memory* to) {
+ std::string key = CreateKey(from, to);
+ return this->GetOp(key);
+ }
+
+ void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
+ std::string key = CreateKey(from, to);
+ this->SetOp(key, op);
+ }
+};
+
+ /// Function to find (or create) a reorder from the memory pointed to by `from`
+ /// to the memory pointed to by `to`; it creates the primitive, or fetches it
+ /// from the pool if it is already cached. Returns the primitive.
+ template <typename T>
+ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
+ CHECK_NOTNULL(from);
+ CHECK_NOTNULL(to);
+ MklReorderPrimitive *reorder_prim =
+ MklReorderPrimitiveFactory<T>::Get(from, to);
+ return *reorder_prim->GetPrimitive();
+ }
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow