From 386d809bde475c65b7940f290efe80e6a05878c4 Mon Sep 17 00:00:00 2001 From: Teng Lu Date: Sat, 20 Jun 2020 19:16:24 +0000 Subject: Support BFloat16 in Eigen --- unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h | 11 ++++ unsupported/Eigen/SpecialFunctions | 2 + .../src/SpecialFunctions/BesselFunctionsBFloat16.h | 68 ++++++++++++++++++++++ .../SpecialFunctions/SpecialFunctionsBFloat16.h | 58 ++++++++++++++++++ unsupported/test/cxx11_tensor_reduction.cpp | 1 + 5 files changed, 140 insertions(+) create mode 100644 unsupported/Eigen/src/SpecialFunctions/BesselFunctionsBFloat16.h create mode 100644 unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsBFloat16.h (limited to 'unsupported') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h index 445248163..ea286fee1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h @@ -101,6 +101,17 @@ Eigen::half RandomToTypeUniform(uint64_t* state, uint64_t stream) { return result - Eigen::half(1.0f); } +template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +Eigen::bfloat16 RandomToTypeUniform(uint64_t* state, uint64_t stream) { + Eigen::bfloat16 result; + // Generate 7 random bits for the mantissa + unsigned rnd = PCG_XSH_RS_generator(state, stream); + result.value = static_cast(rnd & 0x7fu); + // Set the exponent + result.value |= (static_cast(127) << 7); + // Return the final result + return result - Eigen::bfloat16(1.0f); +} template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float RandomToTypeUniform(uint64_t* state, uint64_t stream) { diff --git a/unsupported/Eigen/SpecialFunctions b/unsupported/Eigen/SpecialFunctions index a098ce871..dda6618de 100644 --- a/unsupported/Eigen/SpecialFunctions +++ b/unsupported/Eigen/SpecialFunctions @@ -62,6 +62,7 @@ namespace Eigen { #include "src/SpecialFunctions/BesselFunctionsImpl.h" #include "src/SpecialFunctions/BesselFunctionsPacketMath.h" +#include "src/SpecialFunctions/BesselFunctionsBFloat16.h" #include "src/SpecialFunctions/BesselFunctionsHalf.h" #include "src/SpecialFunctions/BesselFunctionsFunctors.h" #include "src/SpecialFunctions/BesselFunctionsArrayAPI.h" @@ -70,6 +71,7 @@ namespace Eigen { #include "src/SpecialFunctions/HipVectorCompatibility.h" #endif #include "src/SpecialFunctions/SpecialFunctionsPacketMath.h" +#include "src/SpecialFunctions/SpecialFunctionsBFloat16.h" #include "src/SpecialFunctions/SpecialFunctionsHalf.h" #include "src/SpecialFunctions/SpecialFunctionsFunctors.h" #include "src/SpecialFunctions/SpecialFunctionsArrayAPI.h" diff --git a/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsBFloat16.h b/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsBFloat16.h new file mode 100644 index 000000000..6049cc2fe --- /dev/null +++ b/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsBFloat16.h @@ -0,0 +1,68 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_BESSELFUNCTIONS_BFLOAT16_H +#define EIGEN_BESSELFUNCTIONS_BFLOAT16_H + +namespace Eigen { +namespace numext { + +#if EIGEN_HAS_C99_MATH +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i0(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_i0(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i0e(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_i0e(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i1(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_i1(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i1e(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_i1e(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_j0(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_j0(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_j1(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_j1(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_y0(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_y0(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_y1(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_y1(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k0(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_k0(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k0e(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_k0e(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k1(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_k1(static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k1e(const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::bessel_k1e(static_cast(x))); +} +#endif + +} // end namespace numext +} // end namespace Eigen + +#endif // EIGEN_BESSELFUNCTIONS_BFLOAT16_H diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsBFloat16.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsBFloat16.h new file mode 100644 index 000000000..2d94231f0 --- /dev/null +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsBFloat16.h @@ -0,0 +1,58 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_SPECIALFUNCTIONS_BFLOAT16_H +#define EIGEN_SPECIALFUNCTIONS_BFLOAT16_H + +namespace Eigen { +namespace numext { + +#if EIGEN_HAS_C99_MATH +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 lgamma(const Eigen::bfloat16& a) { + return Eigen::bfloat16(Eigen::numext::lgamma(static_cast(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 digamma(const Eigen::bfloat16& a) { + return Eigen::bfloat16(Eigen::numext::digamma(static_cast(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 zeta(const Eigen::bfloat16& x, const Eigen::bfloat16& q) { + return Eigen::bfloat16(Eigen::numext::zeta(static_cast(x), static_cast(q))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 polygamma(const Eigen::bfloat16& n, const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::polygamma(static_cast(n), static_cast(x))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 erf(const Eigen::bfloat16& a) { + return Eigen::bfloat16(Eigen::numext::erf(static_cast(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 erfc(const Eigen::bfloat16& a) { + return Eigen::bfloat16(Eigen::numext::erfc(static_cast(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 ndtri(const Eigen::bfloat16& a) { + return Eigen::bfloat16(Eigen::numext::ndtri(static_cast(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igamma(const Eigen::bfloat16& a, const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::igamma(static_cast(a), static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igamma_der_a(const Eigen::bfloat16& a, const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::igamma_der_a(static_cast(a), static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 gamma_sample_der_alpha(const Eigen::bfloat16& alpha, const Eigen::bfloat16& sample) { + return Eigen::bfloat16(Eigen::numext::gamma_sample_der_alpha(static_cast(alpha), static_cast(sample))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igammac(const Eigen::bfloat16& a, const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::igammac(static_cast(a), static_cast(x))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 betainc(const Eigen::bfloat16& a, const Eigen::bfloat16& b, const Eigen::bfloat16& x) { + return Eigen::bfloat16(Eigen::numext::betainc(static_cast(a), static_cast(b), static_cast(x))); +} +#endif + +} // end namespace numext +} // end namespace Eigen + +#endif // EIGEN_SPECIALFUNCTIONS_BFLOAT16_H diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp index 996dba806..f1ac83b1b 100644 --- a/unsupported/test/cxx11_tensor_reduction.cpp +++ b/unsupported/test/cxx11_tensor_reduction.cpp @@ -511,6 +511,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_reduction) { CALL_SUBTEST(( test_simple_reductions() )); CALL_SUBTEST(( test_simple_reductions() )); CALL_SUBTEST(( test_simple_reductions() )); + CALL_SUBTEST(( test_simple_reductions() )); CALL_SUBTEST(test_reductions_in_expr()); CALL_SUBTEST(test_reductions_in_expr()); CALL_SUBTEST(test_full_reductions()); -- cgit v1.2.3