aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--tensorflow/python/ops/math_grad.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 409e3c5111..024158e709 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -613,6 +613,16 @@ def _AtanGrad(op, grad):
return grad * inv
+@ops.RegisterGradient("Atan2")
+def _Atan2Grad(op, grad):
+ """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2)."""
+ y = op.inputs[0]
+ x = op.inputs[1]
+ with ops.control_dependencies([grad.op]):
+ grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
+ return x * grad_inv, -y * grad_inv
+
+
@ops.RegisterGradient("AddN")
def _AddNGrad(op, grad):
"""Copies the gradient to all inputs."""