aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework/gradient_checker.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/framework/gradient_checker.cc')
-rw-r--r--tensorflow/cc/framework/gradient_checker.cc9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc
index a1eb0d9d08..e9f9c59e3a 100644
--- a/tensorflow/cc/framework/gradient_checker.cc
+++ b/tensorflow/cc/framework/gradient_checker.cc
@@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
auto jac_n = jacobian_ns[i].matrix<JAC_T>();
for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) {
for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) {
- *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
+ auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c));
+ // Treat any NaN as max_error and immediately return.
+ // (Note that std::max may ignore NaN arguments.)
+ if (std::isnan(cur_error)) {
+ *max_error = cur_error;
+ return Status::OK();
+ }
+ *max_error = std::max(*max_error, cur_error);
}
}
}