diff options
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r-- | tensorflow/python/ops/math_grad.py | 40 |
1 files changed, 36 insertions, 4 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 1fd69ae717..3502f11892 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -613,16 +613,48 @@ def _MulGrad(op, grad): @ops.RegisterGradient("Div") def _DivGrad(op, grad): + """The gradient for the Div operator.""" x = op.inputs[0] y = op.inputs[1] sx = array_ops.shape(x) sy = array_ops.shape(y) - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) # pylint: disable=protected-access + # pylint: disable=protected-access + rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) + # pylint: enable=protected-access + x = math_ops.conj(x) + y = math_ops.conj(y) + return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx), + array_ops.reshape(math_ops.reduce_sum( + grad * math_ops.div(-x, math_ops.square(y)), ry), sy)) + + +@ops.RegisterGradient("FloorDiv") +def _FloorDivGrad(_, unused_grad): + """The gradient for the FloorDiv operator.""" + return None, None + + +@ops.RegisterGradient("TruncateDiv") +def _TruncateDivGrad(_, unused_grad): + return None, None + + +@ops.RegisterGradient("RealDiv") +def _RealDivGrad(op, grad): + """RealDiv op gradient.""" + x = op.inputs[0] + y = op.inputs[1] + sx = array_ops.shape(x) + sy = array_ops.shape(y) + # pylint: disable=protected-access + rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) + # pylint: enable=protected-access x = math_ops.conj(x) y = math_ops.conj(y) - return (array_ops.reshape(math_ops.reduce_sum(grad / y, rx), sx), - array_ops.reshape(math_ops.reduce_sum(grad * - (-x / math_ops.square(y)), ry), sy)) + return (array_ops.reshape(math_ops.reduce_sum( + math_ops.realdiv(grad, y), rx), sx), + array_ops.reshape(math_ops.reduce_sum( + grad * math_ops.realdiv(-x, math_ops.square(y)), ry), sy)) @ops.RegisterGradient("Pow") |