diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.cc | 82 |
1 files changed, 42 insertions, 40 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 369f632fb4..a9872b8d6d 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -18,8 +18,8 @@ limitations under the License. #include <string.h> #include <map> -#include <string> #include <vector> +#include <string> #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -46,11 +46,11 @@ limitations under the License. #ifdef INTEL_MKL_DNN #include "mkldnn.hpp" -using mkldnn::prop_kind; using mkldnn::stream; +using mkldnn::prop_kind; -using mkldnn::convolution_direct; using mkldnn::convolution_forward; +using mkldnn::convolution_direct; #endif namespace tensorflow { @@ -523,16 +523,19 @@ class MklConv2DOp : public OpKernel { // Get shapes of input tensors in MKL-DNN order MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); - conv_utl.GetConvFwdSizesInMklOrder( - src_tensor.shape(), filter_tensor.shape(), &src_dims, &filter_dims, - &strides, &output_dims_tf_order, &output_dims_mkl_order, &padding_l, - &padding_r); + conv_utl.GetConvFwdSizesInMklOrder(src_tensor.shape(), + filter_tensor.shape(), + &src_dims, &filter_dims, &strides, + &output_dims_tf_order, + &output_dims_mkl_order, &padding_l, + &padding_r); if (!context->status().ok()) return; // Check for corner case - if there is nothing to compute, return. - TensorShape tf_output_shape( - {output_dims_tf_order[0], output_dims_tf_order[1], - output_dims_tf_order[2], output_dims_tf_order[3]}); + TensorShape tf_output_shape({output_dims_tf_order[0], + output_dims_tf_order[1], + output_dims_tf_order[2], + output_dims_tf_order[3]}); Tensor* output_tensor = nullptr; MklShape mkl_output_mkl_shape; mkl_output_mkl_shape.SetMklTensor(false); @@ -569,13 +572,13 @@ class MklConv2DOp : public OpKernel { // the layout is Tensorflow's layout (NHWC or NCHW depending on data // format). src.SetUsrMem(src_dims, TFDataFormatToMklDnnDataFormat(data_format_), - const_cast<void*>( - static_cast<const void*>(src_tensor.flat<T>().data()))); + const_cast<void*>(static_cast<const void*>( + src_tensor.flat<T>().data()))); // Although filter shape (filter_dims) required is in MKL-DNN order, // the layout is Tensorflow's layout (HWIO). filter.SetUsrMem(filter_dims, memory::format::hwio, const_cast<void*>(static_cast<const void*>( - filter_tensor.flat<T>().data()))); + filter_tensor.flat<T>().data()))); // Although output shape (output_dims) required is in MKL-DNN order, // layout is Tensorflow's layout (NHWC or NCHW depending on data format). output.SetUsrMem(output_dims_mkl_order, @@ -595,36 +598,36 @@ class MklConv2DOp : public OpKernel { const Tensor& bias_tensor = MklGetInput(context, 2); bias.SetUsrMem(bias_size, memory::format::x, const_cast<void*>(static_cast<const void*>( - bias_tensor.flat<T>().data()))); + bias_tensor.flat<T>().data()))); bias.SetOpMemDesc(bias_size, memory::format::any); // Create convolution primitive with Bias. - auto conv_desc = convolution_forward::desc( - prop_kind::forward, convolution_direct, src.GetOpMemDesc(), - filter.GetOpMemDesc(), bias.GetOpMemDesc(), output.GetOpMemDesc(), - strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); + auto conv_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(), + bias.GetOpMemDesc(), output.GetOpMemDesc(), strides, + padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); - auto conv_prim_desc = - convolution_forward::primitive_desc(conv_desc, cpu_engine); + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, + cpu_engine); PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output); } else { // Create convolution primitive without Bias. - auto conv_desc = convolution_forward::desc( - prop_kind::forward, convolution_direct, src.GetOpMemDesc(), - filter.GetOpMemDesc(), output.GetOpMemDesc(), strides, padding_l, - padding_r, TFPaddingToMklDnnPadding(padding_)); + auto conv_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(), + output.GetOpMemDesc(), strides, padding_l, padding_r, + TFPaddingToMklDnnPadding(padding_)); - auto conv_prim_desc = - convolution_forward::primitive_desc(conv_desc, cpu_engine); + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, + cpu_engine); PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output); } - } catch (mkldnn::error& e) { + } catch (mkldnn::error &e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + std::string(e.message) + ", in file " + - std::string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + ", message: " + std::string(e.message) + + ", in file " + std::string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, + errors::Aborted("Operation received an exception:", error_msg)); } } @@ -635,9 +638,9 @@ class MklConv2DOp : public OpKernel { // Prepare and execute net - checks for input and output reorders. void PrepareAndExecuteNet( - const convolution_forward::primitive_desc& conv_prim_desc, - MklDnnData<T>* src, MklDnnData<T>* filter, MklDnnData<T>* bias, - MklDnnData<T>* output) { + const convolution_forward::primitive_desc& conv_prim_desc, + MklDnnData<T>* src, MklDnnData<T>* filter, + MklDnnData<T>* bias, MklDnnData<T>* output) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector<primitive> net; @@ -648,19 +651,18 @@ class MklConv2DOp : public OpKernel { // output side, we will prepare reorder primitive in case output // reorder to user memory is required. bool output_reorder_required = output->PrepareReorderToUserMemIfReq( - conv_prim_desc.dst_primitive_desc()); + conv_prim_desc.dst_primitive_desc()); // Create convolution primitive and add it to net. if (bias) { CHECK_EQ(biasEnabled, true); net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(), - filter->GetOpMem(), bias->GetOpMem(), - output->GetOpMem())); + filter->GetOpMem(), bias->GetOpMem(), + output->GetOpMem())); } else { CHECK_EQ(biasEnabled, false); net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(), - filter->GetOpMem(), - output->GetOpMem())); + filter->GetOpMem(), output->GetOpMem())); } // Insert reorder primitive in the net for output reorder if reorder is |