From f0e46ed5d41eeb450cbcbdb1ce3233d524ad3acd Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Fri, 22 Jan 2021 11:10:54 -0800 Subject: Fix pow and other cwise ops for half/bfloat16. The new `generic_pow` implementation was failing for half/bfloat16 since their construction from int/float is not `constexpr`. Modified in `GenericPacketMathFunctions` to remove `constexpr`. While adding tests for half/bfloat16, found other issues related to implicit conversions. Also needed to implement `numext::arg` for non-integer, non-complex, non-float/double/long double types. These seem to be implicitly converted to `std::complex`, which then fails for half/bfloat16. --- test/array_cwise.cpp | 55 ++++++++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 25 deletions(-) (limited to 'test/array_cwise.cpp') diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 27702c19d..6ea504c09 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -329,7 +329,7 @@ template void array_real(const ArrayType& m) m3(rows, cols), m4 = m1; - m4 = (m4.abs()==Scalar(0)).select(1,m4); + m4 = (m4.abs()==Scalar(0)).select(Scalar(1),m4); Scalar s1 = internal::random(); @@ -358,7 +358,7 @@ template void array_real(const ArrayType& m) VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all()); VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all()); VERIFY((m1.isFinite() == (Eigen::isfinite)(m1)).all()); - VERIFY_IS_APPROX(m1.inverse(), inverse(m1)); + VERIFY_IS_APPROX(m4.inverse(), inverse(m4)); VERIFY_IS_APPROX(m1.abs(), abs(m1)); VERIFY_IS_APPROX(m1.abs2(), abs2(m1)); VERIFY_IS_APPROX(m1.square(), square(m1)); @@ -367,11 +367,11 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m1.sign(), sign(m1)); VERIFY((m1.sqrt().sign().isNaN() == (Eigen::isnan)(sign(sqrt(m1)))).all()); - // avoid NaNs with abs() so verification doesn't fail - m3 = m1.abs(); - VERIFY_IS_APPROX(m3.sqrt(), sqrt(abs(m1))); - VERIFY_IS_APPROX(m3.rsqrt(), Scalar(1)/sqrt(abs(m1))); - VERIFY_IS_APPROX(rsqrt(m3), Scalar(1)/sqrt(abs(m1))); + // avoid inf and NaNs so verification doesn't fail + m3 = m4.abs(); + VERIFY_IS_APPROX(m3.sqrt(), sqrt(abs(m3))); + VERIFY_IS_APPROX(m3.rsqrt(), Scalar(1)/sqrt(abs(m3))); + VERIFY_IS_APPROX(rsqrt(m3), Scalar(1)/sqrt(abs(m3))); VERIFY_IS_APPROX(m3.log(), log(m3)); VERIFY_IS_APPROX(m3.log1p(), log1p(m3)); VERIFY_IS_APPROX(m3.log10(), log10(m3)); @@ -383,23 +383,23 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(sin(m1.asin()), m1); VERIFY_IS_APPROX(cos(m1.acos()), m1); VERIFY_IS_APPROX(tan(m1.atan()), m1); - VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1))); - VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1))); - VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1)))); - VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0+exp(-m1)))); - VERIFY_IS_APPROX(arg(m1), ((m1<0).template cast())*std::acos(-1.0)); + VERIFY_IS_APPROX(sinh(m1), Scalar(0.5)*(exp(m1)-exp(-m1))); + VERIFY_IS_APPROX(cosh(m1), Scalar(0.5)*(exp(m1)+exp(-m1))); + VERIFY_IS_APPROX(tanh(m1), (Scalar(0.5)*(exp(m1)-exp(-m1)))/(Scalar(0.5)*(exp(m1)+exp(-m1)))); + VERIFY_IS_APPROX(logistic(m1), (Scalar(1)/(Scalar(1)+exp(-m1)))); + VERIFY_IS_APPROX(arg(m1), ((m1())*Scalar(std::acos(Scalar(-1)))); VERIFY((round(m1) <= ceil(m1) && round(m1) >= floor(m1)).all()); VERIFY((rint(m1) <= ceil(m1) && rint(m1) >= floor(m1)).all()); VERIFY(((ceil(m1) - round(m1)) <= Scalar(0.5) || (round(m1) - floor(m1)) <= Scalar(0.5)).all()); VERIFY(((ceil(m1) - round(m1)) <= Scalar(1.0) && (round(m1) - floor(m1)) <= Scalar(1.0)).all()); VERIFY(((ceil(m1) - rint(m1)) <= Scalar(0.5) || (rint(m1) - floor(m1)) <= Scalar(0.5)).all()); VERIFY(((ceil(m1) - rint(m1)) <= Scalar(1.0) && (rint(m1) - floor(m1)) <= Scalar(1.0)).all()); - VERIFY((Eigen::isnan)((m1*0.0)/0.0).all()); - VERIFY((Eigen::isinf)(m4/0.0).all()); - VERIFY(((Eigen::isfinite)(m1) && (!(Eigen::isfinite)(m1*0.0/0.0)) && (!(Eigen::isfinite)(m4/0.0))).all()); - VERIFY_IS_APPROX(inverse(inverse(m1)),m1); + VERIFY((Eigen::isnan)((m1*Scalar(0))/Scalar(0)).all()); + VERIFY((Eigen::isinf)(m4/Scalar(0)).all()); + VERIFY(((Eigen::isfinite)(m1) && (!(Eigen::isfinite)(m1*Scalar(0)/Scalar(0))) && (!(Eigen::isfinite)(m4/Scalar(0)))).all()); + VERIFY_IS_APPROX(inverse(inverse(m4)),m4); VERIFY((abs(m1) == m1 || abs(m1) == -m1).all()); - VERIFY_IS_APPROX(m3, sqrt(abs2(m1))); + VERIFY_IS_APPROX(m3, sqrt(abs2(m3))); VERIFY_IS_APPROX(m1.absolute_difference(m2), (m1 > m2).select(m1 - m2, m2 - m1)); VERIFY_IS_APPROX( m1.sign(), -(-m1).sign() ); VERIFY_IS_APPROX( m1*m1.sign(),m1.abs()); @@ -412,26 +412,29 @@ template void array_real(const ArrayType& m) // shift argument of logarithm so that it is not zero Scalar smallNumber = NumTraits::dummy_precision(); - VERIFY_IS_APPROX((m3 + smallNumber).log() , log(abs(m1) + smallNumber)); - VERIFY_IS_APPROX((m3 + smallNumber + 1).log() , log1p(abs(m1) + smallNumber)); + VERIFY_IS_APPROX((m3 + smallNumber).log() , log(abs(m3) + smallNumber)); + VERIFY_IS_APPROX((m3 + smallNumber + Scalar(1)).log() , log1p(abs(m3) + smallNumber)); VERIFY_IS_APPROX(m1.exp() * m2.exp(), exp(m1+m2)); VERIFY_IS_APPROX(m1.exp(), exp(m1)); VERIFY_IS_APPROX(m1.exp() / m2.exp(),(m1-m2).exp()); VERIFY_IS_APPROX(m1.expm1(), expm1(m1)); - VERIFY_IS_APPROX((m3 + smallNumber).exp() - 1, expm1(abs(m3) + smallNumber)); + VERIFY_IS_APPROX((m3 + smallNumber).exp() - Scalar(1), expm1(abs(m3) + smallNumber)); VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(pow(m3,RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(m3.pow(RealScalar(-0.5)), m3.rsqrt()); VERIFY_IS_APPROX(pow(m3,RealScalar(-0.5)), m3.rsqrt()); - VERIFY_IS_APPROX(m1.pow(RealScalar(-2)), m1.square().inverse()); + + // Avoid inf and NaN. + m3 = (m1.square()::epsilon()).select(Scalar(1),m3); + VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse()); pow_test(); - VERIFY_IS_APPROX(log10(m3), log(m3)/log(10)); - VERIFY_IS_APPROX(log2(m3), log(m3)/log(2)); + VERIFY_IS_APPROX(log10(m3), log(m3)/log(Scalar(10))); + VERIFY_IS_APPROX(log2(m3), log(m3)/log(Scalar(2))); // scalar by array division const RealScalar tiny = sqrt(std::numeric_limits::epsilon()); @@ -480,7 +483,7 @@ template void array_complex(const ArrayType& m) VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all()); VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all()); VERIFY((m1.isFinite() == (Eigen::isfinite)(m1)).all()); - VERIFY_IS_APPROX(m1.inverse(), inverse(m1)); + VERIFY_IS_APPROX(m4.inverse(), inverse(m4)); VERIFY_IS_APPROX(m1.log(), log(m1)); VERIFY_IS_APPROX(m1.log10(), log10(m1)); VERIFY_IS_APPROX(m1.log2(), log2(m1)); @@ -534,7 +537,7 @@ template void array_complex(const ArrayType& m) VERIFY(((Eigen::isfinite)(m1) && (!(Eigen::isfinite)(m1*zero/zero)) && (!(Eigen::isfinite)(m1/zero))).all()); - VERIFY_IS_APPROX(inverse(inverse(m1)),m1); + VERIFY_IS_APPROX(inverse(inverse(m4)),m4); VERIFY_IS_APPROX(conj(m1.conjugate()), m1); VERIFY_IS_APPROX(abs(m1), sqrt(square(m1.real())+square(m1.imag()))); VERIFY_IS_APPROX(abs(m1), sqrt(abs2(m1))); @@ -622,6 +625,8 @@ EIGEN_DECLARE_TEST(array_cwise) CALL_SUBTEST_2( array_real(Array22f()) ); CALL_SUBTEST_3( array_real(Array44d()) ); CALL_SUBTEST_5( array_real(ArrayXXf(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); + CALL_SUBTEST_7( array_real(Array()) ); + CALL_SUBTEST_8( array_real(Array()) ); } for(int i = 0; i < g_repeat; i++) { CALL_SUBTEST_4( array_complex(ArrayXXcf(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); -- cgit v1.2.3