diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_pooling_ops_common.h')
-rw-r--r-- | tensorflow/core/kernels/mkl_pooling_ops_common.h | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index 279167aba2..c0dfed7d7d 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -199,13 +199,15 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> { CHECK_NOTNULL(pool_params); CHECK_NOTNULL(dnn_data_input); TensorShape input_tensor_shape = input_tensor.shape(); - memory::desc input_md = + if (input_tensor.NumElements() != 0) { + memory::desc input_md = input_mkl_shape.IsMklTensor() ? input_mkl_shape.GetMklLayout() : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape, this->data_format_tf_), MklDnnType<T>(), this->data_format_mkldnn_); - dnn_data_input->SetUsrMem(input_md, &input_tensor); + dnn_data_input->SetUsrMem(input_md, &input_tensor); + } this->InitMklPoolParameters(context, pool_params, input_mkl_shape, input_tensor_shape); } |