aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--tensorflow/python/ops/math_grad.py40
1 files changed, 36 insertions, 4 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 1fd69ae717..3502f11892 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -613,16 +613,48 @@ def _MulGrad(op, grad):
@ops.RegisterGradient("Div")
def _DivGrad(op, grad):
+ """The gradient for the Div operator."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
- rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ # pylint: enable=protected-access
+ x = math_ops.conj(x)
+ y = math_ops.conj(y)
+ return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(
+ grad * math_ops.div(-x, math_ops.square(y)), ry), sy))
+
+
+@ops.RegisterGradient("FloorDiv")
+def _FloorDivGrad(_, unused_grad):
+ """The gradient for the FloorDiv operator."""
+ return None, None
+
+
+@ops.RegisterGradient("TruncateDiv")
+def _TruncateDivGrad(_, unused_grad):
+ return None, None
+
+
+@ops.RegisterGradient("RealDiv")
+def _RealDivGrad(op, grad):
+ """RealDiv op gradient."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ # pylint: disable=protected-access
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ # pylint: enable=protected-access
x = math_ops.conj(x)
y = math_ops.conj(y)
- return (array_ops.reshape(math_ops.reduce_sum(grad / y, rx), sx),
- array_ops.reshape(math_ops.reduce_sum(grad *
- (-x / math_ops.square(y)), ry), sy))
+ return (array_ops.reshape(math_ops.reduce_sum(
+ math_ops.realdiv(grad, y), rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(
+ grad * math_ops.realdiv(-x, math_ops.square(y)), ry), sy))
@ops.RegisterGradient("Pow")