aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/special_functions.cpp
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-12-02 14:00:57 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2020-12-04 10:16:29 -0800
commite2f21465fea76a80966f12a20d0be36597f19b44 (patch)
tree1ae9b0e3ae489b028902166a343f796d196fde82 /unsupported/test/special_functions.cpp
parent305b8bd2777bda99f65791468f305b76021bf579 (diff)
Special function implementations for half/bfloat16 packets.
Current implementations fail to consider half-float packets, only half-float scalars. Added specializations for packets on AVX, AVX512 and NEON. Added tests to `special_packetmath`. The current `special_functions` tests would fail for half and bfloat16 due to lack of precision. The NEON tests also fail with precision issues and due to different handling of `sqrt(inf)`, so special functions bessel, ndtri have been disabled. Tested with AVX, AVX512.
Diffstat (limited to 'unsupported/test/special_functions.cpp')
-rw-r--r--unsupported/test/special_functions.cpp89
1 files changed, 56 insertions, 33 deletions
diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp
index 1027272d5..56848daa5 100644
--- a/unsupported/test/special_functions.cpp
+++ b/unsupported/test/special_functions.cpp
@@ -11,6 +11,17 @@
#include "main.h"
#include "../Eigen/SpecialFunctions"
+// Hack to allow "implicit" conversions from double to Scalar via comma-initialization.
+template<typename Derived>
+Eigen::CommaInitializer<Derived> operator<<(Eigen::DenseBase<Derived>& dense, double v) {
+ return (dense << static_cast<typename Derived::Scalar>(v));
+}
+
+template<typename XprType>
+Eigen::CommaInitializer<XprType>& operator,(Eigen::CommaInitializer<XprType>& ci, double v) {
+ return (ci, static_cast<typename XprType::Scalar>(v));
+}
+
template<typename X, typename Y>
void verify_component_wise(const X& x, const Y& y)
{
@@ -65,8 +76,8 @@ template<typename ArrayType> void array_special_functions()
// igamma(a, x) = gamma(a, x) / Gamma(a)
// where Gamma and gamma are considered the standard unnormalized
// upper and lower incomplete gamma functions, respectively.
- ArrayType a = m1.abs() + 2;
- ArrayType x = m2.abs() + 2;
+ ArrayType a = m1.abs() + Scalar(2);
+ ArrayType x = m2.abs() + Scalar(2);
ArrayType zero = ArrayType::Zero(rows, cols);
ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
ArrayType a_m1 = a - one;
@@ -83,18 +94,18 @@ template<typename ArrayType> void array_special_functions()
VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
// Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
- VERIFY_IS_APPROX(Gamma_a_x, (a - 1) * Gamma_a_m1_x + x.pow(a-1) * (-x).exp());
+ VERIFY_IS_APPROX(Gamma_a_x, (a - Scalar(1)) * Gamma_a_m1_x + x.pow(a-Scalar(1)) * (-x).exp());
// gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
- VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
+ VERIFY_IS_APPROX(gamma_a_x, (a - Scalar(1)) * gamma_a_m1_x - x.pow(a-Scalar(1)) * (-x).exp());
}
{
// Verify for large a and x that values are between 0 and 1.
ArrayType m1 = ArrayType::Random(rows,cols);
ArrayType m2 = ArrayType::Random(rows,cols);
- Scalar max_exponent = std::numeric_limits<Scalar>::max_exponent10;
- ArrayType a = m1.abs() * pow(10., max_exponent - 1);
- ArrayType x = m2.abs() * pow(10., max_exponent - 1);
+ int max_exponent = std::numeric_limits<Scalar>::max_exponent10;
+ ArrayType a = m1.abs() * Scalar(pow(10., max_exponent - 1));
+ ArrayType x = m2.abs() * Scalar(pow(10., max_exponent - 1));
for (int i = 0; i < a.size(); ++i) {
Scalar igam = numext::igamma(a(i), x(i));
VERIFY(0 <= igam);
@@ -108,27 +119,37 @@ template<typename ArrayType> void array_special_functions()
Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
// location i*6+j corresponds to a_s[i], x_s[j].
- Scalar igamma_s[][6] = {{0.0, nan, nan, nan, nan, nan},
- {0.0, 0.6321205588285578, 0.7768698398515702,
- 0.9816843611112658, 9.999500016666262e-05, 1.0},
- {0.0, 0.4275932955291202, 0.608374823728911,
- 0.9539882943107686, 7.522076445089201e-07, 1.0},
- {0.0, 0.01898815687615381, 0.06564245437845008,
- 0.5665298796332909, 4.166333347221828e-18, 1.0},
- {0.0, 0.9999780593618628, 0.9999899967080838,
- 0.9999996219837988, 0.9991370418689945, 1.0},
- {0.0, 0.0, 0.0, 0.0, 0.0, 0.5042041932513908}};
- Scalar igammac_s[][6] = {{nan, nan, nan, nan, nan, nan},
- {1.0, 0.36787944117144233, 0.22313016014842982,
- 0.018315638888734182, 0.9999000049998333, 0.0},
- {1.0, 0.5724067044708798, 0.3916251762710878,
- 0.04601170568923136, 0.9999992477923555, 0.0},
- {1.0, 0.9810118431238462, 0.9343575456215499,
- 0.4334701203667089, 1.0, 0.0},
- {1.0, 2.1940638138146658e-05, 1.0003291916285e-05,
- 3.7801620118431334e-07, 0.0008629581310054535,
- 0.0},
- {1.0, 1.0, 1.0, 1.0, 1.0, 0.49579580674813944}};
+ Scalar igamma_s[][6] = {
+ {Scalar(0.0), nan, nan, nan, nan, nan},
+ {Scalar(0.0), Scalar(0.6321205588285578), Scalar(0.7768698398515702),
+ Scalar(0.9816843611112658), Scalar(9.999500016666262e-05),
+ Scalar(1.0)},
+ {Scalar(0.0), Scalar(0.4275932955291202), Scalar(0.608374823728911),
+ Scalar(0.9539882943107686), Scalar(7.522076445089201e-07),
+ Scalar(1.0)},
+ {Scalar(0.0), Scalar(0.01898815687615381),
+ Scalar(0.06564245437845008), Scalar(0.5665298796332909),
+ Scalar(4.166333347221828e-18), Scalar(1.0)},
+ {Scalar(0.0), Scalar(0.9999780593618628), Scalar(0.9999899967080838),
+ Scalar(0.9999996219837988), Scalar(0.9991370418689945), Scalar(1.0)},
+ {Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0),
+ Scalar(0.5042041932513908)}};
+ Scalar igammac_s[][6] = {
+ {nan, nan, nan, nan, nan, nan},
+ {Scalar(1.0), Scalar(0.36787944117144233),
+ Scalar(0.22313016014842982), Scalar(0.018315638888734182),
+ Scalar(0.9999000049998333), Scalar(0.0)},
+ {Scalar(1.0), Scalar(0.5724067044708798), Scalar(0.3916251762710878),
+ Scalar(0.04601170568923136), Scalar(0.9999992477923555),
+ Scalar(0.0)},
+ {Scalar(1.0), Scalar(0.9810118431238462), Scalar(0.9343575456215499),
+ Scalar(0.4334701203667089), Scalar(1.0), Scalar(0.0)},
+ {Scalar(1.0), Scalar(2.1940638138146658e-05),
+ Scalar(1.0003291916285e-05), Scalar(3.7801620118431334e-07),
+ Scalar(0.0008629581310054535), Scalar(0.0)},
+ {Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0),
+ Scalar(0.49579580674813944)}};
+
for (int i = 0; i < 6; ++i) {
for (int j = 0; j < 6; ++j) {
if ((std::isnan)(igamma_s[i][j])) {
@@ -162,8 +183,8 @@ template<typename ArrayType> void array_special_functions()
ArrayType m1 = ArrayType::Random(32);
using std::sqrt;
- ArrayType cdf_val = (m1 / sqrt(2.)).erf();
- cdf_val = (cdf_val + 1.) / 2.;
+ ArrayType cdf_val = (m1 / Scalar(sqrt(2.))).erf();
+ cdf_val = (cdf_val + Scalar(1)) / Scalar(2);
verify_component_wise(cdf_val.ndtri(), m1););
}
@@ -190,7 +211,6 @@ template<typename ArrayType> void array_special_functions()
CALL_SUBTEST( res = digamma(x); verify_component_wise(res, ref); );
}
-
#if EIGEN_HAS_C99_MATH
{
ArrayType n(11), x(11), res(11), ref(11);
@@ -323,8 +343,8 @@ template<typename ArrayType> void array_special_functions()
ArrayType m3 = ArrayType::Random(32);
ArrayType one = ArrayType::Constant(32, Scalar(1.0));
const Scalar eps = std::numeric_limits<Scalar>::epsilon();
- ArrayType a = (m1 * 4.0).exp();
- ArrayType b = (m2 * 4.0).exp();
+ ArrayType a = (m1 * Scalar(4)).exp();
+ ArrayType b = (m2 * Scalar(4)).exp();
ArrayType x = m3.abs();
// betainc(a, 1, x) == x**a
@@ -471,4 +491,7 @@ EIGEN_DECLARE_TEST(special_functions)
{
CALL_SUBTEST_1(array_special_functions<ArrayXf>());
CALL_SUBTEST_2(array_special_functions<ArrayXd>());
+ // TODO(cantonios): half/bfloat16 don't have enough precision to reproduce results above.
+ // CALL_SUBTEST_3(array_special_functions<ArrayX<Eigen::half>>());
+ // CALL_SUBTEST_4(array_special_functions<ArrayX<Eigen::bfloat16>>());
}