aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework/gradient_checker.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/framework/gradient_checker.cc')
-rw-r--r--tensorflow/cc/framework/gradient_checker.cc12
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc
index de2645cb44..e9f9c59e3a 100644
--- a/tensorflow/cc/framework/gradient_checker.cc
+++ b/tensorflow/cc/framework/gradient_checker.cc
@@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
auto y_pos_flat = y_pos[y_idx].flat<Y_T>();
auto y_neg_flat = y_neg[y_idx].flat<Y_T>();
const int64 y_size = y_shapes[y_idx].num_elements();
- const Y_T scale = Y_T{2 * delta};
+ const Y_T scale = 2 * delta;
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
for (int c = 0; c < y_size; ++c) {
SetJacobian<Y_T, JAC_T>(&jacobian, r * x_stride + unit_dimension,
@@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
auto jac_n = jacobian_ns[i].matrix<JAC_T>();
for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) {
for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) {
- *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
+ auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c));
+ // Treat any NaN as max_error and immediately return.
+ // (Note that std::max may ignore NaN arguments.)
+ if (std::isnan(cur_error)) {
+ *max_error = cur_error;
+ return Status::OK();
+ }
+ *max_error = std::max(*max_error, cur_error);
}
}
}
@@ -409,6 +416,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x,
const Output& y, const TensorShape& y_shape, JAC_T* max_error);
INSTANTIATE_GRAD_ERR_TYPE(float, float, float);
+INSTANTIATE_GRAD_ERR_TYPE(double, float, double);
INSTANTIATE_GRAD_ERR_TYPE(double, double, double);
INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float);
INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float);