aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h
diff options
context:
space:
mode:
authorGravatar Michael Figurnov <mfigurnov@google.com>2018-06-06 18:49:26 +0100
committerGravatar Michael Figurnov <mfigurnov@google.com>2018-06-06 18:49:26 +0100
commit4bd158fa37b4bba74e6421575d5c69eeea547172 (patch)
tree940bd1497831563a1792aea863ce9e2a9afd0b45 /unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h
parente206f8d4a401fe2060bada4d4b5d92e3bf3b561c (diff)
Derivative of the incomplete Gamma function and the sample of a Gamma random variable.
In addition to igamma(a, x), this code implements: * igamma_der_a(a, x) = d igamma(a, x) / da -- derivative of igamma with respect to the parameter * gamma_sample_der_alpha(alpha, sample) -- reparameterization derivative of a Gamma(alpha, 1) random variable sample with respect to the alpha parameter The derivatives are computed by forward mode differentiation of the igamma(a, x) code. Although gamma_sample_der_alpha can be implemented via igamma_der_a, a separate function is more accurate and efficient due to analytical cancellation of some terms. All three functions are implemented by a method parameterized with "mode" that always computes the derivatives, but does not return them unless required by the mode. The compiler is expected to (and, based on benchmarks, does) skip the unnecessary computations depending on the mode.
Diffstat (limited to 'unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h')
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h35
1 files changed, 35 insertions, 0 deletions
diff --git a/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h b/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h
index c25fea0b3..020ac1b62 100644
--- a/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h
+++ b/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h
@@ -120,6 +120,41 @@ double2 pigamma<double2>(const double2& a, const double2& x)
return make_double2(igamma(a.x, x.x), igamma(a.y, x.y));
}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pigamma_der_a<float4>(
+ const float4& a, const float4& x) {
+ using numext::igamma_der_a;
+ return make_float4(igamma_der_a(a.x, x.x), igamma_der_a(a.y, x.y),
+ igamma_der_a(a.z, x.z), igamma_der_a(a.w, x.w));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pigamma_der_a<double2>(const double2& a, const double2& x) {
+ using numext::igamma_der_a;
+ return make_double2(igamma_der_a(a.x, x.x), igamma_der_a(a.y, x.y));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pgamma_sample_der_alpha<float4>(
+ const float4& alpha, const float4& sample) {
+ using numext::gamma_sample_der_alpha;
+ return make_float4(
+ gamma_sample_der_alpha(alpha.x, sample.x),
+ gamma_sample_der_alpha(alpha.y, sample.y),
+ gamma_sample_der_alpha(alpha.z, sample.z),
+ gamma_sample_der_alpha(alpha.w, sample.w));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pgamma_sample_der_alpha<double2>(const double2& alpha, const double2& sample) {
+ using numext::gamma_sample_der_alpha;
+ return make_double2(
+ gamma_sample_der_alpha(alpha.x, sample.x),
+ gamma_sample_der_alpha(alpha.y, sample.y));
+}
+
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pigammac<float4>(const float4& a, const float4& x)
{