Diffstat (limited to 'third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h')
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h  |  57
1 file changed, 34 insertions(+), 23 deletions(-)
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
index d561b79fbd..6b4b0edcfb 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
@@ -42,39 +42,50 @@ public:
 // Specialized blocking for quantized implementations.
 // Used by TensorContractionThreadPool, inputs must have dimensions that are
 // multiples of 32.
-template<int KcFactor, typename Index>
-struct ComputeGemmByColBlockingSizes<QInt8, QUInt8, KcFactor, Index> {
-  void operator()(Index& k, Index& m, Index& n, Index num_threads)
+template<typename Index,
+         typename LeftTensor,
+         typename left_nocontract_t, typename left_contract_t,
+         bool left_inner_dim_contiguous, bool left_inner_dim_reordered, int LeftAlignment,
+         typename RightTensor,
+         typename right_nocontract_t, typename right_contract_t,
+         bool right_inner_dim_contiguous, bool right_inner_dim_reordered, int RightAlignment, int ShardingType>
+class TensorContractionBlocking<TensorContractionInputMapper<QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32, left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>, TensorContractionInputMapper<QUInt8, Index, Rhs, RightTensor, right_nocontract_t, right_contract_t, 32, right_inner_dim_contiguous, right_inner_dim_reordered, RightAlignment>, Index, ShardingType> {
+ public:
+
+  typedef QInt8 LhsScalar;
+  typedef QUInt8 RhsScalar;
+
+  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
+      kc_(k), mc_(m), nc_(n)
   {
     eigen_assert(m % 32 == 0);
-    eigen_assert(n % 32 == 0);
     eigen_assert(k % 32 == 0);
     if (!k || !m || !n) {
       return;
     }
-    n = (((n / num_threads) + 31) / 32) * 32;
-  }
-};
-// Specialized blocking for quantized implementations.
-// Used by TensorContractionThreadPool, inputs must have dimensions that are
-// multiples of 32.
-template<int KcFactor, typename Index>
-struct ComputeGemmByRowBlockingSizes<QInt8, QUInt8, KcFactor, Index> {
-  void operator()(Index& k, Index& m, Index& n, Index num_threads)
-  {
-    eigen_assert(m % 32 == 0);
-    eigen_assert(n % 32 == 0 || n == 1);
-    eigen_assert(k % 32 == 0);
-    if (!k || !m || !n) {
-      return;
+    if (ShardingType == ShardByCol) {
+      eigen_assert(n % 32 == 0);
+      nc_ = (((n / num_threads) + 31) / 32) * 32;
     }
-    // Special case to avoid breaking the unimplemented matrix-vector case
-    if (n == 1) {
-      n = 32;
+    else {
+      eigen_assert(n % 32 == 0 || n == 1);
+      // Special case to avoid breaking the unimplemented matrix-vector case
+      if (n == 1) {
+        nc_ = 32;
+      }
+      mc_ = (((m / num_threads) + 31) / 32) * 32;
     }
-    m = (((m / num_threads) + 31) / 32) * 32;
   }
+
+  EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
+  EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
+  EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
+
+ private:
+  Index kc_;
+  Index mc_;
+  Index nc_;
 };
 // Specialized blocking for quantized implementations.
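Both branches of the new TensorContractionBlocking constructor reduce to the same integer arithmetic: divide the sharded dimension by the number of threads and round the per-thread share up to the next multiple of 32, so each shard stays compatible with the 32-wide quantized packing. A minimal standalone sketch of that rounding follows; the helper name and the example values are illustrative only and not part of this file.

#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the expression used above:
// (((dim / num_threads) + 31) / 32) * 32 splits `dim` across
// `num_threads` and rounds each share up to a multiple of 32.
static int64_t ShardRoundedTo32(int64_t dim, int64_t num_threads) {
  assert(dim % 32 == 0);   // the specialization requires 32-multiples
  assert(num_threads > 0);
  return (((dim / num_threads) + 31) / 32) * 32;
}

int main() {
  // 256 columns over 3 threads: 256 / 3 = 85, rounded up to 96 per shard.
  int64_t nc = ShardRoundedTo32(256, 3);
  return nc == 96 ? 0 : 1;
}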