diff options
Diffstat (limited to 'tensorflow/python/ops/array_grad.py')
-rw-r--r-- | tensorflow/python/ops/array_grad.py | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 87f8d14860..3c025881cb 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -460,11 +460,7 @@ def _GatherNdGrad(op, grad): ref = op.inputs[0] indices = op.inputs[1] ref_shape = array_ops.shape(ref, out_type=indices.dtype) - if indices.shape.ndims == 2 and indices.shape[-1].value == 1: - ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1), - ref_shape) - else: - ref_grad = array_ops.scatter_nd(indices, grad, ref_shape) + ref_grad = array_ops.scatter_nd(indices, grad, ref_shape) return [ref_grad, None] |