diff options
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r-- | tensorflow/python/ops/math_grad.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 409e3c5111..024158e709 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -613,6 +613,16 @@ def _AtanGrad(op, grad): return grad * inv +@ops.RegisterGradient("Atan2") +def _Atan2Grad(op, grad): + """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2).""" + y = op.inputs[0] + x = op.inputs[1] + with ops.control_dependencies([grad.op]): + grad_inv = grad / (math_ops.square(x) + math_ops.square(y)) + return x * grad_inv, -y * grad_inv + + @ops.RegisterGradient("AddN") def _AddNGrad(op, grad): """Copies the gradient to all inputs.""" |