diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/array_cwise.cpp | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 6910f0e1f..27702c19d 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -9,6 +9,62 @@ #include "main.h" + +// Test the corner cases of pow(x, y) for real types. +template<typename Scalar> +void pow_test() { + const Scalar zero = Scalar(0); + const Scalar one = Scalar(1); + const Scalar sqrt_half = Scalar(std::sqrt(0.5)); + const Scalar sqrt2 = Scalar(std::sqrt(2)); + const Scalar inf = std::numeric_limits<Scalar>::infinity(); + const Scalar nan = std::numeric_limits<Scalar>::quiet_NaN(); + const static Scalar abs_vals[] = {zero, sqrt_half, one, sqrt2, inf, nan}; + const int abs_cases = 6; + const int num_cases = 2*abs_cases * 2*abs_cases; + // Repeat the same value to make sure we hit the vectorized path. + const int num_repeats = 32; + Array<Scalar, Dynamic, Dynamic> x(num_repeats, num_cases); + Array<Scalar, Dynamic, Dynamic> y(num_repeats, num_cases); + Array<Scalar, Dynamic, Dynamic> expected(num_repeats, num_cases); + int count = 0; + for (int i = 0; i < abs_cases; ++i) { + const Scalar abs_x = abs_vals[i]; + for (int sign_x = 0; sign_x < 2; ++sign_x) { + Scalar x_case = sign_x == 0 ? -abs_x : abs_x; + for (int j = 0; j < abs_cases; ++j) { + const Scalar abs_y = abs_vals[j]; + for (int sign_y = 0; sign_y < 2; ++sign_y) { + Scalar y_case = sign_y == 0 ? -abs_y : abs_y; + for (int repeat = 0; repeat < num_repeats; ++repeat) { + x(repeat, count) = x_case; + y(repeat, count) = y_case; + expected(repeat, count) = numext::pow(x_case, y_case); + } + ++count; + } + } + } + } + + Array<Scalar, Dynamic, Dynamic> actual = x.pow(y); + const Scalar tol = test_precision<Scalar>(); + bool all_pass = true; + for (int i = 0; i < 1; ++i) { + for (int j = 0; j < num_cases; ++j) { + Scalar a = actual(i, j); + Scalar e = expected(i, j); + bool fail = !(a==e) && !internal::isApprox(a, e, tol) && !((std::isnan)(a) && (std::isnan)(e)); + all_pass &= !fail; + if (fail) { + std::cout << "pow(" << x(i,j) << "," << y(i,j) << ") = " << a << " != " << e << std::endl; + } + } + } + VERIFY(all_pass); +} + + template<typename ArrayType> void array(const ArrayType& m) { typedef typename ArrayType::Scalar Scalar; @@ -371,6 +427,8 @@ template<typename ArrayType> void array_real(const ArrayType& m) VERIFY_IS_APPROX(m3.pow(RealScalar(-0.5)), m3.rsqrt()); VERIFY_IS_APPROX(pow(m3,RealScalar(-0.5)), m3.rsqrt()); + VERIFY_IS_APPROX(m1.pow(RealScalar(-2)), m1.square().inverse()); + pow_test<Scalar>(); VERIFY_IS_APPROX(log10(m3), log(m3)/log(10)); VERIFY_IS_APPROX(log2(m3), log(m3)/log(2)); |