aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--tensorflow/python/ops/math_grad.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index f0c6bd532f..2a7a2fd51f 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -972,6 +972,24 @@ def _RealDivGrad(op, grad):
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
+@ops.RegisterGradient("UnsafeDiv")
+def _UnsafeDivGrad(op, grad):
+ """UnsafeDiv op gradient."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ x = math_ops.conj(x)
+ y = math_ops.conj(y)
+ return (array_ops.reshape(
+ math_ops.reduce_sum(math_ops.unsafe_div(grad, y), rx), sx),
+ array_ops.reshape(
+ math_ops.reduce_sum(
+ grad * math_ops.unsafe_div(math_ops.unsafe_div(-x, y), y),
+ ry), sy))
+
+
@ops.RegisterGradient("Pow")
def _PowGrad(op, grad):
"""Returns grad * (y*x^(y-1), z*log(x))."""