diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 78 |
1 files changed, 40 insertions, 38 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 9080bf7be8..f291281108 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -45,12 +45,12 @@ limitations under the License. #ifdef INTEL_MKL_DNN #include "mkldnn.hpp" -using mkldnn::prop_kind; using mkldnn::stream; +using mkldnn::prop_kind; +using mkldnn::convolution_forward; using mkldnn::convolution_backward_weights; using mkldnn::convolution_direct; -using mkldnn::convolution_forward; #endif @@ -463,13 +463,12 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // Generate input shapes. TensorShape filter_shape; - OP_REQUIRES( - context, TensorShapeUtils::IsVector(filter_tensor.shape()), - errors::InvalidArgument( + OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()), + errors::InvalidArgument( "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ", filter_tensor.dims())); OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - filter_tensor.vec<int32>(), &filter_shape)); + filter_tensor.vec<int32>(), &filter_shape)); TensorShape input_shape = input_tensor.shape(); TensorShape obp_shape = obp_tensor.shape(); @@ -481,26 +480,27 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // Get forward convolution parameters. MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); - conv_utl.GetConvFwdSizesInMklOrder( - input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims, - &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l, - &padding_r); + conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape, + &fwd_input_dims, &fwd_filter_dims, + &strides, + &fwd_output_dims_tf_order, + &fwd_output_dims, + &padding_l, &padding_r); if (!context->status().ok()) return; // Create Convolution forward descriptor since Convolution backward // API needs it. For that, we first need to create input, filter // and output memory descriptors. auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_); - auto fwd_src_md = - memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format); - auto fwd_filter_md = - memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio); - auto fwd_out_md = - memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format); - auto fwd_desc = convolution_forward::desc( - prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md, - fwd_out_md, strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(), + mkl_data_format); + auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(), + memory::format::hwio); + auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), + mkl_data_format); + auto fwd_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md, + strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); // Allocate output tensor and shape @@ -537,22 +537,23 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { output.SetOpMemDesc(bwd_output_dims, memory::format::any); // Create convolution backward weights primitive. - auto bwd_desc = convolution_backward_weights::desc( - convolution_direct, input.GetOpMemDesc(), output.GetOpMemDesc(), - outbackprop.GetOpMemDesc(), strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto bwd_desc = convolution_backward_weights::desc(convolution_direct, + input.GetOpMemDesc(), output.GetOpMemDesc(), + outbackprop.GetOpMemDesc(), strides, padding_l, + padding_r, TFPaddingToMklDnnPadding(padding_)); - auto bwd_pd = convolution_backward_weights::primitive_desc( - bwd_desc, cpu_engine, fwd_pd); + auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc, + cpu_engine, + fwd_pd); PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output); - } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + } catch (mkldnn::error &e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + + ", in file " + string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", + error_msg)); } } @@ -563,8 +564,9 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // Prepare and execute net - checks for input and output reorders. void PrepareAndExecutePrimitive( - const convolution_backward_weights::primitive_desc& conv_pd, - MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output) { + const convolution_backward_weights::primitive_desc& conv_pd, + MklDnnData<T>* input, MklDnnData<T>* obp, + MklDnnData<T>* output) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector<primitive> net; @@ -575,10 +577,10 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // output side, we will prepare reorder primitive in case output // reorder to user memory is required. bool output_reorder_required = output->PrepareReorderToUserMemIfReq( - conv_pd.diff_weights_primitive_desc()); + conv_pd.diff_weights_primitive_desc()); - net.push_back(convolution_backward_weights( - conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem())); + net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(), + obp->GetOpMem(), output->GetOpMem())); // Insert reorder primitive in the net for output reorder if reorder is // required. |