aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/array_cwise.cpp
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-02-05 16:58:49 -0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-02-05 16:58:49 -0800
commit6e3b795f811e8f3bf75393f8274e558a40479cc9 (patch)
treea9571d702a68f681ac9f6a25e6b7b6d6057a6fab /test/array_cwise.cpp
parentabcde69a79c35c118e156964a1b6fb75f1ea2adb (diff)
Add more tests for pow and fix a corner case for huge exponent where the result is always zero or infinite unless x is one.
Diffstat (limited to 'test/array_cwise.cpp')
-rw-r--r--test/array_cwise.cpp25
1 files changed, 20 insertions, 5 deletions
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp
index fa087a060..a1529bc96 100644
--- a/test/array_cwise.cpp
+++ b/test/array_cwise.cpp
@@ -15,18 +15,31 @@ template<typename Scalar>
void pow_test() {
const Scalar zero = Scalar(0);
const Scalar one = Scalar(1);
+ const Scalar two = Scalar(2);
+ const Scalar three = Scalar(3);
const Scalar sqrt_half = Scalar(std::sqrt(0.5));
const Scalar sqrt2 = Scalar(std::sqrt(2));
const Scalar inf = std::numeric_limits<Scalar>::infinity();
const Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
- const static Scalar abs_vals[] = {zero, sqrt_half, one, sqrt2, inf, nan};
- const int abs_cases = 6;
+ const Scalar min = (std::numeric_limits<Scalar>::min)();
+ const Scalar max = (std::numeric_limits<Scalar>::max)();
+ const static Scalar abs_vals[] = {zero,
+ sqrt_half,
+ one,
+ sqrt2,
+ two,
+ three,
+ min,
+ max,
+ inf,
+ nan};
+
+ const int abs_cases = 10;
const int num_cases = 2*abs_cases * 2*abs_cases;
// Repeat the same value to make sure we hit the vectorized path.
const int num_repeats = 32;
Array<Scalar, Dynamic, Dynamic> x(num_repeats, num_cases);
Array<Scalar, Dynamic, Dynamic> y(num_repeats, num_cases);
- Array<Scalar, Dynamic, Dynamic> expected(num_repeats, num_cases);
int count = 0;
for (int i = 0; i < abs_cases; ++i) {
const Scalar abs_x = abs_vals[i];
@@ -39,7 +52,6 @@ void pow_test() {
for (int repeat = 0; repeat < num_repeats; ++repeat) {
x(repeat, count) = x_case;
y(repeat, count) = y_case;
- expected(repeat, count) = numext::pow(x_case, y_case);
}
++count;
}
@@ -52,8 +64,11 @@ void pow_test() {
bool all_pass = true;
for (int i = 0; i < 1; ++i) {
for (int j = 0; j < num_cases; ++j) {
+ // TODO(rmlarsen): Skip tests that trigger a known bug in pldexp for now.
+ if (std::abs(x(i,j)) == max || std::abs(x(i,j)) == min) continue;
+
+ Scalar e = numext::pow(x(i,j), y(i,j));
Scalar a = actual(i, j);
- Scalar e = expected(i, j);
bool fail = !(a==e) && !internal::isApprox(a, e, tol) && !((numext::isnan)(a) && (numext::isnan)(e));
all_pass &= !fail;
if (fail) {