diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-27 16:33:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-27 16:37:09 -0700 |
commit | 50b999a8336d19400ab75aea66fe46eca2f5fe0b (patch) | |
tree | 7cba4f4af6b131c253b65ff9f2923e851184668c /tensorflow/core/kernels/mkl_conv_ops.cc | |
parent | d6d58a3a1785785679af56c0f8f131e7312b8226 (diff) |
Merge changes from github.
PiperOrigin-RevId: 160344052
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.cc | 51 |
1 file changed, 38 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 76b9f1798d..df49e03f31 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -267,12 +267,15 @@ class MklConv2DOp : public OpKernel { mkl_context.MklCreateInputLayouts(context); + // Temp tensor used to allocate tmp buffers Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor, - mkl_tmp_bias_buf_tensor; // Temp tensor used to allocate tmp - // buffers - mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf_tensor, + mkl_tmp_bias_buf_tensor, mkl_tmp_buf_trans_input; + mkl_context.MklPrepareConvolutionInputs(context, data_format_, + input_in_mkl_format, + &mkl_tmp_input_buf_tensor, &mkl_tmp_filter_buf_tensor, - &mkl_tmp_bias_buf_tensor); + &mkl_tmp_bias_buf_tensor, + &mkl_tmp_buf_trans_input); // Execute convolution CHECK_EQ(dnnExecute_F32(mkl_context.prim_fwd, mkl_context.conv_res), @@ -323,39 +326,59 @@ class MklConv2DOp : public OpKernel { // Compare incoming tensor layouts with MKL preferred layouts and convert // data to the preferred layout if necessary void MklPrepareConvolutionInputs(OpKernelContext* context, + TensorFormat format, + bool input_in_mkl_format, Tensor* mkl_tmp_input_buf_tensor, Tensor* mkl_tmp_filter_buf_tensor, - Tensor* mkl_tmp_bias_buf_tensor) { + Tensor* mkl_tmp_bias_buf_tensor, + Tensor* mkl_tmp_buf_trans_input) { bool mkl_convert_input, mkl_convert_filter, mkl_convert_bias; dnnPrimitive_t mkl_prim_convert_filter, mkl_prim_convert_bias, mkl_prim_convert_input; dnnLayout_t mkl_lt_internal_filter, mkl_lt_internal_bias, - mkl_lt_internal_input; + mkl_lt_internal_input, mkl_lt_trans_input; void *mkl_buf_convert_input, *mkl_buf_convert_filter, - *mkl_buf_convert_bias; + *mkl_buf_convert_bias, *mkl_buf_input; mkl_prim_convert_filter = nullptr; mkl_prim_convert_bias = nullptr; mkl_prim_convert_input = nullptr; mkl_lt_internal_filter = nullptr; mkl_lt_internal_bias = nullptr; 
mkl_lt_internal_input = nullptr; + mkl_lt_trans_input = nullptr; mkl_buf_convert_input = nullptr; mkl_buf_convert_filter = nullptr; mkl_buf_convert_bias = nullptr; + mkl_buf_input = nullptr; // Compare with internal layouts and convert if needed const Tensor& input = MklGetInput(context, 0); - void* mkl_buf_input = - const_cast<void*>(static_cast<const void*>(input.flat<T>().data())); + if (!input_in_mkl_format && format == FORMAT_NHWC) { + TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW, + in_sizes[MklDims::N], in_sizes[MklDims::H], + in_sizes[MklDims::W], in_sizes[MklDims::C]); + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum<float>::value, nchw_shape, mkl_tmp_buf_trans_input)); + MklNHWCToNCHW(input, &mkl_tmp_buf_trans_input); + mkl_buf_input = const_cast<void*>(static_cast<const void*>( + mkl_tmp_buf_trans_input->flat<float>().data())); + size_t strides[4]; + GetStridesFromSizes(FORMAT_NCHW, strides, in_sizes); + CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_trans_input, in_dims, in_sizes, + strides), E_SUCCESS); + } else { + mkl_buf_input = const_cast<void*>( + static_cast<const void*>(input.flat<T>().data())); + mkl_lt_trans_input = lt_input; + } CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input, prim_fwd, dnnResourceSrc), E_SUCCESS); mkl_convert_input = - !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input); + !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_trans_input); if (mkl_convert_input) { - CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input, - mkl_lt_internal_input), - E_SUCCESS); + CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, + mkl_lt_trans_input, mkl_lt_internal_input), E_SUCCESS); AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input, &mkl_buf_convert_input); CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input, @@ -364,6 +387,8 @@ class MklConv2DOp : public OpKernel { dnnDelete_F32(mkl_prim_convert_input); } 
dnnLayoutDelete_F32(mkl_lt_internal_input); + if (!input_in_mkl_format && format == FORMAT_NHWC) + dnnLayoutDelete_F32(mkl_lt_trans_input); conv_res[dnnResourceSrc] = (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input; |