diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 86 |
1 files changed, 42 insertions, 44 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 4a47d0463e..4b6bf92e42 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -23,6 +23,8 @@ limitations under the License. #define EIGEN_USE_THREADS #include <algorithm> #include <vector> +#include "mkl_dnn.h" +#include "mkl_dnn_types.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -41,18 +43,16 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" -#include "mkl_dnn.h" -#include "mkl_dnn_types.h" #ifdef INTEL_MKL_DNN #include "mkldnn.hpp" -using mkldnn::stream; using mkldnn::prop_kind; +using mkldnn::stream; -using mkldnn::convolution_forward; -using mkldnn::convolution_direct; using mkldnn::convolution_backward_data; +using mkldnn::convolution_direct; +using mkldnn::convolution_forward; #endif namespace tensorflow { @@ -397,12 +397,13 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Generate input shape. TensorShape input_shape; - OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()), - errors::InvalidArgument( + OP_REQUIRES( + context, TensorShapeUtils::IsVector(input_tensor.shape()), + errors::InvalidArgument( "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", input_tensor.dims())); OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - input_tensor.vec<int32>(), &input_shape)); + input_tensor.vec<int32>(), &input_shape)); TensorShape filter_shape = filter_tensor.shape(); TensorShape obp_shape = obp_tensor.shape(); @@ -414,27 +415,26 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Get forward convolution parameters. MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); - conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape, - &fwd_input_dims, &fwd_filter_dims, - &strides, - &fwd_output_dims_tf_order, - &fwd_output_dims, - &padding_l, &padding_r); + conv_utl.GetConvFwdSizesInMklOrder( + input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims, + &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l, + &padding_r); if (!context->status().ok()) return; // Create Convolution forward descriptor since Convolution backward // API needs it. For that, we first need to create input, filter // and output memory descriptors. auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_); - auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(), - mkl_data_format); - auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(), - memory::format::hwio); - auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), - mkl_data_format); - auto fwd_desc = convolution_forward::desc(prop_kind::forward, - convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md, - strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); + auto fwd_src_md = + memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format); + auto fwd_filter_md = + memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio); + auto fwd_out_md = + memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format); + auto fwd_desc = convolution_forward::desc( + prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md, + fwd_out_md, strides, padding_l, padding_r, + TFPaddingToMklDnnPadding(padding_)); auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); // Allocate output tensor and shape @@ -475,23 +475,22 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { output.SetOpMemDesc(bwd_output_dims, memory::format::any); // Create convolution backward data primitive. - auto bwd_desc = convolution_backward_data::desc(convolution_direct, - output.GetOpMemDesc(), filter.GetOpMemDesc(), - outbackprop.GetOpMemDesc(), strides, padding_l, - padding_r, TFPaddingToMklDnnPadding(padding_)); + auto bwd_desc = convolution_backward_data::desc( + convolution_direct, output.GetOpMemDesc(), filter.GetOpMemDesc(), + outbackprop.GetOpMemDesc(), strides, padding_l, padding_r, + TFPaddingToMklDnnPadding(padding_)); - auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc, - cpu_engine, - fwd_pd); + auto bwd_pd = convolution_backward_data::primitive_desc( + bwd_desc, cpu_engine, fwd_pd); PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output); - } catch (mkldnn::error &e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", - error_msg)); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } @@ -502,9 +501,8 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Prepare and execute net - checks for input and output reorders. void PrepareAndExecutePrimitive( - const convolution_backward_data::primitive_desc& conv_pd, - MklDnnData<T>* filter, MklDnnData<T>* obp, - MklDnnData<T>* output) { + const convolution_backward_data::primitive_desc& conv_pd, + MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector<primitive> net; @@ -514,11 +512,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Memory for output of convolution. Since we may need reorder on the // output side, we will prepare reorder primitive in case output // reorder to user memory is required. - bool output_reorder_required = output->PrepareReorderToUserMemIfReq( - conv_pd.diff_src_primitive_desc()); + bool output_reorder_required = + output->PrepareReorderToUserMemIfReq(conv_pd.diff_src_primitive_desc()); - net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(), - filter->GetOpMem(), output->GetOpMem())); + net.push_back(convolution_backward_data( + conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem())); // Insert reorder primitive in the net for output reorder if reorder is // required. |