aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/packetmath.cpp
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-01-07 09:39:05 -0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-01-08 01:17:19 +0000
commitf149e0ebc3d3d5ca63234e58ca72690caf07e3b5 (patch)
tree8c5431fd057c96b8231be84b2908d130b49d61ec /test/packetmath.cpp
parent8d9cfba799ce3462c12568a36392e0abf36fc62d (diff)
Fix MSVC complex sqrt and packetmath test.
MSVC incorrectly handles `inf` cases for `std::sqrt<std::complex<T>>`. Here we replace it with a custom version (currently used on GPU). Also fixed the `packetmath` test, which previously skipped several corner cases since `CHECK_CWISE1` only tests the first `PacketSize` elements.
Diffstat (limited to 'test/packetmath.cpp')
-rw-r--r--test/packetmath.cpp14
1 files changed, 7 insertions, 7 deletions
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index f19d72502..ab9bec183 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -933,7 +933,7 @@ void packetmath_complex() {
for (int i = 0; i < size; ++i) {
data1[i] = Scalar(internal::random<RealScalar>(), internal::random<RealScalar>());
}
- CHECK_CWISE1(numext::sqrt, internal::psqrt);
+ CHECK_CWISE1_N(numext::sqrt, internal::psqrt, size);
// Test misc. corner cases.
const RealScalar zero = RealScalar(0);
@@ -944,32 +944,32 @@ void packetmath_complex() {
data1[1] = Scalar(-zero, zero);
data1[2] = Scalar(one, zero);
data1[3] = Scalar(zero, one);
- CHECK_CWISE1(numext::sqrt, internal::psqrt);
+ CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
data1[0] = Scalar(-one, zero);
data1[1] = Scalar(zero, -one);
data1[2] = Scalar(one, one);
data1[3] = Scalar(-one, -one);
- CHECK_CWISE1(numext::sqrt, internal::psqrt);
+ CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
data1[0] = Scalar(inf, zero);
data1[1] = Scalar(zero, inf);
data1[2] = Scalar(-inf, zero);
data1[3] = Scalar(zero, -inf);
- CHECK_CWISE1(numext::sqrt, internal::psqrt);
+ CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
data1[0] = Scalar(inf, inf);
data1[1] = Scalar(-inf, inf);
data1[2] = Scalar(inf, -inf);
data1[3] = Scalar(-inf, -inf);
- CHECK_CWISE1(numext::sqrt, internal::psqrt);
+ CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
data1[0] = Scalar(nan, zero);
data1[1] = Scalar(zero, nan);
data1[2] = Scalar(nan, one);
data1[3] = Scalar(one, nan);
- CHECK_CWISE1(numext::sqrt, internal::psqrt);
+ CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
data1[0] = Scalar(nan, nan);
data1[1] = Scalar(inf, nan);
data1[2] = Scalar(nan, inf);
data1[3] = Scalar(-inf, nan);
- CHECK_CWISE1(numext::sqrt, internal::psqrt);
+ CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
}
}