diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 65 |
1 files changed, 16 insertions, 49 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 3b23c72f0f..f81a448e51 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -206,15 +206,10 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // Mkl needs the entities in its native format. // So create temporary tensors along with buffers to // convert the received entities. - Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor, - mkl_tmp_buf_trans_input; + Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor; // This preparation sets (1) dnnResourceSrc (2) dnnResourceDiffDst - mkl_context.MklPrepareInputs(context, data_format_, - input_in_mkl_format, - out_backprop_in_mkl_format, - &mkl_tmp_input_buf_tensor, - &mkl_tmp_out_backprop_buf_tensor, - &mkl_tmp_buf_trans_input); + mkl_context.MklPrepareInputs(context, &mkl_tmp_input_buf_tensor, + &mkl_tmp_out_backprop_buf_tensor); // Final conv-grad-filter should be in TF layout. Tensor* grad_filter; @@ -312,58 +307,34 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // Compare incoming tensor layouts with MKL preferred layouts and convert // data to the preferred layout if necessary - void MklPrepareInputs(OpKernelContext* context, TensorFormat format, - bool input_in_mkl_format, - bool out_backprop_in_mkl_format, + void MklPrepareInputs(OpKernelContext* context, Tensor* mkl_tmp_input_buf_tensor, - Tensor* mkl_tmp_out_backprop_buf_tensor, - Tensor* mkl_tmp_buf_trans_input) { + Tensor* mkl_tmp_out_backprop_buf_tensor) { bool mkl_convert_input, mkl_convert_out_backprop; dnnPrimitive_t mkl_prim_convert_input, mkl_prim_convert_out_backprop; - dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop, - mkl_lt_trans_input; + dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop; void *mkl_buf_convert_input, *mkl_buf_convert_out_backprop; - void *mkl_buf_input, *mkl_buf_out_backprop; mkl_prim_convert_input = nullptr; mkl_prim_convert_out_backprop = nullptr; mkl_lt_internal_input = nullptr; mkl_lt_internal_out_backprop = nullptr; - mkl_lt_trans_input = nullptr; mkl_buf_convert_input = nullptr; mkl_buf_convert_out_backprop = nullptr; - mkl_buf_input = nullptr; - mkl_buf_out_backprop = nullptr; // Compare with internal layouts and convert if needed const Tensor& input = MklGetInput(context, 0); - if (!input_in_mkl_format && format == FORMAT_NHWC){ - TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW, - in_sizes[MklDims::N], in_sizes[MklDims::H], - in_sizes[MklDims::W], in_sizes[MklDims::C]); - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum<float>::value, nchw_shape, mkl_tmp_buf_trans_input)); - MklNHWCToNCHW(input, &mkl_tmp_buf_trans_input); - mkl_buf_input = const_cast<void*>(static_cast<const void*>( - mkl_tmp_buf_trans_input->flat<float>().data())); - size_t strides[4]; - GetStridesFromSizes(FORMAT_NCHW, strides, in_sizes); - CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_trans_input, in_dims, in_sizes, - strides), E_SUCCESS); - } - else { - mkl_buf_input = - const_cast<void*>(static_cast<const void*>(input.flat<T>().data())); - mkl_lt_trans_input = lt_input; - } + void* mkl_buf_input = + const_cast<void*>(static_cast<const void*>(input.flat<T>().data())); CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( &mkl_lt_internal_input, prim_conv_bwdfilter, dnnResourceSrc), E_SUCCESS); mkl_convert_input = - !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_trans_input); + !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input); if (mkl_convert_input) { - CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, - mkl_lt_trans_input, mkl_lt_internal_input), E_SUCCESS); + CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input, + mkl_lt_internal_input), + E_SUCCESS); AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input, &mkl_buf_convert_input); CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input, @@ -372,30 +343,26 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { dnnDelete_F32(mkl_prim_convert_input); } dnnLayoutDelete_F32(mkl_lt_internal_input); - if (!input_in_mkl_format && format == FORMAT_NHWC) - dnnLayoutDelete_F32(mkl_lt_trans_input); - conv_res[dnnResourceSrc] = (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input; const Tensor& out_backprop = MklGetInput(context, 2); - mkl_buf_out_backprop = const_cast<void*>( - static_cast<const void*>(out_backprop.flat<T>().data())); + void* mkl_buf_out_backprop = const_cast<void*>(static_cast<const void*>( + out_backprop.flat<T>().data())); CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop, prim_conv_bwdfilter, dnnResourceDiffDst), E_SUCCESS); mkl_convert_out_backprop = - !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, - lt_out_backprop); + !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, lt_out_backprop); if (mkl_convert_out_backprop) { CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop, lt_out_backprop, mkl_lt_internal_out_backprop), E_SUCCESS); AllocTmpBuffer(context, mkl_tmp_out_backprop_buf_tensor, - mkl_lt_internal_out_backprop, &mkl_buf_convert_out_backprop); + lt_out_backprop, &mkl_buf_convert_out_backprop); CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop, mkl_buf_out_backprop, mkl_buf_convert_out_backprop), |