aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-11 14:20:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-11 14:20:53 -0700
commitd2981e0f61f5da46b2ab90a0779842565db43519 (patch)
treed7dd4878941779ce01670e272d75b91924ae142c /tensorflow/cc
parentebfa84016deceb3a1d01aabed292d369e5e872e4 (diff)
parent7613b773e03987c89fe5e5883c411588bce59673 (diff)
Merge pull request #19105 from facaiy:ENH/unsafe_div
PiperOrigin-RevId: 208352779
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/BUILD2
-rw-r--r--tensorflow/cc/gradients/math_grad.cc16
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc33
3 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 588a45ea43..f4be60a183 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -379,9 +379,11 @@ tf_cc_test(
srcs = ["gradients/math_grad_test.cc"],
deps = [
":cc_ops",
+ ":client_session",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
+ ":gradients",
":math_grad",
":testutil",
"//tensorflow/core:lib_internal",
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index d95dd879b4..5dcf00857d 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -441,6 +441,22 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
+Status UnsafeDivGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto x_1 = ConjugateHelper(scope, op.input(0));
+ auto x_2 = ConjugateHelper(scope, op.input(1));
+ // y = x_1 / x_2
+ // dy/dx_1 = 1/x_2
+ // dy/dx_2 = -x_1/x_2^2
+ auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
+ auto gx_2 =
+ Mul(scope, grad_inputs[0],
+ UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
+ return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
+}
+REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
+
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index c6c9262786..88aef1fab4 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
+#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
@@ -46,6 +48,7 @@ using ops::SegmentSum;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
+using ops::UnsafeDiv;
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
@@ -851,6 +854,36 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape});
}
+TEST_F(NaryGradTest, UnsafeDiv) {
+ {
+ TensorShape x_shape({3, 2, 5});
+ const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
+ // division errors in the numeric estimator used by the gradient checker.
+ const auto y = UnsafeDiv(
+ scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
+ RunTest({x}, {x_shape}, {y}, {x_shape});
+ }
+ {
+ // Return 0 gradient (rather than NaN) for division by zero.
+ const auto x = Placeholder(scope_, DT_FLOAT);
+ const auto zero = Const<float>(scope_, 0.0);
+ const auto y = UnsafeDiv(scope_, x, zero);
+
+ std::vector<Output> grad_outputs;
+ TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
+ ClientSession session(scope_);
+ std::vector<Tensor> grad_result;
+ TF_EXPECT_OK(
+ session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result));
+ EXPECT_EQ(grad_result.size(), 1);
+ EXPECT_EQ(grad_result[0].NumElements(), 3);
+ EXPECT_EQ(grad_result[0].flat<float>()(0), 0.0f);
+ EXPECT_EQ(grad_result[0].flat<float>()(1), 0.0f);
+ EXPECT_EQ(grad_result[0].flat<float>()(2), 0.0f);
+ }
+}
+
TEST_F(NaryGradTest, SquaredDifference) {
TensorShape x1_shape({3, 2, 5});
TensorShape x2_shape({2, 5});