aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc65
1 files changed, 16 insertions, 49 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 3b23c72f0f..f81a448e51 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -206,15 +206,10 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Mkl needs the entities in its native format.
// So create temporary tensors along with buffers to
// convert the received entities.
- Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor,
- mkl_tmp_buf_trans_input;
+ Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor;
// This preparation sets (1) dnnResourceSrc (2) dnnResourceDiffDst
- mkl_context.MklPrepareInputs(context, data_format_,
- input_in_mkl_format,
- out_backprop_in_mkl_format,
- &mkl_tmp_input_buf_tensor,
- &mkl_tmp_out_backprop_buf_tensor,
- &mkl_tmp_buf_trans_input);
+ mkl_context.MklPrepareInputs(context, &mkl_tmp_input_buf_tensor,
+ &mkl_tmp_out_backprop_buf_tensor);
// Final conv-grad-filter should be in TF layout.
Tensor* grad_filter;
@@ -312,58 +307,34 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Compare incoming tensor layouts with MKL preferred layouts and convert
// data to the preferred layout if necessary
- void MklPrepareInputs(OpKernelContext* context, TensorFormat format,
- bool input_in_mkl_format,
- bool out_backprop_in_mkl_format,
+ void MklPrepareInputs(OpKernelContext* context,
Tensor* mkl_tmp_input_buf_tensor,
- Tensor* mkl_tmp_out_backprop_buf_tensor,
- Tensor* mkl_tmp_buf_trans_input) {
+ Tensor* mkl_tmp_out_backprop_buf_tensor) {
bool mkl_convert_input, mkl_convert_out_backprop;
dnnPrimitive_t mkl_prim_convert_input, mkl_prim_convert_out_backprop;
- dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop,
- mkl_lt_trans_input;
+ dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop;
void *mkl_buf_convert_input, *mkl_buf_convert_out_backprop;
- void *mkl_buf_input, *mkl_buf_out_backprop;
mkl_prim_convert_input = nullptr;
mkl_prim_convert_out_backprop = nullptr;
mkl_lt_internal_input = nullptr;
mkl_lt_internal_out_backprop = nullptr;
- mkl_lt_trans_input = nullptr;
mkl_buf_convert_input = nullptr;
mkl_buf_convert_out_backprop = nullptr;
- mkl_buf_input = nullptr;
- mkl_buf_out_backprop = nullptr;
// Compare with internal layouts and convert if needed
const Tensor& input = MklGetInput(context, 0);
- if (!input_in_mkl_format && format == FORMAT_NHWC){
- TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW,
- in_sizes[MklDims::N], in_sizes[MklDims::H],
- in_sizes[MklDims::W], in_sizes[MklDims::C]);
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<float>::value, nchw_shape, mkl_tmp_buf_trans_input));
- MklNHWCToNCHW(input, &mkl_tmp_buf_trans_input);
- mkl_buf_input = const_cast<void*>(static_cast<const void*>(
- mkl_tmp_buf_trans_input->flat<float>().data()));
- size_t strides[4];
- GetStridesFromSizes(FORMAT_NCHW, strides, in_sizes);
- CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_trans_input, in_dims, in_sizes,
- strides), E_SUCCESS);
- }
- else {
- mkl_buf_input =
- const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
- mkl_lt_trans_input = lt_input;
- }
+ void* mkl_buf_input =
+ const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
&mkl_lt_internal_input, prim_conv_bwdfilter, dnnResourceSrc),
E_SUCCESS);
mkl_convert_input =
- !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_trans_input);
+ !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
if (mkl_convert_input) {
- CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
- mkl_lt_trans_input, mkl_lt_internal_input), E_SUCCESS);
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
+ mkl_lt_internal_input),
+ E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
&mkl_buf_convert_input);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
@@ -372,30 +343,26 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
dnnDelete_F32(mkl_prim_convert_input);
}
dnnLayoutDelete_F32(mkl_lt_internal_input);
- if (!input_in_mkl_format && format == FORMAT_NHWC)
- dnnLayoutDelete_F32(mkl_lt_trans_input);
-
conv_res[dnnResourceSrc] =
(mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
const Tensor& out_backprop = MklGetInput(context, 2);
- mkl_buf_out_backprop = const_cast<void*>(
- static_cast<const void*>(out_backprop.flat<T>().data()));
+ void* mkl_buf_out_backprop = const_cast<void*>(static_cast<const void*>(
+ out_backprop.flat<T>().data()));
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
prim_conv_bwdfilter,
dnnResourceDiffDst),
E_SUCCESS);
mkl_convert_out_backprop =
- !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop,
- lt_out_backprop);
+ !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, lt_out_backprop);
if (mkl_convert_out_backprop) {
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
lt_out_backprop, mkl_lt_internal_out_backprop),
E_SUCCESS);
AllocTmpBuffer(context, mkl_tmp_out_backprop_buf_tensor,
- mkl_lt_internal_out_backprop, &mkl_buf_convert_out_backprop);
+ lt_out_backprop, &mkl_buf_convert_out_backprop);
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
mkl_buf_out_backprop,
mkl_buf_convert_out_backprop),