aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-05-05 17:16:04 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-05-05 18:54:05 +0800
commit043b574402c58e1cf629242b3faad3ec071e5ce4 (patch)
tree253c660045ff8cca4cb2b855d32c784905d626da /tensorflow/cc
parent1bb62968ec5bd726f7cbc11a00b6001d64407ca9 (diff)
ENH: add gradient function
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/gradients/math_grad.cc15
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc10
2 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 52c177212a..ea86fc0a7c 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -440,6 +440,21 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
+Status UnsafeDivGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto x_1 = ConjugateHelper(scope, op.input(0));
+ auto x_2 = ConjugateHelper(scope, op.input(1));
+ // y = x_1 / x_2
+ // dy/dx_1 = 1/x_2
+ // dy/dx_2 = -x_1/x_2^2
+ auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
+ auto gx_2 = Mul(scope, grad_inputs[0],
+ UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
+ return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
+}
+REGISTER_GRADIENT_OP("UnsafeDiv", DivGrad);
+
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 1b4c7c2688..0cc398abcf 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -46,6 +46,7 @@ using ops::RealDiv;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
+using ops::UnsafeDiv;
using ops::Where3;
// TODO(andydavis) Test gradient function against numeric gradients output.
@@ -856,6 +857,15 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape});
}
+TEST_F(NaryGradTest, UnsafeDiv) {
+ TensorShape x_shape({3, 2, 5});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
+ // division errors in the numeric estimator used by the gradient checker.
+ auto y = UnsafeDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
+ RunTest({x}, {x_shape}, {y}, {x_shape});
+}
+
TEST_F(NaryGradTest, SquaredDifference) {
TensorShape x1_shape({3, 2, 5});
TensorShape x2_shape({2, 5});