Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--  tensorflow/python/ops/math_grad.py  59
1 file changed, 55 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 9e7a922b2a..0620a3da2c 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -109,13 +109,41 @@ def _MeanGrad(op, grad):
 @ops.RegisterGradient("Prod")
 def _ProdGrad(op, grad):
   """Gradient for Prod."""
-  # TODO(kearnes): this gives NaNs for 0s in the input tensor
+  # The gradient can be expressed by dividing the product by each entry of the
+  # input tensor, but this approach can't deal with zeros in the input.
+  # Here, we avoid this problem by composing the output as a product of two
+  # cumprod operations.
+
   input_shape = array_ops.shape(op.inputs[0])
+
+  # Expand grad to full input shape
   output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
   tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
-  grad = array_ops.reshape(grad * op.outputs[0], output_shape_kept_dims)
-  grad = math_ops.div(array_ops.tile(grad, tile_scaling), op.inputs[0])
-  return grad, None
+  grad = array_ops.reshape(grad, output_shape_kept_dims)
+  grad = array_ops.tile(grad, tile_scaling)
+
+  # Pack all reduced dimensions into a single one, so we can perform the
+  # cumprod ops. If the reduction dims list is empty, it defaults to float32,
+  # so we need to cast here.
+  reduced = math_ops.cast(op.inputs[1], dtypes.int32)
+  idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
+  other, _ = array_ops.listdiff(idx, reduced)
+  perm = array_ops.concat(0, [reduced, other])
+  reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
+  other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
+  permuted = array_ops.transpose(op.inputs[0], perm)
+  permuted_shape = array_ops.shape(permuted)
+  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))
+
+  # Calculate product, leaving out the current entry
+  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
+  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
+  y = array_ops.reshape(left * right, permuted_shape)
+
+  # Invert the transpose and reshape operations.
+  # Make sure to set the statically known shape information through a reshape.
+  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
+  return array_ops.reshape(out, input_shape), None
 
 
 @ops.RegisterGradient("SegmentSum")
@@ -839,3 +867,26 @@ def _CrossGrad(op, grad):
   u = op.inputs[0]
   v = op.inputs[1]
   return (math_ops.cross(v, grad), math_ops.cross(grad, u))
+
+
+@ops.RegisterGradient("Cumsum")
+def _CumsumGrad(op, grad):
+  axis = op.inputs[1]
+  exclusive = op.get_attr("exclusive")
+  reverse = op.get_attr("reverse")
+  return [math_ops.cumsum(grad, axis, exclusive=exclusive,
+                          reverse=not reverse), None]
+
+
+@ops.RegisterGradient("Cumprod")
+def _CumprodGrad(op, grad):
+  x = op.inputs[0]
+  axis = op.inputs[1]
+  exclusive = op.get_attr("exclusive")
+  reverse = op.get_attr("reverse")
+
+  # TODO This fails when x contains 0 and should be fixed
+  prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
+  out = math_ops.cumsum(prod * grad, axis, exclusive=exclusive,
+                        reverse=not reverse)
+  return [out / x, None]
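
The rewritten _ProdGrad rests on a simple identity: the derivative of prod(x) with respect to x[i] is the product of every entry except x[i], which factors into (product of entries before i) * (product of entries after i), i.e. an exclusive cumprod from the left times an exclusive cumprod from the right. No division is involved, so zeros in the input no longer produce NaNs. A minimal NumPy sketch (illustration only, not part of the commit) checks this identity on a vector that contains a zero:

import numpy as np

x = np.array([2.0, 0.0, 3.0, 4.0])

# Exclusive cumprod from the left: prod(x[:i]) for each i.
left = np.concatenate(([1.0], np.cumprod(x)[:-1]))
# Exclusive cumprod from the right: prod(x[i+1:]) for each i.
right = np.concatenate((np.cumprod(x[::-1])[:-1][::-1], [1.0]))

grad = left * right  # d prod(x) / d x[i]
expected = np.array([np.prod(np.delete(x, i)) for i in range(len(x))])
assert np.allclose(grad, expected)
print(grad)  # [ 0. 24.  0.  0.] -- finite even though x contains a zero

The packing into a (reduced_num, other_num) matrix in the diff exists only so that a single axis-0 cumprod pair covers an arbitrary set of reduction dimensions; the per-entry identity is the same.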
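
For _CumsumGrad: since y[j] = x[0] + ... + x[j], each x[i] feeds every y[j] with j >= i, so the gradient is the cumsum of the incoming gradient run in the opposite direction, which is why the registration flips reverse=not reverse while exclusive carries over unchanged. A small NumPy sketch, again illustration only, with a finite-difference check:

import numpy as np

x = np.array([1.0, 2.0, 3.0])
g = np.array([0.1, 0.2, 0.3])   # upstream gradient dL/dy

dx = np.cumsum(g[::-1])[::-1]   # reversed cumsum of g
print(dx)                       # [0.6 0.5 0.3]

# Finite-difference check of dL/dx[i] with L = dot(g, cumsum(x)):
eps = 1e-6
fd = [(np.dot(g, np.cumsum(x + eps * np.eye(3)[i])) -
       np.dot(g, np.cumsum(x))) / eps for i in range(3)]
assert np.allclose(dx, fd)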
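
For _CumprodGrad: with y = cumprod(x), dy[j]/dx[i] = y[j] / x[i] for j >= i, so dL/dx = reversed_cumsum(grad * y) / x. The trailing division is exactly what the TODO in the diff flags: it breaks when x contains zeros, the same weakness the _ProdGrad rewrite just removed. A NumPy sketch of the identity on a zero-free input, illustration only:

import numpy as np

x = np.array([2.0, 3.0, 4.0])
g = np.array([1.0, 1.0, 1.0])          # upstream gradient dL/dy

y = np.cumprod(x)                      # [ 2.  6. 24.]
dx = np.cumsum((g * y)[::-1])[::-1] / x
print(dx)                              # [16. 10.  6.]
# Check dL/dx[0] by hand: dy[0]/dx[0] + dy[1]/dx[0] + dy[2]/dx[0]
# = 1 + 3 + 12 = 16, matching dx[0].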