/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"

#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
using mkldnn::memory;
using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::stream;
#endif

namespace tensorflow {

#ifndef INTEL_MKL_ML_ONLY

using mkldnn::memory;
using mkldnn::pooling_avg;
using mkldnn::pooling_avg_exclude_padding;
using mkldnn::pooling_avg_include_padding;
using mkldnn::pooling_max;
using mkldnn::prop_kind;

// Bundle of the dimensions and algorithm that describe one pooling
// primitive. The primitive factories below build their cache key from every
// field of this struct.
struct MklPoolingParams {
  memory::dims src_dims;
  memory::dims dst_dims;
  memory::dims filter_dims;
  memory::dims strides;
  memory::dims padding_left;
  memory::dims padding_right;
  mkldnn::algorithm alg_kind;

  MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
                   memory::dims filter_dims, memory::dims strides,
                   memory::dims padding_left, memory::dims padding_right,
                   mkldnn::algorithm alg_kind)
      : src_dims(src_dims),
        dst_dims(dst_dims),
        filter_dims(filter_dims),
        strides(strides),
        padding_left(padding_left),
        padding_right(padding_right),
        alg_kind(alg_kind) {}
};

// Reusable pooling forward primitive. All MKL-DNN state (memory objects,
// descriptors, the primitive itself and its stream) lives in
// PoolingFwdContext so that one instance can be cached by
// MklPoolingFwdPrimitiveFactory and executed repeatedly.
template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
      : cpu_engine_(engine::cpu, 0) {
    context_.fwd_stream.reset(new stream(stream::kind::eager));
    if (context_.fwd == nullptr) Setup(fwdParams);
  }

  ~MklPoolingFwdPrimitive() {}

  // Pooling forward execute
  //   src_data:  input data buffer of src
  //   ws_data:   output data buffer of workspace
  //   dst_data:  output data buffer of dst
  // Declared here; the definition is not part of this header.
  void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);

  std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
      const {
    return context_.fwd_pd;
  }

  memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
  memory::format GetDstMemoryFormat() const { return context_.dst_fmt; }

 private:
  // Declared here; the definition is not part of this header.
  void Setup(const MklPoolingParams& fwdParams);

  // Primitive reuse context for pooling fwd ops
  struct PoolingFwdContext {
    // algorithm
    mkldnn::algorithm alg_kind;

    // expected memory format
    memory::format src_fmt;
    memory::format dst_fmt;
    memory::format ws_fmt;

    // workspace shape
    memory::dims ws_dims;
    memory::data_type ws_dt;
    size_t ws_size;

    // MKL-DNN memory, just dummy data
    std::shared_ptr<mkldnn::memory> ws_mem;
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // desc & primitive desc
    std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;

    // memory desc
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Pooling primitive
    std::shared_ptr<mkldnn::pooling_forward> fwd;
    std::shared_ptr<mkldnn::stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;

    // ws_size is zero-initialized so it never holds an indeterminate value
    // before it is assigned.
    PoolingFwdContext()
        : src_fmt(memory::format::any),
          dst_fmt(memory::format::any),
          ws_fmt(memory::format::any),
          ws_size(0),
          ws_mem(nullptr),
          src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          fwd(nullptr),
          fwd_stream(nullptr) {}
  };

  struct PoolingFwdContext context_;
  engine cpu_engine_;
};

// Singleton factory that hands out cached MklPoolingFwdPrimitive objects,
// creating (and registering) a new one when no primitive exists for the
// requested parameters.
template <typename T>
class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  // Returns the pooled primitive matching `fwdParams`, creating it on a miss.
  // The returned pointer is also stored in the factory via SetPoolingFwd().
  static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) {
    // Get pooling primitive from the pool
    MklPoolingFwdPrimitive<T>* pooling_forward =
        static_cast<MklPoolingFwdPrimitive<T>*>(
            MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd(
                fwdParams));

    if (pooling_forward == nullptr) {
      pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams);
      MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd(
          fwdParams, pooling_forward);
    }
    return pooling_forward;
  }

  static MklPoolingFwdPrimitiveFactory& GetInstance() {
    static MklPoolingFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingFwdPrimitiveFactory() {}
  ~MklPoolingFwdPrimitiveFactory() {}

  // The key to be created will be used to get/set pooling
  // primitive op from reuse perspective.
  // A pooling key is a string which concatenates key parameters
  // as well as algorithm kind (max versus avg).
  static string CreateKey(const MklPoolingParams& fwdParams) {
    string prefix = "pooling_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey(fwdParams.dst_dims);
    key_creator.AddAsKey(fwdParams.filter_dims);
    key_creator.AddAsKey(fwdParams.strides);
    key_creator.AddAsKey(fwdParams.padding_left);
    key_creator.AddAsKey(fwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};

// Reusable pooling backward primitive; the backward counterpart of
// MklPoolingFwdPrimitive. State is kept in PoolingBwdContext.
template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
      : cpu_engine(engine::cpu, 0) {
    context_.bwd_stream.reset(new stream(stream::kind::eager));
    if (context_.bwd == nullptr) Setup(bwdParams);
  }

  ~MklPoolingBwdPrimitive() {}

  // Pooling backward execute
  //   diff_dst_data:  input data buffer of diff_dst
  //   diff_src_data:  output data buffer of diff_src
  //   ws_data:        input data buffer of workspace
  // Declared here; the definition is not part of this header.
  void Execute(const T* diff_dst_data, T* diff_src_data,
               const void* ws_data = nullptr);

  std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
      const {
    return context_.fwd_pd;
  }
  std::shared_ptr<mkldnn::pooling_backward::primitive_desc> GetPoolingBwdPd()
      const {
    return context_.bwd_pd;
  }

  memory::format GetDiffDstFormat() const { return context_.diff_dst_fmt; }

  mkldnn::memory::data_type GetWorkspaceDataType() const {
    return context_.ws_dt;
  }
  memory::format GetWorkspaceFormat() const { return context_.ws_fmt; }

 private:
  // Declared here; the definition is not part of this header.
  void Setup(const MklPoolingParams& bwdParams);

  // Primitive reuse context for pooling bwd ops
  struct PoolingBwdContext {
    // algorithm
    mkldnn::algorithm alg_kind;

    // expected memory format
    mkldnn::memory::format diff_src_fmt;
    mkldnn::memory::format diff_dst_fmt;
    mkldnn::memory::format ws_fmt;

    // workspace attribute
    mkldnn::memory::dims ws_dims;
    mkldnn::memory::data_type ws_dt;

    // MKL-DNN memory
    std::shared_ptr<mkldnn::memory> ws_mem;
    std::shared_ptr<mkldnn::memory> diff_src_mem;
    std::shared_ptr<mkldnn::memory> diff_dst_mem;

    // memory desc
    std::shared_ptr<mkldnn::memory::desc> diff_src_md;
    std::shared_ptr<mkldnn::memory::desc> diff_dst_md;

    // desc & primitive desc
    std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::pooling_backward::desc> bwd_desc;
    std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
    std::shared_ptr<mkldnn::pooling_backward::primitive_desc> bwd_pd;

    // pooling primitive
    std::shared_ptr<mkldnn::pooling_backward> bwd;
    std::shared_ptr<mkldnn::stream> bwd_stream;
    std::vector<mkldnn::primitive> bwd_primitives;

    PoolingBwdContext()
        : diff_src_fmt(memory::format::any),
          diff_dst_fmt(memory::format::any),
          ws_fmt(memory::format::any),
          ws_mem(nullptr),
          diff_src_mem(nullptr),
          diff_dst_mem(nullptr),
          diff_src_md(nullptr),
          diff_dst_md(nullptr),
          fwd_desc(nullptr),
          bwd_desc(nullptr),
          fwd_pd(nullptr),
          bwd_pd(nullptr),
          bwd(nullptr),
          bwd_stream(nullptr) {}
  };

  struct PoolingBwdContext context_;
  // NOTE: intentionally keeps the original name (no trailing underscore);
  // out-of-line member definitions may refer to it.
  engine cpu_engine;
};

// Singleton factory that hands out cached MklPoolingBwdPrimitive objects,
// creating (and registering) a new one when no primitive exists for the
// requested parameters.
template <typename T>
class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  // Returns the pooled primitive matching `bwdParams`, creating it on a miss.
  static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) {
    // Find a pooling backward primitive from the pool
    // If it does not exist, create a new one
    MklPoolingBwdPrimitive<T>* pooling_backward =
        static_cast<MklPoolingBwdPrimitive<T>*>(
            MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd(
                bwdParams));

    if (pooling_backward == nullptr) {
      pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams);
      MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd(
          bwdParams, pooling_backward);
    }
    return pooling_backward;
  }

  static MklPoolingBwdPrimitiveFactory& GetInstance() {
    static MklPoolingBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingBwdPrimitiveFactory() {}
  ~MklPoolingBwdPrimitiveFactory() {}

  // The key to be created will be used to get/set pooling
  // primitive op from reuse perspective.
  // A pooling key is a string which concatenates key parameters
  // as well as algorithm kind (max versus avg).
  static string CreateKey(const MklPoolingParams& bwdParams) {
    string prefix = "pooling_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(bwdParams.dst_dims);
    key_creator.AddAsKey(bwdParams.filter_dims);
    key_creator.AddAsKey(bwdParams.strides);
    key_creator.AddAsKey(bwdParams.padding_left);
    key_creator.AddAsKey(bwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }

  void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};
#endif

typedef Eigen::ThreadPoolDevice CPUDevice;

// Input/window/stride/output/padding sizes of one pooling invocation.
// Fields marked "Pool3D" are only meaningful for 3-D pooling. All fields are
// zero until one of the Init() overloads fills them in.
struct MklPoolParameters {
  int depth;

  int tensor_in_planes;  // Pool3D
  int tensor_in_cols;
  int tensor_in_rows;
  int tensor_in_batch;

  int window_planes;  // Pool3D
  int window_rows;
  int window_cols;
  int depth_window;

  int planes_stride;  // Pool3D
  int row_stride;
  int col_stride;
  int depth_stride;

  int64 out_planes;  // Pool3D
  int64 out_height;
  int64 out_width;
  int out_depth;

  int64 pad_P1;  // Pool3D
  int64 pad_P2;  // Pool3D
  int64 pad_left;
  int64 pad_right;
  int64 pad_top;
  int64 pad_bottom;
  int pad_depth;

  TensorFormat data_format;

  MklPoolParameters()
      : depth(0),
        tensor_in_planes(0),
        tensor_in_cols(0),
        tensor_in_rows(0),
        tensor_in_batch(0),
        window_planes(0),
        window_rows(0),
        window_cols(0),
        depth_window(0),
        planes_stride(0),
        row_stride(0),
        col_stride(0),
        depth_stride(0),
        out_planes(0),
        out_height(0),
        out_width(0),
        out_depth(0),
        pad_P1(0),
        pad_P2(0),
        pad_left(0),
        pad_right(0),
        pad_top(0),
        pad_bottom(0),
        pad_depth(0),
        data_format(TensorFormat::FORMAT_NCHW) {}

  // Updates context->status if there is an invalid input.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const TensorShape& tensor_in_shape);
#ifdef INTEL_MKL_ML_ONLY
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklShape* mkl_in_shape);
#else
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklDnnShape* mkl_in_shape);
#endif

 private:
  // Common initialization for TensorFlow and MKL formats
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format);
};

#ifndef INTEL_MKL_ML_ONLY

// Common base of the MKL-DNN pooling kernels: parses and validates the
// pooling attributes once at construction and offers shape/allocation
// helpers to the forward and backward op bases.
template <class T>
class MklPoolingOpBase : public OpKernel {
 public:
  // Reads "data_format", "ksize", "strides", "padding" and (optionally)
  // "workspace_enabled". Any invalid attribute is reported through
  // `context` and construction returns early.
  explicit MklPoolingOpBase(OpKernelConstruction* context)
      : OpKernel(context), workspace_enabled_(false) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
    OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 or 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
    OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 or 5 dimensions"));
    // ksize_.size() alone decides between Pool2D and Pool3D everywhere in
    // this class, so the strides must have the same rank; otherwise a later
    // per-dimension lookup could index past the end of the shorter vector.
    OP_REQUIRES(context, this->ksize_.size() == this->stride_.size(),
                errors::InvalidArgument("Sliding window ksize and strides "
                                        "fields must specify the same number "
                                        "of dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
    OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
                errors::Unimplemented("Pooling is not yet supported on the "
                                      "batch dimension."));
    bool is_pool2d = (this->ksize_.size() == 4);
    this->data_format_mkldnn_ =
        is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
                  : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);

    // We may not get this attribute for this node if it does not go through
    // graph rewrite pass. So we do not check for error while retrieving this
    // attribute value.
    context->GetAttr("workspace_enabled", &this->workspace_enabled_);
  }
  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Calculate output shape of pooling op in MKL-DNN and TensorFlow order.
  // MKL-DNN uses NCHW(Pool2D) or NCDHW(Pool3D) for output order.
  // But TensorFlow output will be in NHWC/NCHW(Pool2D) or
  // NDHWC/NCDHW(Pool3D) format depending on data format. Function expects
  // output height and width to have already been int32 bounds-checked.
  void GetOutputDims(const MklPoolParameters& mkl_pool_params,
                     memory::dims* output_dims_mkl_order) {
    if (this->ksize_.size() == 4) {
      // Pooling2D: MKL-DNN always needs output in NCHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    } else {
      // Pooling3D: MKL-DNN always needs output in NCDHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_planes),
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    }
  }

  // Fills `pool_params` from the kernel attributes plus either the plain
  // TensorFlow shape or the MKL shape, depending on whether the input is an
  // MKL tensor.
  void InitMklPoolParameters(OpKernelContext* context,
                             MklPoolParameters* pool_params,
                             const MklDnnShape& original_input_mkl_shape,
                             const TensorShape& input_tensor_shape) {
    if (!original_input_mkl_shape.IsMklTensor()) {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, input_tensor_shape);
    } else {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, &original_input_mkl_shape);
    }
  }

  // Converts the window, stride and padding fields of `pool_params` into the
  // spatial-only dims vectors MKL-DNN expects (2 entries for Pool2D,
  // 3 entries for Pool3D).
  void PoolParamsToDims(const MklPoolParameters* pool_params,
                        memory::dims* filter_dims, memory::dims* strides,
                        memory::dims* padding_left,
                        memory::dims* padding_right, bool is_pool2d) {
    if (is_pool2d) {
      // Pool2D
      *filter_dims =
          memory::dims({pool_params->window_rows, pool_params->window_cols});
      *strides =
          memory::dims({pool_params->row_stride, pool_params->col_stride});
      *padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right = memory::dims({static_cast<int>(pool_params->pad_bottom),
                                     static_cast<int>(pool_params->pad_right)});
    } else {
      // Pool3D
      *filter_dims =
          memory::dims({pool_params->window_planes, pool_params->window_rows,
                        pool_params->window_cols});
      *strides =
          memory::dims({pool_params->planes_stride, pool_params->row_stride,
                        pool_params->col_stride});

      *padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
                                    static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right = memory::dims({static_cast<int>(pool_params->pad_P2),
                                     static_cast<int>(pool_params->pad_bottom),
                                     static_cast<int>(pool_params->pad_right)});
    }
  }

  // Allocates output `kOutputIndex` as a plain (non-MKL) tensor whose shape
  // follows the TensorFlow data format recorded in `pool_params`.
  void AllocateEmptyOutputTensor(OpKernelContext* context,
                                 const int kOutputIndex,
                                 MklPoolParameters* pool_params,
                                 const memory::dims output_dims_mkl_order,
                                 Tensor** output_tensor) {
    // Validate the out-pointer before it is handed to the allocator.
    CHECK_NOTNULL(output_tensor);

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(false);

    TensorShape output_tf_shape;
    if (pool_params->data_format == TensorFormat::FORMAT_NCHW) {
      // MKL order already is the channels-first TensorFlow order.
      output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
    } else {
      // Channels-last: NHWC for Pool2D, NDHWC for Pool3D. The planes
      // dimension must be included for Pool3D, otherwise the output would be
      // allocated with a 4-D shape.
      memory::dims output_dims_tf_order;
      if (this->ksize_.size() == 4) {
        output_dims_tf_order = {pool_params->tensor_in_batch,
                                static_cast<int>(pool_params->out_height),
                                static_cast<int>(pool_params->out_width),
                                pool_params->out_depth};
      } else {
        output_dims_tf_order = {pool_params->tensor_in_batch,
                                static_cast<int>(pool_params->out_planes),
                                static_cast<int>(pool_params->out_height),
                                static_cast<int>(pool_params->out_width),
                                pool_params->out_depth};
      }
      output_tf_shape = MklDnnDimsToTFShape(output_dims_tf_order);
    }
    AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  // Returns the number of elements of type T needed to hold the memory
  // described by `pd`. If the byte size is not a multiple of sizeof(T), the
  // count is rounded up so the whole buffer is covered.
  size_t GetNumTElements(const memory::primitive_desc& pd) {
    size_t num_bytes = pd.get_size();
    size_t ret_val = num_bytes / sizeof(T);
    if (num_bytes % sizeof(T) != 0) {
      ret_val++;
    }
    return ret_val;
  }

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_tf_;
  memory::format data_format_mkldnn_;
  bool workspace_enabled_;
};

// Base of the forward pooling kernels: input configuration, MKL-layout
// output allocation and input rank validation.
template <class T>
class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingForwardOpBase(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Binds `input_tensor` to `dnn_data_input` as user memory (skipped for
  // empty tensors) and initializes `pool_params` from the input shape.
  void ConfigureInput(OpKernelContext* context,
                      const MklDnnShape& input_mkl_shape,
                      const Tensor& input_tensor,
                      MklPoolParameters* pool_params,
                      MklDnnData<T>* dnn_data_input) {
    CHECK_NOTNULL(pool_params);
    CHECK_NOTNULL(dnn_data_input);
    TensorShape input_tensor_shape = input_tensor.shape();
    if (input_tensor.NumElements() != 0) {
      // The TF-shape conversion is kept inside the conditional so it is only
      // evaluated for non-MKL inputs.
      memory::desc input_md =
          input_mkl_shape.IsMklTensor()
              ? input_mkl_shape.GetMklLayout()
              : memory::desc(
                    (this->ksize_.size() == 4)
                        ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
                                                    this->data_format_tf_)
                        : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
                                                     this->data_format_tf_),
                    MklDnnType<T>(), this->data_format_mkldnn_);
      dnn_data_input->SetUsrMem(input_md, &input_tensor);

      if (this->ksize_.size() == 5) {
        // Pool3D
        std::vector<int> mkldnn_sizes(5, -1);
        mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
        mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
        mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
        mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
        mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
        dnn_data_input->SetOpMemDesc(mkldnn_sizes, this->data_format_mkldnn_);
      }
    }
    this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
                                input_tensor_shape);
  }

  // Allocates the forward output as an MKL tensor laid out according to the
  // destination primitive descriptor of `pool_fwd_prim_desc`.
  void AllocateOutputTensor(
      OpKernelContext* context,
      const pooling_forward::primitive_desc& pool_fwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd = pool_fwd_prim_desc.dst_primitive_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    TensorShape output_tf_shape;

    // only allocate enough space for the elements we need.
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  // Reports InvalidArgument through `context` unless the input (TensorFlow
  // or MKL shape) is 4- or 5-dimensional.
  void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
                        const MklDnnShape& input_mkl_shape) {
    if (!input_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context, input_tensor.dims() == 4 || input_tensor.dims() == 5,
                  errors::InvalidArgument("Input must be 4 or 5-dimensional"));
    } else {
      OP_REQUIRES(context,
                  input_mkl_shape.GetDimension() == 4 ||
                      input_mkl_shape.GetDimension() == 5,
                  errors::InvalidArgument("Input shape must be "
                                          "4 or 5-dimensional"));
    }
  }
  // .Input("value: T")
  // .Output("output: T")
  const int kInputTensorIndexInput = 0;
  const int kOutputTensorIndexOutput = 0;
};  // MklPoolingForwardOpBase

// Base of the backward pooling kernels: MKL-layout output allocation and
// gradient input configuration.
template <class T>
class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingBackwardOpBase(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  const int kOutputTensorIndexOutput = 0;

  // Allocates the backward output as an MKL tensor laid out according to
  // the diff_src primitive descriptor of `pool_bkwd_prim_desc`.
  void AllocateOutputTensor(
      OpKernelContext* context,
      const pooling_backward::primitive_desc& pool_bkwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd =
        pool_bkwd_prim_desc.diff_src_primitive_desc();
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  // Binds the incoming gradient to `input_gradient_dnn_data` as user memory
  // and returns the memory descriptor to use for diff_dst: a descriptor in
  // the format of `original_output_md` when a reorder is needed, otherwise
  // the gradient's own descriptor.
  memory::desc ConfigureInputGradient(
      const MklDnnShape& input_gradient_mkl_shape,
      const Tensor& input_gradient_tensor,
      MklDnnData<T>* input_gradient_dnn_data,
      const memory::desc& original_output_md) {
    const bool is_pool2d = (this->ksize_.size() == 4);

    // Configure the gradient as is. The TF-shape conversions stay inside the
    // conditionals so they are only evaluated for non-MKL gradients.
    memory::desc original_input_grad_md =
        input_gradient_mkl_shape.IsMklTensor()
            ? input_gradient_mkl_shape.GetMklLayout()
            : memory::desc(
                  is_pool2d
                      ? TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
                                                  this->data_format_tf_)
                      : TFShapeToMklDnnDimsInNCDHW(
                            input_gradient_tensor.shape(),
                            this->data_format_tf_),
                  MklDnnType<T>(), this->data_format_mkldnn_);

    input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
                                       &input_gradient_tensor);

    // Check to see if input grad diff dst is in the right format
    // Create a new memory descriptor with the same shape as the
    // original, but the format of the other tensors.
    memory::format original_output_format =
        static_cast<memory::format>(original_output_md.data.format);
    bool grad_reorder_needed =
        input_gradient_dnn_data->IsReorderNeeded(original_output_format);
    // Use the same 2D/3D dims conversion as above; a Pool3D gradient is
    // 5-dimensional and must not go through the NCHW (4-D) conversion.
    memory::dims diff_dst_dims =
        input_gradient_mkl_shape.IsMklTensor()
            ? input_gradient_mkl_shape.GetSizesAsMklDnnDims()
            : (is_pool2d
                   ? TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
                                               this->data_format_tf_)
                   : TFShapeToMklDnnDimsInNCDHW(input_gradient_tensor.shape(),
                                                this->data_format_tf_));
    memory::desc target_diff_dst_md =
        memory::desc(diff_dst_dims, MklDnnType<T>(), original_output_format);

    return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md;
  }
};
#endif  // INTEL_MKL_ML_ONLY

//-------------------------------------------------------------------
// Utility functions

// Flat pooling parameters (sizes, strides, offsets, kernel geometry) filled
// in by ExtractMklOpParams().
typedef struct {
  size_t in_dim;
  size_t in_sizes[4];
  size_t in_strides[4];
  size_t out_sizes[4];
  size_t out_strides[4];
  int in_offset[4];
  size_t kernel_stride[2];
  size_t kernel_size[2];
} MklPoolingOpParams;

// Transfers the right parameters for pooling to the op parameters
// Updates context->status if there is an invalid input.
void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
                        const MklPoolParameters& params,
                        MklPoolingOpParams* mkl_params);
}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_