aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_pooling_ops_common.h
diff options
context:
space:
mode:
authorGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-05-18 11:44:15 -0700
committerGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-05-18 11:44:15 -0700
commitf5e2edb5fc84179637355c727c4f0953764b48e5 (patch)
tree38c7503b99468ec8ccd0b69175e22c263e818f95 /tensorflow/core/kernels/mkl_pooling_ops_common.h
parent856f3c9d4c36ae7f63f8cf17f05898c866e191a3 (diff)
enhancement with pooling ops primitive reuse
Diffstat (limited to 'tensorflow/core/kernels/mkl_pooling_ops_common.h')
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.h312
1 files changed, 311 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index 279167aba2..468dc41c57 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -19,6 +19,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include <string>
#include <vector>
+#include <memory>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
@@ -32,6 +33,315 @@ using mkldnn::stream;
namespace tensorflow {
+#ifndef INTEL_MKL_ML
+
+using mkldnn::memory;
+using mkldnn::pooling_max;
+using mkldnn::pooling_avg;
+using mkldnn::pooling_avg_include_padding;
+using mkldnn::pooling_avg_exclude_padding;
+using mkldnn::prop_kind;
+
+struct MklPoolingParams {
+ memory::dims src_dims;
+ memory::dims dst_dims;
+ memory::dims filter_dims;
+ memory::dims strides;
+ memory::dims padding_left;
+ memory::dims padding_right;
+ mkldnn::algorithm alg_kind;
+
+ MklPoolingParams(memory::dims src_dims,
+ memory::dims dst_dims, memory::dims filter_dims,
+ memory::dims strides, memory::dims padding_left,
+ memory::dims padding_right, mkldnn::algorithm alg_kind) :
+ src_dims(src_dims), dst_dims(dst_dims),
+ filter_dims(filter_dims), strides(strides),
+ padding_left(padding_left), padding_right(padding_right),
+ alg_kind(alg_kind) {
+ }
+};
+
+template <typename T>
+class MklPoolingFwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams) {
+ context_.fwd_stream.reset(new stream(stream::kind::eager));
+ if (context_.fwd == nullptr)
+ Setup(fwdParams);
+ }
+
+ ~MklPoolingFwdPrimitive() {}
+
+ // Pooling forward execute
+ // src_data: input data buffer of src
+ // ws_data: input data buffer of workspace
+ // dst_data: output data buffer of dst
+ void Execute(const T* src_data, const T* dst_data,
+ const void* ws_data = nullptr);
+
+ std::shared_ptr<mkldnn::pooling_forward::primitive_desc>
+ GetPoolingFwdPd() const {
+ return context_.fwd_pd;
+ }
+
+ memory::format GetSrcMemoryFormat() const {
+ return context_.src_fmt;
+ }
+
+ memory::format GetDstMemoryFormat() const {
+ return context_.dst_fmt;
+ }
+
+ private:
+ void Setup(const MklPoolingParams& fwdParams);
+
+
+ struct PoolingFwdContext {
+ // algorithm
+ mkldnn::algorithm alg_kind;
+
+ // expected memory format
+ memory::format src_fmt;
+ memory::format dst_fmt;
+ memory::format ws_fmt;
+
+ // workspace shape
+ memory::dims ws_dims;
+ memory::data_type ws_dt;
+ size_t ws_size;
+
+ // MKL-DNN memory, just dummy data
+ std::shared_ptr<mkldnn::memory> ws_mem;
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> dst_mem;
+
+ // desc & primitive desc
+ std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
+
+ // memory desc
+ std::shared_ptr<mkldnn::memory::desc> src_md;
+ std::shared_ptr<mkldnn::memory::desc> dst_md;
+
+ // Pooling primitive
+ std::shared_ptr<mkldnn::pooling_forward> fwd;
+ std::shared_ptr<mkldnn::stream> fwd_stream;
+ std::vector<mkldnn::primitive> fwd_primitives;
+
+ PoolingFwdContext() :
+ src_fmt(memory::format::any), dst_fmt(memory::format::any),
+ ws_fmt(memory::format::any), ws_mem(nullptr), src_mem(nullptr),
+ dst_mem(nullptr), fwd_desc(nullptr), fwd_pd(nullptr), src_md(nullptr),
+ dst_md(nullptr), fwd(nullptr), fwd_stream(nullptr) {
+ }
+ } context_;
+
+ engine cpu_engine_ = engine(engine::cpu, 0);
+};
+
+template <typename T>
+class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) {
+ MklPoolingFwdPrimitive<T>* pooling_forward = nullptr;
+
+ // Get pooling primitive from the pool
+ pooling_forward = static_cast<MklPoolingFwdPrimitive<T>*>(
+ MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd(fwdParams));
+
+ if (pooling_forward == nullptr) {
+ pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams);
+ MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd(
+ fwdParams, pooling_forward);
+ }
+ return pooling_forward;
+ }
+
+ static MklPoolingFwdPrimitiveFactory& GetInstance() {
+ static MklPoolingFwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklPoolingFwdPrimitiveFactory() {}
+ ~MklPoolingFwdPrimitiveFactory() {}
+
+ // The key to be created will be used to get/set pooling
+ // primitive op from reuse perspective.
+ // A pooling key is a string which concates key parameters
+ // as well as algorithm kind (max versus avg).
+ static std::string CreateKey(const MklPoolingParams& fwdParams) {
+ std::string prefix = "pooling_fwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(fwdParams.src_dims);
+ key_creator.AddAsKey(fwdParams.dst_dims);
+ key_creator.AddAsKey(fwdParams.filter_dims);
+ key_creator.AddAsKey(fwdParams.strides);
+ key_creator.AddAsKey(fwdParams.padding_left);
+ key_creator.AddAsKey(fwdParams.padding_right);
+ key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
+ std::string key = CreateKey(fwdParams);
+ return this->GetOp(key);
+ }
+
+ void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive *op) {
+ std::string key = CreateKey(fwdParams);
+ this->SetOp(key, op);
+ }
+};
+
+
+template <typename T>
+class MklPoolingBwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) {
+ context_.bwd_stream.reset(new stream(stream::kind::eager));
+ if (context_.bwd == nullptr)
+ Setup(bwdParams);
+ }
+
+ ~MklPoolingBwdPrimitive() {}
+
+ // Pooling backward execute
+ // diff_dst_data: input data buffer of diff_dst
+ // diff_src_data: output data buffer of diff_src
+ // ws_data: input data buffer of workspace
+ void Execute(const T* diff_dst_data, const T* diff_src_data,
+ const void* ws_data = nullptr);
+
+ public:
+ std::shared_ptr<mkldnn::pooling_forward::primitive_desc>
+ GetPoolingFwdPd() const {
+ return context_.fwd_pd;
+ }
+ std::shared_ptr<mkldnn::pooling_backward::primitive_desc>
+ GetPoolingBwdPd() const {
+ return context_.bwd_pd;
+ }
+
+ memory::format GetDiffDstFormat() const {
+ return context_.diff_dst_fmt;
+ }
+
+ mkldnn::memory::data_type GetWorkspaceDataType() const {
+ return context_.ws_dt;
+ }
+ memory::format GetWorkspaceFormat() const {
+ return context_.ws_fmt;
+ }
+
+ private:
+ void Setup(const MklPoolingParams& bwdParams);
+ // Primitive reuse context for pooling bwd ops
+ struct PoolingBwdContext {
+ // algorithm
+ mkldnn::algorithm alg_kind;
+
+ // expected memory format
+ mkldnn::memory::format diff_src_fmt;
+ mkldnn::memory::format diff_dst_fmt;
+ mkldnn::memory::format ws_fmt;
+
+ // workspace attribute
+ mkldnn::memory::dims ws_dims;
+ mkldnn::memory::data_type ws_dt;
+
+ // MKL-DNN memory
+ std::shared_ptr<mkldnn::memory> ws_mem;
+ std::shared_ptr<mkldnn::memory> diff_src_mem;
+ std::shared_ptr<mkldnn::memory> diff_dst_mem;
+
+ // memory desc
+ std::shared_ptr<mkldnn::memory::desc> diff_src_md;
+ std::shared_ptr<mkldnn::memory::desc> diff_dst_md;
+
+ // desc & primitive desc
+ std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::pooling_backward::desc> bwd_desc;
+ std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
+ std::shared_ptr<mkldnn::pooling_backward::primitive_desc> bwd_pd;
+
+ // pooling primitive
+ std::shared_ptr<mkldnn::pooling_backward> bwd;
+ std::shared_ptr<mkldnn::stream> bwd_stream;
+
+ std::vector<mkldnn::primitive> bwd_primitives;
+
+ PoolingBwdContext() :
+ diff_src_fmt(memory::format::any), diff_dst_fmt(memory::format::any),
+ ws_fmt(memory::format::any), ws_mem(nullptr), diff_src_mem(nullptr),
+ diff_dst_mem(nullptr), diff_src_md(nullptr), diff_dst_md(nullptr),
+ fwd_desc(nullptr), bwd_desc(nullptr), fwd_pd(nullptr), bwd_pd(nullptr),
+ bwd(nullptr), bwd_stream(nullptr) {
+ }
+ } context_;
+ // cpu engine
+ engine cpu_engine = engine(engine::cpu, 0);
+};
+
+template <typename T>
+class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklPoolingBwdPrimitive<T> *Get(const MklPoolingParams& bwdParams) {
+ MklPoolingBwdPrimitive<T>* pooling_backward = nullptr;
+
+ // Find a pooling backward primitive from the pool
+ // If it does not exist, create a new one
+ pooling_backward = static_cast<MklPoolingBwdPrimitive<T>*>(
+ MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd(bwdParams));
+ if (pooling_backward == nullptr) {
+ pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams);
+ MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd(
+ bwdParams, pooling_backward);
+ }
+ return pooling_backward;
+ }
+
+ static MklPoolingBwdPrimitiveFactory& GetInstance() {
+ static MklPoolingBwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklPoolingBwdPrimitiveFactory() {}
+ ~MklPoolingBwdPrimitiveFactory() {}
+
+ // The key to be created will be used to get/set pooling
+ // primitive op from reuse perspective.
+ // A pooling key is a string which concates key parameters
+ // as well as algorithm kind (max versus avg).
+ static std::string CreateKey(const MklPoolingParams& bwdParams) {
+ std::string prefix = "pooling_bwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(bwdParams.src_dims);
+ key_creator.AddAsKey(bwdParams.dst_dims);
+ key_creator.AddAsKey(bwdParams.filter_dims);
+ key_creator.AddAsKey(bwdParams.strides);
+ key_creator.AddAsKey(bwdParams.padding_left);
+ key_creator.AddAsKey(bwdParams.padding_right);
+ key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
+ std::string key = CreateKey(bwdParams);
+ return this->GetOp(key);
+ }
+
+ void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive *op) {
+ std::string key = CreateKey(bwdParams);
+ this->SetOp(key, op);
+ }
+};
+#endif
+
typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklPoolParameters {
@@ -351,7 +661,7 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
memory::desc ConfigureOriginalOutput(
const MklPoolParameters& pool_params,
const MklDnnShape& original_output_mkl_shape,
- memory::dims output_dims_mkl_order) {
+ const memory::dims& output_dims_mkl_order) {
this->GetOutputDims(pool_params, &output_dims_mkl_order);
return original_output_mkl_shape.IsMklTensor()