diff options
-rw-r--r-- | Eigen/src/Core/arch/AVX/MathFunctions.h | 10 | ||||
-rw-r--r-- | Eigen/src/Core/arch/AVX/PacketMath.h | 114 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/Half.h | 7 | ||||
-rw-r--r-- | test/packetmath.cpp | 27 |
4 files changed, 120 insertions, 38 deletions
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index 9b123db00..e2e704d82 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -158,6 +158,16 @@ Packet4d prsqrt<Packet4d>(const Packet4d& _x) { return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x)); } +F16_PACKET_FUNCTION(Packet8f, Packet8h, psin) +F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos) +F16_PACKET_FUNCTION(Packet8f, Packet8h, plog) +F16_PACKET_FUNCTION(Packet8f, Packet8h, plog1p) +F16_PACKET_FUNCTION(Packet8f, Packet8h, pexpm1) +F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp) +F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh) +F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt) +F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt) + BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog) diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index b68351356..e9eaaa9e0 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -119,22 +119,34 @@ struct packet_traits<Eigen::half> : default_packet_traits { AlignedOnScalar = 1, size = 8, HasHalfPacket = 0, + + HasCmp = 1, HasAdd = 1, HasSub = 1, HasMul = 1, HasDiv = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, HasNegate = 1, - HasAbs = 0, + HasAbs = 1, HasAbs2 = 0, - HasMin = 0, - HasMax = 0, - HasConj = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, HasSetLinear = 0, - HasSqrt = 0, - HasRsqrt = 0, - HasExp = 0, - HasLog = 0, - HasBlend = 0 + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; @@ -150,16 +162,24 @@ struct packet_traits<bfloat16> : default_packet_traits { size = 8, HasHalfPacket = 0, - HasCmp = 1, + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasMul = 1, HasDiv = 1, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, + HasNegate = 1, + HasAbs = 1, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, HasLog = 1, HasLog1p = 1, HasExpm1 = 1, HasExp = 1, - HasNdtri = 1, - HasBessel = 1, HasSqrt = 1, HasRsqrt = 1, HasTanh = EIGEN_FAST_MATH, @@ -870,8 +890,7 @@ template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, cons } // Packet math for Eigen::half -// TODO(cantonios): add missing packet ops -// - pabs, pmin, pmax, plset, pround, print, pceil, pfloor, pcmp_lt, pcmp_le, pcmp_lt_or_nan + template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; }; template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) { @@ -914,6 +933,16 @@ ploadquad<Packet8h>(const Eigen::half* from) { return _mm_set_epi16(b, b, b, b, a, a, a, a); } +template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) { + return _mm_cmpeq_epi32(a, a); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pabs(const Packet8h& a) { + const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); + return _mm_andnot_si128(sign_mask, a); +} + EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) { #ifdef EIGEN_HAS_FP16_C return _mm256_cvtph_ps(a); @@ -951,8 +980,21 @@ EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { #endif } -template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) { - return _mm_cmpeq_epi32(a, a); +template <> +EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a, + const Packet8h& b) { + return float2half(pmin<Packet8f>(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a, + const Packet8h& b) { + return float2half(pmax<Packet8f>(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) { + return float2half(plset<Packet8f>(static_cast<float>(a))); } template<> EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a,const Packet8h& b) { @@ -974,13 +1016,36 @@ template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Pack return _mm_blendv_epi8(b, a, mask); } +template<> EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) { + return float2half(pround<Packet8f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) { + return float2half(print<Packet8f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) { + return float2half(pceil<Packet8f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) { + return float2half(pfloor<Packet8f>(half2float(a))); +} + template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) { - Packet8f af = half2float(a); - Packet8f bf = half2float(b); - Packet8f rf = pcmp_eq(af, bf); - // Pack the 32-bit flags into 16-bits flags. - return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0), - _mm256_extractf128_si256(_mm256_castps_si256(rf), 1)); + return Pack16To8(pcmp_eq(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a,const Packet8h& b) { + return Pack16To8(pcmp_le(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a,const Packet8h& b) { + return Pack16To8(pcmp_lt(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a,const Packet8h& b) { + return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b))); } template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; } @@ -1148,6 +1213,8 @@ ptranspose(PacketBlock<Packet8h,4>& kernel) { kernel.packet[3] = pload<Packet8h>(out[3]); } +// BFloat16 implementation. + EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) { #ifdef EIGEN_VECTORIZE_AVX2 __m256i extend = _mm256_cvtepu16_epi32(a); @@ -1262,7 +1329,8 @@ template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) { template <> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { - return F32ToBf16(pabs<Packet8f>(Bf16ToF32(a))); + const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); + return _mm_andnot_si128(sign_mask, a); } template <> diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index d6ddff59c..5166f54c7 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -56,6 +56,13 @@ #define EIGEN_CONSTEXPR #endif +#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \ + template <> \ + EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \ + PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \ + return float2half(METHOD<PACKET_F>(half2float(_x))); \ + } + namespace Eigen { struct half; diff --git a/test/packetmath.cpp b/test/packetmath.cpp index ae21dda5c..d52f997dc 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -556,7 +556,7 @@ void packetmath_real() { VERIFY((numext::isnan)(data2[0])); // TODO(rmlarsen): Re-enable for bfloat16. if (!internal::is_same<Scalar, bfloat16>::value) { - VERIFY_IS_EQUAL(std::exp(small), data2[1]); + VERIFY_IS_APPROX(std::exp(small), data2[1]); } data1[0] = -small; @@ -564,21 +564,21 @@ void packetmath_real() { h.store(data2, internal::pexp(h.load(data1))); // TODO(rmlarsen): Re-enable for bfloat16. if (!internal::is_same<Scalar, bfloat16>::value) { - VERIFY_IS_EQUAL(std::exp(-small), data2[0]); + VERIFY_IS_APPROX(std::exp(-small), data2[0]); } VERIFY_IS_EQUAL(std::exp(Scalar(0)), data2[1]); data1[0] = (std::numeric_limits<Scalar>::min)(); data1[1] = -(std::numeric_limits<Scalar>::min)(); h.store(data2, internal::pexp(h.load(data1))); - VERIFY_IS_EQUAL(std::exp((std::numeric_limits<Scalar>::min)()), data2[0]); - VERIFY_IS_EQUAL(std::exp(-(std::numeric_limits<Scalar>::min)()), data2[1]); + VERIFY_IS_APPROX(std::exp((std::numeric_limits<Scalar>::min)()), data2[0]); + VERIFY_IS_APPROX(std::exp(-(std::numeric_limits<Scalar>::min)()), data2[1]); data1[0] = std::numeric_limits<Scalar>::denorm_min(); data1[1] = -std::numeric_limits<Scalar>::denorm_min(); h.store(data2, internal::pexp(h.load(data1))); - VERIFY_IS_EQUAL(std::exp(std::numeric_limits<Scalar>::denorm_min()), data2[0]); - VERIFY_IS_EQUAL(std::exp(-std::numeric_limits<Scalar>::denorm_min()), data2[1]); + VERIFY_IS_APPROX(std::exp(std::numeric_limits<Scalar>::denorm_min()), data2[0]); + VERIFY_IS_APPROX(std::exp(-std::numeric_limits<Scalar>::denorm_min()), data2[1]); } if (PacketTraits::HasTanh) { @@ -618,7 +618,7 @@ void packetmath_real() { test::packet_helper<PacketTraits::HasLog, Packet> h; h.store(data2, internal::plog(h.load(data1))); VERIFY((numext::isnan)(data2[0])); - VERIFY_IS_EQUAL(std::log(std::numeric_limits<Scalar>::epsilon()), data2[1]); + VERIFY_IS_APPROX(std::log(std::numeric_limits<Scalar>::epsilon()), data2[1]); data1[0] = -std::numeric_limits<Scalar>::epsilon(); data1[1] = Scalar(0); @@ -629,7 +629,7 @@ void packetmath_real() { data1[0] = (std::numeric_limits<Scalar>::min)(); data1[1] = -(std::numeric_limits<Scalar>::min)(); h.store(data2, internal::plog(h.load(data1))); - VERIFY_IS_EQUAL(std::log((std::numeric_limits<Scalar>::min)()), data2[0]); + VERIFY_IS_APPROX(std::log((std::numeric_limits<Scalar>::min)()), data2[0]); VERIFY((numext::isnan)(data2[1])); // Note: 32-bit arm always flushes denorms to zero. @@ -672,8 +672,10 @@ void packetmath_real() { VERIFY((numext::isnan)(data2[0])); VERIFY((numext::isnan)(data2[1])); } - // TODO(rmlarsen): Re-enable for bfloat16. - if (PacketTraits::HasCos && !internal::is_same<Scalar, bfloat16>::value) { + // TODO(rmlarsen): Re-enable for half and bfloat16. + if (PacketTraits::HasCos + && !internal::is_same<Scalar, half>::value + && !internal::is_same<Scalar, bfloat16>::value) { test::packet_helper<PacketTraits::HasCos, Packet> h; for (Scalar k = Scalar(1); k < Scalar(10000) / std::numeric_limits<Scalar>::epsilon(); k *= Scalar(2)) { for (int k1 = 0; k1 <= 1; ++k1) { @@ -1074,12 +1076,7 @@ EIGEN_DECLARE_TEST(packetmath) { CALL_SUBTEST_10(test::runner<uint64_t>::run()); CALL_SUBTEST_11(test::runner<std::complex<float> >::run()); CALL_SUBTEST_12(test::runner<std::complex<double> >::run()); -#if defined(EIGEN_VECTORIZE_AVX) - // AVX half packets not fully implemented. - CALL_SUBTEST_13((packetmath<half, internal::packet_traits<half>::type>())); -#else CALL_SUBTEST_13(test::runner<half>::run()); -#endif CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>())); CALL_SUBTEST_15(test::runner<bfloat16>::run()); g_first_pass = false; |