diff options
Diffstat (limited to 'tensorflow/cc/framework/gradient_checker.cc')
-rw-r--r-- | tensorflow/cc/framework/gradient_checker.cc | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index a1eb0d9d08..e9f9c59e3a 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, auto jac_n = jacobian_ns[i].matrix<JAC_T>(); for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) { for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) { - *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c))); + auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c)); + // Treat any NaN as max_error and immediately return. + // (Note that std::max may ignore NaN arguments.) + if (std::isnan(cur_error)) { + *max_error = cur_error; + return Status::OK(); + } + *max_error = std::max(*max_error, cur_error); } } } |