From e2f21465fea76a80966f12a20d0be36597f19b44 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Wed, 2 Dec 2020 14:00:57 -0800 Subject: Special function implementations for half/bfloat16 packets. Current implementations fail to consider half-float packets, only half-float scalars. Added specializations for packets on AVX, AVX512 and NEON. Added tests to `special_packetmath`. The current `special_functions` tests would fail for half and bfloat16 due to lack of precision. The NEON tests also fail with precision issues and due to different handling of `sqrt(inf)`, so special functions bessel, ndtri have been disabled. Tested with AVX, AVX512. --- unsupported/test/special_functions.cpp | 89 +++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 33 deletions(-) (limited to 'unsupported/test/special_functions.cpp') diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp index 1027272d5..56848daa5 100644 --- a/unsupported/test/special_functions.cpp +++ b/unsupported/test/special_functions.cpp @@ -11,6 +11,17 @@ #include "main.h" #include "../Eigen/SpecialFunctions" +// Hack to allow "implicit" conversions from double to Scalar via comma-initialization. +template +Eigen::CommaInitializer operator<<(Eigen::DenseBase& dense, double v) { + return (dense << static_cast(v)); +} + +template +Eigen::CommaInitializer& operator,(Eigen::CommaInitializer& ci, double v) { + return (ci, static_cast(v)); +} + template void verify_component_wise(const X& x, const Y& y) { @@ -65,8 +76,8 @@ template void array_special_functions() // igamma(a, x) = gamma(a, x) / Gamma(a) // where Gamma and gamma are considered the standard unnormalized // upper and lower incomplete gamma functions, respectively. - ArrayType a = m1.abs() + 2; - ArrayType x = m2.abs() + 2; + ArrayType a = m1.abs() + Scalar(2); + ArrayType x = m2.abs() + Scalar(2); ArrayType zero = ArrayType::Zero(rows, cols); ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0)); ArrayType a_m1 = a - one; @@ -83,18 +94,18 @@ template void array_special_functions() VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp()); // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x) - VERIFY_IS_APPROX(Gamma_a_x, (a - 1) * Gamma_a_m1_x + x.pow(a-1) * (-x).exp()); + VERIFY_IS_APPROX(Gamma_a_x, (a - Scalar(1)) * Gamma_a_m1_x + x.pow(a-Scalar(1)) * (-x).exp()); // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x) - VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp()); + VERIFY_IS_APPROX(gamma_a_x, (a - Scalar(1)) * gamma_a_m1_x - x.pow(a-Scalar(1)) * (-x).exp()); } { // Verify for large a and x that values are between 0 and 1. ArrayType m1 = ArrayType::Random(rows,cols); ArrayType m2 = ArrayType::Random(rows,cols); - Scalar max_exponent = std::numeric_limits::max_exponent10; - ArrayType a = m1.abs() * pow(10., max_exponent - 1); - ArrayType x = m2.abs() * pow(10., max_exponent - 1); + int max_exponent = std::numeric_limits::max_exponent10; + ArrayType a = m1.abs() * Scalar(pow(10., max_exponent - 1)); + ArrayType x = m2.abs() * Scalar(pow(10., max_exponent - 1)); for (int i = 0; i < a.size(); ++i) { Scalar igam = numext::igamma(a(i), x(i)); VERIFY(0 <= igam); @@ -108,27 +119,37 @@ template void array_special_functions() Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)}; // location i*6+j corresponds to a_s[i], x_s[j]. - Scalar igamma_s[][6] = {{0.0, nan, nan, nan, nan, nan}, - {0.0, 0.6321205588285578, 0.7768698398515702, - 0.9816843611112658, 9.999500016666262e-05, 1.0}, - {0.0, 0.4275932955291202, 0.608374823728911, - 0.9539882943107686, 7.522076445089201e-07, 1.0}, - {0.0, 0.01898815687615381, 0.06564245437845008, - 0.5665298796332909, 4.166333347221828e-18, 1.0}, - {0.0, 0.9999780593618628, 0.9999899967080838, - 0.9999996219837988, 0.9991370418689945, 1.0}, - {0.0, 0.0, 0.0, 0.0, 0.0, 0.5042041932513908}}; - Scalar igammac_s[][6] = {{nan, nan, nan, nan, nan, nan}, - {1.0, 0.36787944117144233, 0.22313016014842982, - 0.018315638888734182, 0.9999000049998333, 0.0}, - {1.0, 0.5724067044708798, 0.3916251762710878, - 0.04601170568923136, 0.9999992477923555, 0.0}, - {1.0, 0.9810118431238462, 0.9343575456215499, - 0.4334701203667089, 1.0, 0.0}, - {1.0, 2.1940638138146658e-05, 1.0003291916285e-05, - 3.7801620118431334e-07, 0.0008629581310054535, - 0.0}, - {1.0, 1.0, 1.0, 1.0, 1.0, 0.49579580674813944}}; + Scalar igamma_s[][6] = { + {Scalar(0.0), nan, nan, nan, nan, nan}, + {Scalar(0.0), Scalar(0.6321205588285578), Scalar(0.7768698398515702), + Scalar(0.9816843611112658), Scalar(9.999500016666262e-05), + Scalar(1.0)}, + {Scalar(0.0), Scalar(0.4275932955291202), Scalar(0.608374823728911), + Scalar(0.9539882943107686), Scalar(7.522076445089201e-07), + Scalar(1.0)}, + {Scalar(0.0), Scalar(0.01898815687615381), + Scalar(0.06564245437845008), Scalar(0.5665298796332909), + Scalar(4.166333347221828e-18), Scalar(1.0)}, + {Scalar(0.0), Scalar(0.9999780593618628), Scalar(0.9999899967080838), + Scalar(0.9999996219837988), Scalar(0.9991370418689945), Scalar(1.0)}, + {Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), + Scalar(0.5042041932513908)}}; + Scalar igammac_s[][6] = { + {nan, nan, nan, nan, nan, nan}, + {Scalar(1.0), Scalar(0.36787944117144233), + Scalar(0.22313016014842982), Scalar(0.018315638888734182), + Scalar(0.9999000049998333), Scalar(0.0)}, + {Scalar(1.0), Scalar(0.5724067044708798), Scalar(0.3916251762710878), + Scalar(0.04601170568923136), Scalar(0.9999992477923555), + Scalar(0.0)}, + {Scalar(1.0), Scalar(0.9810118431238462), Scalar(0.9343575456215499), + Scalar(0.4334701203667089), Scalar(1.0), Scalar(0.0)}, + {Scalar(1.0), Scalar(2.1940638138146658e-05), + Scalar(1.0003291916285e-05), Scalar(3.7801620118431334e-07), + Scalar(0.0008629581310054535), Scalar(0.0)}, + {Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), + Scalar(0.49579580674813944)}}; + for (int i = 0; i < 6; ++i) { for (int j = 0; j < 6; ++j) { if ((std::isnan)(igamma_s[i][j])) { @@ -162,8 +183,8 @@ template void array_special_functions() ArrayType m1 = ArrayType::Random(32); using std::sqrt; - ArrayType cdf_val = (m1 / sqrt(2.)).erf(); - cdf_val = (cdf_val + 1.) / 2.; + ArrayType cdf_val = (m1 / Scalar(sqrt(2.))).erf(); + cdf_val = (cdf_val + Scalar(1)) / Scalar(2); verify_component_wise(cdf_val.ndtri(), m1);); } @@ -190,7 +211,6 @@ template void array_special_functions() CALL_SUBTEST( res = digamma(x); verify_component_wise(res, ref); ); } - #if EIGEN_HAS_C99_MATH { ArrayType n(11), x(11), res(11), ref(11); @@ -323,8 +343,8 @@ template void array_special_functions() ArrayType m3 = ArrayType::Random(32); ArrayType one = ArrayType::Constant(32, Scalar(1.0)); const Scalar eps = std::numeric_limits::epsilon(); - ArrayType a = (m1 * 4.0).exp(); - ArrayType b = (m2 * 4.0).exp(); + ArrayType a = (m1 * Scalar(4)).exp(); + ArrayType b = (m2 * Scalar(4)).exp(); ArrayType x = m3.abs(); // betainc(a, 1, x) == x**a @@ -471,4 +491,7 @@ EIGEN_DECLARE_TEST(special_functions) { CALL_SUBTEST_1(array_special_functions()); CALL_SUBTEST_2(array_special_functions()); + // TODO(cantonios): half/bfloat16 don't have enough precision to reproduce results above. + // CALL_SUBTEST_3(array_special_functions>()); + // CALL_SUBTEST_4(array_special_functions>()); } -- cgit v1.2.3