From fee3f260d6eba1aec57df09045459790dcae686f Mon Sep 17 00:00:00 2001 From: "Yan Facai (颜发才)" Date: Mon, 30 Jul 2018 13:17:21 +0800 Subject: TST: add test case, division by zero --- tensorflow/cc/gradients/math_grad_test.cc | 33 +++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index b76478d78b..27021e28f8 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -857,12 +859,31 @@ TEST_F(NaryGradTest, RealDiv) { } TEST_F(NaryGradTest, UnsafeDiv) { - TensorShape x_shape({3, 2, 5}); - auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); - // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large - // division errors in the numeric estimator used by the gradient checker. - auto y = UnsafeDiv(scope_, x, Add(scope_, Const(scope_, 1), Abs(scope_, x))); - RunTest({x}, {x_shape}, {y}, {x_shape}); + { + TensorShape x_shape({3, 2, 5}); + const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large + // division errors in the numeric estimator used by the gradient checker. + const auto y = UnsafeDiv(scope_, x, Add(scope_, Const(scope_, 1), Abs(scope_, x))); + RunTest({x}, {x_shape}, {y}, {x_shape}); + } + { + // Return 0 gradient (rather than NaN) for division by zero. + const auto x = Placeholder(scope_, DT_FLOAT); + const auto zero = Const(scope_, 0.0); + const auto y = UnsafeDiv(scope_, x, zero); + + std::vector grad_outputs; + TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs)); + ClientSession session(scope_); + std::vector grad_result; + TF_EXPECT_OK(session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result)); + EXPECT_EQ(grad_result.size(), 1); + EXPECT_EQ(grad_result[0].NumElements(), 3); + EXPECT_EQ(grad_result[0].flat()(0), 0.0f); + EXPECT_EQ(grad_result[0].flat()(1), 0.0f); + EXPECT_EQ(grad_result[0].flat()(2), 0.0f); + } } TEST_F(NaryGradTest, SquaredDifference) { -- cgit v1.2.3