aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc51
1 files changed, 38 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 76b9f1798d..df49e03f31 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -267,12 +267,15 @@ class MklConv2DOp : public OpKernel {
mkl_context.MklCreateInputLayouts(context);
+ // Temp tensor used to allocate tmp buffers
Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor,
- mkl_tmp_bias_buf_tensor; // Temp tensor used to allocate tmp
- // buffers
- mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf_tensor,
+ mkl_tmp_bias_buf_tensor, mkl_tmp_buf_trans_input;
+ mkl_context.MklPrepareConvolutionInputs(context, data_format_,
+ input_in_mkl_format,
+ &mkl_tmp_input_buf_tensor,
&mkl_tmp_filter_buf_tensor,
- &mkl_tmp_bias_buf_tensor);
+ &mkl_tmp_bias_buf_tensor,
+ &mkl_tmp_buf_trans_input);
// Execute convolution
CHECK_EQ(dnnExecute_F32(mkl_context.prim_fwd, mkl_context.conv_res),
@@ -323,39 +326,59 @@ class MklConv2DOp : public OpKernel {
// Compare incoming tensor layouts with MKL preferred layouts and convert
// data to the preferred layout if necessary
void MklPrepareConvolutionInputs(OpKernelContext* context,
+ TensorFormat format,
+ bool input_in_mkl_format,
Tensor* mkl_tmp_input_buf_tensor,
Tensor* mkl_tmp_filter_buf_tensor,
- Tensor* mkl_tmp_bias_buf_tensor) {
+ Tensor* mkl_tmp_bias_buf_tensor,
+ Tensor* mkl_tmp_buf_trans_input) {
bool mkl_convert_input, mkl_convert_filter, mkl_convert_bias;
dnnPrimitive_t mkl_prim_convert_filter, mkl_prim_convert_bias,
mkl_prim_convert_input;
dnnLayout_t mkl_lt_internal_filter, mkl_lt_internal_bias,
- mkl_lt_internal_input;
+ mkl_lt_internal_input, mkl_lt_trans_input;
void *mkl_buf_convert_input, *mkl_buf_convert_filter,
- *mkl_buf_convert_bias;
+ *mkl_buf_convert_bias, *mkl_buf_input;
mkl_prim_convert_filter = nullptr;
mkl_prim_convert_bias = nullptr;
mkl_prim_convert_input = nullptr;
mkl_lt_internal_filter = nullptr;
mkl_lt_internal_bias = nullptr;
mkl_lt_internal_input = nullptr;
+ mkl_lt_trans_input = nullptr;
mkl_buf_convert_input = nullptr;
mkl_buf_convert_filter = nullptr;
mkl_buf_convert_bias = nullptr;
+ mkl_buf_input = nullptr;
// Compare with internal layouts and convert if needed
const Tensor& input = MklGetInput(context, 0);
- void* mkl_buf_input =
- const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+ if (!input_in_mkl_format && format == FORMAT_NHWC) {
+ TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW,
+ in_sizes[MklDims::N], in_sizes[MklDims::H],
+ in_sizes[MklDims::W], in_sizes[MklDims::C]);
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<float>::value, nchw_shape, mkl_tmp_buf_trans_input));
+ MklNHWCToNCHW(input, &mkl_tmp_buf_trans_input);
+ mkl_buf_input = const_cast<void*>(static_cast<const void*>(
+ mkl_tmp_buf_trans_input->flat<float>().data()));
+ size_t strides[4];
+ GetStridesFromSizes(FORMAT_NCHW, strides, in_sizes);
+ CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_trans_input, in_dims, in_sizes,
+ strides), E_SUCCESS);
+ } else {
+ mkl_buf_input = const_cast<void*>(
+ static_cast<const void*>(input.flat<T>().data()));
+ mkl_lt_trans_input = lt_input;
+ }
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
prim_fwd, dnnResourceSrc),
E_SUCCESS);
mkl_convert_input =
- !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
+ !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_trans_input);
if (mkl_convert_input) {
- CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
- mkl_lt_internal_input),
- E_SUCCESS);
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
+ mkl_lt_trans_input, mkl_lt_internal_input), E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
&mkl_buf_convert_input);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
@@ -364,6 +387,8 @@ class MklConv2DOp : public OpKernel {
dnnDelete_F32(mkl_prim_convert_input);
}
dnnLayoutDelete_F32(mkl_lt_internal_input);
+ if (!input_in_mkl_format && format == FORMAT_NHWC)
+ dnnLayoutDelete_F32(mkl_lt_trans_input);
conv_res[dnnResourceSrc] =
(mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;