aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_pooling_ops_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_pooling_ops_common.h')
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.h6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index 279167aba2..c0dfed7d7d 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -199,13 +199,15 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
CHECK_NOTNULL(pool_params);
CHECK_NOTNULL(dnn_data_input);
TensorShape input_tensor_shape = input_tensor.shape();
- memory::desc input_md =
+ if (input_tensor.NumElements() != 0) {
+ memory::desc input_md =
input_mkl_shape.IsMklTensor()
? input_mkl_shape.GetMklLayout()
: memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
this->data_format_tf_),
MklDnnType<T>(), this->data_format_mkldnn_);
- dnn_data_input->SetUsrMem(input_md, &input_tensor);
+ dnn_data_input->SetUsrMem(input_md, &input_tensor);
+ }
this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
input_tensor_shape);
}