aboutsummaryrefslogtreecommitdiffhomepage
path: root/test
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2016-05-19 18:34:41 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2016-05-19 18:34:41 +0200
commitccb408ee6afe34957081f85be4e8471d5270c6cc (patch)
tree02f0c9b3171c8fec21c23d9f29ab55bc79a52eeb /test
parent6761c64d60f297d429a502dbf064b36b6dfb6c9b (diff)
Improve unit tests of zeta, polygamma, and digamma
Diffstat (limited to 'test')
-rw-r--r--test/array.cpp112
1 files changed, 74 insertions, 38 deletions
diff --git a/test/array.cpp b/test/array.cpp
index 1f4afc1c6..cb80c887a 100644
--- a/test/array.cpp
+++ b/test/array.cpp
@@ -310,46 +310,12 @@ template<typename ArrayType> void array_real(const ArrayType& m)
m1 += ArrayType::Constant(rows,cols,Scalar(tiny));
VERIFY_IS_APPROX(s1/m1, s1 * m1.inverse());
+
+
#ifdef EIGEN_HAS_C99_MATH
// check special functions (comparing against numpy implementation)
- if (!NumTraits<Scalar>::IsComplex) {
- VERIFY_IS_APPROX(numext::digamma(Scalar(1)), RealScalar(-0.5772156649015329));
- VERIFY_IS_APPROX(numext::digamma(Scalar(1.5)), RealScalar(0.03648997397857645));
- VERIFY_IS_APPROX(numext::digamma(Scalar(4)), RealScalar(1.2561176684318));
- VERIFY_IS_APPROX(numext::digamma(Scalar(-10.5)), RealScalar(2.398239129535781));
- VERIFY_IS_APPROX(numext::digamma(Scalar(10000.5)), RealScalar(9.210340372392849));
- VERIFY_IS_EQUAL(numext::digamma(Scalar(0)),
- std::numeric_limits<RealScalar>::infinity());
- VERIFY_IS_EQUAL(numext::digamma(Scalar(-1)),
- std::numeric_limits<RealScalar>::infinity());
-
- // Check the zeta function against scipy.special.zeta
- VERIFY_IS_APPROX(numext::zeta(Scalar(1.5), Scalar(2)), RealScalar(1.61237534869));
- VERIFY_IS_APPROX(numext::zeta(Scalar(4), Scalar(1.5)), RealScalar(0.234848505667));
- VERIFY_IS_APPROX(numext::zeta(Scalar(10.5), Scalar(3)), RealScalar(1.03086757337e-5));
- VERIFY_IS_APPROX(numext::zeta(Scalar(10000.5), Scalar(1.0001)), RealScalar(0.367879440865));
- VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097));
- VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter
- std::numeric_limits<RealScalar>::infinity());
- VERIFY((numext::isnan)(numext::zeta(Scalar(0.9), Scalar(1.2345)))); // The second scalar does not matter
-
- // Check the polygamma against scipy.special.polygamma examples
- VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848));
- VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848));
- VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496));
- VERIFY((numext::isnan)(numext::polygamma(Scalar(1.5), Scalar(1.2345)))); // The second scalar does not matter
-
- // Check the polygamma function over a larger range of values
- VERIFY_IS_APPROX(numext::polygamma(Scalar(17), Scalar(4.7)), RealScalar(293.334565435));
- VERIFY_IS_APPROX(numext::polygamma(Scalar(31), Scalar(11.8)), RealScalar(0.445487887616));
- VERIFY_IS_APPROX(numext::polygamma(Scalar(28), Scalar(17.7)), RealScalar(-2.47810300902e-07));
- VERIFY_IS_APPROX(numext::polygamma(Scalar(8), Scalar(30.2)), RealScalar(-8.29668781082e-09));
- /* The following tests only pass for doubles because floats cannot handle the large values of
- the gamma function.
- VERIFY_IS_APPROX(numext::polygamma(Scalar(42), Scalar(15.8)), RealScalar(-0.434562276666));
- VERIFY_IS_APPROX(numext::polygamma(Scalar(147), Scalar(54.1)), RealScalar(0.567742190178));
- VERIFY_IS_APPROX(numext::polygamma(Scalar(170), Scalar(64)), RealScalar(-0.0108615497927));
- */
+ if (!NumTraits<Scalar>::IsComplex)
+ {
{
// Test various propreties of igamma & igammac. These are normalized
@@ -568,6 +534,73 @@ template<typename ArrayType> void min_max(const ArrayType& m)
}
+template<typename X, typename Y>
+void verify_component_wise(const X& x, const Y& y)
+{
+ for(Index i=0; i<x.size(); ++i)
+ {
+ if((numext::isfinite)(y(i)))
+ VERIFY_IS_APPROX( x(i), y(i) );
+ else if((numext::isnan)(y(i)))
+ VERIFY((numext::isnan)(x(i)));
+ else
+ VERIFY_IS_EQUAL( x(i), y(i) );
+ }
+}
+
+// check special functions (comparing against numpy implementation)
+template<typename ArrayType> void array_special_functions()
+{
+ using std::abs;
+ using std::sqrt;
+ typedef typename ArrayType::Scalar Scalar;
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+
+ Scalar plusinf = std::numeric_limits<Scalar>::infinity();
+ Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
+
+ // Check the zeta function against scipy.special.zeta
+ {
+ ArrayType x(7), q(7), res(7), ref(7);
+ x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9;
+ q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345;
+ ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan;
+ CALL_SUBTEST( verify_component_wise(ref, ref); );
+ CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); );
+ CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); );
+ }
+
+ // digamma
+ {
+ ArrayType x(7), res(7), ref(7);
+ x << 1, 1.5, 4, -10.5, 10000.5, 0, -1;
+ ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, plusinf, plusinf;
+ CALL_SUBTEST( verify_component_wise(ref, ref); );
+
+ CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); );
+ CALL_SUBTEST( res = digamma(x); verify_component_wise(res, ref); );
+ }
+
+
+ {
+ ArrayType n(11), x(11), res(11), ref(11);
+ n << 1, 1, 1, 1.5, 17, 31, 28, 8, 42, 147, 170;
+ x << 2, 3, 25.5, 1.5, 4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64;
+ ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927;
+ CALL_SUBTEST( verify_component_wise(ref, ref); );
+
+ if(sizeof(RealScalar)>=64) {
+// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
+ CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); );
+ }
+ else {
+// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
+ CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
+ }
+
+ }
+}
+
void test_array()
{
for(int i = 0; i < g_repeat; i++) {
@@ -609,4 +642,7 @@ void test_array()
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<Xpr>::type,
ArrayBase<Xpr>
>::value));
+
+ CALL_SUBTEST_7(array_special_functions<ArrayXf>());
+ CALL_SUBTEST_7(array_special_functions<ArrayXd>());
}