diff options
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r-- | tensorflow/python/ops/math_grad.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 51e42d9b45..7483d70b69 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -110,6 +110,8 @@ def _ProdGrad(op, grad): # cumprod operations. input_shape = array_ops.shape(op.inputs[0]) + # Reshape reduction indices for the case where the parameter is a scalar + reduction_indices = array_ops.reshape(op.inputs[1], [-1]) # Expand grad to full input shape output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) @@ -122,7 +124,7 @@ def _ProdGrad(op, grad): # so we need to cast here. We put all the shape-related ops on CPU to avoid # copying back and forth, and since listdiff is CPU only. with ops.device("/cpu:0"): - reduced = math_ops.cast(op.inputs[1], dtypes.int32) + reduced = math_ops.cast(reduction_indices, dtypes.int32) idx = math_ops.range(0, array_ops.rank(op.inputs[0])) other, _ = array_ops.listdiff(idx, reduced) perm = array_ops.concat(0, [reduced, other]) |