diff options
author | Tatiana Shpeisman <shpeisman@google.com> | 2018-07-03 18:09:35 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-07-03 18:09:35 -0700 |
commit | b2fe2a874bade4782aaca5c44bf29e7ff6c39200 (patch) | |
tree | 77a6f54a3e40bd8be2a8fe005fed56f27751c044 | |
parent | 3b538660b1eb22e52ad455a17e01598508373969 (diff) | |
parent | 56150c9829b79c2249a4b90087ce25b1e6624f0b (diff) |
Merge pull request #19399 from Intel-tensorflow/primreuse_conv_bwd
INTEL-MKL: Enhance Mkl conv2d backward (filter and input) ops with primitive reuse
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 659 | ||||
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 476 | ||||
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.h | 222 |
3 files changed, 932 insertions, 425 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 356eed8b67..4e80f5acce 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -54,9 +54,311 @@ using mkldnn::stream; #include "tensorflow/core/util/mkl_util.h" namespace tensorflow { - typedef Eigen::ThreadPoolDevice CPUDevice; +#ifndef INTEL_MKL_ML + +struct MklConvBwdFilterParams { + memory::dims src_dims; + memory::dims diff_filter_dims; + memory::dims diff_bias_dims; + memory::dims diff_dst_dims; + memory::dims strides; + memory::dims dilations; + memory::dims padding_left; + memory::dims padding_right; + padding_kind padding; + + MklConvBwdFilterParams(memory::dims src_dims, + memory::dims diff_filter_dims, memory::dims diff_bias_dims, + memory::dims diff_dst_dims, memory::dims strides, + memory::dims dilations, memory::dims padding_left, + memory::dims padding_right, padding_kind padding) : + src_dims(src_dims), diff_filter_dims(diff_filter_dims), + diff_bias_dims(diff_bias_dims), diff_dst_dims(diff_dst_dims), + strides(strides), dilations(dilations), + padding_left(padding_left), padding_right(padding_right), + padding(padding) { + } +}; + +template <typename T> +class MklConv2DBwdFilterPrimitive : public MklPrimitive { + public: + explicit MklConv2DBwdFilterPrimitive( + const MklConvBwdFilterParams& convBwdFilterDims) : + cpu_engine_(engine::cpu, 0) { + context_.bwd_filter_stream.reset(new stream(stream::kind::eager)); + // create conv primitive + if (context_.conv_bwd_filter == nullptr) { + Setup(convBwdFilterDims); + } + } + + ~MklConv2DBwdFilterPrimitive() {} + + // Convolution backward weights with bias + // src_data: input data buffer of src + // diff_filter_data: output data buffer of diff_filter + // diff_bias_data: output data buffer of diff_bias + // diff_dst_data: input data buffer of diff_dst + void Execute(const T* src_data, const T* diff_filter_data, + const T* diff_bias_data, const T* diff_dst_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); + context_.diff_filter_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_filter_data))); + context_.diff_bias_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_bias_data))); + context_.diff_dst_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_dst_data))); + + context_.bwd_filter_stream->submit(context_.bwd_filter_primitives); + + context_.src_mem->set_data_handle(DummyData); + context_.diff_filter_mem->set_data_handle(DummyData); + context_.diff_bias_mem->set_data_handle(DummyData); + context_.diff_dst_mem->set_data_handle(DummyData); + return; + } + + // Convolution backward weights without bias + // src_data: input data buffer of src + // diff_filter_data: output data buffer of diff_filter + // diff_dst_data: input data buffer of diff_dst + void Execute(const T* src_data, + const T* diff_filter_data, const T* diff_dst_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); + context_.diff_filter_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_filter_data))); + context_.diff_dst_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_dst_data))); + + context_.bwd_filter_stream->submit(context_.bwd_filter_primitives); + + context_.src_mem->set_data_handle(DummyData); + context_.diff_filter_mem->set_data_handle(DummyData); + context_.diff_dst_mem->set_data_handle(DummyData); + return; + } + + memory::format GetSrcMemoryFormat() const { + return context_.src_fmt; + } + + memory::format GetDiffDstMemoryFormat() const { + return context_.diff_dst_fmt; + } + + memory::format GetDiffFilterMemoryFormat() const { + return context_.diff_filter_fmt; + } + + // convolution primitive + std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> + GetPrimitiveDesc() const { + return context_.bwd_filter_pd; + } + + private: + // Primitive reuse context for Conv2D bwd filter op + struct ConvBwdFilterContext { + // expected memory format for this primitive instance + memory::format src_fmt; + memory::format diff_dst_fmt; + memory::format diff_filter_fmt; + + // convolution bwd input primitive + std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> + bwd_filter_pd; + std::shared_ptr<mkldnn::primitive> conv_bwd_filter; + + // MKLDNN memory + std::shared_ptr<mkldnn::memory> src_mem; + std::shared_ptr<mkldnn::memory> diff_filter_mem; + std::shared_ptr<mkldnn::memory> diff_bias_mem; + std::shared_ptr<mkldnn::memory> diff_dst_mem; + + // desc & prmitive desc + std::shared_ptr<mkldnn::convolution_backward_weights::desc> bwd_filter_desc; + std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc; + std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd; + + // memory desc: forward & backward can share same memory desc + std::shared_ptr<mkldnn::memory::desc> src_md; + std::shared_ptr<mkldnn::memory::desc> diff_filter_md; + std::shared_ptr<mkldnn::memory::desc> diff_bias_md; + std::shared_ptr<mkldnn::memory::desc> diff_dst_md; + + // MKL pipeline + std::shared_ptr<mkldnn::stream> bwd_filter_stream; + std::vector<mkldnn::primitive> bwd_filter_primitives; + + ConvBwdFilterContext() : + src_fmt(memory::format::any), + diff_dst_fmt(memory::format::any), + diff_filter_fmt(memory::format::any), + src_mem(nullptr), diff_filter_mem(nullptr), + diff_bias_mem(nullptr), diff_dst_mem(nullptr), + bwd_filter_desc(nullptr), fwd_desc(nullptr), fwd_pd(nullptr), + src_md(nullptr), diff_filter_md(nullptr), + diff_bias_md(nullptr), diff_dst_md(nullptr), + bwd_filter_stream(nullptr) { + } + }; + + // Setup Conv2d backward filter (weights) primitives. + void Setup(const MklConvBwdFilterParams& convBwdFilterDims) { + // create memory descriptors for convolution data w/ no specified format + context_.src_md.reset(new memory::desc({convBwdFilterDims.src_dims}, + MklDnnType<T>(), memory::format::any)); + + context_.diff_dst_md.reset(new memory::desc( + {convBwdFilterDims.diff_dst_dims}, + MklDnnType<T>(), memory::format::any)); + + context_.diff_filter_md.reset(new memory::desc( + {convBwdFilterDims.diff_filter_dims}, + MklDnnType<T>(), memory::format::any)); + + if (!convBwdFilterDims.diff_bias_dims.empty()) + context_.diff_bias_md.reset(new memory::desc( + {convBwdFilterDims.diff_bias_dims}, + MklDnnType<T>(), memory::format::x)); + + // create a convolution + if (!convBwdFilterDims.diff_bias_dims.empty()) { + context_.bwd_filter_desc.reset(new convolution_backward_weights::desc( + convolution_direct, *context_.src_md, *context_.diff_filter_md, + *context_.diff_bias_md, *context_.diff_dst_md, + convBwdFilterDims.strides, convBwdFilterDims.dilations, + convBwdFilterDims.padding_left, convBwdFilterDims.padding_right, + convBwdFilterDims.padding)); + } else { + context_.bwd_filter_desc.reset( + new convolution_backward_weights::desc( + convolution_direct, *context_.src_md, *context_.diff_filter_md, + *context_.diff_dst_md, convBwdFilterDims.strides, + convBwdFilterDims.dilations, convBwdFilterDims.padding_left, + convBwdFilterDims.padding_right, convBwdFilterDims.padding)); + } + + // create fwd primitive_desc + context_.fwd_desc.reset(new convolution_forward::desc( + prop_kind::forward, convolution_direct, + *context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md, + convBwdFilterDims.strides, + convBwdFilterDims.dilations, convBwdFilterDims.padding_left, + convBwdFilterDims.padding_right, convBwdFilterDims.padding)); + context_.fwd_pd.reset(new convolution_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); + + // create backward conv primitive_desc + context_.bwd_filter_pd.reset( + new convolution_backward_weights::primitive_desc( + *context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd)); + + // store the expected memory format + auto bwd_filter_pd = context_.bwd_filter_pd.get(); + context_.src_fmt = static_cast<mkldnn::memory::format>( + bwd_filter_pd->src_primitive_desc().desc().data.format); + context_.diff_filter_fmt = static_cast<mkldnn::memory::format>( + bwd_filter_pd->diff_weights_primitive_desc().desc().data.format); + context_.diff_dst_fmt = static_cast<mkldnn::memory::format>( + bwd_filter_pd->diff_dst_primitive_desc().desc().data.format); + + // create memory primitive based on dummy data + context_.src_mem.reset(new memory( + bwd_filter_pd->src_primitive_desc(), DummyData)); + context_.diff_filter_mem.reset(new memory( + bwd_filter_pd->diff_weights_primitive_desc(), DummyData)); + context_.diff_dst_mem.reset(new memory( + bwd_filter_pd->diff_dst_primitive_desc(), DummyData)); + + // create convolution primitive and add it to net + if (!convBwdFilterDims.diff_bias_dims.empty()) { + context_.diff_bias_mem.reset(new memory( + {{{convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(), + memory::format::x}, cpu_engine_}, DummyData)); + context_.conv_bwd_filter.reset(new convolution_backward_weights( + *context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem, + *context_.diff_filter_mem, *context_.diff_bias_mem)); + } else { + context_.conv_bwd_filter.reset(new convolution_backward_weights( + *context_.bwd_filter_pd, *context_.src_mem, + *context_.diff_dst_mem, *context_.diff_filter_mem)); + } + + context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter); + } + + struct ConvBwdFilterContext context_; + engine cpu_engine_; +}; + +template <typename T> +class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> { + public: + static MklConv2DBwdFilterPrimitive<T>* Get( + const MklConvBwdFilterParams& convBwdFilterDims) { + MklConv2DBwdFilterPrimitive<T>* conv2d_bwd_filter = nullptr; + + // look into the pool for reusable primitive + conv2d_bwd_filter = dynamic_cast<MklConv2DBwdFilterPrimitive<T>*> ( + MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().GetConv2dBwdFilter( + convBwdFilterDims)); + + if (conv2d_bwd_filter == nullptr) { + conv2d_bwd_filter = new MklConv2DBwdFilterPrimitive<T>( + convBwdFilterDims); + MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().SetConv2dBwdFilter( + convBwdFilterDims, conv2d_bwd_filter); + } + return conv2d_bwd_filter; + } + + + private: + MklConv2DBwdFilterPrimitiveFactory() {} + ~MklConv2DBwdFilterPrimitiveFactory() {} + + static MklConv2DBwdFilterPrimitiveFactory& GetInstance() { + static MklConv2DBwdFilterPrimitiveFactory instance_; + return instance_; + } + + static std::string CreateKey( + const MklConvBwdFilterParams& convBwdFilterDims) { + std::string prefix = "conv2d_bwd_filter"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(convBwdFilterDims.src_dims); + key_creator.AddAsKey(convBwdFilterDims.diff_filter_dims); + key_creator.AddAsKey(convBwdFilterDims.diff_bias_dims); + key_creator.AddAsKey(convBwdFilterDims.diff_dst_dims); + key_creator.AddAsKey(convBwdFilterDims.strides); + key_creator.AddAsKey(convBwdFilterDims.dilations); + key_creator.AddAsKey(convBwdFilterDims.padding_left); + key_creator.AddAsKey(convBwdFilterDims.padding_right); + return key_creator.GetKey(); + } + + MklPrimitive* GetConv2dBwdFilter( + const MklConvBwdFilterParams& convBwdFilterDims) { + std::string key = CreateKey(convBwdFilterDims); + return this->GetOp(key); + } + + void SetConv2dBwdFilter( + const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) { + std::string key = CreateKey(convBwdFilterDims); + this->SetOp(key, op); + } +}; + +#endif + #ifdef INTEL_MKL_ML template <typename Device, class T> @@ -442,11 +744,213 @@ class MklConv2DCustomBackpropFilterOp : public MklConv2DBackpropCommonOp<Device, T> { public: explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context) - : MklConv2DBackpropCommonOp<Device, T>(context) {} + : MklConv2DBackpropCommonOp<Device, T>(context) { + } + ~MklConv2DCustomBackpropFilterOp() {} + void Compute(OpKernelContext* context) { + try { + MklDnnData<T> src(&cpu_engine_); + MklDnnData<T> diff_dst(&cpu_engine_); + MklDnnData<T> diff_filter(&cpu_engine_); // output + + // Input tensors + const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; + const Tensor& src_tensor = MklGetInput(context, kInputIdx); + const Tensor& filter_tensor = MklGetInput(context, kFilterIdx); + const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx); + + MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape; + GetMklShape(context, kInputIdx, &src_mkl_shape); + GetMklShape(context, kFilterIdx, &filter_mkl_shape); + GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape); + // Allow operator-specific sanity checking of shapes. + ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape); + + // Allow operator-specific generation of shapes. + // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a + // tensor containing shape of filter. So filter.shape() is not + // a correct way to get filter shape. These operator-specific calls + // allow this class to handle this case. + TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor); + TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor); + TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx); + + // Corner cases: output with 0 elements and 0 batch size. + Tensor* diff_filter_tensor = nullptr; + if (src_tf_shape.num_elements() == 0 || + filter_tf_shape.num_elements() == 0 || + diff_dst_tf_shape.num_elements() == 0) { + MklDnnShape diff_filter_mkl_shape; + diff_filter_mkl_shape.SetMklTensor(false); + TensorShape diff_filter_tf_shape = GetOutputTfShape( + src_tf_shape, filter_tf_shape, diff_dst_tf_shape); + const int kOutputIdx = 0; + AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor, + diff_filter_tf_shape, diff_filter_mkl_shape); + CHECK_NOTNULL(diff_filter_tensor); + + // if output tensor has more than 0 elements, we need to 0 them out. + auto diff_filter_data = diff_filter_tensor->flat<T>().data(); + for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) { + diff_filter_data[i] = 0; + } + return; + } + + // By default, all dims are in MKL order. Only dims in TF order + // are those with prefix tf_order. + memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims; + memory::dims padding_left, padding_right, dilations, + strides, fwd_dst_dims; + memory::dims fwd_dst_dims_tf_order; + + // Get forward convolution parameters. + MklDnnConvUtil conv_utl(context, this->strides_, this->padding_, + this->data_format_, this->dilations_); + conv_utl.GetConvFwdSizesInMklOrder( + src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims, + &strides, &dilations, &fwd_dst_dims_tf_order, + &fwd_dst_dims, &padding_left, &padding_right); + if (!context->status().ok()) return; + + auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_); + auto fwd_src_md = + src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetMklLayout() + : memory::desc(fwd_src_dims, MklDnnType<T>(), tf_fmt); + + conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); + if (!context->status().ok()) return; + + auto diff_dst_md = diff_dst_mkl_shape.IsMklTensor() + ? diff_dst_mkl_shape.GetMklLayout() + : memory::desc(diff_dst_dims, + MklDnnType<T>(), tf_fmt); + + memory::dims diff_bias_dims = {}; + int64 depth = 0; + if (biasEnabled) { + TensorShape obp_tf_shape = GetTfShape(context, 2); + depth = (this->data_format_ == FORMAT_NCHW) + ? obp_tf_shape.dim_size(1) + : obp_tf_shape.dim_size(3); + diff_bias_dims = {static_cast<int>(depth)}; + } + + dilations[kDilationH] -= 1; + dilations[kDilationW] -= 1; + + MklConv2DBwdFilterPrimitive<T> *conv2d_bwd_filter = nullptr; + MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims, + diff_bias_dims, diff_dst_dims, strides, dilations, padding_left, + padding_right, TFPaddingToMklDnnPadding(this->padding_)); + conv2d_bwd_filter = MklConv2DBwdFilterPrimitiveFactory<T>::Get( + convBwdFilterDims); + auto bwd_filter_pd = conv2d_bwd_filter->GetPrimitiveDesc(); + + // allocate output tensors: diff_fitler and diff_bias (w bias) + auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims); + + // diff_filter + MklDnnShape diff_filter_mkl_shape; + diff_filter_mkl_shape.SetMklTensor(false); + // output_dims_mkl_order is in OIHW format. + TensorShape diff_filter_tf_shape( + {bwd_output_dims[MklDnnDims::Dim_H], + bwd_output_dims[MklDnnDims::Dim_W], + bwd_output_dims[MklDnnDims::Dim_I], + bwd_output_dims[MklDnnDims::Dim_O]}); + AllocateOutputSetMklShape(context, 0, &diff_filter_tensor, + diff_filter_tf_shape, diff_filter_mkl_shape); + + Tensor* diff_bias_tensor = nullptr; + if (biasEnabled) { + TensorShape diff_bias_shape({depth}); + AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor); + } + + // check if src and diff_dst need reorder + std::vector<primitive> net; + T *src_data = nullptr; + if (fwd_src_md.data.format != conv2d_bwd_filter->GetSrcMemoryFormat()) { + src.SetUsrMem(fwd_src_md, &src_tensor); + src.CheckReorderToOpMem( + bwd_filter_pd->src_primitive_desc(), &net); + src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); + } else { + src_data = static_cast<T*>(const_cast<T*>( + src_tensor.flat<T>().data())); + } + + T *diff_dst_data = nullptr; + if (diff_dst_md.data.format != + conv2d_bwd_filter->GetDiffDstMemoryFormat()) { + diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); + diff_dst.CheckReorderToOpMem( + bwd_filter_pd->diff_dst_primitive_desc(), &net); + diff_dst_data = static_cast<T*>( + diff_dst.GetOpMem().get_data_handle()); + } else { + diff_dst_data = static_cast<T*>(const_cast<T*>( + diff_dst_tensor.flat<T>().data())); + } + stream(stream::kind::eager).submit(net).wait(); + + // For backward filter, convert diff_filter back to Tensorflow layout + // Here we prepare to reorder op memory back to user memory + bool diff_filter_reorder_required = false; + T *diff_filter_data = nullptr; + if (GetOutputFormat(tf_fmt) != + conv2d_bwd_filter->GetDiffFilterMemoryFormat()) { + // Allocate diff filter tensor as Tensorflow layout + diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt), + diff_filter_tensor); + diff_filter_reorder_required = true; + diff_filter.PrepareReorderToUserMemIfReq( + bwd_filter_pd->diff_weights_primitive_desc()); + diff_filter_data = static_cast<T*>( + diff_filter.GetOpMem().get_data_handle()); + } else { + diff_filter_data = static_cast<T*>(const_cast<T*>( + diff_filter_tensor->flat<T>().data())); + } + + // Execute convolution filter bwd + if (biasEnabled) { + T* diff_bias_data = static_cast<T*>(const_cast<T*>( + diff_bias_tensor->flat<T>().data())); + conv2d_bwd_filter->Execute(src_data, diff_filter_data, + diff_bias_data, diff_dst_data); + } else { + conv2d_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data); + } + + // Reorder diff_filter back to Tensorflow layout if necessary + if (diff_filter_reorder_required) { + std::vector<primitive> net; + diff_filter.InsertReorderToUserMem(&net); + stream(stream::kind::eager).submit(net).wait(); + } + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); + } + } + private: + const int kInputIndex_Filter = 1; + const int kInputIndex_InputSizes = 0; const int kDilationH = 0, kDilationW = 1; + engine cpu_engine_ = engine(engine::cpu, 0); + + // Validate input shapes. + // Function asserts that input shapes are valid. void ValidateMklShapes(const MklDnnShape& input_mkl_shape, const MklDnnShape& filter_mkl_shape, const MklDnnShape& obp_mkl_shape) { @@ -454,141 +958,44 @@ class MklConv2DCustomBackpropFilterOp << "Conv2DBackpropFilter: filter should not be in MKL Layout"; } - size_t GetInputTensorIndexWithSizes() { return 1; /* filter index */ } - + // Get TensorFlow shape of input tensor. TensorShape MakeInputTfShape(OpKernelContext* context, const Tensor& input_tensor) { size_t input_idx = 0; return GetTfShape(context, input_idx); } + // Get TensorFlow shape of filter tensor. TensorShape MakeFilterTfShape(OpKernelContext* context, const Tensor& filter_tensor) { TensorShape filter_tf_shape; CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true); CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec<int32>(), - &filter_tf_shape) - .ok(), - true); + &filter_tf_shape).ok(), true); return filter_tf_shape; } + // Get Tensorflow shape of output tensor (diff_filter), + // which is same as shape of filter. TensorShape GetOutputTfShape(const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& outbprop_shape) { - // Shape of output of Conv2DBackpropFilter is same as shape of filter. return filter_shape; } + // Get the shape of output (diff_filter) in MKL-DNN order. + // Computes shape of output from input shape (fwd_input_dims) + // and filter shape (fwd_filter_dims). const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims, const memory::dims& fwd_filter_dims) { - // Shape of output of Conv2DBackpropFilter is same as shape of filter. return fwd_filter_dims; } + // Output layout is Tensorflow's filter layout (HWIO). memory::format GetOutputFormat(const memory::format data_format) { - // Output layout is Tensorflow's filter layout (HWIO). return memory::format::hwio; } - void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine, - const convolution_forward::primitive_desc& conv_fwd_pd, - MklDnnData<T>* input, MklDnnData<T>* filter, - MklDnnData<T>* outbackprop, MklDnnData<T>* output, - Tensor** output_tensor, - const memory::dims& strides, - const memory::dims& dilations, - const memory::dims& padding_l, - const memory::dims& padding_r, padding_kind padding, - const memory::dims& bwd_output_dims, - memory::format bwd_output_format) { - CHECK_NOTNULL(context); - CHECK_NOTNULL(input); - CHECK_NOTNULL(filter); - CHECK_NOTNULL(outbackprop); - CHECK_NOTNULL(output); - CHECK_NOTNULL(output_tensor); - - MklDnnData<T>* bias_grad = nullptr; - int depth = 0; - if (biasEnabled) { - // Data structure for bias_grad - bias_grad = new MklDnnData<T>(&cpu_engine); - TensorShape obp_tf_shape = GetTfShape(context, 2); - depth = (MklConv2DBackpropCommonOp<Device, T>::GetTFDataFormat() == - FORMAT_NCHW) - ? obp_tf_shape.dim_size(1) - : obp_tf_shape.dim_size(3); - memory::dims bias_grad_dims = {depth}; - bias_grad->SetOpMemDesc(bias_grad_dims, memory::format::x); - } - - if (biasEnabled && (bias_grad != nullptr)) { - // Create convolution backward weights with bias primitive. - // Use dilated convolution in case dilate rates are greater than zero. - auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ? - convolution_backward_weights::desc(convolution_direct, - input->GetOpMemDesc(), output->GetOpMemDesc(), - bias_grad->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), strides, - dilations, padding_l, padding_r, padding) : - convolution_backward_weights::desc(convolution_direct, - input->GetOpMemDesc(), output->GetOpMemDesc(), - bias_grad->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), - strides, padding_l, padding_r, padding); - auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc, - cpu_engine, - conv_fwd_pd); - - // Allocate output tensor. - AllocateOutputTensor(context, bwd_pd, bwd_output_dims, - bwd_output_format, output_tensor); - - CHECK_NOTNULL(*output_tensor); - // Set buffer handle using allocated output tensor. - output->SetUsrMemDataHandle(*output_tensor); - - // Allocate bias_grad tensor - TensorShape bias_grad_shape({depth}); - Tensor* bias_grad_tensor = nullptr; - AllocateBiasGradTensor(context, bias_grad_shape, &bias_grad_tensor); - memory::dims bias_grad_dims = {depth}; - // Since Bias is 1D, we use format::x from MKLDNN to represent it. - auto bias_grad_md = - memory::desc({bias_grad_dims}, MklDnnType<T>(), memory::format::x); - bias_grad->SetUsrMem(bias_grad_md, bias_grad_tensor); - bias_grad->SetUsrMemDataHandle(bias_grad_tensor); - - PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output, - bias_grad); - } else { - // Create convolution backward weights primitive. - // Use dilated convolution in case dilate rates are greater than zero. - auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ? - convolution_backward_weights::desc(convolution_direct, - input->GetOpMemDesc(), output->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), strides, - dilations, padding_l, padding_r, padding) : - convolution_backward_weights::desc(convolution_direct, - input->GetOpMemDesc(), output->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), - strides, padding_l, padding_r, padding); - auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc, - cpu_engine, - conv_fwd_pd); - - // Allocate output tensor. - AllocateOutputTensor(context, bwd_pd, bwd_output_dims, - bwd_output_format, output_tensor); - - CHECK_NOTNULL(*output_tensor); - // Set buffer handle using allocated output tensor. - output->SetUsrMemDataHandle(*output_tensor); - PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output); - } - } - // Allocate output tensor. void AllocateOutputTensor( OpKernelContext* context, @@ -623,40 +1030,8 @@ class MklConv2DCustomBackpropFilterOp MklDnnShape bias_grad_mkl_shape; bias_grad_mkl_shape.SetMklTensor(false); - AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape, - bias_grad_mkl_shape); - } - - // Prepare and execute net - checks for input and output reorders. - void PrepareAndExecutePrimitive( - const convolution_backward_weights::primitive_desc& conv_pd, - MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output, - MklDnnData<T>* bias_grad = nullptr) { - // Create reorders between user layout and MKL layout if it is needed and - // add it to the net before convolution. - std::vector<primitive> net; - input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net); - obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net); - - // For BackpropFilter, we convert the output tensor back in Tensorflow - // layout. - bool output_reorder_required = output->PrepareReorderToUserMemIfReq( - conv_pd.diff_weights_primitive_desc()); - - if (biasEnabled && (bias_grad != nullptr)) { - net.push_back(convolution_backward_weights( - conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem(), - bias_grad->GetOpMem())); - } else { - net.push_back(convolution_backward_weights( - conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem())); - } - - if (output_reorder_required) { - output->InsertReorderToUserMem(&net); - } - - stream(stream::kind::eager).submit(net).wait(); + AllocateOutputSetMklShape(context, 1, bias_grad_tensor, + bias_grad_shape, bias_grad_mkl_shape); } }; diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 21b18f9119..0af4568b47 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -55,9 +55,246 @@ using mkldnn::stream; #endif namespace tensorflow { - typedef Eigen::ThreadPoolDevice CPUDevice; +#ifndef INTEL_MKL_ML + +/// utility classes enabling primitive reuse for backward conv2d ops. +struct MklConvBwdInputParams { + memory::dims diff_src_dims; + memory::dims filter_dims; + memory::dims diff_dst_dims; + memory::dims strides; + memory::dims dilations; + memory::dims padding_left; + memory::dims padding_right; + padding_kind padding; + + MklConvBwdInputParams(memory::dims diff_src_dims, + memory::dims filter_dims, memory::dims diff_dst_dims, + memory::dims strides, memory::dims dilations, + memory::dims padding_left, memory::dims padding_right, + padding_kind padding) : + diff_src_dims(diff_src_dims), filter_dims(filter_dims), + diff_dst_dims(diff_dst_dims), strides(strides), + dilations(dilations), padding_left(padding_left), + padding_right(padding_right), padding(padding) { + } +}; + +template <typename T> +class MklConv2DBwdInputPrimitive : public MklPrimitive { + public: + explicit MklConv2DBwdInputPrimitive( + const MklConvBwdInputParams& convBwdInputDims) : + cpu_engine_(engine::cpu, 0) { + context_.bwd_input_stream.reset(new stream(stream::kind::eager)); + + // create conv primitive + if (context_.conv_bwd_input == nullptr) { + Setup(convBwdInputDims); + } + } + ~MklConv2DBwdInputPrimitive() {} + + // Convolution backward filter (weights) + // diff_src_data: output data buffer of diff_src + // filter_data: input data buffer of filter (weights) + // diff_dst_data: input data buffer of dst + // Bias does not matter here + void Execute(const T* diff_src_data, + const T* filter_data, const T* diff_dst_data) { + context_.diff_src_mem->set_data_handle( + static_cast<T*>(const_cast<T*>(diff_src_data))); + context_.filter_mem->set_data_handle( + static_cast<T*>(const_cast<T*>(filter_data))); + context_.diff_dst_mem->set_data_handle( + static_cast<T*>(const_cast<T*>(diff_dst_data))); + + context_.bwd_input_stream->submit(context_.bwd_input_primitives); + + // set back data handle + context_.diff_src_mem->set_data_handle(DummyData); + context_.filter_mem->set_data_handle(DummyData); + context_.diff_dst_mem->set_data_handle(DummyData); + return; + } + + memory::format GetFilterMemoryFormat() const { + return context_.filter_fmt; + } + + memory::format GetDiffDstMemoryFormat() const { + return context_.diff_dst_fmt; + } + + std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> + GetPrimitiveDesc() const { + return context_.bwd_input_pd; + } + + private: + // Primitive reuse context for Conv2D Bwd Input op + struct ConvBwdInputContext { + // expected memory format for this primitive instance + memory::format filter_fmt; + memory::format diff_dst_fmt; + + // MKLDNN memory + std::shared_ptr<mkldnn::memory> diff_src_mem; + std::shared_ptr<mkldnn::memory> filter_mem; + std::shared_ptr<mkldnn::memory> diff_dst_mem; + + // convolution primitive + std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> + bwd_input_pd; + std::shared_ptr<mkldnn::primitive> conv_bwd_input; + + // desc & prmitive desc + std::shared_ptr<mkldnn::convolution_backward_data::desc> bwd_input_desc; + std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc; + std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd; + + // memory desc: forward & backward can share same memory::desc + std::shared_ptr<memory::desc> diff_src_md; + std::shared_ptr<memory::desc> filter_md; + std::shared_ptr<memory::desc> diff_dst_md; + + // MKL pipeline + std::shared_ptr<mkldnn::stream> bwd_input_stream; + std::vector<mkldnn::primitive> bwd_input_primitives; + + ConvBwdInputContext() : + filter_fmt(memory::format::any), diff_dst_fmt(memory::format::any), + diff_src_mem(nullptr), filter_mem(nullptr), diff_dst_mem(nullptr), + bwd_input_pd(nullptr), conv_bwd_input(nullptr), + bwd_input_desc(nullptr), fwd_desc(nullptr), fwd_pd(nullptr), + diff_src_md(nullptr), filter_md(nullptr), diff_dst_md(nullptr), + bwd_input_stream(nullptr) { + } + }; + + + void Setup(const MklConvBwdInputParams& convBwdInputDims) { + // create memory descriptors for convolution data w/ no specified format + context_.diff_src_md.reset(new memory::desc( + {convBwdInputDims.diff_src_dims}, + MklDnnType<T>(), memory::format::any)); + context_.filter_md.reset(new memory::desc( + {convBwdInputDims.filter_dims}, + MklDnnType<T>(), memory::format::any)); + context_.diff_dst_md.reset(new memory::desc( + {convBwdInputDims.diff_dst_dims}, + MklDnnType<T>(), memory::format::any)); + + // create convolution primitives + context_.bwd_input_desc.reset(new convolution_backward_data::desc( + convolution_direct, *context_.diff_src_md, *context_.filter_md, + *context_.diff_dst_md, convBwdInputDims.strides, + convBwdInputDims.dilations, convBwdInputDims.padding_left, + convBwdInputDims.padding_right, convBwdInputDims.padding)); + + context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward, + convolution_direct, *context_.diff_src_md, *context_.filter_md, + *context_.diff_dst_md, convBwdInputDims.strides, + convBwdInputDims.dilations, convBwdInputDims.padding_left, + convBwdInputDims.padding_right, convBwdInputDims.padding)); + + context_.fwd_pd.reset(new convolution_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); + + // create backward conv prim desc + context_.bwd_input_pd.reset( + new convolution_backward_data::primitive_desc( + *context_.bwd_input_desc, cpu_engine_, *context_.fwd_pd)); + + // create memory primitive based on dummy data + context_.diff_src_mem.reset(new memory( + context_.bwd_input_pd.get()->diff_src_primitive_desc(), DummyData)); + context_.filter_mem.reset(new memory( + context_.bwd_input_pd.get()->weights_primitive_desc(), DummyData)); + context_.diff_dst_mem.reset(new memory( + context_.bwd_input_pd.get()->diff_dst_primitive_desc(), DummyData)); + + // store the expected memory format + context_.filter_fmt = static_cast<memory::format>( + context_.bwd_input_pd.get()->weights_primitive_desc().desc().data.format); + context_.diff_dst_fmt = static_cast<memory::format>( + context_.bwd_input_pd.get()->diff_dst_primitive_desc().desc().data.format); + + // create convolution primitive and add it to net + context_.conv_bwd_input.reset(new convolution_backward_data( + *context_.bwd_input_pd, *context_.diff_dst_mem, + *context_.filter_mem, *context_.diff_src_mem)); + + context_.bwd_input_primitives.push_back(*context_.conv_bwd_input); + } + + struct ConvBwdInputContext context_; + engine cpu_engine_; +}; + +template <typename T> +class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> { + private: + MklConv2DBwdInputPrimitiveFactory() {} + ~MklConv2DBwdInputPrimitiveFactory() {} + + public: + static MklConv2DBwdInputPrimitive<T>* Get( + const MklConvBwdInputParams& convBwdInputDims) { + MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr; + + // look into the pool for reusable primitive + conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> ( + MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput( + convBwdInputDims)); + + if (conv2d_bwd_input == nullptr) { + conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>( + convBwdInputDims); + MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput( + convBwdInputDims, conv2d_bwd_input); + } + return conv2d_bwd_input; + } + + private: + static MklConv2DBwdInputPrimitiveFactory& GetInstance() { + static MklConv2DBwdInputPrimitiveFactory instance_; + return instance_; + } + + static std::string CreateKey( + const MklConvBwdInputParams& convBwdInputDims) { + std::string prefix = "conv2d_bwd_input"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(convBwdInputDims.diff_src_dims); + key_creator.AddAsKey(convBwdInputDims.filter_dims); + key_creator.AddAsKey(convBwdInputDims.diff_dst_dims); + key_creator.AddAsKey(convBwdInputDims.strides); + key_creator.AddAsKey(convBwdInputDims.dilations); + key_creator.AddAsKey(convBwdInputDims.padding_left); + key_creator.AddAsKey(convBwdInputDims.padding_right); + return key_creator.GetKey(); + } + + MklPrimitive* GetConv2dBwdInput( + const MklConvBwdInputParams& convBwdInputDims) { + std::string key = CreateKey(convBwdInputDims); + return this->GetOp(key); + } + + void SetConv2dBwdInput( + const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) { + std::string key = CreateKey(convBwdInputDims); + this->SetOp(key, op); + } +}; + +#endif + #ifdef INTEL_MKL_ML template <typename Device, class T> @@ -365,13 +602,173 @@ class MklConv2DCustomBackpropInputOp : public MklConv2DBackpropCommonOp<Device, T> { public: explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context) - : MklConv2DBackpropCommonOp<Device, T>(context) {} + : MklConv2DBackpropCommonOp<Device, T>(context) { + } + ~MklConv2DCustomBackpropInputOp() {} + void Compute(OpKernelContext* context) { + try { + MklDnnData<T> filter(&cpu_engine); + MklDnnData<T> diff_dst(&cpu_engine); + + // Input tensors + const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; + const Tensor& src_tensor = MklGetInput(context, kInputIdx); + const Tensor& filter_tensor = MklGetInput(context, kFilterIdx); + const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx); + + MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape; + GetMklShape(context, kInputIdx, &src_mkl_shape); + GetMklShape(context, kFilterIdx, &filter_mkl_shape); + GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape); + // Allow operator-specific sanity checking of shapes. + ValidateMklShapes(src_mkl_shape, filter_mkl_shape, + diff_dst_mkl_shape); + + // Allow operator-specific generation of shapes. + // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a + // tensor containing shape of filter. So filter.shape() is not + // a correct way to get filter shape. These operator-specific calls + // allow this class to handle this case. + TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor); + TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor); + TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx); + + // Corner cases: output with 0 elements and 0 batch size. + Tensor* diff_src_tensor = nullptr; + if (src_tf_shape.num_elements() == 0 || + filter_tf_shape.num_elements() == 0 || + diff_dst_tf_shape.num_elements() == 0) { + MklDnnShape diff_src_mkl_shape; + diff_src_mkl_shape.SetMklTensor(false); + TensorShape diff_src_tf_shape = GetOutputTfShape( + src_tf_shape, filter_tf_shape, diff_dst_tf_shape); + const int kOutputIdx = 0; + AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor, + diff_src_tf_shape, diff_src_mkl_shape); + CHECK_NOTNULL(diff_src_tensor); + + // if output tensor has more than 0 elements, we need to 0 them out. + auto diff_src_data = diff_src_tensor->flat<T>().data(); + for (size_t i = 0; i < diff_src_tf_shape.num_elements(); ++i) { + diff_src_data[i] = 0; + } + return; + } + // By default, all dims are in MKL order. Only dims in TF order + // are those with postfix tf_order. + memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims; + memory::dims padding_left, padding_right, dilations, strides; + memory::dims fwd_output_dims, fwd_output_dims_tf_order; + + // Get forward convolution parameters. + MklDnnConvUtil conv_utl(context, this->strides_, this->padding_, + this->data_format_, this->dilations_); + conv_utl.GetConvFwdSizesInMklOrder( + src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims, + &strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims, + &padding_left, &padding_right); + if (!context->status().ok()) return; + + // Create Convolution forward descriptor since Convolution backward + // API needs it. For that, we first need to create input, filter + // and output memory descriptors. + auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_); + + // If filter is in MKL layout, then simply grab filter layout; + // otherwise, construct filter in TF layout. + // For TF layout, filter is in HWIO format. + auto fwd_filter_md = filter_mkl_shape.IsMklTensor() + ? filter_mkl_shape.GetMklLayout() + : memory::desc(fwd_filter_dims, MklDnnType<T>(), + memory::format::hwio); + + conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); + if (!context->status().ok()) return; + auto diff_dst_md = diff_dst_mkl_shape.IsMklTensor() + ? diff_dst_mkl_shape.GetMklLayout() + : memory::desc(diff_dst_dims, + MklDnnType<T>(), tf_fmt); + + dilations[kDilationH] -= 1; + dilations[kDilationW] -= 1; + + MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr; + conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); + MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims, + diff_dst_dims, strides, dilations, padding_left, padding_right, + TFPaddingToMklDnnPadding(this->padding_)); + conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get( + convBwdInputDims); + auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc(); + + // allocate output tensor + auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc(); + auto bwd_diff_src_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims); + auto bwd_diff_src_format = GetOutputFormat(tf_fmt); + MklDnnShape diff_src_mkl_shape; + diff_src_mkl_shape.SetMklTensor(true); + diff_src_mkl_shape.SetMklLayout(&diff_src_pd); + diff_src_mkl_shape.SetElemType(MklDnnType<T>()); + diff_src_mkl_shape.SetTfLayout(bwd_diff_src_dims.size(), + bwd_diff_src_dims, bwd_diff_src_format); + TensorShape diff_src_tf_shape; + diff_src_tf_shape.AddDim(diff_src_pd.get_size() / sizeof(T)); + AllocateOutputSetMklShape(context, 0, &diff_src_tensor, + diff_src_tf_shape, diff_src_mkl_shape); + + T *diff_src_data = static_cast<T*>(const_cast<T*>( + diff_src_tensor->flat<T>().data())); + + // check if filter and diff_dst need reorder + std::vector<primitive> net; + T* filter_data = nullptr; + if (fwd_filter_md.data.format != + conv2d_bwd_input->GetFilterMemoryFormat()) { + filter.SetUsrMem(fwd_filter_md, &filter_tensor); + filter.CheckReorderToOpMem( + bwd_input_pd->weights_primitive_desc(), + &net); + filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle()); + } else { + filter_data = static_cast<T*>(const_cast<T*>( + filter_tensor.flat<T>().data())); + } + + T* diff_dst_data = nullptr; + if (diff_dst_md.data.format != + conv2d_bwd_input->GetDiffDstMemoryFormat()) { + diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); + diff_dst.CheckReorderToOpMem( + bwd_input_pd->diff_dst_primitive_desc(), &net); + diff_dst_data = static_cast<T*>( + diff_dst.GetOpMem().get_data_handle()); + } else { + diff_dst_data = static_cast<T*>(const_cast<T*>( + diff_dst_tensor.flat<T>().data())); + } + stream(stream::kind::eager).submit(net).wait(); + + // execute convolution input bwd + conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); + } + } + private: - const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0, - kInputIndex_OutBackProp = 2; + const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0; const int kDilationH = 0, kDilationW = 1; + engine cpu_engine = engine(engine::cpu, 0); + + // Validate input shapes. + // Function asserts that input shapes are valid. void ValidateMklShapes(const MklDnnShape& input_mkl_shape, const MklDnnShape& filter_mkl_shape, const MklDnnShape& obp_mkl_shape) { @@ -382,8 +779,7 @@ class MklConv2DCustomBackpropInputOp << "Conv2DBackpropInput: input should not be in MKL Layout"; } - size_t GetInputTensorIndexWithSizes() { return kInputIndex_InputSizes; } - + // Get TensorFlow shape of input tensor. TensorShape MakeInputTfShape(OpKernelContext* context, const Tensor& input_tensor) { TensorShape input_tf_shape; @@ -395,72 +791,32 @@ class MklConv2DCustomBackpropInputOp return input_tf_shape; } + // Get TensorFlow shape of filter tensor. TensorShape MakeFilterTfShape(OpKernelContext* context, const Tensor& filter_tensor) { return GetTfShape(context, kInputIndex_Filter); } + // Get the Tensorflow shape of Output (diff_src), + // which is same as shape of Conv2D 'input'. TensorShape GetOutputTfShape(const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& outbprop_shape) { - // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'. return input_shape; } + // Get the Tensorflow shape of Output (diff_src), + // which is same as shape of Conv2D 'input'. const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims, const memory::dims& fwd_filter_dims) { - // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'. return fwd_input_dims; } + // Output layout is Tensorflow's layout in data format order. memory::format GetOutputFormat(const memory::format data_format) { - // Output layout is Tensorflow's layout in data format order. return data_format; } - void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine, - const convolution_forward::primitive_desc& conv_fwd_pd, - MklDnnData<T>* input, MklDnnData<T>* filter, - MklDnnData<T>* outbackprop, MklDnnData<T>* output, - Tensor** output_tensor, - const memory::dims& strides, - const memory::dims& dilations, - const memory::dims& padding_l, - const memory::dims& padding_r, padding_kind padding, - const memory::dims& bwd_output_dims, - memory::format bwd_output_format) { - CHECK_NOTNULL(context); - CHECK_NOTNULL(input); - CHECK_NOTNULL(filter); - CHECK_NOTNULL(outbackprop); - CHECK_NOTNULL(output); - CHECK_NOTNULL(output_tensor); - - // Create convolution backward data primitive. - // Use dilated convolution in case dilate rates are greater than zero. - auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ? - convolution_backward_data::desc(convolution_direct, - output->GetOpMemDesc(), filter->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), strides, - dilations, padding_l, padding_r, padding): - convolution_backward_data::desc(convolution_direct, - output->GetOpMemDesc(), filter->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), - strides, padding_l, padding_r, padding); - - auto bwd_pd = convolution_backward_data::primitive_desc( - bwd_desc, cpu_engine, conv_fwd_pd); - - // Allocate output tensor in TensorFlow and MKL layout. - AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format, - output_tensor); - CHECK_NOTNULL(*output_tensor); - // Set buffer handle using allocated output tensor. - output->SetUsrMemDataHandle(*output_tensor); - - PrepareAndExecutePrimitive(bwd_pd, filter, outbackprop, output); - } - // Allocate output tensor. void AllocateOutputTensor( OpKernelContext* context, @@ -487,22 +843,6 @@ class MklConv2DCustomBackpropInputOp AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, output_mkl_shape); } - - // Prepare and execute net - checks for input and output reorders. - void PrepareAndExecutePrimitive( - const convolution_backward_data::primitive_desc& conv_pd, - MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) { - // Create reorders between user layout and MKL layout if it is needed and - // add it to the net before convolution. - std::vector<primitive> net; - filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net); - obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net); - - net.push_back(convolution_backward_data( - conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem())); - - stream(stream::kind::eager).submit(net).wait(); - } }; #endif // INTEL_MKL_ML diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index 8333a09316..5e1a5001dc 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -19,6 +19,7 @@ limitations under the License. #include <limits> #include <string> #include <vector> +#include <memory> #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -349,6 +350,7 @@ class MklDnnConvUtil { } }; + ///////////////////////////////////////////////////////////////////// /// Common class that implements Conv2DBackpropFilter and Input ///////////////////////////////////////////////////////////////////// @@ -388,227 +390,17 @@ class MklConv2DBackpropCommonOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } - void Compute(OpKernelContext* context) override { - try { - auto cpu_engine = engine(engine::cpu, 0); - - // Prepare common tensors for Conv2DBackpropInput and - // Conv2DBackpropFilter. - MklDnnData<T> input(&cpu_engine); - MklDnnData<T> filter(&cpu_engine); - MklDnnData<T> outbackprop(&cpu_engine); - MklDnnData<T> output(&cpu_engine); - - // Input tensors - const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; - const Tensor& input_tensor = MklGetInput(context, kInputIdx); - const Tensor& filter_tensor = MklGetInput(context, kFilterIdx); - const Tensor& outbprop_tensor = MklGetInput(context, kOutbpropIdx); - - MklDnnShape input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape; - GetMklShape(context, kInputIdx, &input_mkl_shape); - GetMklShape(context, kFilterIdx, &filter_mkl_shape); - GetMklShape(context, kOutbpropIdx, &outbprop_mkl_shape); - // Allow operator-specific sanity checking of shapes. - ValidateMklShapes(input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape); - - // Allow operator-specific generation of shapes. - // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a - // tensor containing shape of filter. So filter.shape() is not - // a correct way to get filter shape. These operator-specific calls - // allow this class to handle this case. - TensorShape input_tf_shape = MakeInputTfShape(context, input_tensor); - TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor); - TensorShape outbprop_tf_shape = GetTfShape(context, kOutbpropIdx); - - // Corner cases: output with 0 elements and 0 batch size. - Tensor* output_tensor = nullptr; - if (input_tf_shape.num_elements() == 0 || - filter_tf_shape.num_elements() == 0 || - outbprop_tf_shape.num_elements() == 0) { - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(false); - TensorShape output_tf_shape = GetOutputTfShape( - input_tf_shape, filter_tf_shape, outbprop_tf_shape); - const int kOutputIdx = 0; - AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor, - output_tf_shape, output_mkl_shape); - CHECK_NOTNULL(output_tensor); - - // if output tensor has more than 0 elements, we need to 0 them out. - for (size_t i = 0; i < output_tf_shape.num_elements(); ++i) { - output_tensor->flat<T>().data()[i] = 0; - } - - return; - } - - // By default, all dims are in MKL order. Only dims in TF order - // are those with prefix tf_order. - memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims; - memory::dims padding_l, padding_r, dilations, strides, fwd_output_dims; - memory::dims fwd_output_dims_tf_order; - - // Get forward convolution parameters. - MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_, - dilations_); - conv_utl.GetConvFwdSizesInMklOrder( - input_tf_shape, filter_tf_shape, &fwd_input_dims, &fwd_filter_dims, - &strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims, - &padding_l, &padding_r); - if (!context->status().ok()) return; - - // Create Convolution forward descriptor since Convolution backward - // API needs it. For that, we first need to create input, filter - // and output memory descriptors. - auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_); - // If input is in MKL layout, then simply grab input layout; otherwise, - // construct input TF layout. For TF layout, although input shape - // required is in MKL-DNN order, the layout is Tensorflow's layout - // (NHWC or NCHW depending on data format). - auto fwd_input_md = - input_mkl_shape.IsMklTensor() - ? input_mkl_shape.GetMklLayout() - : memory::desc(fwd_input_dims, MklDnnType<T>(), tf_fmt); - // If filter is in MKL layout, then simply grab filter layout; otherwise - // construct filter in TF layout. For TF layout, filter is in HWIO format. - auto fwd_filter_md = filter_mkl_shape.IsMklTensor() - ? filter_mkl_shape.GetMklLayout() - : memory::desc(fwd_filter_dims, MklDnnType<T>(), - memory::format::hwio); - // Tensorflow Output of Conv2D is in data_format order. - auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), tf_fmt); - - const int kDilationH = 0, kDilationW = 1; - dilations[kDilationH] -= 1; - dilations[kDilationW] -= 1; - auto fwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0)? - convolution_forward::desc(prop_kind::forward, - convolution_direct, fwd_input_md, - fwd_filter_md, fwd_out_md, - strides, dilations, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)) : - convolution_forward::desc(prop_kind::forward, - convolution_direct, fwd_input_md, - fwd_filter_md, fwd_out_md, - strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); - auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); - - // Create memory for user data. Describe how the inputs and outputs of - // Convolution look like. Also specify buffers containing actual input - // and output data. - - // Since this is a common class for both Conv2DBackpropFilter and - // Conv2DBackpropInput, we skip SetUsrMem call for input tensor (for - // Conv2DBackpropInput) and for filter tensor (for - // conv2DBackpropFilter) depending on which tensor is int32 type. - size_t input_with_sizes = GetInputTensorIndexWithSizes(); - if (input_with_sizes != kInputIdx) { - // Shape of Conv2DBackpropFilter's input is same as Conv2D input. - input.SetUsrMem(fwd_input_md, &input_tensor); - } else if (input_with_sizes != kFilterIdx) { - // Shape of Conv2DBackpropInput's filter is same as Conv2D filter. - filter.SetUsrMem(fwd_filter_md, &filter_tensor); - } - - conv_utl.GetInputSizeInMklOrder(outbprop_tf_shape, &outbprop_dims); - if (!context->status().ok()) return; - if (outbprop_mkl_shape.IsMklTensor()) { - // If outbackprop is in Mkl layout, then simply grab it. - auto outbprop_md = outbprop_mkl_shape.GetMklLayout(); - outbackprop.SetUsrMem(outbprop_md, &outbprop_tensor); - } else { - // If outbackprop is in TensorFlow layout, then we need to create memory - // descriptor for it. Outbackprop shape is data format order. - outbackprop.SetUsrMem(outbprop_dims, tf_fmt, &outbprop_tensor); - } - - // Operator specific call to get output shape and data_format. - auto bwd_output_dims = GetOutputDims(fwd_input_dims, fwd_filter_dims); - auto bwd_output_format = GetOutputFormat(tf_fmt); - output.SetUsrMem(bwd_output_dims, bwd_output_format); - - // Create memory descriptors for convolution data w/ no specified format. - input.SetOpMemDesc(fwd_input_dims, memory::format::any); - filter.SetOpMemDesc(fwd_filter_dims, memory::format::any); - outbackprop.SetOpMemDesc(outbprop_dims, memory::format::any); - output.SetOpMemDesc(bwd_output_dims, memory::format::any); - - // Operator-specific call to create and execute primitive. - CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter, - &outbackprop, &output, &output_tensor, - strides, dilations, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_), - bwd_output_dims, bwd_output_format); - } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); - } - } - - /// Pure virtual function to allow operator to check for validity of input - /// shapes. Function asserts that input shapes are valid. - virtual void ValidateMklShapes(const MklDnnShape& input_mkl_shape, - const MklDnnShape& filter_mkl_shape, - const MklDnnShape& outbprop_mkl_shape) = 0; - - /// Operator-specific function that returns index of input that is - /// representing input sizes. For Conv2DBackpropFilter it returns 1 since - /// filter for this operator is filter shape. For Conv2DBackpropInput it - /// returns 0 (for input). - virtual size_t GetInputTensorIndexWithSizes() = 0; - - /// Get TensorFlow shape of input tensor. - virtual TensorShape MakeInputTfShape(OpKernelContext* context, - const Tensor& input_tensor) = 0; - - /// Get TensorFlow shape of filter tensor. - virtual TensorShape MakeFilterTfShape(OpKernelContext* context, - const Tensor& filter_tensor) = 0; - - /// Get the TensorFlow shape of output tensor. - virtual TensorShape GetOutputTfShape(const TensorShape& input_shape, - const TensorShape& filter_shape, - const TensorShape& outbprop_shape) = 0; - - /// Get shape of output in MKL-DNN order. Computes shape of output from - /// input shape (fwd_input_dims) and filter shape (fwd_filter_dims). - virtual const memory::dims& GetOutputDims( - const memory::dims& fwd_input_dims, - const memory::dims& fwd_filter_dims) = 0; - - /// Get data_format of output in MKL-DNN order. If output data format is - /// same as input data format, then it simply returns value of data_format - /// parameter as it is. - virtual memory::format GetOutputFormat(const memory::format data_format) = 0; - - /// Create and execute the primitive storing output in the output_tensor. - virtual void CreatePrimitive(OpKernelContext* context, - const engine& cpu_engine, - const convolution_forward::primitive_desc& conv_fwd_pd, - MklDnnData<T>* input, MklDnnData<T>* filter, MklDnnData<T>* outbackprop, - MklDnnData<T>* output, Tensor** output_tensor, const memory::dims& strides, - const memory::dims& dilations, const memory::dims& padding_l, - const memory::dims& padding_r, padding_kind padding, - const memory::dims& bwd_output_dims, - memory::format bwd_output_format) = 0; - - // Get the data_format {NCHW, NHWC} - TensorFormat GetTFDataFormat() { return data_format_; } - - private: + protected: + // data members accessible to derived classes. std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; - TensorFormat data_format_; + TensorFormat data_format_; // NCHW or NHWC }; + #endif // INTEL_MKL_ML + ///////////////////////////////////////////////////////////////////// /// Dummy Mkl op that is just used for operators that are intermediate /// output of node fusion in the graph |