diff options
Diffstat (limited to 'test/array.cpp')
-rw-r--r-- | test/array.cpp | 130 |
1 files changed, 125 insertions, 5 deletions
diff --git a/test/array.cpp b/test/array.cpp index 5395721f5..beaa62221 100644 --- a/test/array.cpp +++ b/test/array.cpp @@ -202,7 +202,7 @@ template<typename ArrayType> void array_real(const ArrayType& m) m2 = ArrayType::Random(rows, cols), m3(rows, cols), m4 = m1; - + m4 = (m4.abs()==Scalar(0)).select(1,m4); Scalar s1 = internal::random<Scalar>(); @@ -217,6 +217,12 @@ template<typename ArrayType> void array_real(const ArrayType& m) VERIFY_IS_APPROX(m1.sinh(), sinh(m1)); VERIFY_IS_APPROX(m1.cosh(), cosh(m1)); VERIFY_IS_APPROX(m1.tanh(), tanh(m1)); +#ifdef EIGEN_HAS_C99_MATH + VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1)); + VERIFY_IS_APPROX(m1.digamma(), digamma(m1)); + VERIFY_IS_APPROX(m1.erf(), erf(m1)); + VERIFY_IS_APPROX(m1.erfc(), erfc(m1)); +#endif // EIGEN_HAS_C99_MATH VERIFY_IS_APPROX(m1.arg(), arg(m1)); VERIFY_IS_APPROX(m1.round(), round(m1)); VERIFY_IS_APPROX(m1.floor(), floor(m1)); @@ -289,7 +295,6 @@ template<typename ArrayType> void array_real(const ArrayType& m) VERIFY_IS_APPROX(Eigen::pow(m1,2*exponents), m1.square().square()); VERIFY_IS_APPROX(m1.pow(2*exponents), m1.square().square()); VERIFY_IS_APPROX(pow(m1(0,0), exponents), ArrayType::Constant(rows,cols,m1(0,0)*m1(0,0))); - VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(pow(m3,RealScalar(0.5)), m3.sqrt()); @@ -304,7 +309,123 @@ template<typename ArrayType> void array_real(const ArrayType& m) s1 += Scalar(tiny); m1 += ArrayType::Constant(rows,cols,Scalar(tiny)); VERIFY_IS_APPROX(s1/m1, s1 * m1.inverse()); - + +#ifdef EIGEN_HAS_C99_MATH + // check special functions (comparing against numpy implementation) + if (!NumTraits<Scalar>::IsComplex) { + VERIFY_IS_APPROX(numext::digamma(Scalar(1)), RealScalar(-0.5772156649015329)); + VERIFY_IS_APPROX(numext::digamma(Scalar(1.5)), RealScalar(0.03648997397857645)); + VERIFY_IS_APPROX(numext::digamma(Scalar(4)), RealScalar(1.2561176684318)); + VERIFY_IS_APPROX(numext::digamma(Scalar(-10.5)), RealScalar(2.398239129535781)); + VERIFY_IS_APPROX(numext::digamma(Scalar(10000.5)), RealScalar(9.210340372392849)); + VERIFY_IS_EQUAL(numext::digamma(Scalar(0)), + std::numeric_limits<RealScalar>::infinity()); + VERIFY_IS_EQUAL(numext::digamma(Scalar(-1)), + std::numeric_limits<RealScalar>::infinity()); + + // Check the zeta function against scipy.special.zeta + VERIFY_IS_APPROX(numext::zeta(Scalar(1.5), Scalar(2)), RealScalar(1.61237534869)); + VERIFY_IS_APPROX(numext::zeta(Scalar(4), Scalar(1.5)), RealScalar(0.234848505667)); + VERIFY_IS_APPROX(numext::zeta(Scalar(10.5), Scalar(3)), RealScalar(1.03086757337e-5)); + VERIFY_IS_APPROX(numext::zeta(Scalar(10000.5), Scalar(1.0001)), RealScalar(0.367879440865)); + VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097)); + VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter + std::numeric_limits<RealScalar>::infinity()); + VERIFY((numext::isnan)(numext::zeta(Scalar(0.9), Scalar(1.2345)))); // The second scalar does not matter + + // Check the polygamma against scipy.special.polygamma examples + VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496)); + VERIFY((numext::isnan)(numext::polygamma(Scalar(1.5), Scalar(1.2345)))); // The second scalar does not matter + + // Check the polygamma function over a larger range of values + VERIFY_IS_APPROX(numext::polygamma(Scalar(17), Scalar(4.7)), RealScalar(293.334565435)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(31), Scalar(11.8)), RealScalar(0.445487887616)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(28), Scalar(17.7)), RealScalar(-2.47810300902e-07)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(8), Scalar(30.2)), RealScalar(-8.29668781082e-09)); + /* The following tests only pass for doubles because floats cannot handle the large values of + the gamma function. + VERIFY_IS_APPROX(numext::polygamma(Scalar(42), Scalar(15.8)), RealScalar(-0.434562276666)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(147), Scalar(54.1)), RealScalar(0.567742190178)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(170), Scalar(64)), RealScalar(-0.0108615497927)); + */ + + { + // Test various propreties of igamma & igammac. These are normalized + // gamma integrals where + // igammac(a, x) = Gamma(a, x) / Gamma(a) + // igamma(a, x) = gamma(a, x) / Gamma(a) + // where Gamma and gamma are considered the standard unnormalized + // upper and lower incomplete gamma functions, respectively. + ArrayType a = m1.abs() + 2; + ArrayType x = m2.abs() + 2; + ArrayType zero = ArrayType::Zero(rows, cols); + ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0)); + ArrayType a_m1 = a - one; + ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp(); + ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp(); + ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp(); + ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp(); + + // Gamma(a, 0) == Gamma(a) + VERIFY_IS_APPROX(Eigen::igammac(a, zero), one); + + // Gamma(a, x) + gamma(a, x) == Gamma(a) + VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp()); + + // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x) + VERIFY_IS_APPROX(Gamma_a_x, (a - 1) * Gamma_a_m1_x + x.pow(a-1) * (-x).exp()); + + // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x) + VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp()); + } + + // Check exact values of igamma and igammac against a third party calculation. + Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)}; + Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)}; + + // location i*6+j corresponds to a_s[i], x_s[j]. + Scalar nan = std::numeric_limits<Scalar>::quiet_NaN(); + Scalar igamma_s[][6] = {{0.0, nan, nan, nan, nan, nan}, + {0.0, 0.6321205588285578, 0.7768698398515702, + 0.9816843611112658, 9.999500016666262e-05, 1.0}, + {0.0, 0.4275932955291202, 0.608374823728911, + 0.9539882943107686, 7.522076445089201e-07, 1.0}, + {0.0, 0.01898815687615381, 0.06564245437845008, + 0.5665298796332909, 4.166333347221828e-18, 1.0}, + {0.0, 0.9999780593618628, 0.9999899967080838, + 0.9999996219837988, 0.9991370418689945, 1.0}, + {0.0, 0.0, 0.0, 0.0, 0.0, 0.5042041932513908}}; + Scalar igammac_s[][6] = {{nan, nan, nan, nan, nan, nan}, + {1.0, 0.36787944117144233, 0.22313016014842982, + 0.018315638888734182, 0.9999000049998333, 0.0}, + {1.0, 0.5724067044708798, 0.3916251762710878, + 0.04601170568923136, 0.9999992477923555, 0.0}, + {1.0, 0.9810118431238462, 0.9343575456215499, + 0.4334701203667089, 1.0, 0.0}, + {1.0, 2.1940638138146658e-05, 1.0003291916285e-05, + 3.7801620118431334e-07, 0.0008629581310054535, + 0.0}, + {1.0, 1.0, 1.0, 1.0, 1.0, 0.49579580674813944}}; + for (int i = 0; i < 6; ++i) { + for (int j = 0; j < 6; ++j) { + if ((std::isnan)(igamma_s[i][j])) { + VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j]))); + } else { + VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]); + } + + if ((std::isnan)(igammac_s[i][j])) { + VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j]))); + } else { + VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]); + } + } + } + } +#endif // EIGEN_HAS_C99_MATH + // check inplace transpose m3 = m1; m3.transposeInPlace(); @@ -331,8 +452,6 @@ template<typename ArrayType> void array_complex(const ArrayType& m) Array<RealScalar, -1, -1> m3(rows, cols); - Scalar s1 = internal::random<Scalar>(); - for (Index i = 0; i < m.rows(); ++i) for (Index j = 0; j < m.cols(); ++j) m2(i,j) = sqrt(m1(i,j)); @@ -405,6 +524,7 @@ template<typename ArrayType> void array_complex(const ArrayType& m) VERIFY_IS_APPROX( m1.sign() * m1.abs(), m1); // scalar by array division + Scalar s1 = internal::random<Scalar>(); const RealScalar tiny = sqrt(std::numeric_limits<RealScalar>::epsilon()); s1 += Scalar(tiny); m1 += ArrayType::Constant(rows,cols,Scalar(tiny)); |