diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 86 |
1 files changed, 44 insertions, 42 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 4b6bf92e42..4a47d0463e 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -23,8 +23,6 @@ limitations under the License. #define EIGEN_USE_THREADS #include <algorithm> #include <vector> -#include "mkl_dnn.h" -#include "mkl_dnn_types.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -43,16 +41,18 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" +#include "mkl_dnn.h" +#include "mkl_dnn_types.h" #ifdef INTEL_MKL_DNN #include "mkldnn.hpp" -using mkldnn::prop_kind; using mkldnn::stream; +using mkldnn::prop_kind; -using mkldnn::convolution_backward_data; -using mkldnn::convolution_direct; using mkldnn::convolution_forward; +using mkldnn::convolution_direct; +using mkldnn::convolution_backward_data; #endif namespace tensorflow { @@ -397,13 +397,12 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Generate input shape. TensorShape input_shape; - OP_REQUIRES( - context, TensorShapeUtils::IsVector(input_tensor.shape()), - errors::InvalidArgument( + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()), + errors::InvalidArgument( "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", input_tensor.dims())); OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - input_tensor.vec<int32>(), &input_shape)); + input_tensor.vec<int32>(), &input_shape)); TensorShape filter_shape = filter_tensor.shape(); TensorShape obp_shape = obp_tensor.shape(); @@ -415,26 +414,27 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Get forward convolution parameters. MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); - conv_utl.GetConvFwdSizesInMklOrder( - input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims, - &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l, - &padding_r); + conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape, + &fwd_input_dims, &fwd_filter_dims, + &strides, + &fwd_output_dims_tf_order, + &fwd_output_dims, + &padding_l, &padding_r); if (!context->status().ok()) return; // Create Convolution forward descriptor since Convolution backward // API needs it. For that, we first need to create input, filter // and output memory descriptors. auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_); - auto fwd_src_md = - memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format); - auto fwd_filter_md = - memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio); - auto fwd_out_md = - memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format); - auto fwd_desc = convolution_forward::desc( - prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md, - fwd_out_md, strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(), + mkl_data_format); + auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(), + memory::format::hwio); + auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), + mkl_data_format); + auto fwd_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md, + strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); // Allocate output tensor and shape @@ -475,22 +475,23 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { output.SetOpMemDesc(bwd_output_dims, memory::format::any); // Create convolution backward data primitive. - auto bwd_desc = convolution_backward_data::desc( - convolution_direct, output.GetOpMemDesc(), filter.GetOpMemDesc(), - outbackprop.GetOpMemDesc(), strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto bwd_desc = convolution_backward_data::desc(convolution_direct, + output.GetOpMemDesc(), filter.GetOpMemDesc(), + outbackprop.GetOpMemDesc(), strides, padding_l, + padding_r, TFPaddingToMklDnnPadding(padding_)); - auto bwd_pd = convolution_backward_data::primitive_desc( - bwd_desc, cpu_engine, fwd_pd); + auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc, + cpu_engine, + fwd_pd); PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output); - } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + } catch (mkldnn::error &e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + + ", in file " + string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", + error_msg)); } } @@ -501,8 +502,9 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Prepare and execute net - checks for input and output reorders. void PrepareAndExecutePrimitive( - const convolution_backward_data::primitive_desc& conv_pd, - MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) { + const convolution_backward_data::primitive_desc& conv_pd, + MklDnnData<T>* filter, MklDnnData<T>* obp, + MklDnnData<T>* output) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector<primitive> net; @@ -512,11 +514,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Memory for output of convolution. Since we may need reorder on the // output side, we will prepare reorder primitive in case output // reorder to user memory is required. - bool output_reorder_required = - output->PrepareReorderToUserMemIfReq(conv_pd.diff_src_primitive_desc()); + bool output_reorder_required = output->PrepareReorderToUserMemIfReq( + conv_pd.diff_src_primitive_desc()); - net.push_back(convolution_backward_data( - conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem())); + net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(), + filter->GetOpMem(), output->GetOpMem())); // Insert reorder primitive in the net for output reorder if reorder is // required. |