aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_maxpooling_op.cc
diff options
context:
space:
mode:
authorGravatar Li, Yiqiang <yiqiang.li@intel.com>2018-07-18 19:59:01 +0800
committerGravatar Li, Yiqiang <yiqiang.li@intel.com>2018-07-18 19:59:01 +0800
commitd490493cd4848422e5480e8a30a0a88af07641ad (patch)
tree571dc4b037be9db4cbb3609904992ba422f3cfa0 /tensorflow/core/kernels/mkl_maxpooling_op.cc
parent238eb51aeeeda23b3d2fe64c319b93f43ae4fd01 (diff)
Code refactoring per comments.
Diffstat (limited to 'tensorflow/core/kernels/mkl_maxpooling_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_maxpooling_op.cc45
1 files changed, 9 insertions, 36 deletions
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index 7fba7f6bb1..657f007a2e 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -537,23 +537,9 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
// If input is an empty tensor, allocate an empty output tensor and return
if (input_tensor.NumElements() == 0) {
- MklDnnShape output_mkl_shape;
- output_mkl_shape.SetMklTensor(false);
- TensorShape output_tf_shape;
- if (pool_params.data_format == TensorFormat::FORMAT_NCHW) {
- output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
- } else {
- memory::dims output_dims_NHWC_order;
- output_dims_NHWC_order = {pool_params.tensor_in_batch,
- static_cast<int>(pool_params.out_height),
- static_cast<int>(pool_params.out_width),
- pool_params.out_depth};
- output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
- }
const int kOutputIndex = 0;
- AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor,
- output_tf_shape, output_mkl_shape);
- CHECK_NOTNULL(output_tensor);
+ this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params,
+ output_dims_mkl_order, &output_tensor);
return;
}
@@ -572,16 +558,9 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
: TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
this->data_format_tf_);
- memory::dims filter_dims = memory::dims({pool_params.window_rows,
- pool_params.window_cols});
- memory::dims strides = memory::dims({pool_params.row_stride,
- pool_params.col_stride});
- memory::dims padding_left = memory::dims(
- {static_cast<int>(pool_params.pad_top),
- static_cast<int>(pool_params.pad_left)});
- memory::dims padding_right = memory::dims(
- {static_cast<int>(pool_params.pad_bottom),
- static_cast<int>(pool_params.pad_right)});
+ memory::dims filter_dims, strides, padding_left, padding_right;
+ this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
+ &padding_left, &padding_right);
// Get a pooling op from the cached pool
MklPoolingFwdPrimitive<T> *pooling_fwd = nullptr;
@@ -689,16 +668,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
this->InitMklPoolParameters(context, &pool_params,
orig_input_mkl_shape, orig_input_shape);
- memory::dims filter_dims = memory::dims(
- {pool_params.window_rows, pool_params.window_cols});
- memory::dims strides = memory::dims(
- {pool_params.row_stride, pool_params.col_stride});
- memory::dims padding_left = memory::dims(
- {static_cast<int>(pool_params.pad_top),
- static_cast<int>(pool_params.pad_left)});
- memory::dims padding_right = memory::dims(
- {static_cast<int>(pool_params.pad_bottom),
- static_cast<int>(pool_params.pad_right)});
+ memory::dims filter_dims, strides, padding_left, padding_right;
+ this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
+ &padding_left, &padding_right);
+
memory::dims diff_dst_dims = grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetSizesAsMklDnnDims()
: TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),