From 95195bd0ef742f2876a1694ebb4c018f858af062 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 15 Dec 2016 09:34:54 -0800 Subject: Added support for AVX512 to fixed point instructions. (#6323) --- .../eigen3/unsupported/Eigen/CXX11/FixedPoint | 9 +- .../Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h | 2 + .../Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h | 545 +++++++++++++++++++++ .../Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h | 180 +++++++ 4 files changed, 735 insertions(+), 1 deletion(-) create mode 100644 third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h create mode 100644 third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h (limited to 'third_party/eigen3') diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint b/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint index 9d6b9c3f01..8e55a1f3e8 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint +++ b/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint @@ -31,8 +31,15 @@ #include "src/FixedPoint/FixedPointTypes.h" // Use optimized implementations whenever available -#ifdef EIGEN_VECTORIZE_AVX2 +#ifdef EIGEN_VECTORIZE_AVX512 +#include "src/Tensor/TensorContractionThreadPool.h" +#include "src/FixedPoint/PacketMathAVX512.h" +#include "src/FixedPoint/TypeCastingAVX512.h" + +#elif defined EIGEN_VECTORIZE_AVX2 #define EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT +#define EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT +#include "src/Tensor/TensorContractionThreadPool.h" #include "src/FixedPoint/PacketMathAVX2.h" #include "src/FixedPoint/MatMatProductAVX2.h" #include "src/FixedPoint/TypeCastingAVX2.h" diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h index e71c2d8aea..98deb1742e 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h @@ -46,6 +46,7 @@ typedef struct Packet4q32i { Packet4q32i(__m128i val) : val(val) {} } Packet4q32i; +#ifndef EIGEN_VECTORIZE_AVX512 template <> struct packet_traits : default_packet_traits { typedef Packet32q8i type; @@ -112,6 +113,7 @@ struct packet_traits : default_packet_traits { HasSetLinear = 0 }; }; +#endif template <> struct unpacket_traits { diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h new file mode 100644 index 0000000000..b754bbf009 --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h @@ -0,0 +1,545 @@ +#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_ +#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_ + +#include "PacketMathAVX2.h" + +namespace Eigen { +namespace internal { + +typedef struct Packet64q8i { + __m512i val; + operator __m512i() const { return val; } + Packet64q8i(); + Packet64q8i(__m512i val) : val(val) {} +} Packet64q8i; + +typedef struct Packet32q16i { + __m512i val; + operator __m512i() const { return val; } + Packet32q16i(); + Packet32q16i(__m512i val) : val(val) {} +} Packet32q16i; + +typedef struct Packet64q8u { + __m512i val; + operator __m512i() const { return val; } + Packet64q8u(); + Packet64q8u(__m512i val) : val(val) {} +} Packet64q8u; + +typedef struct Packet16q32i { + __m512i val; + operator __m512i() const { return val; } + Packet16q32i(); + Packet16q32i(__m512i val) : val(val) {} +} Packet16q32i; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet64q8i type; + typedef Packet32q8i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 64, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet64q8u type; + typedef Packet32q8u half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 64, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet32q16i type; + typedef Packet16q16i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 32, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet16q32i type; + typedef Packet8q32i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + }; + enum { + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; + +template <> +struct unpacket_traits { + typedef QInt8 type; + typedef Packet32q8i half; + enum { size = 64 }; +}; +template <> +struct unpacket_traits { + typedef QInt16 type; + typedef Packet16q16i half; + enum { size = 32 }; +}; +template <> +struct unpacket_traits { + typedef QUInt8 type; + typedef Packet32q8u half; + enum { size = 64 }; +}; +template <> +struct unpacket_traits { + typedef QInt32 type; + typedef Packet8q32i half; + enum { size = 16 }; +}; + +// Unaligned load +template <> +EIGEN_STRONG_INLINE Packet64q8i ploadu(const QInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet32q16i ploadu(const QInt16* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet64q8u ploadu(const QUInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i ploadu(const QInt32* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} + +// Aligned load +template <> +EIGEN_STRONG_INLINE Packet64q8i pload(const QInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet32q16i pload(const QInt16* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet64q8u pload(const QUInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i pload(const QInt32* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} + +// Unaligned store +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt8* to, const Packet64q8i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from.val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt16* to, const Packet32q16i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from.val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QUInt8* to, const Packet64q8u& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from.val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt32* to, const Packet16q32i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from.val); +} + +// Aligned store +template <> +EIGEN_STRONG_INLINE void pstore(QInt32* to, const Packet16q32i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to), + from.val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QUInt8* to, const Packet64q8u& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to), + from.val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt8* to, const Packet64q8i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to), + from.val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt16* to, const Packet32q16i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to), + from.val); +} + +// Extract first element. +template <> +EIGEN_STRONG_INLINE QInt32 pfirst(const Packet16q32i& a) { + return _mm_cvtsi128_si32(_mm512_extracti32x4_epi32(a, 0)); +} +template <> +EIGEN_STRONG_INLINE QUInt8 pfirst(const Packet64q8u& a) { + return static_cast( + _mm_extract_epi8(_mm512_extracti32x4_epi32(a.val, 0), 0)); +} +template <> +EIGEN_STRONG_INLINE QInt8 pfirst(const Packet64q8i& a) { + return _mm_extract_epi8(_mm512_extracti32x4_epi32(a.val, 0), 0); +} +template <> +EIGEN_STRONG_INLINE QInt16 pfirst(const Packet32q16i& a) { + return _mm_extract_epi16(_mm512_extracti32x4_epi32(a.val, 0), 0); +} + +// Initialize to constant value. +template <> +EIGEN_STRONG_INLINE Packet64q8i pset1(const QInt8& from) { + return _mm512_set1_epi8(from.value); +} +template <> +EIGEN_STRONG_INLINE Packet32q16i pset1(const QInt16& from) { + return _mm512_set1_epi16(from.value); +} +template <> +EIGEN_STRONG_INLINE Packet64q8u pset1(const QUInt8& from) { + return _mm512_set1_epi8(static_cast(from.value)); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i pset1(const QInt32& from) { + return _mm512_set1_epi32(from.value); +} + +// Basic arithmetic packet ops for QInt32. +template <> +EIGEN_STRONG_INLINE Packet16q32i padd(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_add_epi32(a.val, b.val); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i psub(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_sub_epi32(a.val, b.val); +} +// Note: mullo truncates the result to 32 bits. +template <> +EIGEN_STRONG_INLINE Packet16q32i pmul(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_mullo_epi32(a.val, b.val); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i pnegate(const Packet16q32i& a) { + return _mm512_sub_epi32(_mm512_setzero_si512(), a.val); +} + +// Min and max. +template <> +EIGEN_STRONG_INLINE Packet16q32i pmin(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_min_epi32(a.val, b.val); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i pmax(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_max_epi32(a.val, b.val); +} + +template <> +EIGEN_STRONG_INLINE Packet64q8u pmin(const Packet64q8u& a, + const Packet64q8u& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_min_epu8(a.val, b.val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.val, 1); + __m256i r0 = _mm256_min_epu8(ap0, bp0); + __m256i r1 = _mm256_min_epu8(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet64q8u pmax(const Packet64q8u& a, + const Packet64q8u& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_max_epu8(a.val, b.val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.val, 1); + __m256i r0 = _mm256_max_epu8(ap0, bp0); + __m256i r1 = _mm256_max_epu8(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet64q8i pmin(const Packet64q8i& a, + const Packet64q8i& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_min_epi8(a.val, b.val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.val, 1); + __m256i r0 = _mm256_min_epi8(ap0, bp0); + __m256i r1 = _mm256_min_epi8(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet32q16i pmin(const Packet32q16i& a, + const Packet32q16i& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_min_epi16(a.val, b.val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.val, 1); + __m256i r0 = _mm256_min_epi16(ap0, bp0); + __m256i r1 = _mm256_min_epi16(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet64q8i pmax(const Packet64q8i& a, + const Packet64q8i& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_max_epi8(a.val, b.val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.val, 1); + __m256i r0 = _mm256_max_epi8(ap0, bp0); + __m256i r1 = _mm256_max_epi8(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet32q16i pmax(const Packet32q16i& a, + const Packet32q16i& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_max_epi16(a.val, b.val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.val, 1); + __m256i r0 = _mm256_max_epi16(ap0, bp0); + __m256i r1 = _mm256_max_epi16(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} + +// Reductions. +template <> +EIGEN_STRONG_INLINE QInt32 predux_min(const Packet16q32i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.val, 3); + Packet4i res = + _mm_min_epi32(_mm_min_epi32(lane0, lane1), _mm_min_epi32(lane2, lane3)); + res = _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + return pfirst( + _mm_min_epi32( + res, + _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); +} +template <> +EIGEN_STRONG_INLINE QInt32 predux_max(const Packet16q32i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.val, 3); + Packet4i res = + _mm_max_epi32(_mm_max_epi32(lane0, lane1), _mm_max_epi32(lane2, lane3)); + res = _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + return pfirst( + _mm_max_epi32( + res, + _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); +} +template <> +EIGEN_STRONG_INLINE QInt16 predux_min(const Packet32q16i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.val, 3); + Packet4i res = + _mm_min_epi16(_mm_min_epi16(lane0, lane1), _mm_min_epi16(lane2, lane3)); + res = _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + std::uint32_t w = + pfirst( + _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min({ + static_cast(w >> 16), + static_cast(w) + }); +} +template <> +EIGEN_STRONG_INLINE QInt16 predux_max(const Packet32q16i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.val, 3); + Packet4i res = + _mm_max_epi16(_mm_max_epi16(lane0, lane1), _mm_max_epi16(lane2, lane3)); + res = _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + std::uint32_t w = + pfirst( + _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min({ + static_cast(w >> 16), + static_cast(w) + }); +} +template <> +EIGEN_STRONG_INLINE QUInt8 predux_min(const Packet64q8u& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.val, 3); + Packet4i res = + _mm_min_epu8(_mm_min_epu8(lane0, lane1), _mm_min_epu8(lane2, lane3)); + res = _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + std::uint32_t w = + pfirst( + _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min({ + static_cast(w >> 24), + static_cast(w >> 16), + static_cast(w >> 8), + static_cast(w) + }); +} +template <> +EIGEN_STRONG_INLINE QUInt8 predux_max(const Packet64q8u& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.val, 3); + Packet4i res = + _mm_max_epu8(_mm_max_epu8(lane0, lane1), _mm_max_epu8(lane2, lane3)); + res = _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + std::uint32_t w = + pfirst( + _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min({ + static_cast(w >> 24), + static_cast(w >> 16), + static_cast(w >> 8), + static_cast(w) + }); +} +template <> +EIGEN_STRONG_INLINE QInt8 predux_min(const Packet64q8i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.val, 3); + Packet4i res = + _mm_min_epi8(_mm_min_epi8(lane0, lane1), _mm_min_epi8(lane2, lane3)); + res = _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + std::uint32_t w = + pfirst( + _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min({ + static_cast(w >> 24), + static_cast(w >> 16), + static_cast(w >> 8), + static_cast(w) + }); +} +template <> +EIGEN_STRONG_INLINE QInt8 predux_max(const Packet64q8i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.val, 3); + Packet4i res = + _mm_max_epi8(_mm_max_epi8(lane0, lane1), _mm_max_epi8(lane2, lane3)); + res = _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + std::uint32_t w = + pfirst( + _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min({ + static_cast(w >> 24), + static_cast(w >> 16), + static_cast(w >> 8), + static_cast(w) + }); +} + +} // end namespace internal +} // end namespace Eigen + +#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_ diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h new file mode 100644 index 0000000000..cd7120ec00 --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h @@ -0,0 +1,180 @@ +#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_ +#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_ + +namespace Eigen { +namespace internal { + +typedef __m512 Packet16f; +typedef __m512i Packet16i; + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet16f pcast(const Packet16q32i& a) { + return _mm512_cvtepi32_ps(a.val); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet16q32i pcast(const Packet16f& a) { + return _mm512_cvtps_epi32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet32q16i +pcast(const Packet16f& a, const Packet16f& b) { + Packet16i a_int = _mm512_cvtps_epi32(a); + Packet16i b_int = _mm512_cvtps_epi32(b); +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_packs_epi32(a_int, b_int); +#else + Packet8i ab_int16_low = + _mm256_permute4x64_epi64( + _mm256_packs_epi32( + _mm512_castsi512_si256(a_int), + _mm512_castsi512_si256(b_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i ab_int16_high = + _mm256_permute4x64_epi64( + _mm256_packs_epi32( + _mm512_extracti32x8_epi32(a_int, 1), + _mm512_extracti32x8_epi32(b_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + return _mm512_inserti32x8( + _mm512_castsi256_si512(ab_int16_low), + ab_int16_high, 1); +#endif +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet64q8i +pcast(const Packet16f& a, + const Packet16f& b, + const Packet16f& c, + const Packet16f& d) { + Packet16i a_int = _mm512_cvtps_epi32(a); + Packet16i b_int = _mm512_cvtps_epi32(b); + Packet16i c_int = _mm512_cvtps_epi32(c); + Packet16i d_int = _mm512_cvtps_epi32(d); +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_packs_epi16( + _mm512_packs_epi32(a_int, b_int), + _mm512_packs_epi32(c_int, d_int)); +#else + Packet8i ab_int16_low = + _mm256_permute4x64_epi64( + _mm256_packs_epi32( + _mm512_castsi512_si256(a_int), + _mm512_castsi512_si256(b_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i cd_int16_low = + _mm256_permute4x64_epi64( + _mm256_packs_epi32( + _mm512_castsi512_si256(c_int), + _mm512_castsi512_si256(d_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i ab_int16_high = + _mm256_permute4x64_epi64( + _mm256_packs_epi32( + _mm512_extracti32x8_epi32(a_int, 1), + _mm512_extracti32x8_epi32(b_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i cd_int16_high = + _mm256_permute4x64_epi64( + _mm256_packs_epi32( + _mm512_extracti32x8_epi32(c_int, 1), + _mm512_extracti32x8_epi32(d_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i abcd_int8_low = + _mm256_permute4x64_epi64( + _mm256_packs_epi16(ab_int16_low, cd_int16_low), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i abcd_int8_high = + _mm256_permute4x64_epi64( + _mm256_packs_epi16(ab_int16_high, cd_int16_high), + _MM_SHUFFLE(0, 2, 1, 3)); + return _mm512_inserti32x8( + _mm512_castsi256_si512(abcd_int8_low), + abcd_int8_high, 1); +#endif +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet64q8i +pcast(const Packet16q32i& a, + const Packet16q32i& b, + const Packet16q32i& c, + const Packet16q32i& d) { + __m512i converted = _mm512_packs_epi16(_mm512_packs_epi32(a.val, b.val), + _mm512_packs_epi32(c.val, d.val)); + return converted; +} + +template <> +EIGEN_STRONG_INLINE Packet32q16i +pcast(const Packet16q32i& a, + const Packet16q32i& b) { + __m512i converted = _mm512_packs_epi32(a.val, b.val); + return converted; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet64q8u +pcast(const Packet16q32i& a, const Packet16q32i& b, + const Packet16q32i& c, const Packet16q32i& d) { + const __m512i converted = _mm512_packus_epi16( + _mm512_packus_epi32(a.val, b.val), _mm512_packus_epi32(c.val, d.val)); + return converted; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; + +#if 0 +template <> +EIGEN_STRONG_INLINE Packet32q16u +pcast(const Packet16q32i& a, + const Packet16q32i& b) { + const __m512i converted = _mm512_packus_epi32(a.val, b.val); + return converted; +} +#endif + +} // end namespace internal +} // end namespace Eigen + +#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_ -- cgit v1.2.3