aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_grad_input_ops.cc
diff options
context:
space:
mode:
authorGravatar Reed Wanderman-Milne <reedwm@google.com>2017-10-13 15:22:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-13 15:26:24 -0700
commit7679a2ec746bec36191087feaf9ec8371180669c (patch)
treedefd282cd8f64c5e7dca046d2bab54bd0031c416 /tensorflow/core/kernels/conv_grad_input_ops.cc
parent0bbdeaf45e07e1f5fb5e15961104e348e3ad8777 (diff)
Fix crash if tf.nn.conv2d_backprop_filter or tf.nn.conv2d_backprop_input is run with empty filter or input respectively. Resolves #13643.
PiperOrigin-RevId: 172153646
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_input_ops.cc')
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 0b2d01afa9..d28f6b4d10 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -225,6 +225,11 @@ class Conv2DFastBackpropInputOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
+ // If there is nothing to compute, return.
+ if (input_shape.num_elements() == 0) {
+ return;
+ }
+
#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
int64 pad_top, pad_bottom;
int64 pad_left, pad_right;
@@ -318,6 +323,11 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
+ // If there is nothing to compute, return.
+ if (input_shape.num_elements() == 0) {
+ return;
+ }
+
// TODO(andydavis) Consider moving code shared with
// Conv2DCustomBackpropFilterOp into a shared helper function.
#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
@@ -603,6 +613,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
+ // If there is nothing to compute, return.
+ if (input_shape.num_elements() == 0) {
+ return;
+ }
+
// For now we take the stride from the second and third dimensions only (we
// do not support striding on the batch or depth dimension).
const int stride_rows = GetTensorDim(strides_, data_format_, 'H');