aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/array_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/array_grad.py')
-rw-r--r--tensorflow/python/ops/array_grad.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 3678bd4c1f..fe459a96b9 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -568,7 +568,6 @@ ops.NotDifferentiable("Size")
@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
"""Sum reduces grad along the tiled dimensions."""
- assert isinstance(grad, ops.Tensor)
input_shape = array_ops.shape(op.inputs[0])
# We interleave multiples and input_shape to get split_shape,
# reshape grad to split_shape, and reduce along all even
@@ -581,6 +580,13 @@ def _TileGrad(op, grad):
split_shape = array_ops.reshape(
array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
axes = math_ops.range(0, array_ops.size(split_shape), 2)
+ # Sum reduces grad along the first dimension for IndexedSlices
+ if isinstance(grad, ops.IndexedSlices):
+ grad = math_ops.unsorted_segment_sum(
+ grad.values,
+ math_ops.mod(grad.indices, input_shape[0]),
+ input_shape[0])
+ split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
# Fix shape inference
if not context.executing_eagerly():