aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src')
-rw-r--r--Eigen/src/Core/GenericPacketMath.h4
-rw-r--r--Eigen/src/Core/arch/AVX/MathFunctions.h26
-rw-r--r--Eigen/src/Core/arch/AVX512/MathFunctions.h26
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h4
-rw-r--r--Eigen/src/Core/arch/NEON/MathFunctions.h12
5 files changed, 68 insertions, 4 deletions
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index 16119c1d8..b02a9f20b 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -442,7 +442,7 @@ template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) {
int exp;
EIGEN_USING_STD(frexp);
- Packet result = frexp(a, &exp);
+ Packet result = static_cast<Packet>(frexp(a, &exp));
exponent = static_cast<Packet>(exp);
return result;
}
@@ -453,7 +453,7 @@ EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) {
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pldexp(const Packet &a, const Packet &exponent) {
EIGEN_USING_STD(ldexp)
- return ldexp(a, static_cast<int>(exponent));
+ return static_cast<Packet>(ldexp(a, static_cast<int>(exponent)));
}
/** \internal \returns the min of \a a and \a b (coeff-wise) */
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
index c0d362fa7..67041c812 100644
--- a/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -184,6 +184,19 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
+template <>
+EIGEN_STRONG_INLINE Packet8h pfrexp(const Packet8h& a, Packet8h& exponent) {
+ Packet8f fexponent;
+ const Packet8h out = float2half(pfrexp<Packet8f>(half2float(a), fexponent));
+ exponent = float2half(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pldexp(const Packet8h& a, const Packet8h& exponent) {
+ return float2half(pldexp<Packet8f>(half2float(a), half2float(exponent)));
+}
+
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
@@ -195,6 +208,19 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
+template <>
+EIGEN_STRONG_INLINE Packet8bf pfrexp(const Packet8bf& a, Packet8bf& exponent) {
+ Packet8f fexponent;
+ const Packet8bf out = F32ToBf16(pfrexp<Packet8f>(Bf16ToF32(a), fexponent));
+ exponent = F32ToBf16(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& exponent) {
+ return F32ToBf16(pldexp<Packet8f>(Bf16ToF32(a), Bf16ToF32(exponent)));
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index 66f3252cd..41929cb34 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -191,6 +191,32 @@ pexp<Packet8d>(const Packet8d& _x) {
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
+template <>
+EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h& a, Packet16h& exponent) {
+ Packet16f fexponent;
+ const Packet16h out = float2half(pfrexp<Packet16f>(half2float(a), fexponent));
+ exponent = float2half(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h& a, const Packet16h& exponent) {
+ return float2half(pldexp<Packet16f>(half2float(a), half2float(exponent)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pfrexp(const Packet16bf& a, Packet16bf& exponent) {
+ Packet16f fexponent;
+ const Packet16bf out = F32ToBf16(pfrexp<Packet16f>(Bf16ToF32(a), fexponent));
+ exponent = F32ToBf16(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exponent) {
+ return F32ToBf16(pldexp<Packet16f>(Bf16ToF32(a), Bf16ToF32(exponent)));
+}
+
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
// of Newton's method, at a cost of 1-2 bits of precision as opposed to the
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 9a1feb0d9..69c92a8cc 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -31,7 +31,7 @@ pfrexp_float(const Packet& a, Packet& exponent) {
const Packet cst_126f = pset1<Packet>(126.0f);
const Packet cst_half = pset1<Packet>(0.5f);
const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7f800000u);
- exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(a))), cst_126f);
+ exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_126f);
return por(pand(a, cst_inv_mant_mask), cst_half);
}
@@ -41,7 +41,7 @@ pfrexp_double(const Packet& a, Packet& exponent) {
const Packet cst_1022d = pset1<Packet>(1022.0);
const Packet cst_half = pset1<Packet>(0.5);
const Packet cst_inv_mant_mask = pset1frombits<Packet>(static_cast<uint64_t>(~0x7ff0000000000000ull));
- exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(a))), cst_1022d);
+ exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d);
return por(pand(a, cst_inv_mant_mask), cst_half);
}
diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h
index 28167b904..fa6615a85 100644
--- a/Eigen/src/Core/arch/NEON/MathFunctions.h
+++ b/Eigen/src/Core/arch/NEON/MathFunctions.h
@@ -44,6 +44,18 @@ BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
+template <>
+EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) {
+ Packet4f fexponent;
+ const Packet4bf out = F32ToBf16(pfrexp<Packet4f>(Bf16ToF32(a), fexponent));
+ exponent = F32ToBf16(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4bf pldexp(const Packet4bf& a, const Packet4bf& exponent) {
+ return F32ToBf16(pldexp<Packet4f>(Bf16ToF32(a), Bf16ToF32(exponent)));
+}
//---------- double ----------