diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.h')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.h | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index b6883dbaa2..c6456bd5c3 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -390,6 +390,29 @@ class MklConv2DBackpropCommonOp : public OpKernel { TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor); TensorShape outbprop_tf_shape = GetTfShape(context, kOutbpropIdx); + // Corner cases: output with 0 elements and 0 batch size. + Tensor* output_tensor = nullptr; + if (input_tf_shape.num_elements() == 0 || + filter_tf_shape.num_elements() == 0 || + outbprop_tf_shape.num_elements() == 0) { + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(false); + TensorShape output_tf_shape = GetOutputTfShape(input_tf_shape, + filter_tf_shape, + outbprop_tf_shape); + const int kOutputIdx = 0; + AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor, + output_tf_shape, output_mkl_shape); + CHECK_NOTNULL(output_tensor); + + // if output tensor has more than 0 elements, we need to 0 them out. + for (size_t i = 0; i < output_tf_shape.num_elements(); ++i) { + output_tensor->flat<T>().data()[i] = 0; + } + + return; + } + // By default, all dims are in MKL order. Only dims in TF order // are those with prefix tf_order. memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims; @@ -471,7 +494,6 @@ class MklConv2DBackpropCommonOp : public OpKernel { output.SetOpMemDesc(bwd_output_dims, memory::format::any); // Operator-specific call to create and execute primitive. - Tensor* output_tensor = nullptr; CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter, &outbackprop, &output, &output_tensor, strides, padding_l, padding_r, @@ -507,6 +529,11 @@ class MklConv2DBackpropCommonOp : public OpKernel { virtual TensorShape MakeFilterTfShape(OpKernelContext* context, const Tensor& filter_tensor) = 0; + /// Get the TensorFlow shape of output tensor. + virtual TensorShape GetOutputTfShape(const TensorShape& input_shape, + const TensorShape& filter_shape, + const TensorShape& outbprop_shape) = 0; + /// Get shape of output in MKL-DNN order. Computes shape of output from /// input shape (fwd_input_dims) and filter shape (fwd_filter_dims). virtual |