diff options
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r-- | tensorflow/python/ops/math_grad.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index f0c6bd532f..2a7a2fd51f 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -972,6 +972,24 @@ def _RealDivGrad(op, grad): grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) +@ops.RegisterGradient("UnsafeDiv") +def _UnsafeDivGrad(op, grad): + """UnsafeDiv op gradient.""" + x = op.inputs[0] + y = op.inputs[1] + sx = array_ops.shape(x) + sy = array_ops.shape(y) + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) + x = math_ops.conj(x) + y = math_ops.conj(y) + return (array_ops.reshape( + math_ops.reduce_sum(math_ops.unsafe_div(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.unsafe_div(math_ops.unsafe_div(-x, y), y), + ry), sy)) + + @ops.RegisterGradient("Pow") def _PowGrad(op, grad): """Returns grad * (y*x^(y-1), z*log(x)).""" |