about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
-rw-r--r-- tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc 69
1 files changed, 51 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index dc6b88e953..ddcf241277 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -206,10 +206,15 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Mkl needs the entities in its native format.
// So create temporary tensors along with buffers to
// convert the received entities.
- Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor;
+ Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor,
+ mkl_tmp_buf_trans_input;
// This preparation sets (1) dnnResourceSrc (2) dnnResourceDiffDst
- mkl_context.MklPrepareInputs(context, &mkl_tmp_input_buf_tensor,
- &mkl_tmp_out_backprop_buf_tensor);
+ mkl_context.MklPrepareInputs(context, data_format_,
+ input_in_mkl_format,
+ out_backprop_in_mkl_format,
+ &mkl_tmp_input_buf_tensor,
+ &mkl_tmp_out_backprop_buf_tensor,
+ &mkl_tmp_buf_trans_input);
// Final conv-grad-filter should be in TF layout.
Tensor* grad_filter;
@@ -307,34 +312,58 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Compare incoming tensor layouts with MKL preferred layouts and convert
// data to the preferred layout if necessary
- void MklPrepareInputs(OpKernelContext* context,
+ void MklPrepareInputs(OpKernelContext* context, TensorFormat format,
+ bool input_in_mkl_format,
+ bool out_backprop_in_mkl_format,
Tensor* mkl_tmp_input_buf_tensor,
- Tensor* mkl_tmp_out_backprop_buf_tensor) {
+ Tensor* mkl_tmp_out_backprop_buf_tensor,
+ Tensor* mkl_tmp_buf_trans_input) {
bool mkl_convert_input, mkl_convert_out_backprop;
dnnPrimitive_t mkl_prim_convert_input, mkl_prim_convert_out_backprop;
- dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop;
+ dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop,
+ mkl_lt_trans_input;
void *mkl_buf_convert_input, *mkl_buf_convert_out_backprop;
+ void *mkl_buf_input, *mkl_buf_out_backprop;
mkl_prim_convert_input = nullptr;
mkl_prim_convert_out_backprop = nullptr;
mkl_lt_internal_input = nullptr;
mkl_lt_internal_out_backprop = nullptr;
+ mkl_lt_trans_input = nullptr;
mkl_buf_convert_input = nullptr;
mkl_buf_convert_out_backprop = nullptr;
+ mkl_buf_input = nullptr;
+ mkl_buf_out_backprop = nullptr;
// Compare with internal layouts and convert if needed
const Tensor& input = MklGetInput(context, 0);
- void* mkl_buf_input =
- const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+ if (!input_in_mkl_format && format == FORMAT_NHWC){
+ TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW,
+ in_sizes[MklDims::N], in_sizes[MklDims::H],
+ in_sizes[MklDims::W], in_sizes[MklDims::C]);
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<float>::value, nchw_shape, mkl_tmp_buf_trans_input));
+ MklNHWCToNCHW(input, &mkl_tmp_buf_trans_input);
+ mkl_buf_input = const_cast<void*>(static_cast<const void*>(
+ mkl_tmp_buf_trans_input->flat<float>().data()));
+ size_t strides[4];
+ GetStridesFromSizes(FORMAT_NCHW, strides, in_sizes);
+ CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_trans_input, in_dims, in_sizes,
+ strides), E_SUCCESS);
+ }
+ else {
+ mkl_buf_input =
+ const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+ mkl_lt_trans_input = lt_input;
+ }
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
&mkl_lt_internal_input, prim_conv_bwdfilter, dnnResourceSrc),
E_SUCCESS);
mkl_convert_input =
- !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
+ !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_trans_input);
if (mkl_convert_input) {
- CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
- mkl_lt_internal_input),
- E_SUCCESS);
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
+ mkl_lt_trans_input, mkl_lt_internal_input), E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
&mkl_buf_convert_input);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
@@ -343,26 +372,30 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
dnnDelete_F32(mkl_prim_convert_input);
}
dnnLayoutDelete_F32(mkl_lt_internal_input);
+ if (!input_in_mkl_format && format == FORMAT_NHWC)
+ dnnLayoutDelete_F32(mkl_lt_trans_input);
+
conv_res[dnnResourceSrc] =
(mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
const Tensor& out_backprop = MklGetInput(context, 2);
- void* mkl_buf_out_backprop = const_cast<void*>(
- static_cast<const void*>(out_backprop.flat<T>().data()));
+ mkl_buf_out_backprop = const_cast<void*>(
+ static_cast<const void*>(out_backprop.flat<T>().data()));
+
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
prim_conv_bwdfilter,
dnnResourceDiffDst),
E_SUCCESS);
mkl_convert_out_backprop =
- !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, lt_out_backprop);
+ !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop,
+ lt_out_backprop);
if (mkl_convert_out_backprop) {
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
- lt_out_backprop,
- mkl_lt_internal_out_backprop),
+ lt_out_backprop, mkl_lt_internal_out_backprop),
E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_out_backprop_buf_tensor,
- lt_out_backprop, &mkl_buf_convert_out_backprop);
+ mkl_lt_internal_out_backprop, &mkl_buf_convert_out_backprop);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
mkl_buf_out_backprop,
mkl_buf_convert_out_backprop),