aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.h')
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h29
1 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index b6883dbaa2..c6456bd5c3 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -390,6 +390,29 @@ class MklConv2DBackpropCommonOp : public OpKernel {
TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
TensorShape outbprop_tf_shape = GetTfShape(context, kOutbpropIdx);
+ // Corner cases: output with 0 elements and 0 batch size.
+ Tensor* output_tensor = nullptr;
+ if (input_tf_shape.num_elements() == 0 ||
+ filter_tf_shape.num_elements() == 0 ||
+ outbprop_tf_shape.num_elements() == 0) {
+ MklDnnShape output_mkl_shape;
+ output_mkl_shape.SetMklTensor(false);
+ TensorShape output_tf_shape = GetOutputTfShape(input_tf_shape,
+ filter_tf_shape,
+ outbprop_tf_shape);
+ const int kOutputIdx = 0;
+ AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor,
+ output_tf_shape, output_mkl_shape);
+ CHECK_NOTNULL(output_tensor);
+
+ // if output tensor has more than 0 elements, we need to 0 them out.
+ for (size_t i = 0; i < output_tf_shape.num_elements(); ++i) {
+ output_tensor->flat<T>().data()[i] = 0;
+ }
+
+ return;
+ }
+
// By default, all dims are in MKL order. Only dims in TF order
// are those with prefix tf_order.
memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims;
@@ -471,7 +494,6 @@ class MklConv2DBackpropCommonOp : public OpKernel {
output.SetOpMemDesc(bwd_output_dims, memory::format::any);
// Operator-specific call to create and execute primitive.
- Tensor* output_tensor = nullptr;
CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter,
&outbackprop, &output, &output_tensor,
strides, padding_l, padding_r,
@@ -507,6 +529,11 @@ class MklConv2DBackpropCommonOp : public OpKernel {
virtual TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) = 0;
+ /// Get the TensorFlow shape of output tensor.
+ virtual TensorShape GetOutputTfShape(const TensorShape& input_shape,
+ const TensorShape& filter_shape,
+ const TensorShape& outbprop_shape) = 0;
+
/// Get shape of output in MKL-DNN order. Computes shape of output from
/// input shape (fwd_input_dims) and filter shape (fwd_filter_dims).
virtual