From f7362772e3236cdb8ae4d9be175f83a0b19902a0 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 24 Dec 2015 21:15:38 -0800 Subject: Add digamma for CPU + CUDA. Includes tests. --- test/array.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) (limited to 'test') diff --git a/test/array.cpp b/test/array.cpp index 6adedfb06..9366b73fd 100644 --- a/test/array.cpp +++ b/test/array.cpp @@ -219,6 +219,7 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m1.tanh(), tanh(m1)); #ifdef EIGEN_HAS_C99_MATH VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1)); + VERIFY_IS_APPROX(m1.digamma(), digamma(m1)); VERIFY_IS_APPROX(m1.erf(), erf(m1)); VERIFY_IS_APPROX(m1.erfc(), erfc(m1)); #endif // EIGEN_HAS_C99_MATH @@ -309,7 +310,20 @@ template void array_real(const ArrayType& m) s1 += Scalar(tiny); m1 += ArrayType::Constant(rows,cols,Scalar(tiny)); VERIFY_IS_APPROX(s1/m1, s1 * m1.inverse()); - + + // check special functions (comparing against numpy implementation) + if (!NumTraits::IsComplex) { + VERIFY_IS_APPROX(numext::digamma(Scalar(1)), RealScalar(-0.5772156649015329)); + VERIFY_IS_APPROX(numext::digamma(Scalar(1.5)), RealScalar(0.03648997397857645)); + VERIFY_IS_APPROX(numext::digamma(Scalar(4)), RealScalar(1.2561176684318)); + VERIFY_IS_APPROX(numext::digamma(Scalar(-10.5)), RealScalar(2.398239129535781)); + VERIFY_IS_APPROX(numext::digamma(Scalar(10000.5)), RealScalar(9.210340372392849)); + VERIFY_IS_EQUAL(numext::digamma(Scalar(0)), + std::numeric_limits::infinity()); + VERIFY_IS_EQUAL(numext::digamma(Scalar(-1)), + std::numeric_limits::infinity()); + } + // check inplace transpose m3 = m1; m3.transposeInPlace(); @@ -336,8 +350,6 @@ template void array_complex(const ArrayType& m) Array m3(rows, cols); - Scalar s1 = internal::random(); - for (Index i = 0; i < m.rows(); ++i) for (Index j = 0; j < m.cols(); ++j) m2(i,j) = sqrt(m1(i,j)); @@ -410,6 +422,7 @@ template void array_complex(const ArrayType& m) VERIFY_IS_APPROX( m1.sign() * m1.abs(), m1); // scalar by array division + Scalar s1 = internal::random(); const RealScalar tiny = sqrt(std::numeric_limits::epsilon()); s1 += Scalar(tiny); m1 += ArrayType::Constant(rows,cols,Scalar(tiny)); -- cgit v1.2.3