aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--tensorflow/python/ops/math_grad.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 51e42d9b45..7483d70b69 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -110,6 +110,8 @@ def _ProdGrad(op, grad):
# cumprod operations.
input_shape = array_ops.shape(op.inputs[0])
+ # Reshape reduction indices for the case where the parameter is a scalar
+ reduction_indices = array_ops.reshape(op.inputs[1], [-1])
# Expand grad to full input shape
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
@@ -122,7 +124,7 @@ def _ProdGrad(op, grad):
# so we need to cast here. We put all the shape-related ops on CPU to avoid
# copying back and forth, and since listdiff is CPU only.
with ops.device("/cpu:0"):
- reduced = math_ops.cast(op.inputs[1], dtypes.int32)
+ reduced = math_ops.cast(reduction_indices, dtypes.int32)
idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
other, _ = array_ops.listdiff(idx, reduced)
perm = array_ops.concat(0, [reduced, other])