diff options
author | 2017-10-13 15:22:19 -0700 | |
---|---|---|
committer | 2017-10-13 15:26:24 -0700 | |
commit | 7679a2ec746bec36191087feaf9ec8371180669c (patch) | |
tree | defd282cd8f64c5e7dca046d2bab54bd0031c416 /tensorflow/core/kernels/conv_grad_input_ops.cc | |
parent | 0bbdeaf45e07e1f5fb5e15961104e348e3ad8777 (diff) |
Fix crash if tf.nn.conv2d_backprop_filter or tf.nn.conv2d_backprop_input is run with empty filter or input respectively. Resolves #13643.
PiperOrigin-RevId: 172153646
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_input_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_grad_input_ops.cc | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 0b2d01afa9..d28f6b4d10 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -225,6 +225,11 @@ class Conv2DFastBackpropInputOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); + // If there is nothing to compute, return. + if (input_shape.num_elements() == 0) { + return; + } + #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD int64 pad_top, pad_bottom; int64 pad_left, pad_right; @@ -318,6 +323,11 @@ class Conv2DCustomBackpropInputOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); + // If there is nothing to compute, return. + if (input_shape.num_elements() == 0) { + return; + } + // TODO(andydavis) Consider moving code shared with // Conv2DCustomBackpropFilterOp into a shared helper function. #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD @@ -603,6 +613,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); + // If there is nothing to compute, return. + if (input_shape.num_elements() == 0) { + return; + } + // For now we take the stride from the second and third dimensions only (we // do not support striding on the batch or depth dimension). const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); |