aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/array_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/array_grad.py')
-rw-r--r--tensorflow/python/ops/array_grad.py6
1 files changed, 1 insertions, 5 deletions
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 87f8d14860..3c025881cb 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -460,11 +460,7 @@ def _GatherNdGrad(op, grad):
ref = op.inputs[0]
indices = op.inputs[1]
ref_shape = array_ops.shape(ref, out_type=indices.dtype)
- if indices.shape.ndims == 2 and indices.shape[-1].value == 1:
- ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
- ref_shape)
- else:
- ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
+ ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
return [ref_grad, None]