diff options
author | 2016-06-02 17:04:19 -0700 | |
---|---|---|
committer | 2016-06-02 17:04:19 -0700 | |
commit | 39baff850c2f4fe1fee3b7a3918ba62a526e4f08 (patch) | |
tree | 841ea12578450cfc0ab3e96a68d4e433a985a01d /test | |
parent | 02db4e1a82e6059cc217d6aa57bcc5ac6342eb37 (diff) |
Add TernaryFunctors and the betainc SpecialFunction.
TernaryFunctors and their executors allow operations on 3-tuples of inputs.
API fully implemented for Arrays and Tensors based on binary functors.
Ported the cephes betainc function (regularized incomplete beta
integral) to Eigen, with support for CPU and GPU, floats, doubles, and
half types.
Added unit tests in array.cpp and cxx11_tensor_cuda.cu
Collapsed revision
* Merged helper methods for betainc across floats and doubles.
* Added TensorGlobalFunctions with betainc(). Removed betainc() from TensorBase.
* Clean up CwiseTernaryOp checks, change igamma_helper to cephes_helper.
* betainc: merge incbcf and incbd into incbeta_cfe. and more cleanup.
* Update TernaryOp and SpecialFunctions (betainc) based on review comments.
Diffstat (limited to 'test')
-rw-r--r-- | test/array.cpp | 113 |
1 files changed, 110 insertions, 3 deletions
diff --git a/test/array.cpp b/test/array.cpp index 39a7b856f..8ed1269c2 100644 --- a/test/array.cpp +++ b/test/array.cpp @@ -592,16 +592,123 @@ template<typename ArrayType> void array_special_functions() ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927; CALL_SUBTEST( verify_component_wise(ref, ref); ); - if(sizeof(RealScalar)>=64) { -// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); ); + if(sizeof(RealScalar)>=8) { // double + // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232 + // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); ); CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); ); } else { -// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); ); + // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); ); CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); ); } } #endif + +#if EIGEN_HAS_C99_MATH + { + // Inputs and ground truth generated with scipy via: + // a = np.logspace(-3, 3, 5) - 1e-3 + // b = np.logspace(-3, 3, 5) - 1e-3 + // x = np.linspace(-0.1, 1.1, 5) + // (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x)) + // full_a = full_a.flatten().tolist() # same for full_b, full_x + // v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist() + // + // Note in Eigen, we call betainc with arguments in the order (x, a, b). + ArrayType a(125); + ArrayType b(125); + ArrayType x(125); + ArrayType v(125); + ArrayType res(125); + + a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, + 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, + 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, + 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, + 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, + 999.999, 999.999, 999.999; + + b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999, + 0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, + 999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, + 0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999, + 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999, + 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999, + 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, + 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, + 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, + 31.62177660168379, 31.62177660168379, 31.62177660168379, + 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999, + 999.999, 999.999; + + x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, + 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, + 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, + 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, + -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, + 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, + 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, + 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, + 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, + 0.8, 1.1; + + v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan, + 0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan, + 0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan, + 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan, + nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256, + 0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001, + 0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403, + 0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999, + 0.9999999999999999, nan, nan, nan, nan, nan, nan, nan, + 1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06, + nan, nan, 7.864342668429763e-23, 3.015969667594166e-10, + 0.0008598571564165444, nan, nan, 6.031987710123844e-08, + 0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999, + 0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan, + nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan, + 0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0, + 3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan, + 2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan; + + CALL_SUBTEST(res = betainc(a, b, x); + verify_component_wise(res, v);); + } +#endif } void test_array() |