about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/mkl_conv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r-- tensorflow/core/kernels/mkl_conv_ops.cc 45
1 files changed, 10 insertions, 35 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 45d22556aa..203e694631 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -272,13 +272,11 @@ class MklConv2DOp : public OpKernel {
// Temp tensor used to allocate tmp buffers
Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor,
- mkl_tmp_bias_buf_tensor, mkl_tmp_buf_trans_input;
- mkl_context.MklPrepareConvolutionInputs(context, data_format_,
- input_in_mkl_format,
+ mkl_tmp_bias_buf_tensor;
+ mkl_context.MklPrepareConvolutionInputs(context,
&mkl_tmp_input_buf_tensor,
&mkl_tmp_filter_buf_tensor,
- &mkl_tmp_bias_buf_tensor,
- &mkl_tmp_buf_trans_input);
+ &mkl_tmp_bias_buf_tensor);
// Execute convolution
CHECK_EQ(dnnExecute_F32(mkl_context.prim_fwd, mkl_context.conv_res),
@@ -329,59 +327,38 @@ class MklConv2DOp : public OpKernel {
// Compare incoming tensor layouts with MKL preferred layouts and convert
// data to the preferred layout if necessary
void MklPrepareConvolutionInputs(OpKernelContext* context,
- TensorFormat format,
- bool input_in_mkl_format,
Tensor* mkl_tmp_input_buf_tensor,
Tensor* mkl_tmp_filter_buf_tensor,
- Tensor* mkl_tmp_bias_buf_tensor,
- Tensor* mkl_tmp_buf_trans_input) {
+ Tensor* mkl_tmp_bias_buf_tensor) {
bool mkl_convert_input, mkl_convert_filter, mkl_convert_bias;
dnnPrimitive_t mkl_prim_convert_filter, mkl_prim_convert_bias,
mkl_prim_convert_input;
dnnLayout_t mkl_lt_internal_filter, mkl_lt_internal_bias,
- mkl_lt_internal_input, mkl_lt_trans_input;
+ mkl_lt_internal_input;
void *mkl_buf_convert_input, *mkl_buf_convert_filter,
- *mkl_buf_convert_bias, *mkl_buf_input;
+ *mkl_buf_convert_bias;
mkl_prim_convert_filter = nullptr;
mkl_prim_convert_bias = nullptr;
mkl_prim_convert_input = nullptr;
mkl_lt_internal_filter = nullptr;
mkl_lt_internal_bias = nullptr;
mkl_lt_internal_input = nullptr;
- mkl_lt_trans_input = nullptr;
mkl_buf_convert_input = nullptr;
mkl_buf_convert_filter = nullptr;
mkl_buf_convert_bias = nullptr;
- mkl_buf_input = nullptr;
// Compare with internal layouts and convert if needed
const Tensor& input = MklGetInput(context, 0);
- if (!input_in_mkl_format && format == FORMAT_NHWC) {
- TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW,
- in_sizes[MklDims::N], in_sizes[MklDims::H],
- in_sizes[MklDims::W], in_sizes[MklDims::C]);
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<float>::value, nchw_shape, mkl_tmp_buf_trans_input));
- MklNHWCToNCHW(input, &mkl_tmp_buf_trans_input);
- mkl_buf_input = const_cast<void*>(static_cast<const void*>(
- mkl_tmp_buf_trans_input->flat<float>().data()));
- size_t strides[4];
- GetStridesFromSizes(FORMAT_NCHW, strides, in_sizes);
- CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_trans_input, in_dims, in_sizes,
- strides), E_SUCCESS);
- } else {
- mkl_buf_input = const_cast<void*>(
- static_cast<const void*>(input.flat<T>().data()));
- mkl_lt_trans_input = lt_input;
- }
+ void* mkl_buf_input =
+ const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
prim_fwd, dnnResourceSrc),
E_SUCCESS);
mkl_convert_input =
- !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_trans_input);
+ !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
if (mkl_convert_input) {
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
- mkl_lt_trans_input, mkl_lt_internal_input), E_SUCCESS);
+ lt_input, mkl_lt_internal_input), E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
&mkl_buf_convert_input);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
@@ -390,8 +367,6 @@ class MklConv2DOp : public OpKernel {
dnnDelete_F32(mkl_prim_convert_input);
}
dnnLayoutDelete_F32(mkl_lt_internal_input);
- if (!input_in_mkl_format && format == FORMAT_NHWC)
- dnnLayoutDelete_F32(mkl_lt_trans_input);
conv_res[dnnResourceSrc] =
(mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;