diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-06 16:17:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-06 16:17:26 -0700 |
commit | 7bceb06eae02094c7b6a5c4fed010c72962f43ce (patch) | |
tree | 8610a0fa5f5b8cc98f4db05ce1f2e18b922e5f87 /tensorflow/core/util | |
parent | a0c653b684afa859e7c468fea4b5e600271eec0f (diff) | |
parent | 6fdc6be324df7e3f7e3162e161ef4e869bd888fb (diff) |
Merge pull request #19403 from Intel-tensorflow:primreuse_pooling
PiperOrigin-RevId: 207625392
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 94 |
1 files changed, 59 insertions, 35 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index d90f85e422..a66b1215bd 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_ #ifdef INTEL_MKL -#include <vector> +#include <memory> #include <unordered_map> #include <utility> +#include <vector> #ifdef INTEL_MKL_ML #include "mkl_dnn.h" @@ -34,11 +35,11 @@ limitations under the License. #include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" - #ifndef INTEL_MKL_ML #include "mkldnn.hpp" #include "tensorflow/core/lib/core/stringpiece.h" @@ -1503,7 +1504,8 @@ class MklDnnData { /// Operations memory descriptor memory::desc* op_md_; - + /// Operations temp buffer + void* allocated_buffer_; /// CPU engine on which operation will be executed const engine* cpu_engine_; @@ -1512,6 +1514,7 @@ class MklDnnData { : user_memory_(nullptr), reorder_memory_(nullptr), op_md_(nullptr), + allocated_buffer_(nullptr), cpu_engine_(e) {} ~MklDnnData() { @@ -1652,6 +1655,14 @@ class MklDnnData { user_memory_->set_data_handle(GetTensorBuffer(tensor)); } + /// allocate function for data buffer + inline void AllocateBuffer(size_t size) { + const int64 kMemoryAlginment = 64; // For AVX512 memory alignment. + allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlginment, size); + } + + inline void* GetAllocatedBuffer() { return allocated_buffer_; } + /// Get the memory primitive for input and output of an op. If inputs /// to an op require reorders, then this function returns memory primitive /// for reorder. Otherwise, it will return memory primitive for user memory. @@ -1873,7 +1884,6 @@ class MklDnnData { net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_)); stream(stream::kind::eager).submit(net).wait(); } - }; /// Base class for operations with reuse of primitives @@ -1956,11 +1966,25 @@ class FactoryKeyCreator { } }; +static inline memory::format get_desired_format(int channel) { + memory::format fmt_desired = memory::format::any; + + if (port::TestCPUFeature(port::CPUFeature::AVX512F) && (channel % 16) == 0) { + fmt_desired = memory::format::nChw16c; + } else if (port::TestCPUFeature(port::CPUFeature::AVX2) && + (channel % 8) == 0) { + fmt_desired = memory::format::nChw8c; + } else { + fmt_desired = memory::format::nchw; + } + return fmt_desired; +} + class MklReorderPrimitive : public MklPrimitive { - public: - explicit MklReorderPrimitive(const memory* from, const memory* to) { - Setup(from, to); - } + public: + explicit MklReorderPrimitive(const memory* from, const memory* to) { + Setup(from, to); + } ~MklReorderPrimitive() {} std::shared_ptr<primitive> GetPrimitive() { @@ -1972,7 +1996,7 @@ class MklReorderPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle(to->get_data_handle()); } - private: + private: struct ReorderContext { std::shared_ptr<mkldnn::memory> src_mem; std::shared_ptr<mkldnn::memory> dst_mem; @@ -1996,28 +2020,27 @@ class MklReorderPrimitive : public MklPrimitive { template <typename T> class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> { - public: - static MklReorderPrimitive* Get(const memory* from, - const memory* to) { - auto reorderPrim = static_cast<MklReorderPrimitive*>( + public: + static MklReorderPrimitive* Get(const memory* from, const memory* to) { + auto reorderPrim = static_cast<MklReorderPrimitive*>( MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to)); - if (reorderPrim == nullptr) { - reorderPrim = new MklReorderPrimitive(from, to); - MklReorderPrimitiveFactory<T>::GetInstance().SetReorder( - from, to, reorderPrim); - } - reorderPrim->SetMemory(from, to); - return reorderPrim; + if (reorderPrim == nullptr) { + reorderPrim = new MklReorderPrimitive(from, to); + MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to, + reorderPrim); } + reorderPrim->SetMemory(from, to); + return reorderPrim; + } static MklReorderPrimitiveFactory & GetInstance() { static MklReorderPrimitiveFactory instance_; return instance_; } - private: - MklReorderPrimitiveFactory() {}; - ~MklReorderPrimitiveFactory() {}; + private: + MklReorderPrimitiveFactory() {} + ~MklReorderPrimitiveFactory() {} static string CreateKey(const memory* from, const memory* to) { string prefix = "reorder"; @@ -2047,18 +2070,19 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> { } }; - /// Fuction to find(or create) a reorder from memory pointed by from to memory pointed - /// by to, it will created primitive or get primitive from pool if it is cached. - /// Returns the primitive. - template <typename T> - inline primitive FindOrCreateReorder(const memory* from, const memory* to) { - CHECK_NOTNULL(from); - CHECK_NOTNULL(to); - MklReorderPrimitive *reorder_prim = - MklReorderPrimitiveFactory<T>::Get(from, to); - return *reorder_prim->GetPrimitive(); - } - +/// Fuction to find(or create) a reorder from memory pointed by +/// from to memory pointed by to, it will created primitive or +/// get primitive from pool if it is cached. +/// Returns the primitive. +template <typename T> +inline primitive FindOrCreateReorder(const memory* from, const memory* to) { + CHECK_NOTNULL(from); + CHECK_NOTNULL(to); + MklReorderPrimitive* reorder_prim = + MklReorderPrimitiveFactory<T>::Get(from, to); + return *reorder_prim->GetPrimitive(); +} + #endif // INTEL_MKL_DNN } // namespace tensorflow |