aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Figurnov <mfigurnov@google.com>2018-06-06 18:49:26 +0100
committerGravatar Michael Figurnov <mfigurnov@google.com>2018-06-06 18:49:26 +0100
commit4bd158fa37b4bba74e6421575d5c69eeea547172 (patch)
tree940bd1497831563a1792aea863ce9e2a9afd0b45
parente206f8d4a401fe2060bada4d4b5d92e3bf3b561c (diff)
Derivative of the incomplete Gamma function and the sample of a Gamma random variable.
In addition to igamma(a, x), this code implements: * igamma_der_a(a, x) = d igamma(a, x) / da -- derivative of igamma with respect to the parameter * gamma_sample_der_alpha(alpha, sample) -- reparameterization derivative of a Gamma(alpha, 1) random variable sample with respect to the alpha parameter The derivatives are computed by forward mode differentiation of the igamma(a, x) code. Although gamma_sample_der_alpha can be implemented via igamma_der_a, a separate function is more accurate and efficient due to analytical cancellation of some terms. All three functions are implemented by a method parameterized with "mode" that always computes the derivatives, but does not return them unless required by the mode. The compiler is expected to (and, based on benchmarks, does) skip the unnecessary computations depending on the mode.
-rw-r--r--Eigen/src/Core/GenericPacketMath.h2
-rw-r--r--Eigen/src/Core/arch/CUDA/PacketMath.h4
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h14
-rw-r--r--unsupported/Eigen/SpecialFunctions2
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h42
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h54
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h8
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h561
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h15
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h35
-rw-r--r--unsupported/test/cxx11_tensor_cuda.cu157
-rw-r--r--unsupported/test/special_functions.cpp92
12 files changed, 785 insertions, 201 deletions
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index 888a3f7ea..55b6a89e2 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -85,6 +85,8 @@ struct default_packet_traits
HasI0e = 0,
HasI1e = 0,
HasIGamma = 0,
+ HasIGammaDerA = 0,
+ HasGammaSampleDerAlpha = 0,
HasIGammac = 0,
HasBetaInc = 0,
diff --git a/Eigen/src/Core/arch/CUDA/PacketMath.h b/Eigen/src/Core/arch/CUDA/PacketMath.h
index 704a4e0d9..ab8e477f4 100644
--- a/Eigen/src/Core/arch/CUDA/PacketMath.h
+++ b/Eigen/src/Core/arch/CUDA/PacketMath.h
@@ -47,6 +47,8 @@ template<> struct packet_traits<float> : default_packet_traits
HasI0e = 1,
HasI1e = 1,
HasIGamma = 1,
+ HasIGammaDerA = 1,
+ HasGammaSampleDerAlpha = 1,
HasIGammac = 1,
HasBetaInc = 1,
@@ -78,6 +80,8 @@ template<> struct packet_traits<double> : default_packet_traits
HasI0e = 1,
HasI1e = 1,
HasIGamma = 1,
+ HasIGammaDerA = 1,
+ HasGammaSampleDerAlpha = 1,
HasIGammac = 1,
HasBetaInc = 1,
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index a942c98dd..88b0af0d3 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -152,6 +152,20 @@ class TensorBase<Derived, ReadOnlyAccessors>
return binaryExpr(other.derived(), internal::scalar_igamma_op<Scalar>());
}
+ // igamma_der_a(a = this, x = other)
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_igamma_der_a_op<Scalar>, const Derived, const OtherDerived>
+ igamma_der_a(const OtherDerived& other) const {
+ return binaryExpr(other.derived(), internal::scalar_igamma_der_a_op<Scalar>());
+ }
+
+ // gamma_sample_der_alpha(alpha = this, sample = other)
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_gamma_sample_der_alpha_op<Scalar>, const Derived, const OtherDerived>
+ gamma_sample_der_alpha(const OtherDerived& other) const {
+ return binaryExpr(other.derived(), internal::scalar_gamma_sample_der_alpha_op<Scalar>());
+ }
+
// igammac(a = this, x = other)
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_igammac_op<Scalar>, const Derived, const OtherDerived>
diff --git a/unsupported/Eigen/SpecialFunctions b/unsupported/Eigen/SpecialFunctions
index 482ec6e6f..9441ba8f5 100644
--- a/unsupported/Eigen/SpecialFunctions
+++ b/unsupported/Eigen/SpecialFunctions
@@ -29,6 +29,8 @@ namespace Eigen {
* - erfc
* - lgamma
* - igamma
+ * - igamma_der_a
+ * - gamma_sample_der_alpha
* - igammac
* - digamma
* - polygamma
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h
index b7a9d035b..30cdf4751 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h
@@ -33,6 +33,48 @@ igamma(const Eigen::ArrayBase<Derived>& a, const Eigen::ArrayBase<ExponentDerive
);
}
+/** \cpp11 \returns an expression of the coefficient-wise igamma_der_a(\a a, \a x) to the given arrays.
+ *
+ * This function computes the coefficient-wise derivative of the incomplete
+ * gamma function with respect to the parameter a.
+ *
+ * \note This function supports only float and double scalar types in c++11
+ * mode. To support other scalar types,
+ * or float/double in non c++11 mode, the user has to provide implementations
+ * of igamma_der_a(T,T) for any scalar
+ * type T to be supported.
+ *
+ * \sa Eigen::igamma(), Eigen::lgamma()
+ */
+template <typename Derived, typename ExponentDerived>
+inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_igamma_der_a_op<typename Derived::Scalar>, const Derived, const ExponentDerived>
+igamma_der_a(const Eigen::ArrayBase<Derived>& a, const Eigen::ArrayBase<ExponentDerived>& x) {
+ return Eigen::CwiseBinaryOp<Eigen::internal::scalar_igamma_der_a_op<typename Derived::Scalar>, const Derived, const ExponentDerived>(
+ a.derived(),
+ x.derived());
+}
+
+/** \cpp11 \returns an expression of the coefficient-wise gamma_sample_der_alpha(\a alpha, \a sample) to the given arrays.
+ *
+ * This function computes the coefficient-wise derivative of the sample
+ * of a Gamma(alpha, 1) random variable with respect to the parameter alpha.
+ *
+ * \note This function supports only float and double scalar types in c++11
+ * mode. To support other scalar types,
+ * or float/double in non c++11 mode, the user has to provide implementations
+ * of gamma_sample_der_alpha(T,T) for any scalar
+ * type T to be supported.
+ *
+ * \sa Eigen::igamma(), Eigen::lgamma()
+ */
+template <typename AlphaDerived, typename SampleDerived>
+inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_gamma_sample_der_alpha_op<typename AlphaDerived::Scalar>, const AlphaDerived, const SampleDerived>
+gamma_sample_der_alpha(const Eigen::ArrayBase<AlphaDerived>& alpha, const Eigen::ArrayBase<SampleDerived>& sample) {
+ return Eigen::CwiseBinaryOp<Eigen::internal::scalar_gamma_sample_der_alpha_op<typename AlphaDerived::Scalar>, const AlphaDerived, const SampleDerived>(
+ alpha.derived(),
+ sample.derived());
+}
+
/** \cpp11 \returns an expression of the coefficient-wise igammac(\a a, \a x) to the given arrays.
*
* This function computes the coefficient-wise complementary incomplete gamma function.
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h
index 8420f0174..3a63dcdd6 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h
@@ -41,6 +41,60 @@ struct functor_traits<scalar_igamma_op<Scalar> > {
};
};
+/** \internal
+ * \brief Template functor to compute the derivative of the incomplete gamma
+ * function igamma_der_a(a, x)
+ *
+ * \sa class CwiseBinaryOp, Cwise::igamma_der_a
+ */
+template <typename Scalar>
+struct scalar_igamma_der_a_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_igamma_der_a_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const {
+ using numext::igamma_der_a;
+ return igamma_der_a(a, x);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
+ return internal::pigamma_der_a(a, x);
+ }
+};
+template <typename Scalar>
+struct functor_traits<scalar_igamma_der_a_op<Scalar> > {
+ enum {
+ // 2x the cost of igamma
+ Cost = 40 * NumTraits<Scalar>::MulCost + 20 * NumTraits<Scalar>::AddCost,
+ PacketAccess = packet_traits<Scalar>::HasIGammaDerA
+ };
+};
+
+/** \internal
+ * \brief Template functor to compute the derivative of the sample
+ * of a Gamma(alpha, 1) random variable with respect to the parameter alpha
+ * gamma_sample_der_alpha(alpha, sample)
+ *
+ * \sa class CwiseBinaryOp, Cwise::gamma_sample_der_alpha
+ */
+template <typename Scalar>
+struct scalar_gamma_sample_der_alpha_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_gamma_sample_der_alpha_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& alpha, const Scalar& sample) const {
+ using numext::gamma_sample_der_alpha;
+ return gamma_sample_der_alpha(alpha, sample);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& alpha, const Packet& sample) const {
+ return internal::pgamma_sample_der_alpha(alpha, sample);
+ }
+};
+template <typename Scalar>
+struct functor_traits<scalar_gamma_sample_der_alpha_op<Scalar> > {
+ enum {
+ // 2x the cost of igamma, minus the lgamma cost (the lgamma cancels out)
+ Cost = 30 * NumTraits<Scalar>::MulCost + 15 * NumTraits<Scalar>::AddCost,
+ PacketAccess = packet_traits<Scalar>::HasGammaSampleDerAlpha
+ };
+};
/** \internal
* \brief Template functor to compute the complementary incomplete gamma function igammac(a, x)
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h
index c5867002e..fbdfd299e 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h
@@ -33,6 +33,14 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half erfc(const Eigen::h
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma(const Eigen::half& a, const Eigen::half& x) {
return Eigen::half(Eigen::numext::igamma(static_cast<float>(a), static_cast<float>(x)));
}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma_der_a(const Eigen::half& a, const Eigen::half& x) {
+ return Eigen::half(Eigen::numext::igamma_der_a(static_cast<float>(a), static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half gamma_sample_der_alpha(const Eigen::half& alpha, const Eigen::half& sample) {
+ return Eigen::half(Eigen::numext::gamma_sample_der_alpha(static_cast<float>(alpha), static_cast<float>(sample)));
+}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igammac(const Eigen::half& a, const Eigen::half& x) {
return Eigen::half(Eigen::numext::igammac(static_cast<float>(a), static_cast<float>(x)));
}
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
index 6c7ac3f3b..b24df2a95 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
@@ -521,6 +521,198 @@ struct cephes_helper<double> {
}
};
+enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
+
+template <typename Scalar, IgammaComputationMode mode>
+EIGEN_DEVICE_FUNC
+int igamma_num_iterations() {
+ /* Returns the maximum number of internal iterations for igamma computation.
+ */
+ if (mode == VALUE) {
+ return 2000;
+ }
+
+ if (internal::is_same<Scalar, float>::value) {
+ return 200;
+ } else if (internal::is_same<Scalar, double>::value) {
+ return 500;
+ } else {
+ return 2000;
+ }
+}
+
+template <typename Scalar, IgammaComputationMode mode>
+struct igammac_cf_impl {
+ /* Computes igamc(a, x) or derivative (depending on the mode)
+ * using the continued fraction expansion of the complementary
+ * incomplete Gamma function.
+ *
+ * Preconditions:
+ * a > 0
+ * x >= 1
+ * x >= a
+ */
+ EIGEN_DEVICE_FUNC
+ static Scalar run(Scalar a, Scalar x) {
+ const Scalar zero = 0;
+ const Scalar one = 1;
+ const Scalar two = 2;
+ const Scalar machep = cephes_helper<Scalar>::machep();
+ const Scalar big = cephes_helper<Scalar>::big();
+ const Scalar biginv = cephes_helper<Scalar>::biginv();
+
+ if ((numext::isinf)(x)) {
+ return zero;
+ }
+
+ // continued fraction
+ Scalar y = one - a;
+ Scalar z = x + y + one;
+ Scalar c = zero;
+ Scalar pkm2 = one;
+ Scalar qkm2 = x;
+ Scalar pkm1 = x + one;
+ Scalar qkm1 = z * x;
+ Scalar ans = pkm1 / qkm1;
+
+ Scalar dpkm2_da = zero;
+ Scalar dqkm2_da = zero;
+ Scalar dpkm1_da = zero;
+ Scalar dqkm1_da = -x;
+ Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
+
+ for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
+ c += one;
+ y += one;
+ z += two;
+
+ Scalar yc = y * c;
+ Scalar pk = pkm1 * z - pkm2 * yc;
+ Scalar qk = qkm1 * z - qkm2 * yc;
+
+ Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
+ Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
+
+ if (qk != zero) {
+ Scalar ans_prev = ans;
+ ans = pk / qk;
+
+ Scalar dans_da_prev = dans_da;
+ dans_da = (dpk_da - ans * dqk_da) / qk;
+
+ if (mode == VALUE) {
+ if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) {
+ break;
+ }
+ } else {
+ if (numext::abs(dans_da - dans_da_prev) <=
+ machep * numext::abs(dans_da)) {
+ break;
+ }
+ }
+ }
+
+ pkm2 = pkm1;
+ pkm1 = pk;
+ qkm2 = qkm1;
+ qkm1 = qk;
+
+ dpkm2_da = dpkm1_da;
+ dpkm1_da = dpk_da;
+ dqkm2_da = dqkm1_da;
+ dqkm1_da = dqk_da;
+
+ if (numext::abs(pk) > big) {
+ pkm2 *= biginv;
+ pkm1 *= biginv;
+ qkm2 *= biginv;
+ qkm1 *= biginv;
+
+ dpkm2_da *= biginv;
+ dpkm1_da *= biginv;
+ dqkm2_da *= biginv;
+ dqkm1_da *= biginv;
+ }
+ }
+
+ /* Compute x**a * exp(-x) / gamma(a) */
+ Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
+ Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
+ Scalar ax = numext::exp(logax);
+ Scalar dax_da = ax * dlogax_da;
+
+ switch (mode) {
+ case VALUE:
+ return ans * ax;
+ case DERIVATIVE:
+ return ans * dax_da + dans_da * ax;
+ case SAMPLE_DERIVATIVE:
+ return -(dans_da + ans * dlogax_da) * x;
+ }
+ }
+};
+
+template <typename Scalar, IgammaComputationMode mode>
+struct igamma_series_impl {
+ /* Computes igam(a, x) or its derivative (depending on the mode)
+ * using the series expansion of the incomplete Gamma function.
+ *
+ * Preconditions:
+ * x > 0
+ * a > 0
+ * !(x > 1 && x > a)
+ */
+ EIGEN_DEVICE_FUNC
+ static Scalar run(Scalar a, Scalar x) {
+ const Scalar zero = 0;
+ const Scalar one = 1;
+ const Scalar machep = cephes_helper<Scalar>::machep();
+
+ /* power series */
+ Scalar r = a;
+ Scalar c = one;
+ Scalar ans = one;
+
+ Scalar dc_da = zero;
+ Scalar dans_da = zero;
+
+ for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
+ r += one;
+ Scalar term = x / r;
+ Scalar dterm_da = -x / (r * r);
+ dc_da = term * dc_da + dterm_da * c;
+ dans_da += dc_da;
+ c *= term;
+ ans += c;
+
+ if (mode == VALUE) {
+ if (c <= machep * ans) {
+ break;
+ }
+ } else {
+ if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) {
+ break;
+ }
+ }
+ }
+
+ /* Compute x**a * exp(-x) / gamma(a + 1) */
+ Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a + one);
+ Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
+ Scalar ax = numext::exp(logax);
+ Scalar dax_da = ax * dlogax_da;
+
+ switch (mode) {
+ case VALUE:
+ return ans * ax;
+ case DERIVATIVE:
+ return ans * dax_da + dans_da * ax;
+ case SAMPLE_DERIVATIVE:
+ return -(dans_da + ans * dlogax_da) * x / a;
+ }
+ }
+};
+
#if !EIGEN_HAS_C99_MATH
template <typename Scalar>
@@ -535,8 +727,6 @@ struct igammac_impl {
#else
-template <typename Scalar> struct igamma_impl; // predeclare igamma_impl
-
template <typename Scalar>
struct igammac_impl {
EIGEN_DEVICE_FUNC
@@ -604,97 +794,15 @@ struct igammac_impl {
return nan;
}
- if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
+ if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
return nan;
}
if ((x < one) || (x < a)) {
- /* The checks above ensure that we meet the preconditions for
- * igamma_impl::Impl(), so call it, rather than igamma_impl::Run().
- * Calling Run() would also work, but in that case the compiler may not be
- * able to prove that igammac_impl::Run and igamma_impl::Run are not
- * mutually recursive. This leads to worse code, particularly on
- * platforms like nvptx, where recursion is allowed only begrudgingly.
- */
- return (one - igamma_impl<Scalar>::Impl(a, x));
- }
-
- return Impl(a, x);
- }
-
- private:
- /* igamma_impl calls igammac_impl::Impl. */
- friend struct igamma_impl<Scalar>;
-
- /* Actually computes igamc(a, x).
- *
- * Preconditions:
- * a > 0
- * x >= 1
- * x >= a
- */
- EIGEN_DEVICE_FUNC static Scalar Impl(Scalar a, Scalar x) {
- const Scalar zero = 0;
- const Scalar one = 1;
- const Scalar two = 2;
- const Scalar machep = cephes_helper<Scalar>::machep();
- const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
- const Scalar big = cephes_helper<Scalar>::big();
- const Scalar biginv = cephes_helper<Scalar>::biginv();
- const Scalar inf = NumTraits<Scalar>::infinity();
-
- Scalar ans, ax, c, yc, r, t, y, z;
- Scalar pk, pkm1, pkm2, qk, qkm1, qkm2;
-
- if (x == inf) return zero; // std::isinf crashes on CUDA
-
- /* Compute x**a * exp(-x) / gamma(a) */
- ax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
- if (ax < -maxlog) { // underflow
- return zero;
- }
- ax = numext::exp(ax);
-
- // continued fraction
- y = one - a;
- z = x + y + one;
- c = zero;
- pkm2 = one;
- qkm2 = x;
- pkm1 = x + one;
- qkm1 = z * x;
- ans = pkm1 / qkm1;
-
- for (int i = 0; i < 2000; i++) {
- c += one;
- y += one;
- z += two;
- yc = y * c;
- pk = pkm1 * z - pkm2 * yc;
- qk = qkm1 * z - qkm2 * yc;
- if (qk != zero) {
- r = pk / qk;
- t = numext::abs((ans - r) / r);
- ans = r;
- } else {
- t = one;
- }
- pkm2 = pkm1;
- pkm1 = pk;
- qkm2 = qkm1;
- qkm1 = qk;
- if (numext::abs(pk) > big) {
- pkm2 *= biginv;
- pkm1 *= biginv;
- qkm2 *= biginv;
- qkm1 *= biginv;
- }
- if (t <= machep) {
- break;
- }
+ return (one - igamma_series_impl<Scalar, VALUE>::run(a, x));
}
- return (ans * ax);
+ return igammac_cf_impl<Scalar, VALUE>::run(a, x);
}
};
@@ -704,15 +812,10 @@ struct igammac_impl {
* Implementation of igamma (incomplete gamma integral), based on Cephes but requires C++11/C99 *
************************************************************************************************/
-template <typename Scalar>
-struct igamma_retval {
- typedef Scalar type;
-};
-
#if !EIGEN_HAS_C99_MATH
-template <typename Scalar>
-struct igamma_impl {
+template <typename Scalar, IgammaComputationMode mode>
+struct igamma_generic_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
@@ -723,69 +826,17 @@ struct igamma_impl {
#else
-template <typename Scalar>
-struct igamma_impl {
+template <typename Scalar, IgammaComputationMode mode>
+struct igamma_generic_impl {
EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
- /* igam()
- * Incomplete gamma integral
- *
- *
- *
- * SYNOPSIS:
- *
- * double a, x, y, igam();
- *
- * y = igam( a, x );
- *
- * DESCRIPTION:
- *
- * The function is defined by
- *
- * x
- * -
- * 1 | | -t a-1
- * igam(a,x) = ----- | e t dt.
- * - | |
- * | (a) -
- * 0
- *
- *
- * In this implementation both arguments must be positive.
- * The integral is evaluated by either a power series or
- * continued fraction expansion, depending on the relative
- * values of a and x.
- *
- * ACCURACY (double):
- *
- * Relative error:
- * arithmetic domain # trials peak rms
- * IEEE 0,30 200000 3.6e-14 2.9e-15
- * IEEE 0,100 300000 9.9e-14 1.5e-14
- *
- *
- * ACCURACY (float):
- *
- * Relative error:
- * arithmetic domain # trials peak rms
- * IEEE 0,30 20000 7.8e-6 5.9e-7
- *
- */
- /*
- Cephes Math Library Release 2.2: June, 1992
- Copyright 1985, 1987, 1992 by Stephen L. Moshier
- Direct inquiries to 30 Frost Street, Cambridge, MA 02140
- */
-
-
- /* left tail of incomplete gamma function:
- *
- * inf. k
- * a -x - x
- * x e > ----------
- * - -
- * k=0 | (a+k+1)
+ /* Depending on the mode, returns
+ * - VALUE: incomplete Gamma function igamma(a, x)
+ * - DERIVATIVE: derivative of incomplete Gamma function d/da igamma(a, x)
+ * - SAMPLE_DERIVATIVE: implicit derivative of a Gamma random variable
+ * x ~ Gamma(x | a, 1), dx/da = -1 / Gamma(x | a, 1) * d igamma(a, x) / dx
*
+ * Derivatives are implemented by forward-mode differentiation.
*/
const Scalar zero = 0;
const Scalar one = 1;
@@ -797,71 +848,167 @@ struct igamma_impl {
return nan;
}
- if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
+ if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
return nan;
}
if ((x > one) && (x > a)) {
- /* The checks above ensure that we meet the preconditions for
- * igammac_impl::Impl(), so call it, rather than igammac_impl::Run().
- * Calling Run() would also work, but in that case the compiler may not be
- * able to prove that igammac_impl::Run and igamma_impl::Run are not
- * mutually recursive. This leads to worse code, particularly on
- * platforms like nvptx, where recursion is allowed only begrudgingly.
- */
- return (one - igammac_impl<Scalar>::Impl(a, x));
+ Scalar ret = igammac_cf_impl<Scalar, mode>::run(a, x);
+ if (mode == VALUE) {
+ return one - ret;
+ } else {
+ return -ret;
+ }
}
- return Impl(a, x);
+ return igamma_series_impl<Scalar, mode>::run(a, x);
}
+};
- private:
- /* igammac_impl calls igamma_impl::Impl. */
- friend struct igammac_impl<Scalar>;
+#endif // EIGEN_HAS_C99_MATH
- /* Actually computes igam(a, x).
+template <typename Scalar>
+struct igamma_retval {
+ typedef Scalar type;
+};
+
+template <typename Scalar>
+struct igamma_impl : igamma_generic_impl<Scalar, VALUE> {
+ /* igam()
+ * Incomplete gamma integral.
+ *
+ * The CDF of Gamma(a, 1) random variable at the point x.
+ *
+ * Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample
+ * 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points.
+ * The ground truth is computed by mpmath. Mean absolute error:
+ * float: 1.26713e-05
+ * double: 2.33606e-12
+ *
+ * Cephes documentation below.
+ *
+ * SYNOPSIS:
+ *
+ * double a, x, y, igam();
+ *
+ * y = igam( a, x );
+ *
+ * DESCRIPTION:
+ *
+ * The function is defined by
+ *
+ * x
+ * -
+ * 1 | | -t a-1
+ * igam(a,x) = ----- | e t dt.
+ * - | |
+ * | (a) -
+ * 0
+ *
+ *
+ * In this implementation both arguments must be positive.
+ * The integral is evaluated by either a power series or
+ * continued fraction expansion, depending on the relative
+ * values of a and x.
+ *
+ * ACCURACY (double):
+ *
+ * Relative error:
+ * arithmetic domain # trials peak rms
+ * IEEE 0,30 200000 3.6e-14 2.9e-15
+ * IEEE 0,100 300000 9.9e-14 1.5e-14
+ *
+ *
+ * ACCURACY (float):
+ *
+ * Relative error:
+ * arithmetic domain # trials peak rms
+ * IEEE 0,30 20000 7.8e-6 5.9e-7
*
- * Preconditions:
- * x > 0
- * a > 0
- * !(x > 1 && x > a)
*/
- EIGEN_DEVICE_FUNC static Scalar Impl(Scalar a, Scalar x) {
- const Scalar zero = 0;
- const Scalar one = 1;
- const Scalar machep = cephes_helper<Scalar>::machep();
- const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
+ /*
+ Cephes Math Library Release 2.2: June, 1992
+ Copyright 1985, 1987, 1992 by Stephen L. Moshier
+ Direct inquiries to 30 Frost Street, Cambridge, MA 02140
+ */
- Scalar ans, ax, c, r;
+ /* left tail of incomplete gamma function:
+ *
+ * inf. k
+ * a -x - x
+ * x e > ----------
+ * - -
+ * k=0 | (a+k+1)
+ *
+ */
+};
- /* Compute x**a * exp(-x) / gamma(a) */
- ax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
- if (ax < -maxlog) {
- // underflow
- return zero;
- }
- ax = numext::exp(ax);
+template <typename Scalar>
+struct igamma_der_a_retval : igamma_retval<Scalar> {};
- /* power series */
- r = a;
- c = one;
- ans = one;
+template <typename Scalar>
+struct igamma_der_a_impl : igamma_generic_impl<Scalar, DERIVATIVE> {
+ /* Derivative of the incomplete Gamma function with respect to a.
+ *
+ * Computes d/da igamma(a, x) by forward differentiation of the igamma code.
+ *
+ * Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample
+ * 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points.
+ * The ground truth is computed by mpmath. Mean absolute error:
+ * float: 6.27648e-07
+ * double: 4.60455e-12
+ *
+ * Reference:
+ * R. Moore. "Algorithm AS 187: Derivatives of the incomplete gamma
+ * integral". Journal of the Royal Statistical Society. 1982
+ */
+};
- for (int i = 0; i < 2000; i++) {
- r += one;
- c *= x/r;
- ans += c;
- if (c/ans <= machep) {
- break;
- }
- }
+template <typename Scalar>
+struct gamma_sample_der_alpha_retval : igamma_retval<Scalar> {};
- return (ans * ax / a);
- }
+template <typename Scalar>
+struct gamma_sample_der_alpha_impl
+ : igamma_generic_impl<Scalar, SAMPLE_DERIVATIVE> {
+ /* Derivative of a Gamma random variable sample with respect to alpha.
+ *
+ * Consider a sample of a Gamma random variable with the concentration
+ * parameter alpha: sample ~ Gamma(alpha, 1). The reparameterization
+ * derivative that we want to compute is dsample / dalpha =
+ * d igammainv(alpha, u) / dalpha, where u = igamma(alpha, sample).
+ * However, this formula is numerically unstable and expensive, so instead
+ * we use implicit differentiation:
+ *
+ * igamma(alpha, sample) = u, where u ~ Uniform(0, 1).
+ * Apply d / dalpha to both sides:
+ * d igamma(alpha, sample) / dalpha
+ * + d igamma(alpha, sample) / dsample * dsample/dalpha = 0
+ * d igamma(alpha, sample) / dalpha
+ * + Gamma(sample | alpha, 1) dsample / dalpha = 0
+ * dsample/dalpha = - (d igamma(alpha, sample) / dalpha)
+ * / Gamma(sample | alpha, 1)
+ *
+ * Here Gamma(sample | alpha, 1) is the PDF of the Gamma distribution
+ * (note that the derivative of the CDF w.r.t. sample is the PDF).
+ * See the reference below for more details.
+ *
+ * The derivative of igamma(alpha, sample) is computed by forward
+ * differentiation of the igamma code. Division by the Gamma PDF is performed
+ * in the same code, increasing the accuracy and speed due to cancellation
+ * of some terms.
+ *
+ * Accuracy estimation. For each alpha in [10^-2, 10^-1...10^3] we sample
+ * 50 Gamma random variables sample ~ Gamma(sample | alpha, 1), a total of 300
+ * points. The ground truth is computed by mpmath. Mean absolute error:
+ * float: 1.0993e-06
+ * double: 1.47631e-12
+ *
+ * Reference:
+ * M. Figurnov, S. Mohamed, A. Mnih "Implicit Reparameterization Gradients".
+ * 2018
+ */
};
-#endif // EIGEN_HAS_C99_MATH
-
/*****************************************************************************
* Implementation of Riemann zeta function of two arguments, based on Cephes *
*****************************************************************************/
@@ -1951,6 +2098,18 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma, Scalar)
}
template <typename Scalar>
+EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma_der_a, Scalar)
+ igamma_der_a(const Scalar& a, const Scalar& x) {
+ return EIGEN_MATHFUNC_IMPL(igamma_der_a, Scalar)::run(a, x);
+}
+
+template <typename Scalar>
+EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(gamma_sample_der_alpha, Scalar)
+ gamma_sample_der_alpha(const Scalar& a, const Scalar& x) {
+ return EIGEN_MATHFUNC_IMPL(gamma_sample_der_alpha, Scalar)::run(a, x);
+}
+
+template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igammac, Scalar)
igammac(const Scalar& a, const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(igammac, Scalar)::run(a, x);
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h
index 4c176716b..465f41d54 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h
@@ -42,6 +42,21 @@ Packet perfc(const Packet& a) { using numext::erfc; return erfc(a); }
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet pigamma(const Packet& a, const Packet& x) { using numext::igamma; return igamma(a, x); }
+/** \internal \returns the derivative of the incomplete gamma function
+ * igamma_der_a(\a a, \a x) */
+template <typename Packet>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pigamma_der_a(const Packet& a, const Packet& x) {
+ using numext::igamma_der_a; return igamma_der_a(a, x);
+}
+
+/** \internal \returns compute the derivative of the sample
+ * of Gamma(alpha, 1) random variable with respect to the parameter a
+ * gamma_sample_der_alpha(\a alpha, \a sample) */
+template <typename Packet>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pgamma_sample_der_alpha(const Packet& alpha, const Packet& sample) {
+ using numext::gamma_sample_der_alpha; return gamma_sample_der_alpha(alpha, sample);
+}
+
/** \internal \returns the complementary incomplete gamma function igammac(\a a, \a x) */
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet pigammac(const Packet& a, const Packet& x) { using numext::igammac; return igammac(a, x); }
diff --git a/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h b/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h
index c25fea0b3..020ac1b62 100644
--- a/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h
+++ b/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h
@@ -120,6 +120,41 @@ double2 pigamma<double2>(const double2& a, const double2& x)
return make_double2(igamma(a.x, x.x), igamma(a.y, x.y));
}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pigamma_der_a<float4>(
+ const float4& a, const float4& x) {
+ using numext::igamma_der_a;
+ return make_float4(igamma_der_a(a.x, x.x), igamma_der_a(a.y, x.y),
+ igamma_der_a(a.z, x.z), igamma_der_a(a.w, x.w));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pigamma_der_a<double2>(const double2& a, const double2& x) {
+ using numext::igamma_der_a;
+ return make_double2(igamma_der_a(a.x, x.x), igamma_der_a(a.y, x.y));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pgamma_sample_der_alpha<float4>(
+ const float4& alpha, const float4& sample) {
+ using numext::gamma_sample_der_alpha;
+ return make_float4(
+ gamma_sample_der_alpha(alpha.x, sample.x),
+ gamma_sample_der_alpha(alpha.y, sample.y),
+ gamma_sample_der_alpha(alpha.z, sample.z),
+ gamma_sample_der_alpha(alpha.w, sample.w));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pgamma_sample_der_alpha<double2>(const double2& alpha, const double2& sample) {
+ using numext::gamma_sample_der_alpha;
+ return make_double2(
+ gamma_sample_der_alpha(alpha.x, sample.x),
+ gamma_sample_der_alpha(alpha.y, sample.y));
+}
+
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pigammac<float4>(const float4& a, const float4& x)
{
diff --git a/unsupported/test/cxx11_tensor_cuda.cu b/unsupported/test/cxx11_tensor_cuda.cu
index 63d0a345a..f238ed5be 100644
--- a/unsupported/test/cxx11_tensor_cuda.cu
+++ b/unsupported/test/cxx11_tensor_cuda.cu
@@ -1318,6 +1318,157 @@ void test_cuda_i1e()
cudaFree(d_out);
}
+template <typename Scalar>
+void test_cuda_igamma_der_a()
+{
+ Tensor<Scalar, 1> in_x(30);
+ Tensor<Scalar, 1> in_a(30);
+ Tensor<Scalar, 1> out(30);
+ Tensor<Scalar, 1> expected_out(30);
+ out.setZero();
+
+ Array<Scalar, 1, Dynamic> in_a_array(30);
+ Array<Scalar, 1, Dynamic> in_x_array(30);
+ Array<Scalar, 1, Dynamic> expected_out_array(30);
+
+ // See special_functions.cpp for the Python code that generates the test data.
+
+ in_a_array << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0,
+ 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
+ 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
+
+ in_x_array << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
+ 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, 0.0132865061065,
+ 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, 0.333412038288,
+ 1.18135687766, 0.580629033777, 0.170631439426, 0.786686768458,
+ 7.63873279537, 13.1944344379, 11.896042354, 10.5830172417, 10.5020942233,
+ 92.8918587747, 95.003720371, 86.3715926467, 96.0330217672, 82.6389930677,
+ 968.702906754, 969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
+
+ expected_out_array << -32.7256441441, -36.4394150514, -9.66467612263,
+ -36.4394150514, -36.4394150514, -1.0891900302, -2.66351229645,
+ -2.48666868596, -0.929700494428, -3.56327722764, -0.455320135314,
+ -0.391437214323, -0.491352055991, -0.350454834292, -0.471773162921,
+ -0.104084440522, -0.0723646747909, -0.0992828975532, -0.121638215446,
+ -0.122619605294, -0.0317670267286, -0.0359974812869, -0.0154359225363,
+ -0.0375775365921, -0.00794899153653, -0.00777303219211, -0.00796085782042,
+ -0.0125850719397, -0.00455500206958, -0.00476436993148;
+
+ for (int i = 0; i < 30; ++i) {
+ in_x(i) = in_x_array(i);
+ in_a(i) = in_a_array(i);
+ expected_out(i) = expected_out_array(i);
+ }
+
+ std::size_t bytes = in_x.size() * sizeof(Scalar);
+
+ Scalar* d_a;
+ Scalar* d_x;
+ Scalar* d_out;
+ cudaMalloc((void**)(&d_a), bytes);
+ cudaMalloc((void**)(&d_x), bytes);
+ cudaMalloc((void**)(&d_out), bytes);
+
+ cudaMemcpy(d_a, in_a.data(), bytes, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_x, in_x.data(), bytes, cudaMemcpyHostToDevice);
+
+ Eigen::CudaStreamDevice stream;
+ Eigen::GpuDevice gpu_device(&stream);
+
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_a(d_a, 30);
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_x(d_x, 30);
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_out(d_out, 30);
+
+ gpu_out.device(gpu_device) = gpu_a.igamma_der_a(gpu_x);
+
+ assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost,
+ gpu_device.stream()) == cudaSuccess);
+ assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
+
+ for (int i = 0; i < 30; ++i) {
+ VERIFY_IS_APPROX(out(i), expected_out(i));
+ }
+
+ cudaFree(d_a);
+ cudaFree(d_x);
+ cudaFree(d_out);
+}
+
+template <typename Scalar>
+void test_cuda_gamma_sample_der_alpha()
+{
+ Tensor<Scalar, 1> in_alpha(30);
+ Tensor<Scalar, 1> in_sample(30);
+ Tensor<Scalar, 1> out(30);
+ Tensor<Scalar, 1> expected_out(30);
+ out.setZero();
+
+ Array<Scalar, 1, Dynamic> in_alpha_array(30);
+ Array<Scalar, 1, Dynamic> in_sample_array(30);
+ Array<Scalar, 1, Dynamic> expected_out_array(30);
+
+ // See special_functions.cpp for the Python code that generates the test data.
+
+ in_alpha_array << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0,
+ 1.0, 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0,
+ 100.0, 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
+
+ in_sample_array << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
+ 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, 0.0132865061065,
+ 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, 0.333412038288,
+ 1.18135687766, 0.580629033777, 0.170631439426, 0.786686768458,
+ 7.63873279537, 13.1944344379, 11.896042354, 10.5830172417, 10.5020942233,
+ 92.8918587747, 95.003720371, 86.3715926467, 96.0330217672, 82.6389930677,
+ 968.702906754, 969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
+
+ expected_out_array << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738,
+ 1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243,
+ 0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302,
+ 1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534,
+ 0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812,
+ 1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061,
+ 0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206,
+ 1.00106492525, 0.97734200649, 1.02198794179;
+
+ for (int i = 0; i < 30; ++i) {
+ in_alpha(i) = in_alpha_array(i);
+ in_sample(i) = in_sample_array(i);
+ expected_out(i) = expected_out_array(i);
+ }
+
+ std::size_t bytes = in_alpha.size() * sizeof(Scalar);
+
+ Scalar* d_alpha;
+ Scalar* d_sample;
+ Scalar* d_out;
+ cudaMalloc((void**)(&d_alpha), bytes);
+ cudaMalloc((void**)(&d_sample), bytes);
+ cudaMalloc((void**)(&d_out), bytes);
+
+ cudaMemcpy(d_alpha, in_alpha.data(), bytes, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sample, in_sample.data(), bytes, cudaMemcpyHostToDevice);
+
+ Eigen::CudaStreamDevice stream;
+ Eigen::GpuDevice gpu_device(&stream);
+
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_alpha(d_alpha, 30);
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_sample(d_sample, 30);
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_out(d_out, 30);
+
+ gpu_out.device(gpu_device) = gpu_alpha.gamma_sample_der_alpha(gpu_sample);
+
+ assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost,
+ gpu_device.stream()) == cudaSuccess);
+ assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
+
+ for (int i = 0; i < 30; ++i) {
+ VERIFY_IS_APPROX(out(i), expected_out(i));
+ }
+
+ cudaFree(d_alpha);
+ cudaFree(d_sample);
+ cudaFree(d_out);
+}
void test_cxx11_tensor_cuda()
{
@@ -1396,5 +1547,11 @@ void test_cxx11_tensor_cuda()
CALL_SUBTEST_6(test_cuda_i1e<float>());
CALL_SUBTEST_6(test_cuda_i1e<double>());
+
+ CALL_SUBTEST_6(test_cuda_igamma_der_a<float>());
+ CALL_SUBTEST_6(test_cuda_igamma_der_a<double>());
+
+ CALL_SUBTEST_6(test_cuda_gamma_sample_der_alpha<float>());
+ CALL_SUBTEST_6(test_cuda_gamma_sample_der_alpha<double>());
#endif
}
diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp
index 48d0db95e..29ba6203a 100644
--- a/unsupported/test/special_functions.cpp
+++ b/unsupported/test/special_functions.cpp
@@ -375,6 +375,98 @@ template<typename ArrayType> void array_special_functions()
CALL_SUBTEST(res = i1e(x);
verify_component_wise(res, expected););
}
+
+ /* Code to generate the data for the following two test cases.
+ N = 5
+ np.random.seed(3)
+
+ a = np.logspace(-2, 3, 6)
+ a = np.ravel(np.tile(np.reshape(a, [-1, 1]), [1, N]))
+ x = np.random.gamma(a, 1.0)
+ x = np.maximum(x, np.finfo(np.float32).tiny)
+
+ def igamma(a, x):
+ return mpmath.gammainc(a, 0, x, regularized=True)
+
+ def igamma_der_a(a, x):
+ res = mpmath.diff(lambda a_prime: igamma(a_prime, x), a)
+ return np.float64(res)
+
+ def gamma_sample_der_alpha(a, x):
+ igamma_x = igamma(a, x)
+ def igammainv_of_igamma(a_prime):
+ return mpmath.findroot(lambda x_prime: igamma(a_prime, x_prime) -
+ igamma_x, x, solver='newton')
+ return np.float64(mpmath.diff(igammainv_of_igamma, a))
+
+ v_igamma_der_a = np.vectorize(igamma_der_a)(a, x)
+ v_gamma_sample_der_alpha = np.vectorize(gamma_sample_der_alpha)(a, x)
+ */
+
+ // Test igamma_der_a
+ {
+ ArrayType a(30);
+ ArrayType x(30);
+ ArrayType res(30);
+ ArrayType v(30);
+
+ a << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0,
+ 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
+ 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
+
+ x << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
+ 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
+ 0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
+ 0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
+ 0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
+ 10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
+ 86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
+ 969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
+
+ v << -32.7256441441, -36.4394150514, -9.66467612263, -36.4394150514,
+ -36.4394150514, -1.0891900302, -2.66351229645, -2.48666868596,
+ -0.929700494428, -3.56327722764, -0.455320135314, -0.391437214323,
+ -0.491352055991, -0.350454834292, -0.471773162921, -0.104084440522,
+ -0.0723646747909, -0.0992828975532, -0.121638215446, -0.122619605294,
+ -0.0317670267286, -0.0359974812869, -0.0154359225363, -0.0375775365921,
+ -0.00794899153653, -0.00777303219211, -0.00796085782042,
+ -0.0125850719397, -0.00455500206958, -0.00476436993148;
+
+ CALL_SUBTEST(res = igamma_der_a(a, x); verify_component_wise(res, v););
+ }
+
+ // Test gamma_sample_der_alpha
+ {
+ ArrayType alpha(30);
+ ArrayType sample(30);
+ ArrayType res(30);
+ ArrayType v(30);
+
+ alpha << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0,
+ 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
+ 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
+
+ sample << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
+ 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
+ 0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
+ 0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
+ 0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
+ 10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
+ 86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
+ 969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
+
+ v << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738,
+ 1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243,
+ 0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302,
+ 1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534,
+ 0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812,
+ 1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061,
+ 0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206,
+ 1.00106492525, 0.97734200649, 1.02198794179;
+
+ CALL_SUBTEST(res = gamma_sample_der_alpha(alpha, sample);
+ verify_component_wise(res, v););
+ }
#endif
}