Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--  tensorflow/python/ops/math_grad.py  59
1 file changed, 55 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 9e7a922b2a..0620a3da2c 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -109,13 +109,41 @@ def _MeanGrad(op, grad):
@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
"""Gradient for Prod."""
- # TODO(kearnes): this gives NaNs for 0s in the input tensor
+ # The gradient can be expressed by dividing the product by each entry of the
+ # input tensor, but this approach can't deal with zeros in the input.
+ # Here, we avoid this problem by composing the output as a product of two
+ # cumprod operations.
+
input_shape = array_ops.shape(op.inputs[0])
+
+ # Expand grad to full input shape
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
- grad = array_ops.reshape(grad * op.outputs[0], output_shape_kept_dims)
- grad = math_ops.div(array_ops.tile(grad, tile_scaling), op.inputs[0])
- return grad, None
+ grad = array_ops.reshape(grad, output_shape_kept_dims)
+ grad = array_ops.tile(grad, tile_scaling)
+
+ # Pack all reduced dimensions into a single one, so we can perform the
+ # cumprod ops. If the reduction dims list is empty, its dtype defaults to
+ # float32, so we cast it to int32 here.
+ reduced = math_ops.cast(op.inputs[1], dtypes.int32)
+ idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
+ other, _ = array_ops.listdiff(idx, reduced)
+ perm = array_ops.concat(0, [reduced, other])
+ reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
+ other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
+ permuted = array_ops.transpose(op.inputs[0], perm)
+ permuted_shape = array_ops.shape(permuted)
+ reshaped = array_ops.reshape(permuted, (reduced_num, other_num))
+
+ # Calculate product, leaving out the current entry
+ left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
+ right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
+ y = array_ops.reshape(left * right, permuted_shape)
+
+ # Invert the transpose and reshape operations.
+ # Make sure to set the statically known shape information through a reshape.
+ out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
+ return array_ops.reshape(out, input_shape), None
@ops.RegisterGradient("SegmentSum")
@@ -839,3 +867,26 @@ def _CrossGrad(op, grad):
u = op.inputs[0]
v = op.inputs[1]
return (math_ops.cross(v, grad), math_ops.cross(grad, u))
+
+
+@ops.RegisterGradient("Cumsum")
+def _CumsumGrad(op, grad):
+ axis = op.inputs[1]
+ exclusive = op.get_attr("exclusive")
+ reverse = op.get_attr("reverse")
+ return [math_ops.cumsum(grad, axis, exclusive=exclusive,
+ reverse=not reverse), None]
+
+
+@ops.RegisterGradient("Cumprod")
+def _CumprodGrad(op, grad):
+ x = op.inputs[0]
+ axis = op.inputs[1]
+ exclusive = op.get_attr("exclusive")
+ reverse = op.get_attr("reverse")
+
+ # TODO: This fails when x contains zeros and should be fixed.
+ prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
+ out = math_ops.cumsum(prod * grad, axis, exclusive=exclusive,
+ reverse=not reverse)
+ return [out / x, None]
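As a sanity check on the Cumsum gradient (again a plain NumPy sketch, not part of the patch): y = cumsum(x) means y[j] depends on x[i] exactly when i <= j, so the vector-Jacobian product with an upstream gradient g is a cumulative sum of g running in the opposite direction, with the same exclusive flag and the reverse flag flipped. The Cumprod gradient follows the same cumsum-of-(prod * grad) pattern but then divides by x, which is why the TODO above flags inputs containing zeros.

# Illustrative NumPy sketch (not part of the patch): the gradient of
# cumsum is a cumulative sum of the upstream gradient in the reverse
# direction.
import numpy as np

g = np.array([0.1, 0.2, 0.3, 0.4])   # arbitrary upstream gradient

# Jacobian of y = cumsum(x): J[j, i] = 1 if i <= j else 0.
jacobian = np.tril(np.ones((4, 4)))
expected = jacobian.T @ g            # vector-Jacobian product

# Reverse cumulative sum of g, matching _CumsumGrad with `reverse` flipped.
reverse_cumsum = np.cumsum(g[::-1])[::-1]
assert np.allclose(reverse_cumsum, expected)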