Diffstat (limited to 'third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint')
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h     341
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h       255
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h  1743
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h    95
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h       123
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h      409
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h      66
7 files changed, 3032 insertions, 0 deletions
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
new file mode 100644
index 0000000000..564729ce48
--- /dev/null
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
@@ -0,0 +1,341 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_FIXED_POINT_TYPES_H
+#define EIGEN_CXX11_FIXED_POINT_TYPES_H
+
+#include <cmath>
+#include <iostream>
+
+namespace Eigen {
+
+// The mantissa part of the fixed point representation. See
+// go/tensorfixedpoint for details
+struct QInt8;
+struct QUInt8;
+struct QInt16;
+struct QUInt16;
+struct QInt32;
+
+template <>
+struct NumTraits<QInt8> : GenericNumTraits<int8_t> {};
+template <>
+struct NumTraits<QUInt8> : GenericNumTraits<uint8_t> {};
+template <>
+struct NumTraits<QInt16> : GenericNumTraits<int16_t> {};
+template <>
+struct NumTraits<QUInt16> : GenericNumTraits<uint16_t> {};
+template <>
+struct NumTraits<QInt32> : GenericNumTraits<int32_t> {};
+
+namespace internal {
+template <>
+struct scalar_product_traits<QInt32, double> {
+ enum {
+ // Cost = NumTraits<T>::MulCost,
+ Defined = 1
+ };
+ typedef QInt32 ReturnType;
+};
+}
+
+// Wrap the 8-bit int into a QInt8 struct instead of using a typedef to prevent
+// the compiler from silently casting the mantissa to a bigger or a smaller
+// representation.
+struct QInt8 {
+ QInt8() {}
+ QInt8(const int8_t v) : value(v) {}
+ QInt8(const QInt32 v);
+
+ operator int() const { return static_cast<int>(value); }
+
+ int8_t value;
+};
+
+struct QUInt8 {
+ QUInt8() {}
+ QUInt8(const uint8_t v) : value(v) {}
+ QUInt8(const QInt32 v);
+
+ operator int() const { return static_cast<int>(value); }
+
+ uint8_t value;
+};
+
+struct QInt16 {
+ QInt16() {}
+ QInt16(const int16_t v) : value(v) {}
+ QInt16(const QInt32 v);
+ operator int() const { return static_cast<int>(value); }
+
+ int16_t value;
+};
+
+struct QUInt16 {
+ QUInt16() {}
+ QUInt16(const uint16_t v) : value(v) {}
+ QUInt16(const QInt32 v);
+ operator int() const { return static_cast<int>(value); }
+
+ uint16_t value;
+};
+
+struct QInt32 {
+ QInt32() {}
+ QInt32(const int8_t v) : value(v) {}
+ QInt32(const int32_t v) : value(v) {}
+ QInt32(const QInt8 v) : value(v.value) {}
+ QInt32(const float v) : value(static_cast<int32_t>(lrint(v))) {}
+#ifdef EIGEN_MAKING_DOCS
+ // Workaround to fix build on PPC.
+ QInt32(unsigned long v) : value(v) {}
+#endif
+
+ operator float() const { return static_cast<float>(value); }
+
+ int32_t value;
+};
+
+EIGEN_STRONG_INLINE QInt8::QInt8(const QInt32 v)
+ : value(v.value > 127 ? 127 : (v.value < -128 ? -128 : v.value)) {}
+EIGEN_STRONG_INLINE QUInt8::QUInt8(const QInt32 v)
+ : value(v.value > 255 ? 255 : (v.value < 0 ? 0 : v.value)) {}
+EIGEN_STRONG_INLINE QInt16::QInt16(const QInt32 v)
+ : value(v.value > 32767 ? 32767 : (v.value < -32768 ? -32768 : v.value)) {}
+EIGEN_STRONG_INLINE QUInt16::QUInt16(const QInt32 v)
+ : value(v.value > 65535 ? 65535 : (v.value < 0 ? 0 : v.value)) {}
+
+// Basic widening 8-bit operations: This will be vectorized in future CLs.
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt8 b) {
+ return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QUInt8 b) {
+ return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt8 b) {
+ return QInt32(static_cast<int32_t>(a.value) + static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt8 b) {
+ return QInt32(static_cast<int32_t>(a.value) - static_cast<int32_t>(b.value));
+}
+
+// Basic widening 16-bit operations: This will be vectorized in future CLs.
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt16 b) {
+ return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QUInt16 b) {
+ return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt16 b) {
+ return QInt32(static_cast<int32_t>(a.value) + static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt16 b) {
+ return QInt32(static_cast<int32_t>(a.value) - static_cast<int32_t>(b.value));
+}
+
+// Mixed QInt32 op QInt8 operations. This will be vectorized in future CLs.
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt8 b) {
+ return QInt32(a.value + static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) + b.value);
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt8 b) {
+ return QInt32(a.value - static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) - b.value);
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt8 b) {
+ return QInt32(a.value * static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) * b.value);
+}
+
+// Mixed QInt32 op QInt16 operations. This will be vectorized in future CLs.
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt16 b) {
+ return QInt32(a.value + static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) + b.value);
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt16 b) {
+ return QInt32(a.value - static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) - b.value);
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt16 b) {
+ return QInt32(a.value * static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) * b.value);
+}
+
+// Mixed QInt32 op QUInt8 operations. This will be vectorized in future CLs.
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt8 b) {
+ return QInt32(a.value + static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator+(const QUInt8 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) + b.value);
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt8 b) {
+ return QInt32(a.value - static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QUInt8 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) - b.value);
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt8 b) {
+ return QInt32(a.value * static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QUInt8 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) * b.value);
+}
+
+// Mixed QInt32 op QUInt16 operations. This will be vectorized in future CLs.
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt16 b) {
+ return QInt32(a.value + static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator+(const QUInt16 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) + b.value);
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt16 b) {
+ return QInt32(a.value - static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QUInt16 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) - b.value);
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt16 b) {
+ return QInt32(a.value * static_cast<int32_t>(b.value));
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QUInt16 a, const QInt32 b) {
+ return QInt32(static_cast<int32_t>(a.value) * b.value);
+}
+
+// Basic arithmetic operations on QInt32, which behaves like an int32_t.
+EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt32 b) {
+ return a.value + b.value;
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt32 b) {
+ return a.value - b.value;
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt32 b) {
+ return a.value * b.value;
+}
+EIGEN_STRONG_INLINE QInt32 operator/(const QInt32 a, const QInt32 b) {
+ return a.value / b.value;
+}
+EIGEN_STRONG_INLINE QInt32& operator+=(QInt32& a, const QInt32 b) {
+ a.value += b.value;
+ return a;
+}
+EIGEN_STRONG_INLINE QInt32& operator-=(QInt32& a, const QInt32 b) {
+ a.value -= b.value;
+ return a;
+}
+EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const QInt32 b) {
+ a.value *= b.value;
+ return a;
+}
+EIGEN_STRONG_INLINE QInt32& operator/=(QInt32& a, const QInt32 b) {
+ a.value /= b.value;
+ return a;
+}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) {
+ return -a.value;
+}
+
+// Scaling QInt32 by double. We do the arithmetic in double because
+// float only has 24 bits of precision, so casting QInt32 to float might reduce
+// accuracy by discarding up to 7 (least significant) bits.
+EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const double b) {
+ return static_cast<int32_t>(lrint(static_cast<double>(a.value) * b));
+}
+EIGEN_STRONG_INLINE QInt32 operator*(const double a, const QInt32 b) {
+ return static_cast<int32_t>(lrint(a * static_cast<double>(b.value)));
+}
+EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const double b) {
+ a.value = static_cast<int32_t>(lrint(static_cast<double>(a.value) * b));
+ return a;
+}
+
+// Comparisons
+EIGEN_STRONG_INLINE bool operator==(const QInt8 a, const QInt8 b) {
+ return a.value == b.value;
+}
+EIGEN_STRONG_INLINE bool operator==(const QUInt8 a, const QUInt8 b) {
+ return a.value == b.value;
+}
+EIGEN_STRONG_INLINE bool operator==(const QInt16 a, const QInt16 b) {
+ return a.value == b.value;
+}
+EIGEN_STRONG_INLINE bool operator==(const QUInt16 a, const QUInt16 b) {
+ return a.value == b.value;
+}
+EIGEN_STRONG_INLINE bool operator==(const QInt32 a, const QInt32 b) {
+ return a.value == b.value;
+}
+
+EIGEN_STRONG_INLINE bool operator<(const QInt8 a, const QInt8 b) {
+ return a.value < b.value;
+}
+EIGEN_STRONG_INLINE bool operator<(const QUInt8 a, const QUInt8 b) {
+ return a.value < b.value;
+}
+EIGEN_STRONG_INLINE bool operator<(const QInt16 a, const QInt16 b) {
+ return a.value < b.value;
+}
+EIGEN_STRONG_INLINE bool operator<(const QUInt16 a, const QUInt16 b) {
+ return a.value < b.value;
+}
+EIGEN_STRONG_INLINE bool operator<(const QInt32 a, const QInt32 b) {
+ return a.value < b.value;
+}
+
+EIGEN_STRONG_INLINE bool operator>(const QInt8 a, const QInt8 b) {
+ return a.value > b.value;
+}
+EIGEN_STRONG_INLINE bool operator>(const QUInt8 a, const QUInt8 b) {
+ return a.value > b.value;
+}
+EIGEN_STRONG_INLINE bool operator>(const QInt16 a, const QInt16 b) {
+ return a.value > b.value;
+}
+EIGEN_STRONG_INLINE bool operator>(const QUInt16 a, const QUInt16 b) {
+ return a.value > b.value;
+}
+EIGEN_STRONG_INLINE bool operator>(const QInt32 a, const QInt32 b) {
+ return a.value > b.value;
+}
+
+EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt8 a) {
+ os << static_cast<int>(a.value);
+ return os;
+}
+EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt8 a) {
+ os << static_cast<int>(a.value);
+ return os;
+}
+EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt16 a) {
+ os << static_cast<int>(a.value);
+ return os;
+}
+EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt16 a) {
+ os << static_cast<int>(a.value);
+ return os;
+}
+EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt32 a) {
+ os << a.value;
+ return os;
+}
+
+} // namespace Eigen
+
+#endif // EIGEN_CXX11_FIXED_POINT_TYPES_H
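
For reference, here is a minimal usage sketch of the types defined above. It is not part of the patch; it assumes the header is reachable through the third_party/eigen3 include path and that the core Eigen headers (which provide EIGEN_STRONG_INLINE and NumTraits) are pulled in first.

    #include <iostream>
    #include <Eigen/Core>  // assumed: supplies EIGEN_STRONG_INLINE, NumTraits, etc.
    #include "unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h"

    int main() {
      Eigen::QInt8 a(100), b(50);
      Eigen::QInt32 p = a * b;         // widening multiply: 100 * 50 = 5000, no 8-bit overflow
      Eigen::QInt8 s(p);               // saturating conversion clamps 5000 to 127
      Eigen::QInt32 scaled = p * 0.5;  // scaling by double, rounded via lrint: 2500
      std::cout << p << " " << s << " " << scaled << "\n";  // prints: 5000 127 2500
      return 0;
    }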
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
new file mode 100644
index 0000000000..4d0dca07df
--- /dev/null
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
@@ -0,0 +1,255 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H
+#define EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H
+
+
+namespace Eigen {
+namespace internal {
+
+// Accumulate the product of 2 QInt8 inputs in 32 bits to prevent
+// overflows.
+template<> struct scalar_product_traits<QInt8, QInt8>
+{
+ enum {
+ Defined = 1
+ };
+ typedef QInt32 ReturnType;
+};
+
+// Accumulate the product of QInt8 inputs with QUInt8 inputs in 32 bits
+// to prevent overflows.
+template<> struct scalar_product_traits<QInt8, QUInt8>
+{
+ enum {
+ Defined = 1
+ };
+ typedef QInt32 ReturnType;
+};
+
+// Description of the product implementation. It's pretty simple now since
+// nothing is vectorized yet.
+// This definition tackles the case where both the lhs and the rhs are encoded
+// using signed 8-bit integers.
+#ifndef EIGEN_USE_OPTIMIZED_INT8_INT8_MAT_MAT_PRODUCT
+
+template<bool _ConjLhs, bool _ConjRhs>
+class gebp_traits<QInt8, QInt8, _ConjLhs, _ConjRhs>
+{
+public:
+ typedef QInt8 LhsScalar;
+ typedef QInt8 RhsScalar;
+ typedef QInt32 ResScalar;
+
+ enum {
+ // register block size along the M and N directions
+ // One for the current implementation
+ nr = 1,
+ mr = 1,
+ // Progress made at each iteration of the product loop
+ // also 1 for the current implementation
+ LhsProgress = 1,
+ RhsProgress = 1
+ };
+};
+
+// The signed 8bit Mat-Mat product itself.
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ EIGEN_DONT_INLINE
+ void operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE
+void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+::operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ eigen_assert(alpha.value == 1);
+ eigen_assert(strideA == -1);
+ eigen_assert(strideB == -1);
+ eigen_assert(offsetA == 0);
+ eigen_assert(offsetB == 0);
+
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+ eigen_assert(depth > 0);
+ eigen_assert(blockA);
+ eigen_assert(blockB);
+
+ for (Index j = 0; j < cols; ++j) {
+ Index startB = j * depth;
+
+ for (Index i = 0; i < rows; ++i) {
+ Index startA = i * depth;
+
+ for (Index k = 0; k < depth; ++k) {
+ res(i, j) += blockA[startA + k] * blockB[startB + k];
+ }
+ }
+ }
+}
+#endif
+
+
+// This definition tackles the case where the lhs is encoded using signed
+// 8-bit integers and the rhs using unsigned 8-bit integers.
+#ifndef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
+template<bool _ConjLhs, bool _ConjRhs>
+class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
+{
+public:
+ typedef QInt8 LhsScalar;
+ typedef QUInt8 RhsScalar;
+ typedef QInt32 ResScalar;
+
+ enum {
+ // register block size along the M and N directions
+ // One for the current implementation
+ nr = 1,
+ mr = 1,
+ // Progress made at each iteration of the product loop
+ // also 1 for the current implementation
+ LhsProgress = 1,
+ RhsProgress = 1
+ };
+};
+
+// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ EIGEN_DONT_INLINE
+ void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE
+void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ eigen_assert(alpha.value == 1);
+ eigen_assert(strideA == -1);
+ eigen_assert(strideB == -1);
+ eigen_assert(offsetA == 0);
+ eigen_assert(offsetB == 0);
+
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+ eigen_assert(depth > 0);
+ eigen_assert(blockA);
+ eigen_assert(blockB);
+
+ for (Index j = 0; j < cols; ++j) {
+ Index startB = j * depth;
+
+ for (Index i = 0; i < rows; ++i) {
+ Index startA = i * depth;
+
+ for (Index k = 0; k < depth; ++k) {
+ res(i, j) += blockA[startA + k] * blockB[startB + k];
+ }
+ }
+ }
+}
+#endif
+
+
+// This definition tackles the case where the lhs is encoded using unsigned
+// 8-bit integers and the rhs using signed 8-bit integers.
+#ifndef EIGEN_USE_OPTIMIZED_UINT8_INT8_MAT_MAT_PRODUCT
+template<bool _ConjLhs, bool _ConjRhs>
+class gebp_traits<QUInt8, QInt8, _ConjLhs, _ConjRhs>
+{
+public:
+ typedef QUInt8 LhsScalar;
+ typedef QInt8 RhsScalar;
+ typedef QInt32 ResScalar;
+
+ enum {
+ // register block size along the M and N directions
+ // One for the current implementation
+ nr = 1,
+ mr = 1,
+ // Progress made at each iteration of the product loop
+ // also 1 for the current implementation
+ LhsProgress = 1,
+ RhsProgress = 1
+ };
+};
+
+
+// Mat-Mat product of an unsigned 8bit lhs with a signed 8bit rhs
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ EIGEN_DONT_INLINE
+ void operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE
+void gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+::operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ eigen_assert(alpha.value == 1);
+ eigen_assert(strideA == -1);
+ eigen_assert(strideB == -1);
+ eigen_assert(offsetA == 0);
+ eigen_assert(offsetB == 0);
+
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+ eigen_assert(depth > 0);
+ eigen_assert(blockA);
+ eigen_assert(blockB);
+
+ for (Index j = 0; j < cols; ++j) {
+ Index startB = j * depth;
+
+ for (Index i = 0; i < rows; ++i) {
+ Index startA = i * depth;
+
+ for (Index k = 0; k < depth; ++k) {
+ res(i, j) += blockA[startA + k] * blockB[startB + k];
+ }
+ }
+ }
+}
+#endif
+
+} // namespace internal
+} // namespace Eigen
+
+
+
+#endif // EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H
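
All three fallback kernels above perform the same depth-major accumulation: row i of the packed lhs occupies blockA[i*depth .. i*depth + depth - 1] and column j of the packed rhs occupies blockB[j*depth .. j*depth + depth - 1], with every 8-bit product widened to 32 bits. A standalone scalar sketch of that loop, shown for the QInt8 x QUInt8 case (hypothetical helper, not part of the patch):

    #include <cstdint>

    // Scalar model of the fallback gebp kernels: each 8-bit product is widened
    // to 32 bits before accumulation, mirroring scalar_product_traits<QInt8, QUInt8>.
    void ref_gebp_qint8(const int8_t* blockA, const uint8_t* blockB,
                        int rows, int depth, int cols,
                        int32_t* res /* rows x cols, column-major */) {
      for (int j = 0; j < cols; ++j) {
        const int startB = j * depth;
        for (int i = 0; i < rows; ++i) {
          const int startA = i * depth;
          int32_t acc = res[j * rows + i];
          for (int k = 0; k < depth; ++k) {
            acc += static_cast<int32_t>(blockA[startA + k]) *
                   static_cast<int32_t>(blockB[startB + k]);
          }
          res[j * rows + i] = acc;
        }
      }
    }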
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
new file mode 100644
index 0000000000..d561b79fbd
--- /dev/null
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
@@ -0,0 +1,1743 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
+// Copyright (C) 2015 Matthew Sarett <msarett@google.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_AVX2_H
+#define EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_AVX2_H
+
+namespace Eigen {
+namespace internal {
+
+// AVX2 optimized implementation of Mat-Mat product.
+// LHS is encoded using signed 8-bit integers.
+// RHS is encoded using unsigned 8-bit integers.
+#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
+
+// Define quantized traits
+template<bool _ConjLhs, bool _ConjRhs>
+class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
+{
+public:
+ typedef QInt8 LhsScalar;
+ typedef QUInt8 RhsScalar;
+ typedef QInt32 ResScalar;
+
+ enum {
+ // Define register blocking scheme.
+ nr = 32,
+ mr = 32,
+ kr = 8,
+ // Ignore progress tracking per loop iteration.
+ LhsProgress = -1,
+ RhsProgress = -1
+ };
+};
+
+// Specialized blocking for quantized implementations.
+// Used by TensorContractionThreadPool; inputs must have dimensions that are
+// multiples of 32.
+template<int KcFactor, typename Index>
+struct ComputeGemmByColBlockingSizes<QInt8, QUInt8, KcFactor, Index> {
+ void operator()(Index& k, Index& m, Index& n, Index num_threads)
+ {
+ eigen_assert(m % 32 == 0);
+ eigen_assert(n % 32 == 0);
+ eigen_assert(k % 32 == 0);
+ if (!k || !m || !n) {
+ return;
+ }
+ n = (((n / num_threads) + 31) / 32) * 32;
+ }
+};
+
+// Specialized blocking for quantized implementations.
+// Used by TensorContractionThreadPool; inputs must have dimensions that are
+// multiples of 32.
+template<int KcFactor, typename Index>
+struct ComputeGemmByRowBlockingSizes<QInt8, QUInt8, KcFactor, Index> {
+ void operator()(Index& k, Index& m, Index& n, Index num_threads)
+ {
+ eigen_assert(m % 32 == 0);
+ eigen_assert(n % 32 == 0 || n == 1);
+ eigen_assert(k % 32 == 0);
+ if (!k || !m || !n) {
+ return;
+ }
+ // Special case to avoid breaking the unimplemented matrix-vector case
+ if (n == 1) {
+ n = 32;
+ }
+ m = (((m / num_threads) + 31) / 32) * 32;
+ }
+};
+
+// Specialized blocking for quantized implementations.
+// Used by TensorContraction and GeneralMatrixMatrix; inputs are padded to
+// multiples of 32.
+template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
+class gemm_blocking_space<ColMajor, QInt8, QInt8, MaxRows, MaxCols, MaxDepth,
+ KcFactor, false>
+ : public level3_blocking<QInt8, QInt8> {
+ DenseIndex m_sizeA;
+ DenseIndex m_sizeB;
+
+ public:
+ gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
+ DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
+ this->m_mc = ((rows + 31) / 32) * 32;
+ this->m_nc = ((cols + 31) / 32) * 32;
+ this->m_kc = ((depth + 31) / 32) * 32;
+ m_sizeA = this->m_mc * this->m_kc;
+ m_sizeB = this->m_kc * this->m_nc;
+ }
+ void allocateA() {
+ if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt8>(m_sizeA);
+ }
+ void allocateB() {
+ if (this->m_blockB == 0) this->m_blockB = aligned_new<QInt8>(m_sizeB);
+ }
+ void allocateAll() {
+ allocateA();
+ allocateB();
+ }
+ ~gemm_blocking_space() {
+ aligned_delete(this->m_blockA, m_sizeA);
+ aligned_delete(this->m_blockB, m_sizeB);
+ }
+};
+
+
+template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
+class gemm_blocking_space<ColMajor, QInt8, QUInt8, MaxRows, MaxCols, MaxDepth,
+ KcFactor, false>
+ : public level3_blocking<QInt8, QUInt8> {
+ DenseIndex m_sizeA;
+ DenseIndex m_sizeB;
+
+ public:
+ gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
+ DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
+ this->m_mc = ((rows + 31) / 32) * 32;
+ this->m_nc = ((cols + 31) / 32) * 32;
+ this->m_kc = ((depth + 31) / 32) * 32;
+ m_sizeA = this->m_mc * this->m_kc;
+ m_sizeB = this->m_kc * this->m_nc;
+ }
+ void allocateA() {
+ if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt8>(m_sizeA);
+ }
+ void allocateB() {
+ if (this->m_blockB == 0) this->m_blockB = aligned_new<QUInt8>(m_sizeB);
+ }
+ void allocateAll() {
+ allocateA();
+ allocateB();
+ }
+ ~gemm_blocking_space() {
+ aligned_delete(this->m_blockA, m_sizeA);
+ aligned_delete(this->m_blockB, m_sizeB);
+ }
+};
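+// Note: both blocking specializations above round m_mc, m_nc and m_kc up to the
+// next multiple of 32 via ((x + 31) / 32) * 32, so the packed buffers always
+// cover whole 32x32 tiles even when the problem dimensions do not.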
+
+// Alternate templates for any input sizes
+template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
+struct gemm_pack_lhs_any;
+template <typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode> {
+ EIGEN_DONT_INLINE void operator()
+ (QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
+};
+
+template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
+struct gemm_pack_rhs_any;
+template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
+ EIGEN_DONT_INLINE void operator()
+ (QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
+};
+
+template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
+struct gebp_kernel_any;
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef typename DataMapper::LinearMapper LinearMapper;
+
+ EIGEN_DONT_INLINE
+ void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+// Alternate implementations for any input sizes
+template <typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode>
+EIGEN_DONT_INLINE void gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode>::
+operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ // Get vector pointer
+ __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);
+
+ // Get even multiples of the dimensions
+ Index rows_32 = (rows / 32) * 32;
+ Index depth_8 = (depth / 8) * 8;
+
+ // Get padding for when depth is not a multiple of 32
+ int padding = 0;
+ if (depth % 32 != 0) {
+ int depth_32 = (depth / 32) * 32;
+ int extra_depth = depth - depth_32;
+ int extra_depth_8 = ((extra_depth + 7) / 8) * 8;
+ padding = 32 - extra_depth_8;
+ }
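+  // With this padding each packed 32-row panel spans a whole multiple of 32
+  // depth values, i.e. 32 rows times ((depth + 31) / 32) * 32 bytes, which is
+  // the layout the gebp kernel further down assumes when it walks blockA.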
+
+ // Pack rows in sets of 32
+ for (Index m = 0; m < rows_32; m += 32) {
+ // Pack depth in sets of 8
+ for (Index k = 0; k < depth_8; k += 8) {
+ // Load vectors
+ __m256i L_A = lhs.loadPacket(m, k);
+ __m256i L_B = lhs.loadPacket(m, k + 1);
+
+ // Interleave 8-bit elements
+ __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
+ __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
+
+ __m256i L_C = lhs.loadPacket(m, k + 2);
+ __m256i L_D = lhs.loadPacket(m, k + 3);
+ __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
+ __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
+
+ // Interleave 16-bit elements
+ __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
+ __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
+
+ // Use permute before we store to cross 128-bit lanes
+ __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD0);
+
+ // Complete packing for 32 x 8 block
+ __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
+ __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD8);
+ _mm256_store_si256(blockA_256++, L_AD16);
+ __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
+ _mm256_store_si256(blockA_256++, L_AD24);
+ __m256i L_E = lhs.loadPacket(m, k + 4);
+ __m256i L_F = lhs.loadPacket(m, k + 5);
+ __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
+ __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
+ __m256i L_G = lhs.loadPacket(m, k + 6);
+ __m256i L_H = lhs.loadPacket(m, k + 7);
+ __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
+ __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
+ __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH0);
+ __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
+ __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH8);
+ _mm256_store_si256(blockA_256++, L_EH16);
+ __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
+ _mm256_store_si256(blockA_256++, L_EH24);
+ }
+
+ // Finish the k dimension, padding with zeros
+ if (depth_8 < depth) {
+ __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H;
+ switch (depth - depth_8) {
+ case 1:
+ L_A = lhs.loadPacket(m, depth_8);
+ L_B = _mm256_setzero_si256();
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ break;
+ case 2:
+ L_A = lhs.loadPacket(m, depth_8);
+ L_B = lhs.loadPacket(m, depth_8 + 1);
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ break;
+ case 3:
+ L_A = lhs.loadPacket(m, depth_8);
+ L_B = lhs.loadPacket(m, depth_8 + 1);
+ L_C = lhs.loadPacket(m, depth_8 + 2);
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ break;
+ case 4:
+ L_A = lhs.loadPacket(m, depth_8);
+ L_B = lhs.loadPacket(m, depth_8 + 1);
+ L_C = lhs.loadPacket(m, depth_8 + 2);
+ L_D = lhs.loadPacket(m, depth_8 + 3);
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ break;
+ case 5:
+ L_A = lhs.loadPacket(m, depth_8);
+ L_B = lhs.loadPacket(m, depth_8 + 1);
+ L_C = lhs.loadPacket(m, depth_8 + 2);
+ L_D = lhs.loadPacket(m, depth_8 + 3);
+ L_E = lhs.loadPacket(m, depth_8 + 4);
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ break;
+ case 6:
+ L_A = lhs.loadPacket(m, depth_8);
+ L_B = lhs.loadPacket(m, depth_8 + 1);
+ L_C = lhs.loadPacket(m, depth_8 + 2);
+ L_D = lhs.loadPacket(m, depth_8 + 3);
+ L_E = lhs.loadPacket(m, depth_8 + 4);
+ L_F = lhs.loadPacket(m, depth_8 + 5);
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ break;
+ case 7:
+ L_A = lhs.loadPacket(m, depth_8);
+ L_B = lhs.loadPacket(m, depth_8 + 1);
+ L_C = lhs.loadPacket(m, depth_8 + 2);
+ L_D = lhs.loadPacket(m, depth_8 + 3);
+ L_E = lhs.loadPacket(m, depth_8 + 4);
+ L_F = lhs.loadPacket(m, depth_8 + 5);
+ L_G = lhs.loadPacket(m, depth_8 + 6);
+ L_H = _mm256_setzero_si256();
+ break;
+ }
+
+ // Interleave 8-bit elements
+ __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
+ __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
+
+ __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
+ __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
+
+ // Interleave 16-bit elements
+ __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
+ __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
+
+ // Use permute before we store to cross 128-bit lanes
+ __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD0);
+
+ // Complete packing
+ __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
+ __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD8);
+ _mm256_store_si256(blockA_256++, L_AD16);
+ __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
+ _mm256_store_si256(blockA_256++, L_AD24);
+ __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
+ __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
+ __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
+ __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
+ __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH0);
+ __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
+ __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH8);
+ _mm256_store_si256(blockA_256++, L_EH16);
+ __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
+ _mm256_store_si256(blockA_256++, L_EH24);
+ }
+ blockA_256 += padding;
+ }
+
+ // Finish the m dimension, padding with zeros
+ if (rows_32 < rows) {
+ // Pack depth in sets of 8
+ for (Index k = 0; k < depth_8; k += 8) {
+ // Load vectors
+ __m256i L_A = _mm256_setzero_si256();
+ __m256i L_B = _mm256_setzero_si256();
+ __m256i L_C = _mm256_setzero_si256();
+ __m256i L_D = _mm256_setzero_si256();
+ __m256i L_E = _mm256_setzero_si256();
+ __m256i L_F = _mm256_setzero_si256();
+ __m256i L_G = _mm256_setzero_si256();
+ __m256i L_H = _mm256_setzero_si256();
+ for (Index m = 0; m < rows - rows_32; m++) {
+ QInt8* ptr = (QInt8*) &L_A;
+ ptr[m] = lhs(rows_32 + m, k);
+ ptr = (QInt8*) &L_B;
+ ptr[m] = lhs(rows_32 + m, k + 1);
+ ptr = (QInt8*) &L_C;
+ ptr[m] = lhs(rows_32 + m, k + 2);
+ ptr = (QInt8*) &L_D;
+ ptr[m] = lhs(rows_32 + m, k + 3);
+ ptr = (QInt8*) &L_E;
+ ptr[m] = lhs(rows_32 + m, k + 4);
+ ptr = (QInt8*) &L_F;
+ ptr[m] = lhs(rows_32 + m, k + 5);
+ ptr = (QInt8*) &L_G;
+ ptr[m] = lhs(rows_32 + m, k + 6);
+ ptr = (QInt8*) &L_H;
+ ptr[m] = lhs(rows_32 + m, k + 7);
+ }
+
+ // Interleave 8-bit elements
+ __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
+ __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
+ __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
+ __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
+
+ // Interleave 16-bit elements
+ __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
+ __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
+
+ // Use permute before we store to cross 128-bit lanes
+ __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD0);
+
+ // Complete packing for 32 x 8 block
+ __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
+ __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD8);
+ _mm256_store_si256(blockA_256++, L_AD16);
+ __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
+ _mm256_store_si256(blockA_256++, L_AD24);
+ __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
+ __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
+ __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
+ __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
+ __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH0);
+ __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
+ __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH8);
+ _mm256_store_si256(blockA_256++, L_EH16);
+ __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
+ _mm256_store_si256(blockA_256++, L_EH24);
+ }
+
+ // Finish the k dimension, padding with zeros
+ if (depth_8 < depth) {
+ __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H;
+ QInt8* ptr;
+ switch (depth - depth_8) {
+ case 1:
+ L_A = _mm256_setzero_si256();
+ L_B = _mm256_setzero_si256();
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ for (Index m = 0; m < rows - rows_32; m++) {
+ QInt8* ptr = (QInt8*) &L_A;
+ ptr[m] = lhs(rows_32 + m, depth_8);
+ }
+ break;
+ case 2:
+ L_A = _mm256_setzero_si256();
+ L_B = _mm256_setzero_si256();
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ for (Index m = 0; m < rows - rows_32; m++) {
+ ptr = (QInt8*) &L_A;
+ ptr[m] = lhs(rows_32 + m, depth_8);
+ ptr = (QInt8*) &L_B;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 1);
+ }
+ break;
+ case 3:
+ L_A = _mm256_setzero_si256();
+ L_B = _mm256_setzero_si256();
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ for (Index m = 0; m < rows - rows_32; m++) {
+ ptr = (QInt8*) &L_A;
+ ptr[m] = lhs(rows_32 + m, depth_8);
+ ptr = (QInt8*) &L_B;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 1);
+ ptr = (QInt8*) &L_C;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 2);
+ }
+ break;
+ case 4:
+ L_A = _mm256_setzero_si256();
+ L_B = _mm256_setzero_si256();
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ for (Index m = 0; m < rows - rows_32; m++) {
+ ptr = (QInt8*) &L_A;
+ ptr[m] = lhs(rows_32 + m, depth_8);
+ ptr = (QInt8*) &L_B;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 1);
+ ptr = (QInt8*) &L_C;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 2);
+ ptr = (QInt8*) &L_D;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 3);
+ }
+ break;
+ case 5:
+ L_A = _mm256_setzero_si256();
+ L_B = _mm256_setzero_si256();
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ for (Index m = 0; m < rows - rows_32; m++) {
+ ptr = (QInt8*) &L_A;
+ ptr[m] = lhs(rows_32 + m, depth_8);
+ ptr = (QInt8*) &L_B;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 1);
+ ptr = (QInt8*) &L_C;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 2);
+ ptr = (QInt8*) &L_D;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 3);
+ ptr = (QInt8*) &L_E;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 4);
+ }
+ break;
+ case 6:
+ L_A = _mm256_setzero_si256();
+ L_B = _mm256_setzero_si256();
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ for (Index m = 0; m < rows - rows_32; m++) {
+ ptr = (QInt8*) &L_A;
+ ptr[m] = lhs(rows_32 + m, depth_8);
+ ptr = (QInt8*) &L_B;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 1);
+ ptr = (QInt8*) &L_C;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 2);
+ ptr = (QInt8*) &L_D;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 3);
+ ptr = (QInt8*) &L_E;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 4);
+ ptr = (QInt8*) &L_F;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 5);
+ }
+ break;
+ case 7:
+ L_A = _mm256_setzero_si256();
+ L_B = _mm256_setzero_si256();
+ L_C = _mm256_setzero_si256();
+ L_D = _mm256_setzero_si256();
+ L_E = _mm256_setzero_si256();
+ L_F = _mm256_setzero_si256();
+ L_G = _mm256_setzero_si256();
+ L_H = _mm256_setzero_si256();
+ for (Index m = 0; m < rows - rows_32; m++) {
+ ptr = (QInt8*) &L_A;
+ ptr[m] = lhs(rows_32 + m, depth_8);
+ ptr = (QInt8*) &L_B;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 1);
+ ptr = (QInt8*) &L_C;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 2);
+ ptr = (QInt8*) &L_D;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 3);
+ ptr = (QInt8*) &L_E;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 4);
+ ptr = (QInt8*) &L_F;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 5);
+ ptr = (QInt8*) &L_G;
+ ptr[m] = lhs(rows_32 + m, depth_8 + 6);
+ }
+ break;
+ }
+
+ // Interleave 8-bit elements
+ __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
+ __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
+ __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
+ __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
+
+ // Interleave 16-bit elements
+ __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
+ __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
+
+ // Use permute before we store to cross 128-bit lanes
+ __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD0);
+
+ // Complete packing
+ __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
+ __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD8);
+ _mm256_store_si256(blockA_256++, L_AD16);
+ __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
+ _mm256_store_si256(blockA_256++, L_AD24);
+ __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
+ __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
+ __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
+ __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
+ __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH0);
+ __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
+ __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH8);
+ _mm256_store_si256(blockA_256++, L_EH16);
+ __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
+ _mm256_store_si256(blockA_256++, L_EH24);
+ }
+ }
+}
+
+template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+EIGEN_DONT_INLINE void gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::
+operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ // Get vector pointer
+ __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);
+
+ // Get even multiples of the dimensions
+ Index cols_32 = (cols / 32) * 32;
+ Index depth_32 = (depth / 32) * 32;
+
+ // Perform a step of the packing for 4 columns
+ __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24;
+#define PACK_STEP \
+ R_AB_L = _mm256_unpacklo_epi64(R_A, R_B); \
+ R_CD_L = _mm256_unpacklo_epi64(R_C, R_D); \
+ R_AB_H = _mm256_unpackhi_epi64(R_A, R_B); \
+ R_CD_H = _mm256_unpackhi_epi64(R_C, R_D); \
+ R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20); \
+ R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \
+ R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20); \
+ R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \
+ _mm256_store_si256(blockB_256, R_AD_0); \
+ _mm256_store_si256(blockB_256 + 8, R_AD_8); \
+ _mm256_store_si256(blockB_256 + 16, R_AD_16); \
+ _mm256_store_si256(blockB_256 + 24, R_AD_24); \
+ blockB_256++;
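+// Each PACK_STEP stores four 32-byte vectors at offsets 0, 8, 16 and 24 from
+// blockB_256 and advances the pointer by one, so the eight steps that cover one
+// 32x32 block fill vectors 0..31; the "blockB_256 += 24" after them then skips
+// to the start of the next block.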
+
+ // Pack cols in sets of 32
+ for (Index n = 0; n < cols_32; n += 32) {
+ // Pack depth in sets of 32
+ for (Index k = 0; k < depth_32; k += 32) {
+ __m256i R_A = rhs.loadPacket(k, n);
+ __m256i R_B = rhs.loadPacket(k, n + 1);
+ __m256i R_C = rhs.loadPacket(k, n + 2);
+ __m256i R_D = rhs.loadPacket(k, n + 3);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 4);
+ R_B = rhs.loadPacket(k, n + 5);
+ R_C = rhs.loadPacket(k, n + 6);
+ R_D = rhs.loadPacket(k, n + 7);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 8);
+ R_B = rhs.loadPacket(k, n + 9);
+ R_C = rhs.loadPacket(k, n + 10);
+ R_D = rhs.loadPacket(k, n + 11);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 12);
+ R_B = rhs.loadPacket(k, n + 13);
+ R_C = rhs.loadPacket(k, n + 14);
+ R_D = rhs.loadPacket(k, n + 15);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 16);
+ R_B = rhs.loadPacket(k, n + 17);
+ R_C = rhs.loadPacket(k, n + 18);
+ R_D = rhs.loadPacket(k, n + 19);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 20);
+ R_B = rhs.loadPacket(k, n + 21);
+ R_C = rhs.loadPacket(k, n + 22);
+ R_D = rhs.loadPacket(k, n + 23);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 24);
+ R_B = rhs.loadPacket(k, n + 25);
+ R_C = rhs.loadPacket(k, n + 26);
+ R_D = rhs.loadPacket(k, n + 27);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 28);
+ R_B = rhs.loadPacket(k, n + 29);
+ R_C = rhs.loadPacket(k, n + 30);
+ R_D = rhs.loadPacket(k, n + 31);
+ PACK_STEP;
+
+ blockB_256 += 24;
+ }
+
+ if (depth_32 < depth) {
+ QUInt8* ptr;
+ __m256i R_A = _mm256_setzero_si256();
+ __m256i R_B = _mm256_setzero_si256();
+ __m256i R_C = _mm256_setzero_si256();
+ __m256i R_D = _mm256_setzero_si256();
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 1);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 2);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 3);
+ }
+ PACK_STEP;
+
+ R_A = _mm256_setzero_si256();
+ R_B = _mm256_setzero_si256();
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n + 4);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 5);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 6);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 7);
+ }
+ PACK_STEP;
+
+ R_A = _mm256_setzero_si256();
+ R_B = _mm256_setzero_si256();
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n + 8);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 9);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 10);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 11);
+ }
+ PACK_STEP;
+
+ R_A = _mm256_setzero_si256();
+ R_B = _mm256_setzero_si256();
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n + 12);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 13);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 14);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 15);
+ }
+ PACK_STEP;
+
+ R_A = _mm256_setzero_si256();
+ R_B = _mm256_setzero_si256();
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n + 16);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 17);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 18);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 19);
+ }
+ PACK_STEP;
+
+ R_A = _mm256_setzero_si256();
+ R_B = _mm256_setzero_si256();
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n + 20);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 21);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 22);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 23);
+ }
+ PACK_STEP;
+
+ R_A = _mm256_setzero_si256();
+ R_B = _mm256_setzero_si256();
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n + 24);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 25);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 26);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 27);
+ }
+ PACK_STEP;
+
+ R_A = _mm256_setzero_si256();
+ R_B = _mm256_setzero_si256();
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n + 28);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 29);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 30);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 31);
+ }
+ PACK_STEP;
+ blockB_256 += 24;
+ }
+ }
+
+ // Finish packing cols
+ if (cols_32 < cols) {
+ // Pack depth in sets of 32
+ for (Index k = 0; k < depth_32; k += 32) {
+ __m256i R_A, R_B, R_C, R_D;
+ Index n;
+ for (n = cols_32; n < cols; n += 4) {
+ switch (cols - n) {
+ case 1:
+ R_A = rhs.loadPacket(k, n);
+ R_B = _mm256_setzero_si256();
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ PACK_STEP;
+ break;
+ case 2:
+ R_A = rhs.loadPacket(k, n);
+ R_B = rhs.loadPacket(k, n + 1);
+ R_C = _mm256_setzero_si256();
+ R_D = _mm256_setzero_si256();
+ PACK_STEP;
+ break;
+ case 3:
+ R_A = rhs.loadPacket(k, n);
+ R_B = rhs.loadPacket(k, n + 1);
+ R_C = rhs.loadPacket(k, n + 2);
+ R_D = _mm256_setzero_si256();
+ PACK_STEP;
+ break;
+ default:
+ R_A = rhs.loadPacket(k, n);
+ R_B = rhs.loadPacket(k, n + 1);
+ R_C = rhs.loadPacket(k, n + 2);
+ R_D = rhs.loadPacket(k, n + 3);
+ PACK_STEP;
+ break;
+ }
+ }
+
+ // Increment the block pointer.
+ // We must pad if cols is not a multiple of 32.
+ blockB_256 += 32 - (n - cols_32) / 4;
+ }
+
+ if (depth_32 < depth) {
+ for (Index n = cols_32; n < cols; n += 4) {
+ QUInt8* ptr;
+ __m256i R_A = _mm256_setzero_si256();
+ __m256i R_B = _mm256_setzero_si256();
+ __m256i R_C = _mm256_setzero_si256();
+ __m256i R_D = _mm256_setzero_si256();
+ switch (cols - n) {
+ case 1:
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n);
+ }
+ PACK_STEP;
+ break;
+ case 2:
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 1);
+ }
+ PACK_STEP;
+ break;
+ case 3:
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 1);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 2);
+ }
+ PACK_STEP;
+ break;
+ default:
+ for (Index k = depth_32; k < depth; k++) {
+ ptr = (QUInt8*) &R_A;
+ ptr[k - depth_32] = rhs(k, n);
+ ptr = (QUInt8*) &R_B;
+ ptr[k - depth_32] = rhs(k, n + 1);
+ ptr = (QUInt8*) &R_C;
+ ptr[k - depth_32] = rhs(k, n + 2);
+ ptr = (QUInt8*) &R_D;
+ ptr[k - depth_32] = rhs(k, n + 3);
+ }
+ PACK_STEP;
+ break;
+ }
+ }
+ }
+ }
+#undef PACK_STEP
+}
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE
+void gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ eigen_assert(alpha.value == 1);
+ eigen_assert(strideA == -1);
+ eigen_assert(strideB == -1);
+ eigen_assert(offsetA == 0);
+ eigen_assert(offsetB == 0);
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+ eigen_assert(depth > 0);
+ eigen_assert(blockA);
+ eigen_assert(blockB);
+
+ Index rows_32 = ((rows + 31) / 32) * 32;
+ Index cols_32 = ((cols + 31) / 32) * 32;
+ Index depth_32 = ((depth + 31) / 32) * 32;
+
+ // Create result block
+ ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
+ memset(blockO, 0, 32 * 32 * sizeof(QInt32));
+
+ // Get vectorized pointers
+ __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
+ const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
+ const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);
+
+ // Loop over blocks of 32 columns
+ for (Index n = 0; n < cols_32; n += 32) {
+ // Reset index into blockA
+ Index indexL = 0;
+ // Loop over blocks of 32 rows
+ for (Index m = 0; m < rows_32; m += 32) {
+ // Reset index into blockB
+ Index indexR = n / 32 * depth_32;
+ // Loop over blocks of 8 on depth
+ for (Index k = 0; k < depth_32; k += 8) {
+ // Load inputs
+ __m256i L_AD0 = blockA_256[indexL++];
+ __m256i L_AD8 = blockA_256[indexL++];
+ __m256i L_AD16 = blockA_256[indexL++];
+ __m256i L_AD24 = blockA_256[indexL++];
+ __m256i L_EH0 = blockA_256[indexL++];
+ __m256i L_EH8 = blockA_256[indexL++];
+ __m256i L_EH16 = blockA_256[indexL++];
+ __m256i L_EH24 = blockA_256[indexL++];
+ __m256i R_AH0 = blockB_256[indexR++];
+ __m256i R_AH4 = blockB_256[indexR++];
+ __m256i R_AH8 = blockB_256[indexR++];
+ __m256i R_AH12 = blockB_256[indexR++];
+ __m256i R_AH16 = blockB_256[indexR++];
+ __m256i R_AH20 = blockB_256[indexR++];
+ __m256i R_AH24 = blockB_256[indexR++];
+ __m256i R_AH28 = blockB_256[indexR++];
+
+ // This constant is used with madd to convert 16 bit to 32 bit
+ const __m256i ONE = _mm256_set1_epi32(0x00010001);
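+
+ // _mm256_maddubs_epi16 multiplies unsigned 8-bit rhs bytes by signed 8-bit
+ // lhs bytes and adds adjacent pairs into 16-bit sums; _mm256_madd_epi16 with
+ // this vector of 16-bit ones then adds adjacent 16-bit pairs into 32-bit
+ // lanes, so each lane accumulates four 8-bit products per COMPUTE_STEP.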
+
+ // Declare variables used in COMPUTE_STEP
+ __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32;
+
+#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET) \
+ P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0); \
+ P_32_A = _mm256_madd_epi16(P_16_A, ONE); \
+ P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0); \
+ P_32_B = _mm256_madd_epi16(P_16_B, ONE); \
+ P_32 = _mm256_add_epi32(P_32_A, P_32_B); \
+ _mm256_store_si256( \
+ blockO_256 + 4 * OFFSET, \
+ _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32)); \
+ \
+ P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8); \
+ P_32_A = _mm256_madd_epi16(P_16_A, ONE); \
+ P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8); \
+ P_32_B = _mm256_madd_epi16(P_16_B, ONE); \
+ P_32 = _mm256_add_epi32(P_32_A, P_32_B); \
+ _mm256_store_si256( \
+ blockO_256 + 4 * OFFSET + 1, \
+ _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \
+ \
+ P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16); \
+ P_32_A = _mm256_madd_epi16(P_16_A, ONE); \
+ P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16); \
+ P_32_B = _mm256_madd_epi16(P_16_B, ONE); \
+ P_32 = _mm256_add_epi32(P_32_A, P_32_B); \
+ _mm256_store_si256( \
+ blockO_256 + 4 * OFFSET + 2, \
+ _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \
+ \
+ P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24); \
+ P_32_A = _mm256_madd_epi16(P_16_A, ONE); \
+ P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24); \
+ P_32_B = _mm256_madd_epi16(P_16_B, ONE); \
+ P_32 = _mm256_add_epi32(P_32_A, P_32_B); \
+ _mm256_store_si256( \
+ blockO_256 + 4 * OFFSET + 3, \
+ _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32));
+
+        // Permute and shuffle to copy a single value across the entire
+        // vector, then compute the multiplication.
+ __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
+ __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 0);
+ __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 1);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
+ __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 2);
+ __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 3);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 4);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 5);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 6);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 7);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 8);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 9);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 10);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 11);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 12);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 13);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 14);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 15);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 16);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 17);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 18);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 19);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 20);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 21);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 22);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 23);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 24);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 25);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 26);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 27);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 28);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 29);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 30);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 31);
+
+#undef COMPUTE_STEP
+ }
+
+ // Transfer the results to the result matrix.
+ if (m + 32 <= rows && n + 32 <= cols) {
+ Index i = 0;
+ for (Index j = n; j < n + 32; j++) {
+ LinearMapper r0 = res.getLinearMapper(m, j);
+ LinearMapper r1 = res.getLinearMapper(m + 8, j);
+ LinearMapper r2 = res.getLinearMapper(m + 16, j);
+ LinearMapper r3 = res.getLinearMapper(m + 24, j);
+ r0.storePacket(
+ 0, _mm256_add_epi32(blockO_256[i++], r0.loadPacket(0)));
+ r1.storePacket(
+ 0, _mm256_add_epi32(blockO_256[i++], r1.loadPacket(0)));
+ r2.storePacket(
+ 0, _mm256_add_epi32(blockO_256[i++], r2.loadPacket(0)));
+ r3.storePacket(
+ 0, _mm256_add_epi32(blockO_256[i++], r3.loadPacket(0)));
+ }
+ }
+ else {
+ for (Index j = n; j < cols; j++) {
+ for (Index i = m; i < rows; i++) {
+ res(i, j) = blockO[(j - n) * 32 + (i - m)];
+ }
+ }
+ }
+
+ // Zero the result block so it can be reused
+ memset(blockO, 0, 32 * 32 * sizeof(QInt32));
+ }
+ }
+}
+
+// Below are the fully optimized versions that are correct only for sizes that
+// are multiples of 32. Keeping these implementations separate from the generic
+// ones above yields about a 10% performance benefit.
+
+// Arrange a block of the left input matrix in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ...
+// A1 B1 C1 D1 E1 F1 G1 H1 ...
+// A2 B2 C2 D2 E2 F2 G2 H2 ...
+// A3 B3 C3 D3 E3 F3 G3 H3 ...
+// A4 B4 C4 D4 E4 F4 G4 H4 ...
+// A5 B5 C5 D5 E5 F5 G5 H5 ...
+// A6 B6 C6 D6 E6 F6 G6 H6 ...
+// A7 B7 C7 D7 E7 F7 G7 H7 ...
+// A8 ...
+// ...
+//
+// Packing yields output (A0 beside B0 in memory):
+// A0 B0 C0 D0
+// A1 B1 C1 D1
+// A2 B2 C2 D2
+// A3 B3 C3 D3
+// A4 B4 C4 D4
+// A5 B5 C5 D5
+// A6 B6 C6 D6
+// A7 B7 C7 D7
+// ...
+// A31 B31 C31 D31
+// E0 F0 G0 H0
+// E1 F1 G1 H1
+// E2 F2 G2 H2
+// E3 F3 G3 H3
+// E4 F4 G4 H4
+// E5 F5 G5 H5
+// E6 F6 G6 H6
+// E7 F7 G7 H7
+// ...
+//
+// Four elements of the same row are arranged contiguously because maddubs and
+// madd both perform an adjacent addition in the kernel.
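+//
+// With this layout, each 32 (row) x 8 (depth) sub-block of the lhs occupies
+// eight consecutive __m256i vectors in blockA, stored in the order AD0, AD8,
+// AD16, AD24, EH0, EH8, EH16, EH24 in which the kernel loads them (A-D and E-H
+// are depth values k..k+3 and k+4..k+7; the number is the starting row).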
+template <typename Index, typename DataMapper, int Pack1, int Pack2,
+ bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
+ Conjugate, PanelMode> {
+ EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs,
+ Index depth, Index rows, Index stride = 0,
+ Index offset = 0);
+};
+
+template <typename Index, typename DataMapper, int Pack1, int Pack2,
+ bool Conjugate, bool PanelMode>
+EIGEN_DONT_INLINE void gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2,
+ ColMajor, Conjugate, PanelMode>::
+operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
+ Index stride, Index offset) {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+  // Use the alternate implementation for sizes that are not multiples of 32
+ if (rows % 32 != 0 || depth % 32 != 0) {
+ gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode> lhs_pack;
+ return lhs_pack(blockA, lhs, depth, rows, stride, offset);
+ }
+
+ // Get vector pointer
+ __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);
+
+ // Pack rows in sets of 32
+ for (Index m = 0; m < rows; m += 32) {
+ // Pack depth in sets of 8
+ for (Index k = 0; k < depth; k += 8) {
+ // Load vectors
+ __m256i L_A = lhs.loadPacket(m, k);
+ __m256i L_B = lhs.loadPacket(m, k + 1);
+
+ // Interleave 8-bit elements
+ __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
+ __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
+
+ __m256i L_C = lhs.loadPacket(m, k + 2);
+ __m256i L_D = lhs.loadPacket(m, k + 3);
+ __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
+ __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
+
+ // Interleave 16-bit elements
+ __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
+ __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
+
+ // Use permute before we store to cross 128-bit lanes
+ __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD0);
+
+ // Complete packing for 32 x 8 block
+ __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
+ __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
+ __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
+ _mm256_store_si256(blockA_256++, L_AD8);
+ _mm256_store_si256(blockA_256++, L_AD16);
+ __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
+ _mm256_store_si256(blockA_256++, L_AD24);
+ __m256i L_E = lhs.loadPacket(m, k + 4);
+ __m256i L_F = lhs.loadPacket(m, k + 5);
+ __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
+ __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
+ __m256i L_G = lhs.loadPacket(m, k + 6);
+ __m256i L_H = lhs.loadPacket(m, k + 7);
+ __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
+ __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
+ __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
+ __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH0);
+ __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
+ __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
+ __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
+ _mm256_store_si256(blockA_256++, L_EH8);
+ _mm256_store_si256(blockA_256++, L_EH16);
+ __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
+ _mm256_store_si256(blockA_256++, L_EH24);
+ }
+ }
+}
+
+// Arrange a block of the right input matrix in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ...
+// A1 B1 C1 D1 E1 F1 G1 H1 ...
+// A2 B2 C2 D2 E2 F2 G2 H2 ...
+// A3 B3 C3 D3 E3 F3 G3 H3 ...
+// A4 B4 C4 D4 E4 F4 G4 H4 ...
+// A5 B5 C5 D5 E5 F5 G5 H5 ...
+// A6 B6 C6 D6 E6 F6 G6 H6 ...
+// A7 B7 C7 D7 E7 F7 G7 H7 ...
+// A8 ...
+// ...
+//
+// Packing yields row major output (A0 beside A1 in memory):
+// A0 A1 A2 A3 A4 A5 A6 A7
+// B0 B1 B2 B3 B4 B5 B6 B7
+// ...
+//
+// At least four elements of the same column are arranged contiguously because
+// maddubs and madd both perform an adjacent addition in the kernel. Keeping 8
+// adjacent elements saves additional work, since the kernel consumes the depth
+// dimension 8 values at a time (kr = 8).
+template <typename Index, typename DataMapper, int nr, bool Conjugate,
+ bool PanelMode>
+struct gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
+ PanelMode> {
+ EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0);
+};
+
+template <typename Index, typename DataMapper, int nr, bool Conjugate,
+ bool PanelMode>
+EIGEN_DONT_INLINE void gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor,
+ Conjugate, PanelMode>::
+operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
+ Index stride, Index offset) {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+  // Use the alternate implementation for sizes that are not multiples of 32
+ if (cols % 32 != 0 || depth % 32 != 0) {
+ gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> rhs_pack;
+ return rhs_pack(blockB, rhs, depth, cols, stride, offset);
+ }
+
+ // Get vector pointer
+ __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);
+
+ // Perform a step of the packing for 4 columns
+ __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24;
+#define PACK_STEP \
+ R_AB_L = _mm256_unpacklo_epi64(R_A, R_B); \
+ R_CD_L = _mm256_unpacklo_epi64(R_C, R_D); \
+ R_AB_H = _mm256_unpackhi_epi64(R_A, R_B); \
+ R_CD_H = _mm256_unpackhi_epi64(R_C, R_D); \
+ R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20); \
+ R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \
+ R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20); \
+ R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \
+ _mm256_store_si256(blockB_256, R_AD_0); \
+ _mm256_store_si256(blockB_256 + 8, R_AD_8); \
+ _mm256_store_si256(blockB_256 + 16, R_AD_16); \
+ _mm256_store_si256(blockB_256 + 24, R_AD_24); \
+ blockB_256++;
+
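+  // Each PACK_STEP handles 4 columns for 32 depth values: the stores at
+  // offsets 0, 8, 16 and 24 hold depth rows k..k+7, k+8..k+15, k+16..k+23 and
+  // k+24..k+31 of those 4 columns. After 8 PACK_STEPs (32 columns) the
+  // 32-vector block is complete, hence the blockB_256 += 24 below.
+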
+ // Pack cols in sets of 32
+ for (Index n = 0; n < cols; n += 32) {
+ // Pack depth in sets of 32
+ for (Index k = 0; k < depth; k += 32) {
+ __m256i R_A = rhs.loadPacket(k, n);
+ __m256i R_B = rhs.loadPacket(k, n + 1);
+ __m256i R_C = rhs.loadPacket(k, n + 2);
+ __m256i R_D = rhs.loadPacket(k, n + 3);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 4);
+ R_B = rhs.loadPacket(k, n + 5);
+ R_C = rhs.loadPacket(k, n + 6);
+ R_D = rhs.loadPacket(k, n + 7);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 8);
+ R_B = rhs.loadPacket(k, n + 9);
+ R_C = rhs.loadPacket(k, n + 10);
+ R_D = rhs.loadPacket(k, n + 11);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 12);
+ R_B = rhs.loadPacket(k, n + 13);
+ R_C = rhs.loadPacket(k, n + 14);
+ R_D = rhs.loadPacket(k, n + 15);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 16);
+ R_B = rhs.loadPacket(k, n + 17);
+ R_C = rhs.loadPacket(k, n + 18);
+ R_D = rhs.loadPacket(k, n + 19);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 20);
+ R_B = rhs.loadPacket(k, n + 21);
+ R_C = rhs.loadPacket(k, n + 22);
+ R_D = rhs.loadPacket(k, n + 23);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 24);
+ R_B = rhs.loadPacket(k, n + 25);
+ R_C = rhs.loadPacket(k, n + 26);
+ R_D = rhs.loadPacket(k, n + 27);
+ PACK_STEP;
+
+ R_A = rhs.loadPacket(k, n + 28);
+ R_B = rhs.loadPacket(k, n + 29);
+ R_C = rhs.loadPacket(k, n + 30);
+ R_D = rhs.loadPacket(k, n + 31);
+ PACK_STEP;
+
+ blockB_256 += 24;
+ }
+ }
+#undef PACK_STEP
+}
+
+// Perform the actual multiplication on packed inputs
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef typename DataMapper::LinearMapper LinearMapper;
+
+ EIGEN_DONT_INLINE
+ void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE
+void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ eigen_assert(alpha.value == 1);
+ eigen_assert(strideA == -1);
+ eigen_assert(strideB == -1);
+ eigen_assert(offsetA == 0);
+ eigen_assert(offsetB == 0);
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+ eigen_assert(depth > 0);
+ eigen_assert(blockA);
+ eigen_assert(blockB);
+
+  // Use the alternate implementation for sizes that are not multiples of 32
+ if (rows % 32 != 0 || cols % 32 != 0 || depth % 32 != 0) {
+ gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> gebp;
+ return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+ // Create result block
+ QInt32* blockO = aligned_new<QInt32>(32 * 32);
+  // Dynamically allocating the result block is about 5-10% faster than
+  // declaring it on the stack; it is unclear why this is the case.
+ // ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
+ memset(blockO, 0, 32 * 32 * sizeof(QInt32));
+
+ // Get vectorized pointers
+ __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
+ const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
+ const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);
+
+ // Loop over blocks of 32 columns
+ for (Index n = 0; n < cols; n += 32) {
+ // Reset index into blockA
+ Index indexL = 0;
+ // Loop over blocks of 32 rows
+ for (Index m = 0; m < rows; m += 32) {
+ // Reset index into blockB
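+      // (each packed 32-column panel of the rhs occupies `depth` __m256i
+      // vectors, so the panel for columns [n, n + 32) starts at n / 32 * depth)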
+ Index indexR = n / 32 * depth;
+ // Loop over blocks of 8 on depth
+ for (Index k = 0; k < depth; k += 8) {
+ // Load inputs
+ __m256i L_AD0 = blockA_256[indexL++];
+ __m256i L_AD8 = blockA_256[indexL++];
+ __m256i L_AD16 = blockA_256[indexL++];
+ __m256i L_AD24 = blockA_256[indexL++];
+ __m256i L_EH0 = blockA_256[indexL++];
+ __m256i L_EH8 = blockA_256[indexL++];
+ __m256i L_EH16 = blockA_256[indexL++];
+ __m256i L_EH24 = blockA_256[indexL++];
+ __m256i R_AH0 = blockB_256[indexR++];
+ __m256i R_AH4 = blockB_256[indexR++];
+ __m256i R_AH8 = blockB_256[indexR++];
+ __m256i R_AH12 = blockB_256[indexR++];
+ __m256i R_AH16 = blockB_256[indexR++];
+ __m256i R_AH20 = blockB_256[indexR++];
+ __m256i R_AH24 = blockB_256[indexR++];
+ __m256i R_AH28 = blockB_256[indexR++];
+
+ // This constant is used with madd to convert 16 bit to 32 bit
+ const __m256i ONE = _mm256_set1_epi32(0x00010001);
+
+ // Declare variables used in COMPUTE_STEP
+ __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32;
+
+#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET) \
+ P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0); \
+ P_32_A = _mm256_madd_epi16(P_16_A, ONE); \
+ P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0); \
+ P_32_B = _mm256_madd_epi16(P_16_B, ONE); \
+ P_32 = _mm256_add_epi32(P_32_A, P_32_B); \
+ _mm256_store_si256( \
+ blockO_256 + 4 * OFFSET, \
+ _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32)); \
+ \
+ P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8); \
+ P_32_A = _mm256_madd_epi16(P_16_A, ONE); \
+ P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8); \
+ P_32_B = _mm256_madd_epi16(P_16_B, ONE); \
+ P_32 = _mm256_add_epi32(P_32_A, P_32_B); \
+ _mm256_store_si256( \
+ blockO_256 + 4 * OFFSET + 1, \
+ _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \
+ \
+ P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16); \
+ P_32_A = _mm256_madd_epi16(P_16_A, ONE); \
+ P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16); \
+ P_32_B = _mm256_madd_epi16(P_16_B, ONE); \
+ P_32 = _mm256_add_epi32(P_32_A, P_32_B); \
+ _mm256_store_si256( \
+ blockO_256 + 4 * OFFSET + 2, \
+ _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \
+ \
+ P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24); \
+ P_32_A = _mm256_madd_epi16(P_16_A, ONE); \
+ P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24); \
+ P_32_B = _mm256_madd_epi16(P_16_B, ONE); \
+ P_32 = _mm256_add_epi32(P_32_A, P_32_B); \
+ _mm256_store_si256( \
+ blockO_256 + 4 * OFFSET + 3, \
+ _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32));
+
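+        // blockO holds the 32x32 result tile in column-major order:
+        // blockO_256[4 * c + r] accumulates rows 8r..8r+7 of tile column c,
+        // matching the transfer loop below.
+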
+        // Permute and shuffle to copy a single value across the entire
+        // vector, then compute the multiplication.
+ __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
+ __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 0);
+ __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 1);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
+ __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 2);
+ __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 3);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 4);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 5);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 6);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 7);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 8);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 9);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 10);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 11);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 12);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 13);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 14);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 15);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 16);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 17);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 18);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 19);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 20);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 21);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 22);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 23);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 24);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 25);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 26);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 27);
+
+ R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00);
+ R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD0, R_EH0, 28);
+ R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD1, R_EH1, 29);
+ R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11);
+ R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
+ R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
+ COMPUTE_STEP(R_AD2, R_EH2, 30);
+ R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
+ R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
+ COMPUTE_STEP(R_AD3, R_EH3, 31);
+
+#undef COMPUTE_STEP
+ }
+
+ // Transfer the results to the result matrix
+ Index i = 0;
+ for (Index j = n; j < n + 32; j++) {
+ LinearMapper r0 = res.getLinearMapper(m, j);
+ LinearMapper r1 = res.getLinearMapper(m + 8, j);
+ LinearMapper r2 = res.getLinearMapper(m + 16, j);
+ LinearMapper r3 = res.getLinearMapper(m + 24, j);
+ r0.storePacket(
+ 0, _mm256_add_epi32(blockO_256[i++], r0.loadPacket(0)));
+ r1.storePacket(
+ 0, _mm256_add_epi32(blockO_256[i++], r1.loadPacket(0)));
+ r2.storePacket(
+ 0, _mm256_add_epi32(blockO_256[i++], r2.loadPacket(0)));
+ r3.storePacket(
+ 0, _mm256_add_epi32(blockO_256[i++], r3.loadPacket(0)));
+ }
+
+ // Zero the result block so it can be reused
+ memset(blockO, 0, 32 * 32 * sizeof(QInt32));
+ }
+ }
+ aligned_delete(blockO, 32 * 32);
+}
+
+#endif // EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
+
+} // namespace internal
+} // namespace Eigen
+
+#endif // EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_AVX2_H
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
new file mode 100644
index 0000000000..99894cafb5
--- /dev/null
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
@@ -0,0 +1,95 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
+// Copyright (C) 2015 Benoit Jacob <benoitjacob@google.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_NEON_H
+#define EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_NEON_H
+
+
+namespace Eigen {
+namespace internal {
+
+
+// Mat-mat product for the case where the lhs is encoded using signed 8bit
+// integers and the rhs using unsigned 8bit integers. This is currently a plain
+// scalar reference implementation rather than a NEON-optimized one.
+#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
+
+template<bool _ConjLhs, bool _ConjRhs>
+class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
+{
+public:
+ typedef QInt8 LhsScalar;
+ typedef QUInt8 RhsScalar;
+ typedef QInt32 ResScalar;
+
+ enum {
+ // register block size along the M and N directions
+ // One for the current implementation
+ nr = 1,
+ mr = 1,
+ // Progress made at each iteration of the product loop
+ // also 1 for the current implementation
+ LhsProgress = 1,
+ RhsProgress = 1
+ };
+};
+
+// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ EIGEN_DONT_INLINE
+ void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE
+void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+ Index rows, Index depth, Index cols, QInt32 alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ eigen_assert(alpha.value == 1);
+ eigen_assert(strideA == -1);
+ eigen_assert(strideB == -1);
+ eigen_assert(offsetA == 0);
+ eigen_assert(offsetB == 0);
+
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+ eigen_assert(depth > 0);
+ eigen_assert(blockA);
+ eigen_assert(blockB);
+
+ for (Index j = 0; j < cols; ++j) {
+ Index startB = j * depth;
+
+ for (Index i = 0; i < rows; ++i) {
+ Index startA = i * depth;
+
+ for (Index k = 0; k < depth; ++k) {
+ res(i, j) += blockA[startA + k] * blockB[startB + k];
+ }
+ }
+ }
+}
+#endif
+
+
+} // namespace internal
+} // namespace Eigen
+
+
+
+#endif // EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_NEON_H
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
new file mode 100644
index 0000000000..18b5085b89
--- /dev/null
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
@@ -0,0 +1,123 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_FIXED_POINT_MAT_VEC_PRODUCT_H
+#define EIGEN_CXX11_FIXED_POINT_MAT_VEC_PRODUCT_H
+
+
+namespace Eigen {
+namespace internal {
+
+// Mat-Vec product
+// Both lhs and rhs are encoded as 8bit signed integers
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
+struct general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>
+{
+EIGEN_DONT_INLINE static void run(
+ Index rows, Index cols,
+ const LhsMapper& lhs,
+ const RhsMapper& rhs,
+ QInt32* res, Index resIncr,
+ QInt8 alpha);
+};
+
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
+EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>::run(
+ Index rows, Index cols,
+ const LhsMapper& lhs,
+ const RhsMapper& rhs,
+ QInt32* res, Index resIncr,
+ QInt8 alpha)
+{
+ eigen_assert(alpha.value == 1);
+ eigen_assert(resIncr == 1);
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+
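+  // Note: the products are accumulated into res, so the caller is expected to
+  // provide an initialized destination buffer.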
+ for (Index i = 0; i < rows; ++i) {
+ for (Index j = 0; j < cols; ++j) {
+ res[i] += lhs(i, j) * rhs(j, 0);
+ }
+ }
+}
+
+
+// Mat-Vec product
+// The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned integers
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
+struct general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QUInt8,RhsMapper,ConjugateRhs,Version>
+{
+EIGEN_DONT_INLINE static void run(
+ Index rows, Index cols,
+ const LhsMapper& lhs,
+ const RhsMapper& rhs,
+ QInt32* res, Index resIncr,
+ QUInt8 alpha);
+};
+
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
+EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QUInt8,RhsMapper,ConjugateRhs,Version>::run(
+ Index rows, Index cols,
+ const LhsMapper& lhs,
+ const RhsMapper& rhs,
+ QInt32* res, Index resIncr,
+ QUInt8 alpha)
+{
+ eigen_assert(alpha.value == 1);
+ eigen_assert(resIncr == 1);
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+
+ for (Index i = 0; i < rows; ++i) {
+ for (Index j = 0; j < cols; ++j) {
+ res[i] += lhs(i, j) * rhs(j, 0);
+ }
+ }
+}
+
+
+// Mat-Vec product
+// The lhs is encoded using 8bit unsigned integers, the rhs using 8bit signed integers
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
+struct general_matrix_vector_product<Index,QUInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>
+{
+EIGEN_DONT_INLINE static void run(
+ Index rows, Index cols,
+ const LhsMapper& lhs,
+ const RhsMapper& rhs,
+ QInt32* res, Index resIncr,
+ QInt8 alpha);
+};
+
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
+EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QUInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>::run(
+ Index rows, Index cols,
+ const LhsMapper& lhs,
+ const RhsMapper& rhs,
+ QInt32* res, Index resIncr,
+ QInt8 alpha)
+{
+ eigen_assert(alpha.value == 1);
+ eigen_assert(resIncr == 1);
+ eigen_assert(rows > 0);
+ eigen_assert(cols > 0);
+
+ for (Index i = 0; i < rows; ++i) {
+ for (Index j = 0; j < cols; ++j) {
+ res[i] += lhs(i, j) * rhs(j, 0);
+ }
+ }
+}
+
+} // namespace internal
+} // namespace Eigen
+
+
+
+#endif // EIGEN_CXX11_FIXED_POINT_MAT_VEC_PRODUCT_H
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
new file mode 100644
index 0000000000..cae1a0b06d
--- /dev/null
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
@@ -0,0 +1,409 @@
+#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
+#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
+
+namespace Eigen {
+namespace internal {
+
+typedef struct Packet32q8i {
+ __m256i val;
+ operator __m256i() const { return val; }
+ Packet32q8i();
+ Packet32q8i(__m256i val) : val(val) {}
+} Packet32q8i;
+
+typedef struct Packet32q8u {
+ __m256i val;
+ operator __m256i() const { return val; }
+ Packet32q8u();
+ Packet32q8u(__m256i val) : val(val) {}
+} Packet32q8u;
+
+typedef struct Packet16q8i {
+ __m128i val;
+ operator __m128i() const { return val; }
+ Packet16q8i();
+ Packet16q8i(__m128i val) : val(val) {}
+} Packet16q8i;
+
+typedef struct Packet16q8u {
+ __m128i val;
+ operator __m128i() const { return val; }
+ Packet16q8u();
+ Packet16q8u(__m128i val) : val(val) {}
+} Packet16q8u;
+
+typedef struct Packet8q32i {
+ __m256i val;
+ operator __m256i() const { return val; }
+ Packet8q32i();
+ Packet8q32i(__m256i val) : val(val) {}
+} Packet8q32i;
+
+typedef struct Packet4q32i {
+ __m128i val;
+ operator __m128i() const { return val; }
+ Packet4q32i();
+ Packet4q32i(__m128i val) : val(val) {}
+} Packet4q32i;
+
+template <>
+struct packet_traits<QInt8> : default_packet_traits {
+ typedef Packet32q8i type;
+ typedef Packet16q8i half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 32,
+ };
+ enum {
+ HasAdd = 0,
+ HasSub = 0,
+ HasMul = 0,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 0,
+ HasSetLinear = 0
+ };
+};
+template <>
+struct packet_traits<QUInt8> : default_packet_traits {
+ typedef Packet32q8u type;
+ typedef Packet16q8u half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 32,
+ };
+ enum {
+ HasAdd = 0,
+ HasSub = 0,
+ HasMul = 0,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 0,
+ HasSetLinear = 0
+ };
+};
+template <>
+struct packet_traits<QInt32> : default_packet_traits {
+ typedef Packet8q32i type;
+ typedef Packet4q32i half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ };
+ enum {
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 0,
+ HasSetLinear = 0
+ };
+};
+
+template <>
+struct unpacket_traits<Packet32q8i> {
+ typedef QInt8 type;
+ typedef Packet16q8i half;
+ enum { size = 32 };
+};
+template <>
+struct unpacket_traits<Packet32q8u> {
+ typedef QUInt8 type;
+ typedef Packet16q8u half;
+ enum { size = 32 };
+};
+template <>
+struct unpacket_traits<Packet8q32i> {
+ typedef QInt32 type;
+ typedef Packet4q32i half;
+ enum { size = 8 };
+};
+
+// Unaligned load
+template <>
+EIGEN_STRONG_INLINE Packet32q8i ploadu<Packet32q8i>(const QInt8* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(from));
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u ploadu<Packet32q8u>(const QUInt8* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(from));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i ploadu<Packet8q32i>(const QInt32* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(from));
+}
+
+// Aligned load
+template <>
+EIGEN_STRONG_INLINE Packet32q8i pload<Packet32q8i>(const QInt8* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(from));
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u pload<Packet32q8u>(const QUInt8* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(from));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pload<Packet8q32i>(const QInt32* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(from));
+}
+
+// Unaligned store
+template <>
+EIGEN_STRONG_INLINE void pstoreu<QInt8>(QInt8* to, const Packet32q8i& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(to), from.val);
+}
+template <>
+EIGEN_STRONG_INLINE void pstoreu<QUInt8>(QUInt8* to, const Packet32q8u& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(to), from.val);
+}
+template <>
+EIGEN_STRONG_INLINE void pstoreu<QInt32>(QInt32* to, const Packet8q32i& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(to), from.val);
+}
+
+// Aligned store
+template <>
+EIGEN_STRONG_INLINE void pstore<QInt32>(QInt32* to, const Packet8q32i& from) {
+ EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to),
+ from.val);
+}
+template <>
+EIGEN_STRONG_INLINE void pstore<QUInt8>(QUInt8* to, const Packet32q8u& from) {
+ EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to),
+ from.val);
+}
+template <>
+EIGEN_STRONG_INLINE void pstore<QInt8>(QInt8* to, const Packet32q8i& from) {
+ EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to),
+ from.val);
+}
+
+// Extract first element.
+template <>
+EIGEN_STRONG_INLINE QInt32 pfirst<Packet8q32i>(const Packet8q32i& a) {
+ return _mm_cvtsi128_si32(_mm256_castsi256_si128(a));
+}
+template <>
+EIGEN_STRONG_INLINE QUInt8 pfirst<Packet32q8u>(const Packet32q8u& a) {
+ return static_cast<uint8_t>(_mm256_extract_epi8(a.val, 0));
+}
+template <>
+EIGEN_STRONG_INLINE QInt8 pfirst<Packet32q8i>(const Packet32q8i& a) {
+ return _mm256_extract_epi8(a.val, 0);
+}
+
+// Initialize to constant value.
+template <>
+EIGEN_STRONG_INLINE Packet32q8i pset1<Packet32q8i>(const QInt8& from) {
+ return _mm256_set1_epi8(from.value);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u pset1<Packet32q8u>(const QUInt8& from) {
+ return _mm256_set1_epi8(static_cast<uint8_t>(from.value));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pset1<Packet8q32i>(const QInt32& from) {
+ return _mm256_set1_epi32(from.value);
+}
+
+// Basic arithmetic packet ops for QInt32.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i padd<Packet8q32i>(const Packet8q32i& a,
+ const Packet8q32i& b) {
+ return _mm256_add_epi32(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i psub<Packet8q32i>(const Packet8q32i& a,
+ const Packet8q32i& b) {
+ return _mm256_sub_epi32(a.val, b.val);
+}
+// Note: mullo truncates the result to 32 bits.
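+// For example, pmul on lanes holding 0x00010000 wraps to 0, since the true
+// product 2^32 does not fit in 32 bits.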
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pmul<Packet8q32i>(const Packet8q32i& a,
+ const Packet8q32i& b) {
+ return _mm256_mullo_epi32(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pnegate<Packet8q32i>(const Packet8q32i& a) {
+ return _mm256_sub_epi32(_mm256_setzero_si256(), a.val);
+}
+
+// Min and max.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pmin<Packet8q32i>(const Packet8q32i& a,
+ const Packet8q32i& b) {
+ return _mm256_min_epi32(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pmax<Packet8q32i>(const Packet8q32i& a,
+ const Packet8q32i& b) {
+ return _mm256_max_epi32(a.val, b.val);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32q8u pmin<Packet32q8u>(const Packet32q8u& a,
+ const Packet32q8u& b) {
+ return _mm256_min_epu8(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u pmax<Packet32q8u>(const Packet32q8u& a,
+ const Packet32q8u& b) {
+ return _mm256_max_epu8(a.val, b.val);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32q8i pmin<Packet32q8i>(const Packet32q8i& a,
+ const Packet32q8i& b) {
+ return _mm256_min_epi8(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8i pmax<Packet32q8i>(const Packet32q8i& a,
+ const Packet32q8i& b) {
+ return _mm256_max_epi8(a.val, b.val);
+}
+
+// Reductions.
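+// The reductions below fold the two 128-bit halves together first, then keep
+// halving the active width with 32-bit and 16-bit shuffles; the 8-bit variants
+// finish by taking the min/max of the two lowest bytes.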
+template <>
+EIGEN_STRONG_INLINE QInt32 predux_min<Packet8q32i>(const Packet8q32i& a) {
+ __m256i tmp = _mm256_min_epi32(a, _mm256_permute2f128_si256(a, a, 1));
+ tmp =
+ _mm256_min_epi32(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ return pfirst<Packet8q32i>(
+ _mm256_min_epi32(tmp, _mm256_shuffle_epi32(tmp, 1)));
+}
+template <>
+EIGEN_STRONG_INLINE QInt32 predux_max<Packet8q32i>(const Packet8q32i& a) {
+ __m256i tmp = _mm256_max_epi32(a, _mm256_permute2f128_si256(a, a, 1));
+ tmp =
+ _mm256_max_epi32(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ return pfirst<Packet8q32i>(
+ _mm256_max_epi32(tmp, _mm256_shuffle_epi32(tmp, 1)));
+}
+
+template <>
+EIGEN_STRONG_INLINE QUInt8 predux_min<Packet32q8u>(const Packet32q8u& a) {
+ __m256i tmp = _mm256_min_epu8(a, _mm256_permute2f128_si256(a, a, 1));
+ tmp =
+ _mm256_min_epu8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ tmp = _mm256_min_epu8(tmp, _mm256_shuffle_epi32(tmp, 1));
+ tmp = _mm256_min_epu8(tmp,
+ _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ return std::min(static_cast<uint8_t>(_mm256_extract_epi8(tmp, 0)),
+ static_cast<uint8_t>(_mm256_extract_epi8(tmp, 1)));
+}
+template <>
+EIGEN_STRONG_INLINE QUInt8 predux_max<Packet32q8u>(const Packet32q8u& a) {
+ __m256i tmp = _mm256_max_epu8(a, _mm256_permute2f128_si256(a, a, 1));
+ tmp =
+ _mm256_max_epu8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ tmp = _mm256_max_epu8(tmp, _mm256_shuffle_epi32(tmp, 1));
+ tmp = _mm256_max_epu8(tmp,
+ _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ return std::max(static_cast<uint8_t>(_mm256_extract_epi8(tmp, 0)),
+ static_cast<uint8_t>(_mm256_extract_epi8(tmp, 1)));
+}
+
+template <>
+EIGEN_STRONG_INLINE QInt8 predux_min<Packet32q8i>(const Packet32q8i& a) {
+ __m256i tmp = _mm256_min_epi8(a, _mm256_permute2f128_si256(a, a, 1));
+ tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
+ tmp = _mm256_min_epi8(tmp, _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ return std::min(_mm256_extract_epi8(tmp, 0), _mm256_extract_epi8(tmp, 1));
+}
+template <>
+EIGEN_STRONG_INLINE QInt8 predux_max<Packet32q8i>(const Packet32q8i& a) {
+ __m256i tmp = _mm256_max_epi8(a, _mm256_permute2f128_si256(a, a, 1));
+ tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
+ tmp = _mm256_max_epi8(tmp, _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
+ return std::max(_mm256_extract_epi8(tmp, 0), _mm256_extract_epi8(tmp, 1));
+}
+
+// Comparisons
+template <>
+EIGEN_STRONG_INLINE Packet8q32i peq<Packet8q32i>(const Packet8q32i& a,
+ const Packet8q32i& b) {
+ return _mm256_cmpeq_epi32(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8i peq<Packet32q8i>(const Packet32q8i& a,
+ const Packet32q8i& b) {
+ return _mm256_cmpeq_epi8(a.val, b.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8u peq<Packet32q8u>(const Packet32q8u& a,
+ const Packet32q8u& b) {
+ return _mm256_cmpeq_epi8(a.val, b.val);
+}
+
+// Note: There are no instructions in AVX2 for unsigned lt/gt comparison.
+// These are added in AVX-512.
+template <>
+EIGEN_STRONG_INLINE Packet8q32i ple<Packet8q32i>(const Packet8q32i& a,
+ const Packet8q32i& b) {
+  // a <= b is the complement of a > b.
+  const __m256i gt = _mm256_cmpgt_epi32(a.val, b.val);
+  return _mm256_xor_si256(gt, _mm256_set1_epi32(-1));
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8i ple<Packet32q8i>(const Packet32q8i& a,
+ const Packet32q8i& b) {
+  // a <= b is the complement of a > b.
+  const __m256i gt = _mm256_cmpgt_epi8(a.val, b.val);
+  return _mm256_xor_si256(gt, _mm256_set1_epi8(-1));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8q32i plt<Packet8q32i>(const Packet8q32i& a,
+ const Packet8q32i& b) {
+ return _mm256_cmpgt_epi32(b.val, a.val);
+}
+template <>
+EIGEN_STRONG_INLINE Packet32q8i plt<Packet32q8i>(const Packet32q8i& a,
+ const Packet32q8i& b) {
+ return _mm256_cmpgt_epi8(b.val, a.val);
+}
+
+// Vectorized scaling of Packet8q32i (QInt32) by a double.
+template <>
+struct functor_traits<scalar_multiple2_op<QInt32, double>> {
+ enum { Cost = 4 * NumTraits<float>::MulCost, PacketAccess = true };
+};
+
+template <>
+EIGEN_STRONG_INLINE const Packet8q32i
+scalar_multiple2_op<QInt32, double>::packetOp(const Packet8q32i& a) const {
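+  // The eight int32 lanes are scaled in double precision four at a time
+  // (cvtepi32_pd widens 4 ints to 4 doubles); cvtpd_epi32 converts back using
+  // the current rounding mode (round-to-nearest-even by default).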
+ __m256d scale = _mm256_set1_pd(m_other);
+ __m256d a_lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(a));
+ __m128i result_lo = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_lo));
+ __m256d a_hi = _mm256_cvtepi32_pd(_mm256_extracti128_si256(a, 1));
+ __m128i result_hi = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_hi));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi,
+ 1);
+}
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
new file mode 100644
index 0000000000..045384d7fc
--- /dev/null
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
@@ -0,0 +1,66 @@
+#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
+#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
+
+namespace Eigen {
+namespace internal {
+
+typedef __m256 Packet8f;
+
+template <>
+struct type_casting_traits<QInt32, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet8f pcast<Packet8q32i>(const Packet8q32i& a) {
+ return _mm256_cvtepi32_ps(a.val);
+}
+
+template <>
+struct type_casting_traits<float, QInt32> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet8q32i pcast<Packet8f>(const Packet8f& a) {
+ return _mm256_cvtps_epi32(a);
+}
+
+template <>
+struct type_casting_traits<QInt32, QInt8> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet32q8i
+pcast<Packet8q32i, Packet32q8i>(const Packet8q32i& a, const Packet8q32i& b,
+ const Packet8q32i& c, const Packet8q32i& d) {
+ __m256i converted = _mm256_packs_epi16(_mm256_packs_epi32(a.val, b.val),
+ _mm256_packs_epi32(c.val, d.val));
+ // Since packs does not cross 128 bit lane boundaries,
+ // we have to permute to properly order the final result.
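+  // For example, with inputs a, b, c, d the two pack stages produce the byte
+  // order a0..a3 b0..b3 c0..c3 d0..d3 | a4..a7 b4..b7 c4..c7 d4..d7; the dword
+  // permute (0, 4, 1, 5, 2, 6, 3, 7) restores a0..a7 b0..b7 c0..c7 d0..d7.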
+ const __m256i permute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
+ return _mm256_permutevar8x32_epi32(converted, permute_mask);
+}
+
+template <>
+struct type_casting_traits<QInt32, QUInt8> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet32q8u
+pcast<Packet8q32i, Packet32q8u>(const Packet8q32i& a, const Packet8q32i& b,
+ const Packet8q32i& c, const Packet8q32i& d) {
+ const __m256i converted = _mm256_packus_epi16(
+ _mm256_packs_epi32(a.val, b.val), _mm256_packs_epi32(c.val, d.val));
+ // Since packus does not cross 128 bit lane boundaries,
+ // we have to permute to properly order the final result.
+ const __m256i permute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
+ return _mm256_permutevar8x32_epi32(converted, permute_mask);
+}
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_