diff options
Diffstat (limited to 'tensorflow/cc/gradients/math_grad.cc')
-rw-r--r-- | tensorflow/cc/gradients/math_grad.cc | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index d95dd879b4..5dcf00857d 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -441,6 +441,22 @@ Status RealDivGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("RealDiv", RealDivGrad); +Status UnsafeDivGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + auto x_1 = ConjugateHelper(scope, op.input(0)); + auto x_2 = ConjugateHelper(scope, op.input(1)); + // y = x_1 / x_2 + // dy/dx_1 = 1/x_2 + // dy/dx_2 = -x_1/x_2^2 + auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2); + auto gx_2 = + Mul(scope, grad_inputs[0], + UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2)); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); +} +REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad); + Status SquaredDifferenceGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { |