Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc  82
1 file changed, 42 insertions, 40 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 369f632fb4..a9872b8d6d 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <string.h>
#include <map>
-#include <string>
#include <vector>
+#include <string>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -46,11 +46,11 @@ limitations under the License.
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
-using mkldnn::prop_kind;
using mkldnn::stream;
+using mkldnn::prop_kind;
-using mkldnn::convolution_direct;
using mkldnn::convolution_forward;
+using mkldnn::convolution_direct;
#endif
namespace tensorflow {
@@ -523,16 +523,19 @@ class MklConv2DOp : public OpKernel {
// Get shapes of input tensors in MKL-DNN order
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
- conv_utl.GetConvFwdSizesInMklOrder(
- src_tensor.shape(), filter_tensor.shape(), &src_dims, &filter_dims,
- &strides, &output_dims_tf_order, &output_dims_mkl_order, &padding_l,
- &padding_r);
+ conv_utl.GetConvFwdSizesInMklOrder(src_tensor.shape(),
+ filter_tensor.shape(),
+ &src_dims, &filter_dims, &strides,
+ &output_dims_tf_order,
+ &output_dims_mkl_order, &padding_l,
+ &padding_r);
if (!context->status().ok()) return;
// Check for corner case - if there is nothing to compute, return.
- TensorShape tf_output_shape(
- {output_dims_tf_order[0], output_dims_tf_order[1],
- output_dims_tf_order[2], output_dims_tf_order[3]});
+ TensorShape tf_output_shape({output_dims_tf_order[0],
+ output_dims_tf_order[1],
+ output_dims_tf_order[2],
+ output_dims_tf_order[3]});
Tensor* output_tensor = nullptr;
MklShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(false);
@@ -569,13 +572,13 @@ class MklConv2DOp : public OpKernel {
// the layout is Tensorflow's layout (NHWC or NCHW depending on data
// format).
src.SetUsrMem(src_dims, TFDataFormatToMklDnnDataFormat(data_format_),
- const_cast<void*>(
- static_cast<const void*>(src_tensor.flat<T>().data())));
+ const_cast<void*>(static_cast<const void*>(
+ src_tensor.flat<T>().data())));
// Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO).
filter.SetUsrMem(filter_dims, memory::format::hwio,
const_cast<void*>(static_cast<const void*>(
- filter_tensor.flat<T>().data())));
+ filter_tensor.flat<T>().data())));
// Although output shape (output_dims) required is in MKL-DNN order,
// layout is Tensorflow's layout (NHWC or NCHW depending on data format).
output.SetUsrMem(output_dims_mkl_order,
@@ -595,36 +598,36 @@ class MklConv2DOp : public OpKernel {
const Tensor& bias_tensor = MklGetInput(context, 2);
bias.SetUsrMem(bias_size, memory::format::x,
const_cast<void*>(static_cast<const void*>(
- bias_tensor.flat<T>().data())));
+ bias_tensor.flat<T>().data())));
bias.SetOpMemDesc(bias_size, memory::format::any);
// Create convolution primitive with Bias.
- auto conv_desc = convolution_forward::desc(
- prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
- filter.GetOpMemDesc(), bias.GetOpMemDesc(), output.GetOpMemDesc(),
- strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
+ auto conv_desc = convolution_forward::desc(prop_kind::forward,
+ convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(),
+ bias.GetOpMemDesc(), output.GetOpMemDesc(), strides,
+ padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
- auto conv_prim_desc =
- convolution_forward::primitive_desc(conv_desc, cpu_engine);
+ auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc,
+ cpu_engine);
PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output);
} else {
// Create convolution primitive without Bias.
- auto conv_desc = convolution_forward::desc(
- prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
- filter.GetOpMemDesc(), output.GetOpMemDesc(), strides, padding_l,
- padding_r, TFPaddingToMklDnnPadding(padding_));
+ auto conv_desc = convolution_forward::desc(prop_kind::forward,
+ convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(),
+ output.GetOpMemDesc(), strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
- auto conv_prim_desc =
- convolution_forward::primitive_desc(conv_desc, cpu_engine);
+ auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc,
+ cpu_engine);
PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output);
}
- } catch (mkldnn::error& e) {
+ } catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + std::string(e.message) + ", in file " +
- std::string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
+ ", message: " + std::string(e.message) +
+ ", in file " + std::string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:", error_msg));
}
}
@@ -635,9 +638,9 @@ class MklConv2DOp : public OpKernel {
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecuteNet(
- const convolution_forward::primitive_desc& conv_prim_desc,
- MklDnnData<T>* src, MklDnnData<T>* filter, MklDnnData<T>* bias,
- MklDnnData<T>* output) {
+ const convolution_forward::primitive_desc& conv_prim_desc,
+ MklDnnData<T>* src, MklDnnData<T>* filter,
+ MklDnnData<T>* bias, MklDnnData<T>* output) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
@@ -648,19 +651,18 @@ class MklConv2DOp : public OpKernel {
// output side, we will prepare reorder primitive in case output
// reorder to user memory is required.
bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
- conv_prim_desc.dst_primitive_desc());
+ conv_prim_desc.dst_primitive_desc());
// Create convolution primitive and add it to net.
if (bias) {
CHECK_EQ(biasEnabled, true);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
- filter->GetOpMem(), bias->GetOpMem(),
- output->GetOpMem()));
+ filter->GetOpMem(), bias->GetOpMem(),
+ output->GetOpMem()));
} else {
CHECK_EQ(biasEnabled, false);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
- filter->GetOpMem(),
- output->GetOpMem()));
+ filter->GetOpMem(), output->GetOpMem()));
}
// Insert reorder primitive in the net for output reorder if reorder is
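For reference, the kernel code in this diff follows the standard mkldnn 0.x pattern: describe user memory, build a convolution_forward::desc, turn it into a primitive_desc for the CPU engine, then push the primitive onto a net and submit it on a stream. Below is a minimal standalone sketch of that pattern; the shapes, NCHW/OIHW layouts, data, and the eager stream are illustrative and are not taken from the TensorFlow kernel.

#include <vector>
#include "mkldnn.hpp"

using namespace mkldnn;

int main() {
  engine cpu_engine(engine::cpu, 0);

  // Illustrative sizes: batch 1, 3 input channels, 8 output channels,
  // 4x4 image, 3x3 filter, stride 1, padding 1.
  memory::dims src_dims    = {1, 3, 4, 4};   // NCHW
  memory::dims filter_dims = {8, 3, 3, 3};   // OIHW
  memory::dims dst_dims    = {1, 8, 4, 4};
  memory::dims strides     = {1, 1};
  memory::dims padding     = {1, 1};

  std::vector<float> src_data(1 * 3 * 4 * 4, 1.0f);
  std::vector<float> filter_data(8 * 3 * 3 * 3, 1.0f);
  std::vector<float> dst_data(1 * 8 * 4 * 4, 0.0f);

  // User memory in fixed layouts, analogous to MklDnnData::SetUsrMem in the kernel.
  auto src_mem = memory({{{src_dims}, memory::data_type::f32, memory::format::nchw},
                         cpu_engine}, src_data.data());
  auto filter_mem = memory({{{filter_dims}, memory::data_type::f32, memory::format::oihw},
                            cpu_engine}, filter_data.data());
  auto dst_mem = memory({{{dst_dims}, memory::data_type::f32, memory::format::nchw},
                         cpu_engine}, dst_data.data());

  // Descriptor -> primitive_desc -> primitive, as in the no-bias branch above.
  auto conv_desc = convolution_forward::desc(
      prop_kind::forward, convolution_direct,
      src_mem.get_primitive_desc().desc(),
      filter_mem.get_primitive_desc().desc(),
      dst_mem.get_primitive_desc().desc(),
      strides, padding, padding, padding_kind::zero);
  auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, cpu_engine);

  // Build the net and execute it eagerly on the CPU stream.
  std::vector<primitive> net;
  net.push_back(convolution_forward(conv_prim_desc, src_mem, filter_mem, dst_mem));
  stream(stream::kind::eager).submit(net).wait();
  return 0;
}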