diff options
author | 2018-07-18 19:59:01 +0800 | |
---|---|---|
committer | 2018-07-18 19:59:01 +0800 | |
commit | d490493cd4848422e5480e8a30a0a88af07641ad (patch) | |
tree | 571dc4b037be9db4cbb3609904992ba422f3cfa0 /tensorflow/core/kernels/mkl_maxpooling_op.cc | |
parent | 238eb51aeeeda23b3d2fe64c319b93f43ae4fd01 (diff) |
Code refactoring per comments.
Diffstat (limited to 'tensorflow/core/kernels/mkl_maxpooling_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_maxpooling_op.cc | 45 |
1 files changed, 9 insertions, 36 deletions
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index 7fba7f6bb1..657f007a2e 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -537,23 +537,9 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { // If input is an empty tensor, allocate an empty output tensor and return if (input_tensor.NumElements() == 0) { - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(false); - TensorShape output_tf_shape; - if (pool_params.data_format == TensorFormat::FORMAT_NCHW) { - output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); - } else { - memory::dims output_dims_NHWC_order; - output_dims_NHWC_order = {pool_params.tensor_in_batch, - static_cast<int>(pool_params.out_height), - static_cast<int>(pool_params.out_width), - pool_params.out_depth}; - output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order); - } const int kOutputIndex = 0; - AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor, - output_tf_shape, output_mkl_shape); - CHECK_NOTNULL(output_tensor); + this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params, + output_dims_mkl_order, &output_tensor); return; } @@ -572,16 +558,9 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(), this->data_format_tf_); - memory::dims filter_dims = memory::dims({pool_params.window_rows, - pool_params.window_cols}); - memory::dims strides = memory::dims({pool_params.row_stride, - pool_params.col_stride}); - memory::dims padding_left = memory::dims( - {static_cast<int>(pool_params.pad_top), - static_cast<int>(pool_params.pad_left)}); - memory::dims padding_right = memory::dims( - {static_cast<int>(pool_params.pad_bottom), - static_cast<int>(pool_params.pad_right)}); + memory::dims filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); // Get a pooling op from the cached pool MklPoolingFwdPrimitive<T> *pooling_fwd = nullptr; @@ -689,16 +668,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape, orig_input_shape); - memory::dims filter_dims = memory::dims( - {pool_params.window_rows, pool_params.window_cols}); - memory::dims strides = memory::dims( - {pool_params.row_stride, pool_params.col_stride}); - memory::dims padding_left = memory::dims( - {static_cast<int>(pool_params.pad_top), - static_cast<int>(pool_params.pad_left)}); - memory::dims padding_right = memory::dims( - {static_cast<int>(pool_params.pad_bottom), - static_cast<int>(pool_params.pad_right)}); + memory::dims filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); + memory::dims diff_dst_dims = grad_mkl_shape.IsMklTensor() ? grad_mkl_shape.GetSizesAsMklDnnDims() : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(), |