diff options
Diffstat (limited to 'third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint')
7 files changed, 3032 insertions, 0 deletions
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h new file mode 100644 index 0000000000..564729ce48 --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h @@ -0,0 +1,341 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_FIXED_POINT_TYPES_H +#define EIGEN_CXX11_FIXED_POINT_TYPES_H + +#include <cmath> +#include <iostream> + +namespace Eigen { + +// The mantissa part of the fixed point representation. See +// go/tensorfixedpoint for details +struct QInt8; +struct QUInt8; +struct QInt16; +struct QUInt16; +struct QInt32; + +template <> +struct NumTraits<QInt8> : GenericNumTraits<int8_t> {}; +template <> +struct NumTraits<QUInt8> : GenericNumTraits<uint8_t> {}; +template <> +struct NumTraits<QInt16> : GenericNumTraits<int16_t> {}; +template <> +struct NumTraits<QUInt16> : GenericNumTraits<uint16_t> {}; +template <> +struct NumTraits<QInt32> : GenericNumTraits<int32_t> {}; + +namespace internal { +template <> +struct scalar_product_traits<QInt32, double> { + enum { + // Cost = NumTraits<T>::MulCost, + Defined = 1 + }; + typedef QInt32 ReturnType; +}; +} + +// Wrap the 8bit int into a QInt8 struct instead of using a typedef to prevent +// the compiler from silently type cast the mantissa into a bigger or a smaller +// representation. 
+struct QInt8 { + QInt8() {} + QInt8(const int8_t v) : value(v) {} + QInt8(const QInt32 v); + + operator int() const { return static_cast<int>(value); } + + int8_t value; +}; + +struct QUInt8 { + QUInt8() {} + QUInt8(const uint8_t v) : value(v) {} + QUInt8(const QInt32 v); + + operator int() const { return static_cast<int>(value); } + + uint8_t value; +}; + +struct QInt16 { + QInt16() {} + QInt16(const int16_t v) : value(v) {} + QInt16(const QInt32 v); + operator int() const { return static_cast<int>(value); } + + int16_t value; +}; + +struct QUInt16 { + QUInt16() {} + QUInt16(const uint16_t v) : value(v) {} + QUInt16(const QInt32 v); + operator int() const { return static_cast<int>(value); } + + uint16_t value; +}; + +struct QInt32 { + QInt32() {} + QInt32(const int8_t v) : value(v) {} + QInt32(const int32_t v) : value(v) {} + QInt32(const QInt8 v) : value(v.value) {} + QInt32(const float v) : value(static_cast<int32_t>(lrint(v))) {} +#ifdef EIGEN_MAKING_DOCS + // Workaround to fix build on PPC. + QInt32(unsigned long v) : value(v) {} +#endif + + operator float() const { return static_cast<float>(value); } + + int32_t value; +}; + +EIGEN_STRONG_INLINE QInt8::QInt8(const QInt32 v) + : value(v.value > 127 ? 127 : (v.value < -128 ? -128 : v.value)) {} +EIGEN_STRONG_INLINE QUInt8::QUInt8(const QInt32 v) + : value(v.value > 255 ? 255 : (v.value < 0 ? 0 : v.value)) {} +EIGEN_STRONG_INLINE QInt16::QInt16(const QInt32 v) + : value(v.value > 32767 ? 32767 : (v.value < -32768 ? -32768 : v.value)) {} +EIGEN_STRONG_INLINE QUInt16::QUInt16(const QInt32 v) + : value(v.value > 65535 ? 65535 : (v.value < 0 ? 0 : v.value)) {} + +// Basic widening 8-bit operations: This will be vectorized in future CLs. 
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt8 b) { + return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QUInt8 b) { + return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt8 b) { + return QInt32(static_cast<int32_t>(a.value) + static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt8 b) { + return QInt32(static_cast<int32_t>(a.value) - static_cast<int32_t>(b.value)); +} + +// Basic widening 16-bit operations: This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt16 b) { + return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QUInt16 b) { + return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt16 b) { + return QInt32(static_cast<int32_t>(a.value) + static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt16 b) { + return QInt32(static_cast<int32_t>(a.value) - static_cast<int32_t>(b.value)); +} + +// Mixed QInt32 op QInt8 operations. This will be vectorized in future CLs. 
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt8 b) { + return QInt32(a.value + static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) + b.value); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt8 b) { + return QInt32(a.value - static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) - b.value); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt8 b) { + return QInt32(a.value * static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) * b.value); +} + +// Mixed QInt32 op QInt16 operations. This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt16 b) { + return QInt32(a.value + static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) + b.value); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt16 b) { + return QInt32(a.value - static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) - b.value); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt16 b) { + return QInt32(a.value * static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) * b.value); +} + +// Mixed QInt32 op QUInt8 operations. This will be vectorized in future CLs. 
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt8 b) { + return QInt32(a.value + static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QUInt8 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) + b.value); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt8 b) { + return QInt32(a.value - static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QUInt8 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) - b.value); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt8 b) { + return QInt32(a.value * static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QUInt8 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) * b.value); +} + +// Mixed QInt32 op QUInt16 operations. This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt16 b) { + return QInt32(a.value + static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QUInt16 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) + b.value); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt16 b) { + return QInt32(a.value - static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QUInt16 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) - b.value); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt16 b) { + return QInt32(a.value * static_cast<int32_t>(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QUInt16 a, const QInt32 b) { + return QInt32(static_cast<int32_t>(a.value) * b.value); +} + +// Basic arithmetic operations on QInt32, which behaves like a int32_t. 
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt32 b) { + return a.value + b.value; +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt32 b) { + return a.value - b.value; +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt32 b) { + return a.value * b.value; +} +EIGEN_STRONG_INLINE QInt32 operator/(const QInt32 a, const QInt32 b) { + return a.value / b.value; +} +EIGEN_STRONG_INLINE QInt32& operator+=(QInt32& a, const QInt32 b) { + a.value += b.value; + return a; +} +EIGEN_STRONG_INLINE QInt32& operator-=(QInt32& a, const QInt32 b) { + a.value -= b.value; + return a; +} +EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const QInt32 b) { + a.value *= b.value; + return a; +} +EIGEN_STRONG_INLINE QInt32& operator/=(QInt32& a, const QInt32 b) { + a.value /= b.value; + return a; +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) { + return -a.value; +} + +// Scaling QInt32 by double. We do the arithmetic in double because +// float only has 23 bits of mantissa, so casting QInt32 to float might reduce +// accuracy by discarding up to 7 (least significant) bits. 
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const double b) { + return static_cast<int32_t>(lrint(static_cast<double>(a.value) * b)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const double a, const QInt32 b) { + return static_cast<int32_t>(lrint(a * static_cast<double>(b.value))); +} +EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const double b) { + a.value = static_cast<int32_t>(lrint(static_cast<double>(a.value) * b)); + return a; +} + +// Comparisons +EIGEN_STRONG_INLINE bool operator==(const QInt8 a, const QInt8 b) { + return a.value == b.value; +} +EIGEN_STRONG_INLINE bool operator==(const QUInt8 a, const QUInt8 b) { + return a.value == b.value; +} +EIGEN_STRONG_INLINE bool operator==(const QInt16 a, const QInt16 b) { + return a.value == b.value; +} +EIGEN_STRONG_INLINE bool operator==(const QUInt16 a, const QUInt16 b) { + return a.value == b.value; +} +EIGEN_STRONG_INLINE bool operator==(const QInt32 a, const QInt32 b) { + return a.value == b.value; +} + +EIGEN_STRONG_INLINE bool operator<(const QInt8 a, const QInt8 b) { + return a.value < b.value; +} +EIGEN_STRONG_INLINE bool operator<(const QUInt8 a, const QUInt8 b) { + return a.value < b.value; +} +EIGEN_STRONG_INLINE bool operator<(const QInt16 a, const QInt16 b) { + return a.value < b.value; +} +EIGEN_STRONG_INLINE bool operator<(const QUInt16 a, const QUInt16 b) { + return a.value < b.value; +} +EIGEN_STRONG_INLINE bool operator<(const QInt32 a, const QInt32 b) { + return a.value < b.value; +} + +EIGEN_STRONG_INLINE bool operator>(const QInt8 a, const QInt8 b) { + return a.value > b.value; +} +EIGEN_STRONG_INLINE bool operator>(const QUInt8 a, const QUInt8 b) { + return a.value > b.value; +} +EIGEN_STRONG_INLINE bool operator>(const QInt16 a, const QInt16 b) { + return a.value > b.value; +} +EIGEN_STRONG_INLINE bool operator>(const QUInt16 a, const QUInt16 b) { + return a.value > b.value; +} +EIGEN_STRONG_INLINE bool operator>(const QInt32 a, const QInt32 b) { + return a.value > b.value; +} 
+ +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt8 a) { + os << static_cast<int>(a.value); + return os; +} +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt8 a) { + os << static_cast<int>(a.value); + return os; +} +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt16 a) { + os << static_cast<int>(a.value); + return os; +} +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt16 a) { + os << static_cast<int>(a.value); + return os; +} +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt32 a) { + os << a.value; + return os; +} + +} // namespace Eigen + +#endif // EIGEN_CXX11_FIXED_POINT_TYPES_H diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h new file mode 100644 index 0000000000..4d0dca07df --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h @@ -0,0 +1,255 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#ifndef EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H +#define EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H + + +namespace Eigen { +namespace internal { + +// Accumulate the product of 2 QInt8 inputs on 32 bits to prevent +// overflows +template<> struct scalar_product_traits<QInt8, QInt8> +{ + enum { + Defined = 1 + }; + typedef QInt32 ReturnType; +}; + +// Accumulate the product of QInt8 inputs with QUint8 inputs on 32 bits +// to prevent overflows +template<> struct scalar_product_traits<QInt8, QUInt8> +{ + enum { + Defined = 1 + }; + typedef QInt32 ReturnType; +}; + +// Description of the product implementation. It's pretty simple now since +// nothing is vectorized yet. +// This definition tackle the case where both lhs and rhs are encoded using +// signed 8bit integers +#ifndef EIGEN_USE_OPTIMIZED_INT8_INT8_MAT_MAT_PRODUCT + +template<bool _ConjLhs, bool _ConjRhs> +class gebp_traits<QInt8, QInt8, _ConjLhs, _ConjRhs> +{ +public: + typedef QInt8 LhsScalar; + typedef QInt8 RhsScalar; + typedef QInt32 ResScalar; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + +// The signed 8bit Mat-Mat product itself. 
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +struct gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +EIGEN_DONT_INLINE +void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +::operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k = 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + + +// This definition tackle the case where the lhs is encoded using signed 8bit +// integers and the rhs using unsigned 8bit integers. 
+#ifndef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT +template<bool _ConjLhs, bool _ConjRhs> +class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> +{ +public: + typedef QInt8 LhsScalar; + typedef QUInt8 RhsScalar; + typedef QInt32 ResScalar; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + +// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +EIGEN_DONT_INLINE +void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k 
= 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + + +// This definition tackle the case where the khs is encoded using unsigned 8bit +// integers and the rhs using signed 8bit integers. +#ifndef EIGEN_USE_OPTIMIZED_UINT8_INT8_MAT_MAT_PRODUCT +template<bool _ConjLhs, bool _ConjRhs> +class gebp_traits<QUInt8, QInt8, _ConjLhs, _ConjRhs> +{ +public: + typedef QUInt8 LhsScalar; + typedef QInt8 RhsScalar; + typedef QInt32 ResScalar; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + + +// Mat-Mat product of an unsigned 8bit lhs with a signed 8bit rhs +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +struct gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +EIGEN_DONT_INLINE +void gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +::operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + 
eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k = 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + +} // namespace internal +} // namespace Eigen + + + +#endif // EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h new file mode 100644 index 0000000000..d561b79fbd --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h @@ -0,0 +1,1743 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com> +// Copyright (C) 2015 Matthew Sarett <msarett@google.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_AVX2_H +#define EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_AVX2_H + +namespace Eigen { +namespace internal { + +// AVX2 optimized implementation of Mat-Mat product. +// LHS is encoded using signed 8-bit integers. +// RHS is encoded using unsigned 8-bit integers. +#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT + +// Define quantized traits +template<bool _ConjLhs, bool _ConjRhs> +class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> +{ +public: + typedef QInt8 LhsScalar; + typedef QUInt8 RhsScalar; + typedef QInt32 ResScalar; + + enum { + // Define register blocking scheme. + nr = 32, + mr = 32, + kr = 8, + // Ignore progress tracking per loop iteration. 
+ LhsProgress = -1, + RhsProgress = -1 + }; +}; + +// Specialized blocking for quantized implementations. +// Used by TensorContractionThreadPool, inputs must have dimensions that are +// multiples of 32. +template<int KcFactor, typename Index> +struct ComputeGemmByColBlockingSizes<QInt8, QUInt8, KcFactor, Index> { + void operator()(Index& k, Index& m, Index& n, Index num_threads) + { + eigen_assert(m % 32 == 0); + eigen_assert(n % 32 == 0); + eigen_assert(k % 32 == 0); + if (!k || !m || !n) { + return; + } + n = (((n / num_threads) + 31) / 32) * 32; + } +}; + +// Specialized blocking for quantized implementations. +// Used by TensorContractionThreadPool, inputs must have dimensions that are +// multiples of 32. +template<int KcFactor, typename Index> +struct ComputeGemmByRowBlockingSizes<QInt8, QUInt8, KcFactor, Index> { + void operator()(Index& k, Index& m, Index& n, Index num_threads) + { + eigen_assert(m % 32 == 0); + eigen_assert(n % 32 == 0 || n == 1); + eigen_assert(k % 32 == 0); + if (!k || !m || !n) { + return; + } + // Special case to avoid breaking the unimplemented matrix-vector case + if (n == 1) { + n = 32; + } + m = (((m / num_threads) + 31) / 32) * 32; + } +}; + +// Specialized blocking for quantized implementations. +// Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to +// multiples of 32. 
+template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor> +class gemm_blocking_space<ColMajor, QInt8, QInt8, MaxRows, MaxCols, MaxDepth, + KcFactor, false> + : public level3_blocking<QInt8, QInt8> { + DenseIndex m_sizeA; + DenseIndex m_sizeB; + + public: + gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth, + DenseIndex /*num_threads*/, bool /*l3_blocking*/) { + this->m_mc = ((rows + 31) / 32) * 32; + this->m_nc = ((cols + 31) / 32) * 32; + this->m_kc = ((depth + 31) / 32) * 32; + m_sizeA = this->m_mc * this->m_kc; + m_sizeB = this->m_kc * this->m_nc; + } + void allocateA() { + if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt8>(m_sizeA); + } + void allocateB() { + if (this->m_blockB == 0) this->m_blockB = aligned_new<QInt8>(m_sizeB); + } + void allocateAll() { + allocateA(); + allocateB(); + } + ~gemm_blocking_space() { + aligned_delete(this->m_blockA, m_sizeA); + aligned_delete(this->m_blockB, m_sizeB); + } +}; + + +template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor> +class gemm_blocking_space<ColMajor, QInt8, QUInt8, MaxRows, MaxCols, MaxDepth, + KcFactor, false> + : public level3_blocking<QInt8, QUInt8> { + DenseIndex m_sizeA; + DenseIndex m_sizeB; + + public: + gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth, + DenseIndex /*num_threads*/, bool /*l3_blocking*/) { + this->m_mc = ((rows + 31) / 32) * 32; + this->m_nc = ((cols + 31) / 32) * 32; + this->m_kc = ((depth + 31) / 32) * 32; + m_sizeA = this->m_mc * this->m_kc; + m_sizeB = this->m_kc * this->m_nc; + } + void allocateA() { + if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt8>(m_sizeA); + } + void allocateB() { + if (this->m_blockB == 0) this->m_blockB = aligned_new<QUInt8>(m_sizeB); + } + void allocateAll() { + allocateA(); + allocateB(); + } + ~gemm_blocking_space() { + aligned_delete(this->m_blockA, m_sizeA); + aligned_delete(this->m_blockB, m_sizeB); + } +}; + +// Alternate templates for any input sizes +template<typename 
Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false> +struct gemm_pack_lhs_any; +template <typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode> +struct gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode> { + EIGEN_DONT_INLINE void operator() + (QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0); +}; + +template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false> +struct gemm_pack_rhs_any; +template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> +struct gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> { + EIGEN_DONT_INLINE void operator() + (QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0); +}; + +template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false> +struct gebp_kernel_any; +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +struct gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef typename DataMapper::LinearMapper LinearMapper; + + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +// Alternate implementations for any input sizes +template <typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode> +EIGEN_DONT_INLINE void gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode>:: +operator()(QInt8* blockA, const DataMapper& lhs, Index depth, 
Index rows, Index stride, Index offset) { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + // Get vector pointer + __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA); + + // Get even multiples of the dimensions + Index rows_32 = (rows / 32) * 32; + Index depth_8 = (depth / 8) * 8; + + // Get padding for when depth is not a multiple of 32 + int padding = 0; + if (depth % 32 != 0) { + int depth_32 = (depth / 32) * 32; + int extra_depth = depth - depth_32; + int extra_depth_8 = ((extra_depth + 7) / 8) * 8; + padding = 32 - extra_depth_8; + } + + // Pack rows in sets of 32 + for (Index m = 0; m < rows_32; m += 32) { + // Pack depth in sets of 8 + for (Index k = 0; k < depth_8; k += 8) { + // Load vectors + __m256i L_A = lhs.loadPacket(m, k); + __m256i L_B = lhs.loadPacket(m, k + 1); + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + + __m256i L_C = lhs.loadPacket(m, k + 2); + __m256i L_D = lhs.loadPacket(m, k + 3); + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing for 32 x 8 block + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + 
_mm256_store_si256(blockA_256++, L_AD24); + __m256i L_E = lhs.loadPacket(m, k + 4); + __m256i L_F = lhs.loadPacket(m, k + 5); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_G = lhs.loadPacket(m, k + 6); + __m256i L_H = lhs.loadPacket(m, k + 7); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + + // Finish the k dimension, padding with zeros + if (depth_8 < depth) { + __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H; + switch (depth - depth_8) { + case 1: + L_A = lhs.loadPacket(m, depth_8); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 2: + L_A = lhs.loadPacket(m, depth_8); + L_B = lhs.loadPacket(m, depth_8 + 1); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 3: + L_A = lhs.loadPacket(m, depth_8); + L_B = lhs.loadPacket(m, depth_8 + 1); 
+ L_C = lhs.loadPacket(m, depth_8 + 2); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 4: + L_A = lhs.loadPacket(m, depth_8); + L_B = lhs.loadPacket(m, depth_8 + 1); + L_C = lhs.loadPacket(m, depth_8 + 2); + L_D = lhs.loadPacket(m, depth_8 + 3); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 5: + L_A = lhs.loadPacket(m, depth_8); + L_B = lhs.loadPacket(m, depth_8 + 1); + L_C = lhs.loadPacket(m, depth_8 + 2); + L_D = lhs.loadPacket(m, depth_8 + 3); + L_E = lhs.loadPacket(m, depth_8 + 4); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 6: + L_A = lhs.loadPacket(m, depth_8); + L_B = lhs.loadPacket(m, depth_8 + 1); + L_C = lhs.loadPacket(m, depth_8 + 2); + L_D = lhs.loadPacket(m, depth_8 + 3); + L_E = lhs.loadPacket(m, depth_8 + 4); + L_F = lhs.loadPacket(m, depth_8 + 5); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 7: + L_A = lhs.loadPacket(m, depth_8); + L_B = lhs.loadPacket(m, depth_8 + 1); + L_C = lhs.loadPacket(m, depth_8 + 2); + L_D = lhs.loadPacket(m, depth_8 + 3); + L_E = lhs.loadPacket(m, depth_8 + 4); + L_F = lhs.loadPacket(m, depth_8 + 5); + L_G = lhs.loadPacket(m, depth_8 + 6); + L_H = _mm256_setzero_si256(); + break; + } + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = 
_mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + blockA_256 += padding; + } + + // Finish the m dimension, padding with zeros + if (rows_32 < rows) { + // Pack depth in sets of 8 + for (Index k = 0; k < depth_8; k += 8) { + // Load vectors + __m256i L_A = _mm256_setzero_si256(); + __m256i L_B = _mm256_setzero_si256(); + __m256i L_C = _mm256_setzero_si256(); + __m256i L_D = _mm256_setzero_si256(); + __m256i L_E = 
_mm256_setzero_si256(); + __m256i L_F = _mm256_setzero_si256(); + __m256i L_G = _mm256_setzero_si256(); + __m256i L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + QInt8* ptr = (QInt8*) &L_A; + ptr[m] = lhs(rows_32 + m, k); + ptr = (QInt8*) &L_B; + ptr[m] = lhs(rows_32 + m, k + 1); + ptr = (QInt8*) &L_C; + ptr[m] = lhs(rows_32 + m, k + 2); + ptr = (QInt8*) &L_D; + ptr[m] = lhs(rows_32 + m, k + 3); + ptr = (QInt8*) &L_E; + ptr[m] = lhs(rows_32 + m, k + 4); + ptr = (QInt8*) &L_F; + ptr[m] = lhs(rows_32 + m, k + 5); + ptr = (QInt8*) &L_G; + ptr[m] = lhs(rows_32 + m, k + 6); + ptr = (QInt8*) &L_H; + ptr[m] = lhs(rows_32 + m, k + 7); + } + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing for 32 x 8 block + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i 
L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + + // Finish the k dimension, padding with zeros + if (depth_8 < depth) { + __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H; + QInt8* ptr; + switch (depth - depth_8) { + case 1: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + QInt8* ptr = (QInt8*) &L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + } + break; + case 2: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*) &L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*) &L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + } + break; + case 3: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = 
_mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*) &L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*) &L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*) &L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + } + break; + case 4: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*) &L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*) &L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*) &L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*) &L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + } + break; + case 5: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*) &L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*) &L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*) &L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*) &L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*) &L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + } + break; + case 6: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*) &L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*) &L_B; + ptr[m] = lhs(rows_32 + m, 
depth_8 + 1); + ptr = (QInt8*) &L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*) &L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*) &L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + ptr = (QInt8*) &L_F; + ptr[m] = lhs(rows_32 + m, depth_8 + 5); + } + break; + case 7: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*) &L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*) &L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*) &L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*) &L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*) &L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + ptr = (QInt8*) &L_F; + ptr[m] = lhs(rows_32 + m, depth_8 + 5); + ptr = (QInt8*) &L_G; + ptr[m] = lhs(rows_32 + m, depth_8 + 6); + } + break; + } + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + 
_mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + } +} + +template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> +EIGEN_DONT_INLINE void gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>:: +operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + // Get vector pointer + __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB); + + // Get even multiples of the dimensions + Index cols_32 = (cols / 32) * 32; + Index depth_32 = (depth / 32) * 32; + + // Perform a step of the packing for 4 columns + __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24; +#define PACK_STEP \ + R_AB_L = _mm256_unpacklo_epi64(R_A, R_B); \ + R_CD_L = _mm256_unpacklo_epi64(R_C, R_D); 
\ + R_AB_H = _mm256_unpackhi_epi64(R_A, R_B); \ + R_CD_H = _mm256_unpackhi_epi64(R_C, R_D); \ + R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20); \ + R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \ + R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20); \ + R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \ + _mm256_store_si256(blockB_256, R_AD_0); \ + _mm256_store_si256(blockB_256 + 8, R_AD_8); \ + _mm256_store_si256(blockB_256 + 16, R_AD_16); \ + _mm256_store_si256(blockB_256 + 24, R_AD_24); \ + blockB_256++; + + // Pack cols in sets of 32 + for (Index n = 0; n < cols_32; n += 32) { + // Pack depth in sets of 32 + for (Index k = 0; k < depth_32; k += 32) { + __m256i R_A = rhs.loadPacket(k, n); + __m256i R_B = rhs.loadPacket(k, n + 1); + __m256i R_C = rhs.loadPacket(k, n + 2); + __m256i R_D = rhs.loadPacket(k, n + 3); + PACK_STEP; + + R_A = rhs.loadPacket(k, n + 4); + R_B = rhs.loadPacket(k, n + 5); + R_C = rhs.loadPacket(k, n + 6); + R_D = rhs.loadPacket(k, n + 7); + PACK_STEP; + + R_A = rhs.loadPacket(k, n + 8); + R_B = rhs.loadPacket(k, n + 9); + R_C = rhs.loadPacket(k, n + 10); + R_D = rhs.loadPacket(k, n + 11); + PACK_STEP; + + R_A = rhs.loadPacket(k, n + 12); + R_B = rhs.loadPacket(k, n + 13); + R_C = rhs.loadPacket(k, n + 14); + R_D = rhs.loadPacket(k, n + 15); + PACK_STEP; + + R_A = rhs.loadPacket(k, n + 16); + R_B = rhs.loadPacket(k, n + 17); + R_C = rhs.loadPacket(k, n + 18); + R_D = rhs.loadPacket(k, n + 19); + PACK_STEP; + + R_A = rhs.loadPacket(k, n + 20); + R_B = rhs.loadPacket(k, n + 21); + R_C = rhs.loadPacket(k, n + 22); + R_D = rhs.loadPacket(k, n + 23); + PACK_STEP; + + R_A = rhs.loadPacket(k, n + 24); + R_B = rhs.loadPacket(k, n + 25); + R_C = rhs.loadPacket(k, n + 26); + R_D = rhs.loadPacket(k, n + 27); + PACK_STEP; + + R_A = rhs.loadPacket(k, n + 28); + R_B = rhs.loadPacket(k, n + 29); + R_C = rhs.loadPacket(k, n + 30); + R_D = rhs.loadPacket(k, n + 31); + PACK_STEP; + + blockB_256 += 24; + } + + if 
(depth_32 < depth) { + QUInt8* ptr; + __m256i R_A = _mm256_setzero_si256(); + __m256i R_B = _mm256_setzero_si256(); + __m256i R_C = _mm256_setzero_si256(); + __m256i R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 1); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 2); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 3); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n + 4); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 5); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 6); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 7); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n + 8); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 9); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 10); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 11); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n + 12); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 13); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 14); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 15); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = 
(QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n + 16); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 17); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 18); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 19); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n + 20); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 21); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 22); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 23); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n + 24); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 25); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 26); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 27); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n + 28); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 29); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 30); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 31); + } + PACK_STEP; + blockB_256 += 24; + } + } + + // Finish packing cols + if (cols_32 < cols) { + // Pack depth in sets of 32 + for (Index k = 0; k < depth_32; k += 32) { + __m256i R_A, R_B, R_C, R_D; + Index n; + for (n = cols_32; n < cols; n += 4) { + switch (cols - n) { + case 1: + R_A = rhs.loadPacket(k, n); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + case 
2: + R_A = rhs.loadPacket(k, n); + R_B = rhs.loadPacket(k, n + 1); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + case 3: + R_A = rhs.loadPacket(k, n); + R_B = rhs.loadPacket(k, n + 1); + R_C = rhs.loadPacket(k, n + 2); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + default: + R_A = rhs.loadPacket(k, n); + R_B = rhs.loadPacket(k, n + 1); + R_C = rhs.loadPacket(k, n + 2); + R_D = rhs.loadPacket(k, n + 3); + PACK_STEP; + break; + } + } + + // Increment the block pointer. + // We must pad if cols is not a multiple of 32. + blockB_256 += 32 - (n - cols_32) / 4; + } + + if (depth_32 < depth) { + for (Index n = cols_32; n < cols; n += 4) { + QUInt8* ptr; + __m256i R_A = _mm256_setzero_si256(); + __m256i R_B = _mm256_setzero_si256(); + __m256i R_C = _mm256_setzero_si256(); + __m256i R_D = _mm256_setzero_si256(); + switch (cols - n) { + case 1: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n); + } + PACK_STEP; + break; + case 2: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 1); + } + PACK_STEP; + break; + case 3: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 1); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 2); + } + PACK_STEP; + break; + default: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*) &R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*) &R_B; + ptr[k - depth_32] = rhs(k, n + 1); + ptr = (QUInt8*) &R_C; + ptr[k - depth_32] = rhs(k, n + 2); + ptr = (QUInt8*) &R_D; + ptr[k - depth_32] = rhs(k, n + 3); + } + PACK_STEP; + break; + } + } + } + } +#undef PACK_STEP +} + +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +EIGEN_DONT_INLINE +void gebp_kernel_any<QInt8, 
QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
             Index rows, Index depth, Index cols, QInt32 alpha,
             Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  // Block-panel kernel for packed QInt8 (left) x QUInt8 (right) operands of
  // arbitrary size, producing QInt32 results.
  //
  // The output is processed in 32x32 tiles. Each tile is accumulated in the
  // scratch buffer blockO, stored column-major as 4 vectors of 8 int32 per
  // column, and then transferred into `res`. All three dimensions are rounded
  // up to a multiple of 32 when indexing blockA/blockB, so the packed inputs
  // are read as if they were padded to that size.
  //
  // Only alpha == 1, strides of -1 and offsets of 0 are supported (asserted).
  EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  eigen_assert(alpha.value == 1);
  eigen_assert(strideA == -1);
  eigen_assert(strideB == -1);
  eigen_assert(offsetA == 0);
  eigen_assert(offsetB == 0);
  eigen_assert(rows > 0);
  eigen_assert(cols > 0);
  eigen_assert(depth > 0);
  eigen_assert(blockA);
  eigen_assert(blockB);

  // Dimensions rounded up to the next multiple of 32.
  const Index rows_32 = ((rows + 31) / 32) * 32;
  const Index cols_32 = ((cols + 31) / 32) * 32;
  const Index depth_32 = ((depth + 31) / 32) * 32;

  // Create result block
  ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
  memset(blockO, 0, 32 * 32 * sizeof(QInt32));

  // Get vectorized pointers
  __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
  const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
  const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);

  // Every 16-bit lane is 1: madd_epi16 against it sums adjacent 16-bit
  // products into 32-bit lanes.
  const __m256i ONE = _mm256_set1_epi32(0x00010001);

  // The macros below expand to the same fully unrolled instruction sequence
  // for all 32 columns of a tile. They use the locals of the depth loop
  // (L_AD*, L_EH*, P_*, R_HALF, R_LO4, R_HI4) by name.

  // Accumulates 8 rows of one tile column into blockO_256[SLOT]:
  // maddubs multiplies the unsigned right-hand bytes with the signed
  // left-hand bytes and adds adjacent pairs into 16 bits (saturating), madd
  // with ONE adds adjacent 16-bit pairs into 32 bits, and the two depth
  // halves (A-D and E-H) are summed before being added to the tile.
#define ACCUMULATE_BAND(R_INPUT_A, R_INPUT_B, L_LOW, L_HIGH, SLOT)            \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_LOW);                            \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                    \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_HIGH);                           \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                    \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                    \
  _mm256_store_si256(                                                         \
      blockO_256 + (SLOT),                                                    \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + (SLOT)), P_32));

  // Accumulates all 32 rows (4 bands of 8) of tile column OFFSET.
#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                            \
  ACCUMULATE_BAND(R_INPUT_A, R_INPUT_B, L_AD0, L_EH0, 4 * (OFFSET))           \
  ACCUMULATE_BAND(R_INPUT_A, R_INPUT_B, L_AD8, L_EH8, 4 * (OFFSET) + 1)       \
  ACCUMULATE_BAND(R_INPUT_A, R_INPUT_B, L_AD16, L_EH16, 4 * (OFFSET) + 2)     \
  ACCUMULATE_BAND(R_INPUT_A, R_INPUT_B, L_AD24, L_EH24, 4 * (OFFSET) + 3)

  // Handles the 4 tile columns whose right-hand values live in R_SRC. The
  // 128-bit permute selects one pair of columns, the 32-bit shuffles then
  // broadcast 4 consecutive depth values of a single column across the
  // whole vector.
#define COMPUTE_FOUR_COLUMNS(R_SRC, FIRST_COLUMN)                             \
  R_HALF = _mm256_permute2x128_si256(R_SRC, R_SRC, 0x00);                     \
  R_LO4 = _mm256_shuffle_epi32(R_HALF, 0x00);                                 \
  R_HI4 = _mm256_shuffle_epi32(R_HALF, 0x55);                                 \
  COMPUTE_STEP(R_LO4, R_HI4, (FIRST_COLUMN))                                  \
  R_LO4 = _mm256_shuffle_epi32(R_HALF, 0xAA);                                 \
  R_HI4 = _mm256_shuffle_epi32(R_HALF, 0xFF);                                 \
  COMPUTE_STEP(R_LO4, R_HI4, (FIRST_COLUMN) + 1)                              \
  R_HALF = _mm256_permute2x128_si256(R_SRC, R_SRC, 0x11);                     \
  R_LO4 = _mm256_shuffle_epi32(R_HALF, 0x00);                                 \
  R_HI4 = _mm256_shuffle_epi32(R_HALF, 0x55);                                 \
  COMPUTE_STEP(R_LO4, R_HI4, (FIRST_COLUMN) + 2)                              \
  R_LO4 = _mm256_shuffle_epi32(R_HALF, 0xAA);                                 \
  R_HI4 = _mm256_shuffle_epi32(R_HALF, 0xFF);                                 \
  COMPUTE_STEP(R_LO4, R_HI4, (FIRST_COLUMN) + 3)

  // Loop over blocks of 32 columns
  for (Index n = 0; n < cols_32; n += 32) {
    // Reset index into blockA
    Index indexL = 0;
    // Loop over blocks of 32 rows
    for (Index m = 0; m < rows_32; m += 32) {
      // Reset index into blockB: each column panel spans depth_32 vectors.
      Index indexR = n / 32 * depth_32;
      // Loop over blocks of 8 on depth
      for (Index k = 0; k < depth_32; k += 8) {
        // Left operand: 32 rows x 8 depth values, split into depth halves
        // A-D and E-H, each in 4 row bands of 8.
        __m256i L_AD0 = blockA_256[indexL++];
        __m256i L_AD8 = blockA_256[indexL++];
        __m256i L_AD16 = blockA_256[indexL++];
        __m256i L_AD24 = blockA_256[indexL++];
        __m256i L_EH0 = blockA_256[indexL++];
        __m256i L_EH8 = blockA_256[indexL++];
        __m256i L_EH16 = blockA_256[indexL++];
        __m256i L_EH24 = blockA_256[indexL++];
        // Right operand: 8 depth values for 32 columns, 4 columns per vector.
        __m256i R_AH0 = blockB_256[indexR++];
        __m256i R_AH4 = blockB_256[indexR++];
        __m256i R_AH8 = blockB_256[indexR++];
        __m256i R_AH12 = blockB_256[indexR++];
        __m256i R_AH16 = blockB_256[indexR++];
        __m256i R_AH20 = blockB_256[indexR++];
        __m256i R_AH24 = blockB_256[indexR++];
        __m256i R_AH28 = blockB_256[indexR++];

        // Scratch registers used by the macros.
        __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32;
        __m256i R_HALF, R_LO4, R_HI4;

        COMPUTE_FOUR_COLUMNS(R_AH0, 0);
        COMPUTE_FOUR_COLUMNS(R_AH4, 4);
        COMPUTE_FOUR_COLUMNS(R_AH8, 8);
        COMPUTE_FOUR_COLUMNS(R_AH12, 12);
        COMPUTE_FOUR_COLUMNS(R_AH16, 16);
        COMPUTE_FOUR_COLUMNS(R_AH20, 20);
        COMPUTE_FOUR_COLUMNS(R_AH24, 24);
        COMPUTE_FOUR_COLUMNS(R_AH28, 28);
      }

      // Transfer the results to the result matrix.
      if (m + 32 <= rows && n + 32 <= cols) {
        // Full tile: add the tile to `res` one 8-row vector at a time.
        Index i = 0;
        for (Index j = n; j < n + 32; j++) {
          LinearMapper r0 = res.getLinearMapper(m, j);
          LinearMapper r1 = res.getLinearMapper(m + 8, j);
          LinearMapper r2 = res.getLinearMapper(m + 16, j);
          LinearMapper r3 = res.getLinearMapper(m + 24, j);
          r0.storePacket(
              0, _mm256_add_epi32(blockO_256[i++], r0.loadPacket(0)));
          r1.storePacket(
              0, _mm256_add_epi32(blockO_256[i++], r1.loadPacket(0)));
          r2.storePacket(
              0, _mm256_add_epi32(blockO_256[i++], r2.loadPacket(0)));
          r3.storePacket(
              0, _mm256_add_epi32(blockO_256[i++], r3.loadPacket(0)));
        }
      }
      else {
        // Partial tile: copy only the cells that belong to this tile. Both
        // bounds are clamped to the tile so blockO is never indexed past its
        // 32x32 extent and cells owned by other tiles are left alone.
        // NOTE(review): this path assigns to `res` while the full-tile path
        // adds to it; the two agree only if `res` starts out zeroed - confirm
        // against the caller.
        const Index row_end = (m + 32 < rows) ? (m + 32) : rows;
        const Index col_end = (n + 32 < cols) ? (n + 32) : cols;
        for (Index j = n; j < col_end; j++) {
          for (Index i = m; i < row_end; i++) {
            res(i, j) = blockO[(j - n) * 32 + (i - m)];
          }
        }
      }

      // Zero the result block so it can be reused
      memset(blockO, 0, 32 * 32 * sizeof(QInt32));
    }
  }

#undef COMPUTE_FOUR_COLUMNS
#undef COMPUTE_STEP
#undef ACCUMULATE_BAND
}

// Below are the fully optimized versions that are correct only for sizes that
// are multiples of 32. It is about a 10% performance benefit to keep these
// implementations separate.

// Arrange a block of the left input matrix in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ...
// A1 B1 C1 D1 E1 F1 G1 H1 ...
// A2 B2 C2 D2 E2 F2 G2 H2 ...
// A3 B3 C3 D3 E3 F3 G3 H3 ...
// A4 B4 C4 D4 E4 F4 G4 H4 ...
// A5 B5 C5 D5 E5 F5 G5 H5 ...
// A6 B6 C6 D6 E6 F6 G6 H6 ...
// A7 B7 C7 D7 E7 F7 G7 H7 ...
// A8 ...
// ...
//
// Packing yields output (A0 beside B0 in memory):
// A0 B0 C0 D0
// A1 B1 C1 D1
// A2 B2 C2 D2
// A3 B3 C3 D3
// A4 B4 C4 D4
// A5 B5 C5 D5
// A6 B6 C6 D6
// A7 B7 C7 D7
// ...
// A31 B31 C31 D31
// E0 F0 G0 H0
// E1 F1 G1 H1
// E2 F2 G2 H2
// E3 F3 G3 H3
// E4 F4 G4 H4
// E5 F5 G5 H5
// E6 F6 G6 H6
// E7 F7 G7 H7
// ...
//
// Four elements of the same row are arranged contiguously because maddubs and
// madd both perform an adjacent addition in the kernel.
// Specialization of gemm_pack_lhs for a column-major QInt8 left-hand side.
// It only declares the packing functor; the definition follows below.
template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
                     Conjugate, PanelMode> {
  EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs,
                                    Index depth, Index rows, Index stride = 0,
                                    Index offset = 0);
};

// Packs the lhs into blockA, one 32 (rows) x 8 (depth) tile at a time.
//
//   blockA  destination buffer. It is written with _mm256_store_si256
//           (an aligned store), so it has to be 32-byte aligned.
//   lhs     source mapper; lhs.loadPacket(row, k) is used to fetch one
//           256-bit vector starting at (row, k).
//   depth   number of lhs columns to pack.
//   rows    number of lhs rows to pack.
//   stride  must be 0 (asserted).
//   offset  must be 0 (asserted).
//
// When rows or depth is not a multiple of 32 the work is delegated to
// gemm_pack_lhs_any. Otherwise each tile produces eight vectors, stored in
// the order AD0, AD8, AD16, AD24, EH0, EH8, EH16, EH24.
template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2,
                                     ColMajor, Conjugate, PanelMode>::
operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
           Index stride, Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  // Use alternate function for weird sizes
  if (rows % 32 != 0 || depth % 32 != 0) {
    gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode> lhs_pack;
    return lhs_pack(blockA, lhs, depth, rows, stride, offset);
  }

  // Get vector pointer
  // (advanced by one __m256i after every store below)
  __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);

  // Pack rows in sets of 32
  for (Index m = 0; m < rows; m += 32) {
    // Pack depth in sets of 8
    for (Index k = 0; k < depth; k += 8) {
      // Load vectors
      // (columns k and k + 1, named A and B in the layout comment)
      __m256i L_A = lhs.loadPacket(m, k);
      __m256i L_B = lhs.loadPacket(m, k + 1);

      // Interleave 8-bit elements
      __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
      __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);

      // Same treatment for columns k + 2 and k + 3 (C and D).
      __m256i L_C = lhs.loadPacket(m, k + 2);
      __m256i L_D = lhs.loadPacket(m, k + 3);
      __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
      __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);

      // Interleave 16-bit elements
      __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
      __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);

      // Use permute before we store to cross 128-bit lanes
      __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
      _mm256_store_si256(blockA_256++, L_AD0);

      // Complete packing for 32 x 8 block
      // L_AD16 is computed here but stored after L_AD8 so that the output
      // order is AD0, AD8, AD16, AD24.
      __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
      __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
      _mm256_store_si256(blockA_256++, L_AD8);
      _mm256_store_si256(blockA_256++, L_AD16);
      __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
      _mm256_store_si256(blockA_256++, L_AD24);
      // Second half of the tile: columns k + 4 .. k + 7 (E, F, G, H) go
      // through the same interleave / permute / store sequence.
      __m256i L_E = lhs.loadPacket(m, k + 4);
      __m256i L_F = lhs.loadPacket(m, k + 5);
      __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
      __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
      __m256i L_G = lhs.loadPacket(m, k + 6);
      __m256i L_H = lhs.loadPacket(m, k + 7);
      __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
      __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
      __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
      _mm256_store_si256(blockA_256++, L_EH0);
      __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
      __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
      _mm256_store_si256(blockA_256++, L_EH8);
      _mm256_store_si256(blockA_256++, L_EH16);
      __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
      _mm256_store_si256(blockA_256++, L_EH24);
    }
  }
}

// Arrange a block of the right input matrix in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ...
// A1 B1 C1 D1 E1 F1 G1 H1 ...
// A2 B2 C2 D2 E2 F2 G2 H2 ...
// A3 B3 C3 D3 E3 F3 G3 H3 ...
// A4 B4 C4 D4 E4 F4 G4 H4 ...
// A5 B5 C5 D5 E5 F5 G5 H5 ...
// A6 B6 C6 D6 E6 F6 G6 H6 ...
// A7 B7 C7 D7 E7 F7 G7 H7 ...
// A8 ...
// ...
//
// Packing yields row major output (A0 beside A1 in memory):
// A0 A1 A2 A3 A4 A5 A6 A7
// B0 B1 B2 B3 B4 B5 B6 B7
// ...
//
// At least four elements of the same col are arranged contiguously because
// maddubs and madd both perform an adjacent addition in the kernel. We can
// save work by leaving 8 adjacent elements because kr = 8.
//
// Specialization of gemm_pack_rhs for a column-major QUInt8 right-hand side.
template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
struct gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
                     PanelMode> {
  EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};

// Packs the rhs into blockB, one 32 (depth) x 32 (cols) tile at a time.
//
//   blockB  destination buffer. It is written with _mm256_store_si256
//           (an aligned store), so it has to be 32-byte aligned.
//   rhs     source mapper; rhs.loadPacket(k, col) is used to fetch one
//           256-bit vector starting at (k, col).
//   depth   number of rhs rows to pack.
//   cols    number of rhs columns to pack.
//   stride  must be 0 (asserted).
//   offset  must be 0 (asserted).
//
// When cols or depth is not a multiple of 32 the work is delegated to
// gemm_pack_rhs_any.
template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor,
                                     Conjugate, PanelMode>::
operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
           Index stride, Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  // Use alternate function for weird sizes
  if (cols % 32 != 0 || depth % 32 != 0) {
    gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> rhs_pack;
    return rhs_pack(blockB, rhs, depth, cols, stride, offset);
  }

  // Get vector pointer
  __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);

  // Perform a step of the packing for 4 columns
  // PACK_STEP reads the four column vectors R_A..R_D from the enclosing
  // scope, writes four output vectors at offsets 0, 8, 16 and 24 from
  // blockB_256, and then advances blockB_256 by one vector.
  __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24;
#define PACK_STEP                                            \
  R_AB_L = _mm256_unpacklo_epi64(R_A, R_B);                  \
  R_CD_L = _mm256_unpacklo_epi64(R_C, R_D);                  \
  R_AB_H = _mm256_unpackhi_epi64(R_A, R_B);                  \
  R_CD_H = _mm256_unpackhi_epi64(R_C, R_D);                  \
  R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20);  \
  R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \
  R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20);  \
  R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \
  _mm256_store_si256(blockB_256, R_AD_0);                    \
  _mm256_store_si256(blockB_256 + 8, R_AD_8);                \
  _mm256_store_si256(blockB_256 + 16, R_AD_16);              \
  _mm256_store_si256(blockB_256 + 24, R_AD_24);              \
  blockB_256++;

  // Pack cols in sets of 32
  for (Index n = 0; n < cols; n += 32) {
    // Pack depth in sets of 32
    for (Index k = 0; k < depth; k += 32) {
      // Eight PACK_STEPs of four columns each cover the 32 columns of the
      // tile and advance blockB_256 by 8 vectors in total.
      __m256i R_A = rhs.loadPacket(k, n);
      __m256i R_B = rhs.loadPacket(k, n + 1);
      __m256i R_C = rhs.loadPacket(k, n + 2);
      __m256i R_D = rhs.loadPacket(k, n + 3);
      PACK_STEP;

      R_A = rhs.loadPacket(k, n + 4);
      R_B = rhs.loadPacket(k, n + 5);
      R_C = rhs.loadPacket(k, n + 6);
      R_D = rhs.loadPacket(k, n + 7);
      PACK_STEP;

      R_A = rhs.loadPacket(k, n + 8);
      R_B = rhs.loadPacket(k, n + 9);
      R_C = rhs.loadPacket(k, n + 10);
      R_D = rhs.loadPacket(k, n + 11);
      PACK_STEP;

      R_A = rhs.loadPacket(k, n + 12);
      R_B = rhs.loadPacket(k, n + 13);
      R_C = rhs.loadPacket(k, n + 14);
      R_D = rhs.loadPacket(k, n + 15);
      PACK_STEP;

      R_A = rhs.loadPacket(k, n + 16);
      R_B = rhs.loadPacket(k, n + 17);
      R_C = rhs.loadPacket(k, n + 18);
      R_D = rhs.loadPacket(k, n + 19);
      PACK_STEP;

      R_A = rhs.loadPacket(k, n + 20);
      R_B = rhs.loadPacket(k, n + 21);
      R_C = rhs.loadPacket(k, n + 22);
      R_D = rhs.loadPacket(k, n + 23);
      PACK_STEP;

      R_A = rhs.loadPacket(k, n + 24);
      R_B = rhs.loadPacket(k, n + 25);
      R_C = rhs.loadPacket(k, n + 26);
      R_D = rhs.loadPacket(k, n + 27);
      PACK_STEP;

      R_A = rhs.loadPacket(k, n + 28);
      R_B = rhs.loadPacket(k, n + 29);
      R_C = rhs.loadPacket(k, n + 30);
      R_D = rhs.loadPacket(k, n + 31);
      PACK_STEP;

      // Skip the 24 vectors already written through the +8/+16/+24 stores,
      // so that the next tile starts 32 vectors after this one.
      blockB_256 += 24;
    }
  }
#undef PACK_STEP
}

// Perform the actual multiplication on packed inputs
// (QInt8 lhs block times QUInt8 rhs block, accumulated into a QInt32 result).
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef typename DataMapper::LinearMapper LinearMapper;

  EIGEN_DONT_INLINE
  void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
                  Index rows, Index depth, Index cols, QInt32 alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

// Multiplies the packed blocks and adds the product to res.
//
//   res               result mapper; each 32 x 32 tile of the product is added
//                     to the values already stored in res.
//   blockA            packed lhs; read as aligned __m256i vectors.
//   blockB            packed rhs; read as aligned __m256i vectors.
//   rows/depth/cols   problem size, all asserted to be > 0.
//   alpha             must be 1 (asserted).
//   strideA/strideB   must be -1 (asserted).
//   offsetA/offsetB   must be 0 (asserted).
//
// Conjugation is rejected at compile time. When any of rows, cols or depth
// is not a multiple of 32 the call is forwarded to gebp_kernel_any.
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE
void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
             Index rows, Index depth, Index cols, QInt32 alpha,
             Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  eigen_assert(alpha.value == 1);
  eigen_assert(strideA == -1);
  eigen_assert(strideB == -1);
  eigen_assert(offsetA == 0);
  eigen_assert(offsetB == 0);
  eigen_assert(rows > 0);
  eigen_assert(cols > 0);
  eigen_assert(depth > 0);
  eigen_assert(blockA);
  eigen_assert(blockB);

  // Use alternate function for weird sizes
  if (rows % 32 != 0 || cols % 32 != 0 || depth % 32 != 0) {
    gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> gebp;
    return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }

  // Create result block
  // (a 32 x 32 QInt32 scratch tile, released by aligned_delete at the end)
  QInt32* blockO = aligned_new<QInt32>(32 * 32);
  // Allocating the result block is about 5-10% faster than declaring stack
  // space. It is unclear why this is the case.
  // ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
  memset(blockO, 0, 32 * 32 * sizeof(QInt32));

  // Get vectorized pointers
  __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
  const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
  const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);

  // Loop over blocks of 32 columns
  for (Index n = 0; n < cols; n += 32) {
    // Reset index into blockA
    // (indexL keeps advancing across the m loop; it only restarts here)
    Index indexL = 0;
    // Loop over blocks of 32 rows
    for (Index m = 0; m < rows; m += 32) {
      // Reset index into blockB
      // (each block of 32 columns occupies `depth` vectors in blockB)
      Index indexR = n / 32 * depth;
      // Loop over blocks of 8 on depth
      for (Index k = 0; k < depth; k += 8) {
        // Load inputs
        // Eight lhs vectors, in the order written by gemm_pack_lhs, followed
        // by eight rhs vectors.
        __m256i L_AD0 = blockA_256[indexL++];
        __m256i L_AD8 = blockA_256[indexL++];
        __m256i L_AD16 = blockA_256[indexL++];
        __m256i L_AD24 = blockA_256[indexL++];
        __m256i L_EH0 = blockA_256[indexL++];
        __m256i L_EH8 = blockA_256[indexL++];
        __m256i L_EH16 = blockA_256[indexL++];
        __m256i L_EH24 = blockA_256[indexL++];
        __m256i R_AH0 = blockB_256[indexR++];
        __m256i R_AH4 = blockB_256[indexR++];
        __m256i R_AH8 = blockB_256[indexR++];
        __m256i R_AH12 = blockB_256[indexR++];
        __m256i R_AH16 = blockB_256[indexR++];
        __m256i R_AH20 = blockB_256[indexR++];
        __m256i R_AH24 = blockB_256[indexR++];
        __m256i R_AH28 = blockB_256[indexR++];

        // This constant is used with madd to convert 16 bit to 32 bit
        const __m256i ONE = _mm256_set1_epi32(0x00010001);

        // Declare variables used in COMPUTE_STEP
        __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32;

        // COMPUTE_STEP accumulates one result column (index OFFSET within the
        // tile, four vectors = 32 rows) into blockO. The rhs operand is passed
        // first to _mm256_maddubs_epi16 because that intrinsic treats its
        // first argument as unsigned bytes and its second as signed bytes.
        // NOTE(review): maddubs saturates its 16-bit pair sums; confirm the
        // value ranges fed to this kernel cannot hit that limit.
#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                             \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0);                             \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0);                             \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET,                                                 \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32));     \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8);                             \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8);                             \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 1,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16);                            \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16);                            \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 2,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24);                            \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24);                            \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 3,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32));

        // Permute and shuffle to copy a single value across the entire vector
        // Then compute the multiplication
        // Each rhs vector feeds four consecutive result columns; the eight
        // groups below therefore cover OFFSET 0..31.
        __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
        __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 0);
        __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 1);
        R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
        __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 2);
        __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 3);

        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 4);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 5);
        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 6);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 7);

        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 8);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 9);
        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 10);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 11);

        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 12);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 13);
        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 14);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 15);

        R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 16);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 17);
        R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 18);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 19);

        R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 20);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 21);
        R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 22);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 23);

        R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 24);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 25);
        R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 26);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 27);

        R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 28);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 29);
        R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 30);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 31);

#undef COMPUTE_STEP
      }

      // Transfer the results to the result matrix
      // For every column j, the four scratch vectors are added to rows
      // m, m + 8, m + 16 and m + 24 of res (read-modify-write).
      Index i = 0;
      for (Index j = n; j < n + 32; j++) {
        LinearMapper r0 = res.getLinearMapper(m, j);
        LinearMapper r1 = res.getLinearMapper(m + 8, j);
        LinearMapper r2 = res.getLinearMapper(m + 16, j);
        LinearMapper r3 = res.getLinearMapper(m + 24, j);
        r0.storePacket(
            0, _mm256_add_epi32(blockO_256[i++], r0.loadPacket(0)));
        r1.storePacket(
            0, _mm256_add_epi32(blockO_256[i++], r1.loadPacket(0)));
        r2.storePacket(
            0, _mm256_add_epi32(blockO_256[i++], r2.loadPacket(0)));
        r3.storePacket(
            0, _mm256_add_epi32(blockO_256[i++], r3.loadPacket(0)));
      }

      // Zero the result block so it can be reused
      memset(blockO, 0, 32 * 32 * sizeof(QInt32));
    }
  }
  aligned_delete(blockO, 32 * 32);
}

#endif  // EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_AVX2_H
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h new file mode 100644 index 0000000000..99894cafb5 --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h @@ -0,0 +1,95 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Benoit Jacob <benoitjacob@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0.
If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_NEON_H +#define EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_NEON_H + + +namespace Eigen { +namespace internal { + + +// AVX2 optimized implementation of the case where the lhs is encoded using signed 8bit +// integers and the rhs using unsigned 8bit integers. +#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT + +template<bool _ConjLhs, bool _ConjRhs> +class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> +{ +public: + typedef QInt8 LhsScalar; + typedef QUInt8 RhsScalar; + typedef QInt32 ResScalar; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + +// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +EIGEN_DONT_INLINE +void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + 
eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k = 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + + +} // namespace internal +} // namespace Eigen + + + +#endif // EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_NEON_H diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h new file mode 100644 index 0000000000..18b5085b89 --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h @@ -0,0 +1,123 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#ifndef EIGEN_CXX11_FIXED_POINT_MAT_VEC_PRODUCT_H +#define EIGEN_CXX11_FIXED_POINT_MAT_VEC_PRODUCT_H + + +namespace Eigen { +namespace internal { + +// Mat-Vec product +// Both lhs and rhs are encoded as 8bit signed integers +template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> +struct general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version> +{ +EIGEN_DONT_INLINE static void run( + Index rows, Index cols, + const LhsMapper& lhs, + const RhsMapper& rhs, + QInt32* res, Index resIncr, + QInt8 alpha); +}; + +template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> +EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>::run( + Index rows, Index cols, + const LhsMapper& lhs, + const RhsMapper& rhs, + QInt32* res, Index resIncr, + QInt8 alpha) +{ + eigen_assert(alpha.value == 1); + eigen_assert(resIncr == 1); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + + for (Index i = 0; i < rows; ++i) { + for (Index j = 0; j < cols; ++j) { + res[i] += lhs(i, j) * rhs(j, 0); + } + } +} + + +// Mat-Vec product +// The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned integers +template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> +struct general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QUInt8,RhsMapper,ConjugateRhs,Version> +{ +EIGEN_DONT_INLINE static void run( + Index rows, Index cols, + const LhsMapper& lhs, + const RhsMapper& rhs, + QInt32* res, Index resIncr, + QUInt8 alpha); +}; + +template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> +EIGEN_DONT_INLINE void 
general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QUInt8,RhsMapper,ConjugateRhs,Version>::run( + Index rows, Index cols, + const LhsMapper& lhs, + const RhsMapper& rhs, + QInt32* res, Index resIncr, + QUInt8 alpha) +{ + eigen_assert(alpha.value == 1); + eigen_assert(resIncr == 1); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + + for (Index i = 0; i < rows; ++i) { + for (Index j = 0; j < cols; ++j) { + res[i] += lhs(i, j) * rhs(j, 0); + } + } +} + + +// Mat-Vec product +// The lhs is encoded using bit unsigned integers, the rhs using 8bit signed integers +template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> +struct general_matrix_vector_product<Index,QUInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version> +{ +EIGEN_DONT_INLINE static void run( + Index rows, Index cols, + const LhsMapper& lhs, + const RhsMapper& rhs, + QInt32* res, Index resIncr, + QInt8 alpha); +}; + +template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> +EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QUInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>::run( + Index rows, Index cols, + const LhsMapper& lhs, + const RhsMapper& rhs, + QInt32* res, Index resIncr, + QInt8 alpha) +{ + eigen_assert(alpha.value == 1); + eigen_assert(resIncr == 1); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + + for (Index i = 0; i < rows; ++i) { + for (Index j = 0; j < cols; ++j) { + res[i] += lhs(i, j) * rhs(j, 0); + } + } +} + +} // namespace internal +} // namespace Eigen + + + +#endif // EIGEN_CXX11_FIXED_POINT_MAT_VEC_PRODUCT_H diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h new file mode 100644 index 0000000000..cae1a0b06d --- /dev/null +++ 
b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h @@ -0,0 +1,409 @@
+#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
+#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
+
+namespace Eigen {
+namespace internal {
+
+// Packet wrapper types for the quantized scalars. Every wrapper stores one raw
+// SSE/AVX integer register and converts implicitly to and from it, while
+// remaining a distinct C++ type per scalar kind. The default constructors are
+// only declared here; no definition is visible in this header.
+
+// Full packet of QInt8: 32 lanes in a 256-bit register.
+struct Packet32q8i {
+  Packet32q8i();
+  Packet32q8i(__m256i v) : val(v) {}
+  operator __m256i() const { return val; }
+  __m256i val;
+};
+
+// Full packet of QUInt8: 32 lanes in a 256-bit register.
+struct Packet32q8u {
+  Packet32q8u();
+  Packet32q8u(__m256i v) : val(v) {}
+  operator __m256i() const { return val; }
+  __m256i val;
+};
+
+// Half packet of QInt8: 128-bit register.
+struct Packet16q8i {
+  Packet16q8i();
+  Packet16q8i(__m128i v) : val(v) {}
+  operator __m128i() const { return val; }
+  __m128i val;
+};
+
+// Half packet of QUInt8: 128-bit register.
+struct Packet16q8u {
+  Packet16q8u();
+  Packet16q8u(__m128i v) : val(v) {}
+  operator __m128i() const { return val; }
+  __m128i val;
+};
+
+// Full packet of QInt32: 8 lanes in a 256-bit register.
+struct Packet8q32i {
+  Packet8q32i();
+  Packet8q32i(__m256i v) : val(v) {}
+  operator __m256i() const { return val; }
+  __m256i val;
+};
+
+// Half packet of QInt32: 128-bit register.
+struct Packet4q32i {
+  Packet4q32i();
+  Packet4q32i(__m128i v) : val(v) {}
+  operator __m128i() const { return val; }
+  __m128i val;
+};
+
+// Packet traits. All three scalar types are flagged as vectorizable. The
+// 8-bit types only advertise min/max; QInt32 additionally advertises add,
+// sub, mul and negate.
+template <>
+struct packet_traits<QInt8> : default_packet_traits {
+  typedef Packet32q8i type;
+  typedef Packet16q8i half;
+  enum { Vectorizable = 1, AlignedOnScalar = 1, size = 32 };
+  enum {
+    HasAdd = 0,
+    HasSub = 0,
+    HasMul = 0,
+    HasNegate = 0,
+    HasAbs = 0,
+    HasAbs2 = 0,
+    HasMin = 1,
+    HasMax = 1,
+    HasConj = 0,
+    HasSetLinear = 0
+  };
+};
+template <>
+struct packet_traits<QUInt8> : default_packet_traits {
+  typedef Packet32q8u type;
+  typedef Packet16q8u half;
+  enum { Vectorizable = 1, AlignedOnScalar = 1, size = 32 };
+  enum {
+    HasAdd = 0,
+    HasSub = 0,
+    HasMul = 0,
+    HasNegate = 0,
+    HasAbs = 0,
+    HasAbs2 = 0,
+    HasMin = 1,
+    HasMax = 1,
+    HasConj = 0,
+    HasSetLinear = 0
+  };
+};
+template <>
+struct packet_traits<QInt32> : default_packet_traits {
+  typedef Packet8q32i type;
+  typedef Packet4q32i half;
+  enum { Vectorizable = 1, AlignedOnScalar = 1, size = 8 };
+  enum {
+    HasAdd = 1,
+    HasSub = 1,
+    HasMul = 1,
+    HasNegate = 1,
+    HasAbs = 0,
+    HasAbs2 = 0,
+    HasMin = 1,
+    HasMax = 1,
+    HasConj = 0,
+    HasSetLinear = 0
+  };
+};
+
+// Reverse mapping from each full packet type to its scalar type, half packet
+// type and lane count.
+template <>
+struct unpacket_traits<Packet32q8i> {
+  typedef QInt8 type;
+  typedef Packet16q8i half;
+  enum { size = 32 };
+};
+template <>
+struct unpacket_traits<Packet32q8u> {
+  typedef QUInt8 type;
+  typedef Packet16q8u half;
+  enum { size = 32 };
+};
+template <>
+struct unpacket_traits<Packet8q32i> {
+  typedef QInt32 type;
+  typedef Packet4q32i half;
+  enum { size = 8 };
+};
+
+// Unaligned load: reads 32 bytes starting at `from`.
+template <>
+EIGEN_STRONG_INLINE Packet32q8i ploadu<Packet32q8i>(const QInt8* from) {
+  const __m256i* src = reinterpret_cast<const __m256i*>(from);
+  EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(src);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u ploadu<Packet32q8u>(const QUInt8* from) {
+  const __m256i* src = reinterpret_cast<const __m256i*>(from);
+  EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(src);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i ploadu<Packet8q32i>(const QInt32* from) {
+  const __m256i* src = reinterpret_cast<const __m256i*>(from);
+  EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(src);
+}
+
+// Aligned load: same as above, but uses the aligned load instruction.
+template <>
+EIGEN_STRONG_INLINE Packet32q8i pload<Packet32q8i>(const QInt8* from) {
+  const __m256i* src = reinterpret_cast<const __m256i*>(from);
+  EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(src);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u pload<Packet32q8u>(const QUInt8* from) {
+  const __m256i* src = reinterpret_cast<const __m256i*>(from);
+  EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(src);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pload<Packet8q32i>(const QInt32* from) {
+  const __m256i* src = reinterpret_cast<const __m256i*>(from);
+  EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(src);
+}
+
+// Unaligned store: writes the 32 bytes of `from` starting at `to`.
+template <>
+EIGEN_STRONG_INLINE void pstoreu<QInt8>(QInt8* to, const Packet32q8i& from) {
+  __m256i* dst = reinterpret_cast<__m256i*>(to);
+  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(dst, from.val);
+}
+template <>
+EIGEN_STRONG_INLINE void pstoreu<QUInt8>(QUInt8* to, const Packet32q8u& from) {
+  __m256i* dst = reinterpret_cast<__m256i*>(to);
+  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(dst, from.val);
+}
+template <>
+EIGEN_STRONG_INLINE void pstoreu<QInt32>(QInt32* to, const Packet8q32i& from) {
+  __m256i* dst = reinterpret_cast<__m256i*>(to);
+  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(dst, from.val);
+}
+
+// Aligned store: same as above, but uses the aligned store instruction.
+template <>
+EIGEN_STRONG_INLINE void pstore<QInt32>(QInt32* to, const Packet8q32i& from) {
+  __m256i* dst = reinterpret_cast<__m256i*>(to);
+  EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(dst, from.val);
+}
+template <>
+EIGEN_STRONG_INLINE void pstore<QUInt8>(QUInt8* to, const Packet32q8u& from) {
+  __m256i* dst = reinterpret_cast<__m256i*>(to);
+  EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(dst, from.val);
+}
+template <>
+EIGEN_STRONG_INLINE void pstore<QInt8>(QInt8* to, const Packet32q8i& from) {
+  __m256i* dst = reinterpret_cast<__m256i*>(to);
+  EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(dst, from.val);
+}
+
+// Extract first element (lane 0). The 8-bit variants narrow the int returned
+// by _mm256_extract_epi8 to the scalar's storage type.
+template <>
+EIGEN_STRONG_INLINE QInt32 pfirst<Packet8q32i>(const Packet8q32i& a) {
+  const __m128i low_lane = _mm256_castsi256_si128(a.val);
+  return _mm_cvtsi128_si32(low_lane);
+}
+template <>
+EIGEN_STRONG_INLINE QUInt8 pfirst<Packet32q8u>(const Packet32q8u& a) {
+  return static_cast<uint8_t>(_mm256_extract_epi8(a.val, 0));
+}
+template <>
+EIGEN_STRONG_INLINE QInt8 pfirst<Packet32q8i>(const Packet32q8i& a) {
+  return static_cast<int8_t>(_mm256_extract_epi8(a.val, 0));
+}
+
+// Initialize to constant value.
+template <>
+EIGEN_STRONG_INLINE Packet32q8i pset1<Packet32q8i>(const QInt8& from) {
+  const int8_t raw = from.value;
+  return _mm256_set1_epi8(raw);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u pset1<Packet32q8u>(const QUInt8& from) {
+  const uint8_t raw = from.value;
+  return _mm256_set1_epi8(raw);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pset1<Packet8q32i>(const QInt32& from) {
+  const int32_t raw = from.value;
+  return _mm256_set1_epi32(raw);
+}
+
+// Basic arithmetic packet ops for QInt32. Each one applies the matching
+// 32-bit integer instruction lane by lane to the raw stored values.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i padd<Packet8q32i>(const Packet8q32i& a,
+                                                  const Packet8q32i& b) {
+  const __m256i sum = _mm256_add_epi32(a.val, b.val);
+  return sum;
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i psub<Packet8q32i>(const Packet8q32i& a,
+                                                  const Packet8q32i& b) {
+  const __m256i difference = _mm256_sub_epi32(a.val, b.val);
+  return difference;
+}
+// Note: mullo keeps only the low 32 bits of each lane's product.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pmul<Packet8q32i>(const Packet8q32i& a,
+                                                  const Packet8q32i& b) {
+  const __m256i product = _mm256_mullo_epi32(a.val, b.val);
+  return product;
+}
+// Negation is computed as 0 - a.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pnegate<Packet8q32i>(const Packet8q32i& a) {
+  const __m256i zero = _mm256_setzero_si256();
+  return _mm256_sub_epi32(zero, a.val);
+}
+
+// Min and max.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pmin<Packet8q32i>(const Packet8q32i& a,
+                                                  const Packet8q32i& b) {
+  return _mm256_min_epi32(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pmax<Packet8q32i>(const Packet8q32i& a,
+                                                  const Packet8q32i& b) {
+  return _mm256_max_epi32(a.val, b.val);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32q8u pmin<Packet32q8u>(const Packet32q8u& a,
+                                                  const Packet32q8u& b) {
+  return _mm256_min_epu8(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u pmax<Packet32q8u>(const Packet32q8u& a,
+                                                  const Packet32q8u& b) {
+  return _mm256_max_epu8(a.val, b.val);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32q8i pmin<Packet32q8i>(const Packet32q8i& a,
+                                                  const Packet32q8i& b) {
+  return _mm256_min_epi8(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8i pmax<Packet32q8i>(const Packet32q8i& a,
+                                                  const Packet32q8i& b) {
+  return _mm256_max_epi8(a.val, b.val);
+}
+
+// Reductions.
+//
+// Every reduction halves the number of candidate lanes per step by combining
+// the register with a shuffled copy of itself:
+//   1. swap the two 128-bit lanes,
+//   2. swap the two 64-bit halves of each lane,
+//   3. bring 32-bit element 1 down to element 0.
+// For QInt32 the answer is then in element 0. The 8-bit variants need two
+// more steps:
+//   4. bring 16-bit word 1 down to word 0,
+//   5. combine bytes 0 and 1 in scalar code.
+// Step 4 has to select word 1. After step 3, elements 0 and 1 are identical,
+// so selecting word 2 would combine word 0 with a copy of itself and bytes 2
+// and 3 would never be considered.
+template <>
+EIGEN_STRONG_INLINE QInt32 predux_min<Packet8q32i>(const Packet8q32i& a) {
+  __m256i tmp =
+      _mm256_min_epi32(a.val, _mm256_permute2f128_si256(a.val, a.val, 1));
+  tmp =
+      _mm256_min_epi32(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+  return pfirst<Packet8q32i>(
+      _mm256_min_epi32(tmp, _mm256_shuffle_epi32(tmp, 1)));
+}
+template <>
+EIGEN_STRONG_INLINE QInt32 predux_max<Packet8q32i>(const Packet8q32i& a) {
+  __m256i tmp =
+      _mm256_max_epi32(a.val, _mm256_permute2f128_si256(a.val, a.val, 1));
+  tmp =
+      _mm256_max_epi32(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+  return pfirst<Packet8q32i>(
+      _mm256_max_epi32(tmp, _mm256_shuffle_epi32(tmp, 1)));
+}
+
+template <>
+EIGEN_STRONG_INLINE QUInt8 predux_min<Packet32q8u>(const Packet32q8u& a) {
+  __m256i tmp =
+      _mm256_min_epu8(a.val, _mm256_permute2f128_si256(a.val, a.val, 1));
+  tmp =
+      _mm256_min_epu8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+  tmp = _mm256_min_epu8(tmp, _mm256_shuffle_epi32(tmp, 1));
+  // Word 1 -> word 0, so bytes 2 and 3 are folded onto bytes 0 and 1.
+  tmp = _mm256_min_epu8(tmp, _mm256_shufflelo_epi16(tmp, 1));
+  return std::min(static_cast<uint8_t>(_mm256_extract_epi8(tmp, 0)),
+                  static_cast<uint8_t>(_mm256_extract_epi8(tmp, 1)));
+}
+template <>
+EIGEN_STRONG_INLINE QUInt8 predux_max<Packet32q8u>(const Packet32q8u& a) {
+  __m256i tmp =
+      _mm256_max_epu8(a.val, _mm256_permute2f128_si256(a.val, a.val, 1));
+  tmp =
+      _mm256_max_epu8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+  tmp = _mm256_max_epu8(tmp, _mm256_shuffle_epi32(tmp, 1));
+  // Word 1 -> word 0, so bytes 2 and 3 are folded onto bytes 0 and 1.
+  tmp = _mm256_max_epu8(tmp, _mm256_shufflelo_epi16(tmp, 1));
+  return std::max(static_cast<uint8_t>(_mm256_extract_epi8(tmp, 0)),
+                  static_cast<uint8_t>(_mm256_extract_epi8(tmp, 1)));
+}
+
+// The signed variants narrow the extracted bytes to int8_t before the final
+// scalar comparison, so that the comparison is made on signed values no
+// matter how the int returned by _mm256_extract_epi8 was extended.
+template <>
+EIGEN_STRONG_INLINE QInt8 predux_min<Packet32q8i>(const Packet32q8i& a) {
+  __m256i tmp =
+      _mm256_min_epi8(a.val, _mm256_permute2f128_si256(a.val, a.val, 1));
+  tmp =
+      _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+  tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
+  // Word 1 -> word 0, so bytes 2 and 3 are folded onto bytes 0 and 1.
+  tmp = _mm256_min_epi8(tmp, _mm256_shufflelo_epi16(tmp, 1));
+  return std::min(static_cast<int8_t>(_mm256_extract_epi8(tmp, 0)),
+                  static_cast<int8_t>(_mm256_extract_epi8(tmp, 1)));
+}
+template <>
+EIGEN_STRONG_INLINE QInt8 predux_max<Packet32q8i>(const Packet32q8i& a) {
+  __m256i tmp =
+      _mm256_max_epi8(a.val, _mm256_permute2f128_si256(a.val, a.val, 1));
+  tmp =
+      _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+  tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
+  // Word 1 -> word 0, so bytes 2 and 3 are folded onto bytes 0 and 1.
+  tmp = _mm256_max_epi8(tmp, _mm256_shufflelo_epi16(tmp, 1));
+  return std::max(static_cast<int8_t>(_mm256_extract_epi8(tmp, 0)),
+                  static_cast<int8_t>(_mm256_extract_epi8(tmp, 1)));
+}
+
+// Comparisons. Each result lane is all ones when the predicate holds for that
+// lane and all zeros otherwise.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i peq<Packet8q32i>(const Packet8q32i& a,
+                                                 const Packet8q32i& b) {
+  return _mm256_cmpeq_epi32(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8i peq<Packet32q8i>(const Packet32q8i& a,
+                                                 const Packet32q8i& b) {
+  return _mm256_cmpeq_epi8(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u peq<Packet32q8u>(const Packet32q8u& a,
+                                                 const Packet32q8u& b) {
+  return _mm256_cmpeq_epi8(a.val, b.val);
+}
+
+// Note: There are no instructions in AVX2 for unsigned lt/gt comparison.
+// These are added in AVX-512. Only the signed packet types get ple/plt here.
+//
+// a <= b is the complement of a > b, so the greater-than mask is XORed with
+// all ones to flip every bit.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i ple<Packet8q32i>(const Packet8q32i& a,
+                                                 const Packet8q32i& b) {
+  const __m256i gt = _mm256_cmpgt_epi32(a.val, b.val);
+  return _mm256_xor_si256(gt, _mm256_set1_epi32(-1));
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8i ple<Packet32q8i>(const Packet32q8i& a,
+                                                 const Packet32q8i& b) {
+  const __m256i gt = _mm256_cmpgt_epi8(a.val, b.val);
+  return _mm256_xor_si256(gt, _mm256_set1_epi32(-1));
+}
+
+// a < b is evaluated as b > a.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i plt<Packet8q32i>(const Packet8q32i& a,
+                                                 const Packet8q32i& b) {
+  return _mm256_cmpgt_epi32(b.val, a.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8i plt<Packet32q8i>(const Packet32q8i& a,
+                                                 const Packet32q8i& b) {
+  return _mm256_cmpgt_epi8(b.val, a.val);
+}
+
+// Vectorized scaling of Packet8q32i by a double.
+template <>
+struct functor_traits<scalar_multiple2_op<QInt32, double>> {
+  enum { Cost = 4 * NumTraits<float>::MulCost, PacketAccess = true };
+};
+
+// Multiplies all 8 lanes by m_other. A 256-bit register holds only 4 doubles,
+// so the lower and upper 4 lanes are widened to double, scaled and converted
+// back to 32-bit integers separately, then joined again.
+template <>
+EIGEN_STRONG_INLINE const Packet8q32i
+scalar_multiple2_op<QInt32, double>::packetOp(const Packet8q32i& a) const {
+  __m256d scale = _mm256_set1_pd(m_other);
+  __m256d a_lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(a.val));
+  __m128i result_lo = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_lo));
+  __m256d a_hi = _mm256_cvtepi32_pd(_mm256_extracti128_si256(a.val, 1));
+  __m128i result_hi = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_hi));
+  return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi,
+                                 1);
+}
+
+}  // end namespace internal
+}  // end namespace Eigen
+
+#endif  // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h new file
mode 100644 index 0000000000..045384d7fc --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h @@ -0,0 +1,66 @@
+#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
+#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
+
+namespace Eigen {
+namespace internal {
+
+typedef __m256 Packet8f;
+
+// QInt32 -> float: one source packet yields one target packet.
+template <>
+struct type_casting_traits<QInt32, float> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet8f pcast<Packet8q32i>(const Packet8q32i& a) {
+  const __m256i raw = a.val;
+  return _mm256_cvtepi32_ps(raw);
+}
+
+// float -> QInt32: one source packet yields one target packet.
+template <>
+struct type_casting_traits<float, QInt32> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pcast<Packet8f>(const Packet8f& a) {
+  const __m256i as_int = _mm256_cvtps_epi32(a);
+  return as_int;
+}
+
+// QInt32 -> QInt8: four 8-lane source packets fill one 32-lane target packet.
+template <>
+struct type_casting_traits<QInt32, QInt8> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet32q8i
+pcast<Packet8q32i, Packet32q8i>(const Packet8q32i& a, const Packet8q32i& b,
+                                const Packet8q32i& c, const Packet8q32i& d) {
+  // Narrow 32 -> 16 -> 8 bits with signed saturation at both steps.
+  const __m256i ab16 = _mm256_packs_epi32(a.val, b.val);
+  const __m256i cd16 = _mm256_packs_epi32(c.val, d.val);
+  const __m256i packed = _mm256_packs_epi16(ab16, cd16);
+  // Since packs does not cross 128 bit lane boundaries, the 32-bit groups
+  // come out interleaved between the two lanes; the permutation below (listed
+  // from element 0 upwards) restores the a, b, c, d order.
+  const __m256i lane_fix = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+  return _mm256_permutevar8x32_epi32(packed, lane_fix);
+}
+
+// QInt32 -> QUInt8: four 8-lane source packets fill one 32-lane target packet.
+template <>
+struct type_casting_traits<QInt32, QUInt8> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet32q8u
+pcast<Packet8q32i, Packet32q8u>(const Packet8q32i& a, const Packet8q32i& b,
+                                const Packet8q32i& c, const Packet8q32i& d) {
+  // Narrow 32 -> 16 bits with signed saturation, then 16 -> 8 bits with
+  // unsigned saturation.
+  const __m256i ab16 = _mm256_packs_epi32(a.val, b.val);
+  const __m256i cd16 = _mm256_packs_epi32(c.val, d.val);
+  const __m256i packed = _mm256_packus_epi16(ab16, cd16);
+  // Since packus does not cross 128 bit lane boundaries, the 32-bit groups
+  // come out interleaved between the two lanes; the permutation below (listed
+  // from element 0 upwards) restores the a, b, c, d order.
+  const __m256i lane_fix = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+  return _mm256_permutevar8x32_epi32(packed, lane_fix);
+}
+
+}  // end namespace internal
+}  // end namespace Eigen
+
+#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_ |