From f5e2edb5fc84179637355c727c4f0953764b48e5 Mon Sep 17 00:00:00 2001 From: Guozhong Zhuang Date: Fri, 18 May 2018 11:44:15 -0700 Subject: enhancement with pooling ops primitive reuse --- tensorflow/core/kernels/mkl_pooling_ops_common.h | 312 ++++++++++++++++++++++- 1 file changed, 311 insertions(+), 1 deletion(-) (limited to 'tensorflow/core/kernels/mkl_pooling_ops_common.h') diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index 279167aba2..468dc41c57 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -19,6 +19,7 @@ limitations under the License. #ifdef INTEL_MKL #include #include +#include #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" @@ -32,6 +33,315 @@ using mkldnn::stream; namespace tensorflow { +#ifndef INTEL_MKL_ML + +using mkldnn::memory; +using mkldnn::pooling_max; +using mkldnn::pooling_avg; +using mkldnn::pooling_avg_include_padding; +using mkldnn::pooling_avg_exclude_padding; +using mkldnn::prop_kind; + +struct MklPoolingParams { + memory::dims src_dims; + memory::dims dst_dims; + memory::dims filter_dims; + memory::dims strides; + memory::dims padding_left; + memory::dims padding_right; + mkldnn::algorithm alg_kind; + + MklPoolingParams(memory::dims src_dims, + memory::dims dst_dims, memory::dims filter_dims, + memory::dims strides, memory::dims padding_left, + memory::dims padding_right, mkldnn::algorithm alg_kind) : + src_dims(src_dims), dst_dims(dst_dims), + filter_dims(filter_dims), strides(strides), + padding_left(padding_left), padding_right(padding_right), + alg_kind(alg_kind) { + } +}; + +template +class MklPoolingFwdPrimitive : public MklPrimitive { + public: + explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams) { + context_.fwd_stream.reset(new stream(stream::kind::eager)); + if (context_.fwd == nullptr) + Setup(fwdParams); + } + + ~MklPoolingFwdPrimitive() {} + + // Pooling forward execute + // src_data: input data buffer of src + // ws_data: input data buffer of workspace + // dst_data: output data buffer of dst + void Execute(const T* src_data, const T* dst_data, + const void* ws_data = nullptr); + + std::shared_ptr + GetPoolingFwdPd() const { + return context_.fwd_pd; + } + + memory::format GetSrcMemoryFormat() const { + return context_.src_fmt; + } + + memory::format GetDstMemoryFormat() const { + return context_.dst_fmt; + } + + private: + void Setup(const MklPoolingParams& fwdParams); + + + struct PoolingFwdContext { + // algorithm + mkldnn::algorithm alg_kind; + + // expected memory format + memory::format src_fmt; + memory::format dst_fmt; + memory::format ws_fmt; + + // workspace shape + memory::dims ws_dims; + memory::data_type ws_dt; + size_t ws_size; + + // MKL-DNN memory, just dummy data + std::shared_ptr ws_mem; + std::shared_ptr src_mem; + std::shared_ptr dst_mem; + + // desc & primitive desc + std::shared_ptr fwd_desc; + std::shared_ptr fwd_pd; + + // memory desc + std::shared_ptr src_md; + std::shared_ptr dst_md; + + // Pooling primitive + std::shared_ptr fwd; + std::shared_ptr fwd_stream; + std::vector fwd_primitives; + + PoolingFwdContext() : + src_fmt(memory::format::any), dst_fmt(memory::format::any), + ws_fmt(memory::format::any), ws_mem(nullptr), src_mem(nullptr), + dst_mem(nullptr), fwd_desc(nullptr), fwd_pd(nullptr), src_md(nullptr), + dst_md(nullptr), fwd(nullptr), fwd_stream(nullptr) { + } + } context_; + + engine cpu_engine_ = engine(engine::cpu, 0); +}; + +template +class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklPoolingFwdPrimitive* Get(const MklPoolingParams& fwdParams) { + MklPoolingFwdPrimitive* pooling_forward = nullptr; + + // Get pooling primitive from the pool + pooling_forward = static_cast*>( + MklPoolingFwdPrimitiveFactory::GetInstance().GetPoolingFwd(fwdParams)); + + if (pooling_forward == nullptr) { + pooling_forward = new MklPoolingFwdPrimitive(fwdParams); + MklPoolingFwdPrimitiveFactory::GetInstance().SetPoolingFwd( + fwdParams, pooling_forward); + } + return pooling_forward; + } + + static MklPoolingFwdPrimitiveFactory& GetInstance() { + static MklPoolingFwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklPoolingFwdPrimitiveFactory() {} + ~MklPoolingFwdPrimitiveFactory() {} + + // The key to be created will be used to get/set pooling + // primitive op from reuse perspective. + // A pooling key is a string which concates key parameters + // as well as algorithm kind (max versus avg). + static std::string CreateKey(const MklPoolingParams& fwdParams) { + std::string prefix = "pooling_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey(fwdParams.dst_dims); + key_creator.AddAsKey(fwdParams.filter_dims); + key_creator.AddAsKey(fwdParams.strides); + key_creator.AddAsKey(fwdParams.padding_left); + key_creator.AddAsKey(fwdParams.padding_right); + key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); + return key_creator.GetKey(); + } + + MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) { + std::string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive *op) { + std::string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +}; + + +template +class MklPoolingBwdPrimitive : public MklPrimitive { + public: + explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) { + context_.bwd_stream.reset(new stream(stream::kind::eager)); + if (context_.bwd == nullptr) + Setup(bwdParams); + } + + ~MklPoolingBwdPrimitive() {} + + // Pooling backward execute + // diff_dst_data: input data buffer of diff_dst + // diff_src_data: output data buffer of diff_src + // ws_data: input data buffer of workspace + void Execute(const T* diff_dst_data, const T* diff_src_data, + const void* ws_data = nullptr); + + public: + std::shared_ptr + GetPoolingFwdPd() const { + return context_.fwd_pd; + } + std::shared_ptr + GetPoolingBwdPd() const { + return context_.bwd_pd; + } + + memory::format GetDiffDstFormat() const { + return context_.diff_dst_fmt; + } + + mkldnn::memory::data_type GetWorkspaceDataType() const { + return context_.ws_dt; + } + memory::format GetWorkspaceFormat() const { + return context_.ws_fmt; + } + + private: + void Setup(const MklPoolingParams& bwdParams); + // Primitive reuse context for pooling bwd ops + struct PoolingBwdContext { + // algorithm + mkldnn::algorithm alg_kind; + + // expected memory format + mkldnn::memory::format diff_src_fmt; + mkldnn::memory::format diff_dst_fmt; + mkldnn::memory::format ws_fmt; + + // workspace attribute + mkldnn::memory::dims ws_dims; + mkldnn::memory::data_type ws_dt; + + // MKL-DNN memory + std::shared_ptr ws_mem; + std::shared_ptr diff_src_mem; + std::shared_ptr diff_dst_mem; + + // memory desc + std::shared_ptr diff_src_md; + std::shared_ptr diff_dst_md; + + // desc & primitive desc + std::shared_ptr fwd_desc; + std::shared_ptr bwd_desc; + std::shared_ptr fwd_pd; + std::shared_ptr bwd_pd; + + // pooling primitive + std::shared_ptr bwd; + std::shared_ptr bwd_stream; + + std::vector bwd_primitives; + + PoolingBwdContext() : + diff_src_fmt(memory::format::any), diff_dst_fmt(memory::format::any), + ws_fmt(memory::format::any), ws_mem(nullptr), diff_src_mem(nullptr), + diff_dst_mem(nullptr), diff_src_md(nullptr), diff_dst_md(nullptr), + fwd_desc(nullptr), bwd_desc(nullptr), fwd_pd(nullptr), bwd_pd(nullptr), + bwd(nullptr), bwd_stream(nullptr) { + } + } context_; + // cpu engine + engine cpu_engine = engine(engine::cpu, 0); +}; + +template +class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklPoolingBwdPrimitive *Get(const MklPoolingParams& bwdParams) { + MklPoolingBwdPrimitive* pooling_backward = nullptr; + + // Find a pooling backward primitive from the pool + // If it does not exist, create a new one + pooling_backward = static_cast*>( + MklPoolingBwdPrimitiveFactory::GetInstance().GetPoolingBwd(bwdParams)); + if (pooling_backward == nullptr) { + pooling_backward = new MklPoolingBwdPrimitive(bwdParams); + MklPoolingBwdPrimitiveFactory::GetInstance().SetPoolingBwd( + bwdParams, pooling_backward); + } + return pooling_backward; + } + + static MklPoolingBwdPrimitiveFactory& GetInstance() { + static MklPoolingBwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklPoolingBwdPrimitiveFactory() {} + ~MklPoolingBwdPrimitiveFactory() {} + + // The key to be created will be used to get/set pooling + // primitive op from reuse perspective. + // A pooling key is a string which concates key parameters + // as well as algorithm kind (max versus avg). + static std::string CreateKey(const MklPoolingParams& bwdParams) { + std::string prefix = "pooling_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(bwdParams.dst_dims); + key_creator.AddAsKey(bwdParams.filter_dims); + key_creator.AddAsKey(bwdParams.strides); + key_creator.AddAsKey(bwdParams.padding_left); + key_creator.AddAsKey(bwdParams.padding_right); + key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); + return key_creator.GetKey(); + } + + MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) { + std::string key = CreateKey(bwdParams); + return this->GetOp(key); + } + + void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive *op) { + std::string key = CreateKey(bwdParams); + this->SetOp(key, op); + } +}; +#endif + typedef Eigen::ThreadPoolDevice CPUDevice; struct MklPoolParameters { @@ -351,7 +661,7 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase { memory::desc ConfigureOriginalOutput( const MklPoolParameters& pool_params, const MklDnnShape& original_output_mkl_shape, - memory::dims output_dims_mkl_order) { + const memory::dims& output_dims_mkl_order) { this->GetOutputDims(pool_params, &output_dims_mkl_order); return original_output_mkl_shape.IsMklTensor() -- cgit v1.2.3