diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.cc | 193 |
1 files changed, 97 insertions, 96 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index cede0b9dd6..b568973220 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -70,23 +70,25 @@ struct MklConvFwdParams { memory::dims padding_left; memory::dims padding_right; - MklConvFwdParams(memory::dims src_dims, - memory::dims filter_dims, memory::dims bias_dims, - memory::dims dst_dims, memory::dims strides, - memory::dims dilations, memory::dims padding_left, - memory::dims padding_right) : - src_dims(src_dims), filter_dims(filter_dims), - bias_dims(bias_dims), dst_dims(dst_dims), - strides(strides), dilations(dilations), - padding_left(padding_left), padding_right(padding_right) { - } + MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims, + memory::dims bias_dims, memory::dims dst_dims, + memory::dims strides, memory::dims dilations, + memory::dims padding_left, memory::dims padding_right) + : src_dims(src_dims), + filter_dims(filter_dims), + bias_dims(bias_dims), + dst_dims(dst_dims), + strides(strides), + dilations(dilations), + padding_left(padding_left), + padding_right(padding_right) {} }; template <typename T> -class MklConv2DFwdPrimitive: public MklPrimitive { +class MklConv2DFwdPrimitive : public MklPrimitive { public: - explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims) : - cpu_engine_(engine::cpu, 0) { + explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims) + : cpu_engine_(engine::cpu, 0) { context_.fwd_stream.reset(new stream(stream::kind::eager)); // create conv primitive if (context_.conv_fwd == nullptr) { @@ -101,8 +103,8 @@ class MklConv2DFwdPrimitive: public MklPrimitive { // filter_data: input data buffer of filter (weights) // bias_data: input data buffer of bias // dst_data: output data buffer of dst - void Execute(const T* src_data, const T* filter_data, - const T* bias_data, const T* dst_data) { + void Execute(const T* src_data, const T* filter_data, const T* bias_data, + const T* dst_data) { context_.src_mem->set_data_handle( static_cast<void*>(const_cast<T*>(src_data))); context_.filter_mem->set_data_handle( @@ -126,8 +128,7 @@ class MklConv2DFwdPrimitive: public MklPrimitive { // src_data: input data buffer of src // filter_data: input data buffer of filter (weights) // dst_data: output data buffer of dst - void Execute(const T* src_data, const T* filter_data, - const T* dst_data) { + void Execute(const T* src_data, const T* filter_data, const T* dst_data) { context_.src_mem->set_data_handle( static_cast<void*>(const_cast<T*>(src_data))); context_.filter_mem->set_data_handle( @@ -142,13 +143,9 @@ class MklConv2DFwdPrimitive: public MklPrimitive { context_.dst_mem->set_data_handle(DummyData); } - memory::format GetSrcMemoryFormat() const { - return context_.src_fmt; - } + memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } - memory::format GetFilterMemoryFormat() const { - return context_.filter_fmt; - } + memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; } std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetPrimitiveDesc() const { @@ -184,43 +181,50 @@ class MklConv2DFwdPrimitive: public MklPrimitive { std::shared_ptr<mkldnn::stream> fwd_stream; std::vector<mkldnn::primitive> fwd_primitives; - ConvFwdContext() : - src_fmt(memory::format::any), filter_fmt(memory::format::any), - src_mem(nullptr), filter_mem(nullptr), bias_mem(nullptr), - dst_mem(nullptr), fwd_desc(nullptr), - src_md(nullptr), filter_md(nullptr), bias_md(nullptr), - fwd_pd(nullptr), conv_fwd(nullptr), fwd_stream(nullptr) { - } + ConvFwdContext() + : src_fmt(memory::format::any), + filter_fmt(memory::format::any), + src_mem(nullptr), + filter_mem(nullptr), + bias_mem(nullptr), + dst_mem(nullptr), + fwd_desc(nullptr), + src_md(nullptr), + filter_md(nullptr), + bias_md(nullptr), + fwd_pd(nullptr), + conv_fwd(nullptr), + fwd_stream(nullptr) {} }; void Setup(const MklConvFwdParams& convFwdDims) { // create memory descriptors for convolution data w/ no specified format - context_.src_md.reset(new memory::desc({convFwdDims.src_dims}, - MklDnnType<T>(), memory::format::any)); + context_.src_md.reset(new memory::desc( + {convFwdDims.src_dims}, MklDnnType<T>(), memory::format::any)); - context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims}, - MklDnnType<T>(), memory::format::any)); + context_.filter_md.reset(new memory::desc( + {convFwdDims.filter_dims}, MklDnnType<T>(), memory::format::any)); - context_.dst_md.reset(new memory::desc({convFwdDims.dst_dims}, - MklDnnType<T>(), memory::format::any)); + context_.dst_md.reset(new memory::desc( + {convFwdDims.dst_dims}, MklDnnType<T>(), memory::format::any)); if (!convFwdDims.bias_dims.empty()) - context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, - MklDnnType<T>(), memory::format::any)); + context_.bias_md.reset(new memory::desc( + {convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::any)); // create a convolution if (!convFwdDims.bias_dims.empty()) { - context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward, - convolution_direct, *context_.src_md, *context_.filter_md, - *context_.bias_md, *context_.dst_md, + context_.fwd_desc.reset(new convolution_forward::desc( + prop_kind::forward, convolution_direct, *context_.src_md, + *context_.filter_md, *context_.bias_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right, padding_kind::zero)); } else { - context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward, - convolution_direct, *context_.src_md, *context_.filter_md, - *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, - convFwdDims.padding_left, convFwdDims.padding_right, - padding_kind::zero)); + context_.fwd_desc.reset(new convolution_forward::desc( + prop_kind::forward, convolution_direct, *context_.src_md, + *context_.filter_md, *context_.dst_md, convFwdDims.strides, + convFwdDims.dilations, convFwdDims.padding_left, + convFwdDims.padding_right, padding_kind::zero)); } context_.fwd_pd.reset(new convolution_forward::primitive_desc( @@ -234,24 +238,26 @@ class MklConv2DFwdPrimitive: public MklPrimitive { context_.fwd_pd.get()->weights_primitive_desc().desc().data.format); // create memory primitive based on dummy data - context_.src_mem.reset(new memory( - context_.fwd_pd.get()->src_primitive_desc(), DummyData)); - context_.filter_mem.reset(new memory( - context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); - context_.dst_mem.reset(new memory( - context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); + context_.src_mem.reset( + new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData)); + context_.filter_mem.reset( + new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); + context_.dst_mem.reset( + new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); // create convolution primitive and add it to net if (!convFwdDims.bias_dims.empty()) { - context_.bias_mem.reset(new memory({{{convFwdDims.bias_dims}, - MklDnnType<T>(), memory::format::x}, cpu_engine_}, DummyData)); - context_.conv_fwd.reset(new convolution_forward( - *context_.fwd_pd, *context_.src_mem, *context_.filter_mem, - *context_.bias_mem, *context_.dst_mem)); + context_.bias_mem.reset(new memory( + {{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x}, + cpu_engine_}, + DummyData)); + context_.conv_fwd.reset(new convolution_forward( + *context_.fwd_pd, *context_.src_mem, *context_.filter_mem, + *context_.bias_mem, *context_.dst_mem)); } else { - context_.conv_fwd.reset(new convolution_forward( - *context_.fwd_pd, *context_.src_mem, - *context_.filter_mem, *context_.dst_mem)); + context_.conv_fwd.reset( + new convolution_forward(*context_.fwd_pd, *context_.src_mem, + *context_.filter_mem, *context_.dst_mem)); } context_.fwd_primitives.push_back(*context_.conv_fwd); @@ -266,19 +272,19 @@ template <typename T> class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> { public: static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) { - MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr; - - // try to find a suitable one in pool - conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*> ( - MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd( - convFwdDims)); - - if (conv2d_fwd == nullptr) { - conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims); - MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd( - convFwdDims, conv2d_fwd); - } - return conv2d_fwd; + MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr; + + // try to find a suitable one in pool + conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>( + MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd( + convFwdDims)); + + if (conv2d_fwd == nullptr) { + conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims); + MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims, + conv2d_fwd); + } + return conv2d_fwd; } private: @@ -312,7 +318,7 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> { return this->GetOp(key); } - void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive *op) { + void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) { std::string key = CreateKey(convFwdDims); this->SetOp(key, op); } @@ -865,22 +871,24 @@ class MklConv2DOp : public OpKernel { dilations[kDilationW] -= 1; // get a conv2d fwd from primitive pool - MklConv2DFwdPrimitive<T> *conv2d_fwd = nullptr; + MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr; if (biasEnabled) { memory::dims bias_dims = {}; conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims); MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims, - dst_dims_mkl_order, strides, dilations, padding_left, padding_right); + dst_dims_mkl_order, strides, dilations, + padding_left, padding_right); conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims); } else { MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS, - dst_dims_mkl_order, strides, dilations, padding_left, padding_right); + dst_dims_mkl_order, strides, dilations, + padding_left, padding_right); conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims); } // allocate output tensors output_tensor and filter_out_tensor - std::shared_ptr<mkldnn::convolution_forward::primitive_desc> - conv_fwd_pd = conv2d_fwd->GetPrimitiveDesc(); + std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd = + conv2d_fwd->GetPrimitiveDesc(); AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt, &dst_tensor); Tensor* filter_out_tensor = nullptr; @@ -891,31 +899,25 @@ class MklConv2DOp : public OpKernel { T* dst_data = static_cast<T*>(dst_tensor->flat<T>().data()); // check whether src/filter need reorder - std::vector<primitive> net; T *src_data = nullptr; if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) { src.SetUsrMem(src_md, &src_tensor); - src.CheckReorderToOpMem( - conv_fwd_pd.get()->src_primitive_desc(), &net); + src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc()); src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); } else { - src_data = static_cast<T*>(const_cast<T*>( - src_tensor.flat<T>().data())); + src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data())); } - T *filter_data = nullptr; + T* filter_data = nullptr; if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) { filter.SetUsrMem(filter_md, &filter_tensor); - filter.CheckReorderToOpMem( - conv_fwd_pd.get()->weights_primitive_desc(), - filter.GetTensorBuffer(filter_out_tensor), &net); + filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(), + filter.GetTensorBuffer(filter_out_tensor)); filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle()); } else { - filter_data = static_cast<T*>(const_cast<T*>( - filter_tensor.flat<T>().data())); + filter_data = + static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data())); } - stream(stream::kind::eager).submit(net).wait(); - // execute convolution if (biasEnabled) { @@ -1010,16 +1012,15 @@ class MklConv2DOp : public OpKernel { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. No need to check for output // reorder as we propagate output layout to the next layer. - std::vector<primitive> net; - src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net); + src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc()); // rather than re-order to a temp buffer, reorder directly to the // filter output tensor filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(), - filter->GetTensorBuffer(filter_out_tensor), - &net); + filter->GetTensorBuffer(filter_out_tensor)); // Create convolution primitive and add it to net. + std::vector<primitive> net; if (bias) { CHECK_EQ(biasEnabled, true); net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(), |