Diffstat (limited to 'tensorflow/core/kernels/mkl_avgpooling_op.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_avgpooling_op.cc  31  ++++++++++++++++++++++++++-----
1 file changed, 26 insertions(+), 5 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index a7c569ee05..d545d34fdf 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -24,7 +24,7 @@
#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::algorithm;
using mkldnn::engine;
@@ -40,8 +40,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-// For now, MKL-ML is default. So making MKL-DNN not a default choice.
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklAvgPoolingOp : public OpKernel {
@@ -429,7 +428,7 @@ class MklAvgPoolingGradOp : public OpKernel {
TensorFormat data_format_;
}; // MklAvgPoolingGradOp
-#else // INTEL_MKL_DNN is defined
+#else
template <typename Device, typename T>
class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
@@ -466,6 +465,28 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
+ // If input is an empty tensor, allocate an empty output tensor and return
+ if (input_tensor.NumElements() == 0) {
+ MklDnnShape output_mkl_shape;
+ output_mkl_shape.SetMklTensor(false);
+ TensorShape output_tf_shape;
+ if (pool_params.data_format == TensorFormat::FORMAT_NCHW) {
+ output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
+ } else {
+ memory::dims output_dims_NHWC_order;
+ output_dims_NHWC_order = {pool_params.tensor_in_batch,
+ static_cast<int>(pool_params.out_height),
+ static_cast<int>(pool_params.out_width),
+ pool_params.out_depth};
+ output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
+ }
+ const int kOutputIndex = 0;
+ AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor,
+ output_tf_shape, output_mkl_shape);
+ CHECK_NOTNULL(output_tensor);
+ return;
+ }
+
// If input is in Mkl layout, then just get the memory format from it
// directly, instead of using input data_format to AvgPool.
if (dnn_shape_input.IsMklTensor()) {
@@ -678,7 +699,7 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklAvgPoolingGradOp
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
.Device(DEVICE_CPU)