aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients/math_grad.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/gradients/math_grad.cc')
-rw-r--r--tensorflow/cc/gradients/math_grad.cc15
1 files changed, 7 insertions, 8 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 5dcf00857d..1329b568ab 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -441,21 +441,20 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const Scope& scope, const Operation& op,
- const std::vector<Output>& grad_inputs,
- std::vector<Output>* grad_outputs) {
+Status DivNoNanGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
auto x_1 = ConjugateHelper(scope, op.input(0));
auto x_2 = ConjugateHelper(scope, op.input(1));
// y = x_1 / x_2
// dy/dx_1 = 1/x_2
// dy/dx_2 = -x_1/x_2^2
- auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
- auto gx_2 =
- Mul(scope, grad_inputs[0],
- UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
+ auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
+ auto gx_2 = Mul(scope, grad_inputs[0],
+ DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
-REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
+REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,