diff options
author | Till Hoffmann <tillahoffmann@gmail.com> | 2016-04-01 14:35:21 +0100 |
---|---|---|
committer | Till Hoffmann <tillahoffmann@gmail.com> | 2016-04-01 14:35:21 +0100 |
commit | 57239f4a8149dbd603ad376e90a0a4574b846710 (patch) | |
tree | 4ff7041da6cd7121c19271bfb385c813ea66e0a8 | |
parent | dd5d390daf3a3a561a772b64f1b602e5f240bf8b (diff) |
Added polygamma function.
-rw-r--r-- | Eigen/src/Core/GenericPacketMath.h | 5 | ||||
-rw-r--r-- | Eigen/src/Core/GlobalFunctions.h | 1 | ||||
-rw-r--r-- | Eigen/src/Core/SpecialFunctions.h | 52 | ||||
-rw-r--r-- | Eigen/src/Core/arch/CUDA/MathFunctions.h | 22 | ||||
-rw-r--r-- | Eigen/src/Core/arch/CUDA/PacketMath.h | 1 | ||||
-rw-r--r-- | Eigen/src/Core/functors/UnaryFunctors.h | 22 | ||||
-rw-r--r-- | Eigen/src/plugins/ArrayCwiseUnaryOps.h | 9 | ||||
-rw-r--r-- | test/array.cpp | 5 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 6 |
9 files changed, 118 insertions, 5 deletions
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 988fc9c99..6ff61c18a 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -77,6 +77,7 @@ struct default_packet_traits HasLGamma = 0, HasDiGamma = 0, HasZeta = 0, + HasPolygamma = 0, HasErf = 0, HasErfc = 0, HasIGamma = 0, @@ -456,6 +457,10 @@ Packet pdigamma(const Packet& a) { using numext::digamma; return digamma(a); } template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pzeta(const Packet& x, const Packet& q) { using numext::zeta; return zeta(x, q); } +/** \internal \returns the polygamma function (coeff-wise) */ +template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet ppolygamma(const Packet& n, const Packet& x) { using numext::polygamma; return polygamma(n, x); } + /** \internal \returns the erf(\a a) (coeff-wise) */ template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet perf(const Packet& a) { using numext::erf; return erf(a); } diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h index a013cca1f..05ba6ddb4 100644 --- a/Eigen/src/Core/GlobalFunctions.h +++ b/Eigen/src/Core/GlobalFunctions.h @@ -52,6 +52,7 @@ namespace Eigen EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(lgamma,scalar_lgamma_op) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(digamma,scalar_digamma_op) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(zeta,scalar_zeta_op) + EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(polygamma,scalar_polygamma_op) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erf,scalar_erf_op) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erfc,scalar_erfc_op) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(exp,scalar_exp_op) diff --git a/Eigen/src/Core/SpecialFunctions.h b/Eigen/src/Core/SpecialFunctions.h index 4240ebf2f..02ac7cf3f 100644 --- a/Eigen/src/Core/SpecialFunctions.h +++ b/Eigen/src/Core/SpecialFunctions.h @@ -736,7 +736,7 @@ struct zeta_retval { template <typename Scalar> struct zeta_impl { EIGEN_DEVICE_FUNC - static Scalar run(Scalar x) { + static Scalar run(Scalar x, Scalar q) { EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED); return Scalar(0); @@ -905,6 +905,50 @@ struct zeta_impl { #endif // EIGEN_HAS_C99_MATH +/**************************************************************************** + * Implementation of polygamma function * + ****************************************************************************/ + +template <typename Scalar> +struct polygamma_retval { + typedef Scalar type; +}; + +#ifndef EIGEN_HAS_C99_MATH + +template <typename Scalar> +struct polygamma_impl { + EIGEN_DEVICE_FUNC + static Scalar run(Scalar n, Scalar x) { + EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), + THIS_TYPE_IS_NOT_SUPPORTED); + return Scalar(0); + } +}; + +#else + +template <typename Scalar> +struct polygamma_impl { + EIGEN_DEVICE_FUNC + static Scalar run(Scalar n, Scalar x) { + Scalar zero = 0.0, one = 1.0; + Scalar nplus = n + one; + + // Just return the digamma function for n = 1 + if (n == zero) { + return digamma_impl<Scalar>::run(x); + } + // Use the same implementation as scipy + else { + Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus)); + return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x); + } + } +}; + +#endif // EIGEN_HAS_C99_MATH + } // end namespace internal namespace numext { @@ -928,6 +972,12 @@ zeta(const Scalar& x, const Scalar& q) { } template <typename Scalar> +EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(polygamma, Scalar) +polygamma(const Scalar& n, const Scalar& x) { + return EIGEN_MATHFUNC_IMPL(polygamma, Scalar)::run(n, x); +} + +template <typename Scalar> EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erf, Scalar) erf(const Scalar& x) { return EIGEN_MATHFUNC_IMPL(erf, Scalar)::run(x); diff --git a/Eigen/src/Core/arch/CUDA/MathFunctions.h b/Eigen/src/Core/arch/CUDA/MathFunctions.h index 858775523..317499b29 100644 --- a/Eigen/src/Core/arch/CUDA/MathFunctions.h +++ b/Eigen/src/Core/arch/CUDA/MathFunctions.h @@ -93,17 +93,31 @@ double2 pdigamma<double2>(const double2& a) } template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -float4 pzeta<float4>(const float4& a) +float4 pzeta<float4>(const float4& x, const float4& q) { using numext::zeta; - return make_float4(zeta(a.x), zeta(a.y), zeta(a.z), zeta(a.w)); + return make_float4(zeta(x.x, q.x), zeta(x.y, q.y), zeta(x.z, q.z), zeta(x.w, q.w)); } template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -double2 pzeta<double2>(const double2& a) +double2 pzeta<double2>(const double2& x, const double2& q) { using numext::zeta; - return make_double2(zeta(a.x), zeta(a.y)); + return make_double2(zeta(x.x, q.x), zeta(x.y, q.y)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +float4 ppolygamma<float4>(const float4& n, const float4& x) +{ + using numext::polygamma; + return make_float4(polygamma(n.x, x.x), polygamma(n.y, x.y), polygamma(n.z, x.z), polygamma(n.w, x.w)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +double2 ppolygamma<double2>(const double2& n, const double2& x) +{ + using numext::polygamma; + return make_double2(polygamma(n.x, x.x), polygamma(n.y, x.y)); } template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE diff --git a/Eigen/src/Core/arch/CUDA/PacketMath.h b/Eigen/src/Core/arch/CUDA/PacketMath.h index e0db18fbf..932df1092 100644 --- a/Eigen/src/Core/arch/CUDA/PacketMath.h +++ b/Eigen/src/Core/arch/CUDA/PacketMath.h @@ -41,6 +41,7 @@ template<> struct packet_traits<float> : default_packet_traits HasLGamma = 1, HasDiGamma = 1, HasZeta = 1, + HasPolygamma = 1, HasErf = 1, HasErfc = 1, HasIgamma = 1, diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index e2fb8d8d6..826d84f69 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -472,6 +472,28 @@ struct functor_traits<scalar_zeta_op<Scalar> > }; /** \internal + * \brief Template functor to compute the polygamma function. + * \sa class CwiseUnaryOp, Cwise::polygamma() + */ +template<typename Scalar> struct scalar_polygamma_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_polygamma_op) + EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& n, const Scalar& x) const { + using numext::polygamma; return polygamma(n, x); + } + typedef typename packet_traits<Scalar>::type Packet; + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& n, const Packet& x) const { return internal::ppolygamma(n, x); } +}; +template<typename Scalar> +struct functor_traits<scalar_polygamma_op<Scalar> > +{ + enum { + // Guesstimate + Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost, + PacketAccess = packet_traits<Scalar>::HasPolygamma + }; +}; + +/** \internal * \brief Template functor to compute the Gauss error function of a * scalar * \sa class CwiseUnaryOp, Cwise::erf() diff --git a/Eigen/src/plugins/ArrayCwiseUnaryOps.h b/Eigen/src/plugins/ArrayCwiseUnaryOps.h index 6c92a2f1b..56c71172c 100644 --- a/Eigen/src/plugins/ArrayCwiseUnaryOps.h +++ b/Eigen/src/plugins/ArrayCwiseUnaryOps.h @@ -24,6 +24,7 @@ typedef CwiseUnaryOp<internal::scalar_cosh_op<Scalar>, const Derived> CoshReturn typedef CwiseUnaryOp<internal::scalar_lgamma_op<Scalar>, const Derived> LgammaReturnType; typedef CwiseUnaryOp<internal::scalar_digamma_op<Scalar>, const Derived> DigammaReturnType; typedef CwiseUnaryOp<internal::scalar_zeta_op<Scalar>, const Derived> ZetaReturnType; +typedef CwiseUnaryOp<internal::scalar_polygamma_op<Scalar>, const Derived> PolygammaReturnType; typedef CwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived> ErfReturnType; typedef CwiseUnaryOp<internal::scalar_erfc_op<Scalar>, const Derived> ErfcReturnType; typedef CwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> PowReturnType; @@ -338,6 +339,14 @@ zeta() const return ZetaReturnType(derived()); } +/** \returns an expression of the coefficient-wise polygamma function. + */ +inline const PolygammaReturnType +polygamma() const +{ + return PolygammaReturnType(derived()); +} + /** \returns an expression of the coefficient-wise Gauss error * function of *this. * diff --git a/test/array.cpp b/test/array.cpp index 2f0b0f1b6..56d196923 100644 --- a/test/array.cpp +++ b/test/array.cpp @@ -331,6 +331,11 @@ template<typename ArrayType> void array_real(const ArrayType& m) VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097)); VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter std::numeric_limits<RealScalar>::infinity()); + + // Check the polygamma against scipy.special.polygamma + VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848)); + VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496)); { // Test various propreties of igamma & igammac. These are normalized diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 7c427d3c1..65b969aaf 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -138,6 +138,12 @@ class TensorBase<Derived, ReadOnlyAccessors> zeta() const { return unaryExpr(internal::scalar_zeta_op<Scalar>()); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_polygamma_op<Scalar>, const Derived> + polygamma() const { + return unaryExpr(internal::scalar_polygamma_op<Scalar>()); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived> |