diff options
Diffstat (limited to 'third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h')
-rw-r--r-- | third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h | 57 |
1 files changed, 34 insertions, 23 deletions
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
index d561b79fbd..6b4b0edcfb 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
@@ -42,39 +42,50 @@ public:
 // Specialized blocking for quantized implementations.
 // Used by TensorContractionThreadPool, inputs must have dimensions that are
 // multiples of 32.
-template<int KcFactor, typename Index>
-struct ComputeGemmByColBlockingSizes<QInt8, QUInt8, KcFactor, Index> {
-  void operator()(Index& k, Index& m, Index& n, Index num_threads)
+template<typename Index,
+  typename LeftTensor,
+  typename left_nocontract_t, typename left_contract_t,
+  bool left_inner_dim_contiguous, bool left_inner_dim_reordered, int LeftAlignment,
+  typename RightTensor,
+  typename right_nocontract_t, typename right_contract_t,
+  bool right_inner_dim_contiguous, bool right_inner_dim_reordered, int RightAlignment, int ShardingType>
+class TensorContractionBlocking<TensorContractionInputMapper<QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32, left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>, TensorContractionInputMapper<QUInt8, Index, Rhs, RightTensor, right_nocontract_t, right_contract_t, 32, right_inner_dim_contiguous, right_inner_dim_reordered, RightAlignment>, Index, ShardingType> {
+ public:
+
+  typedef QInt8 LhsScalar;
+  typedef QUInt8 RhsScalar;
+
+  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
+    kc_(k), mc_(m), nc_(n)
   {
     eigen_assert(m % 32 == 0);
-    eigen_assert(n % 32 == 0);
     eigen_assert(k % 32 == 0);
     if (!k || !m || !n) {
       return;
     }
-    n = (((n / num_threads) + 31) / 32) * 32;
-  }
-};
 
-// Specialized blocking for quantized implementations.
-// Used by TensorContractionThreadPool, inputs must have dimensions that are
-// multiples of 32.
-template<int KcFactor, typename Index>
-struct ComputeGemmByRowBlockingSizes<QInt8, QUInt8, KcFactor, Index> {
-  void operator()(Index& k, Index& m, Index& n, Index num_threads)
-  {
-    eigen_assert(m % 32 == 0);
-    eigen_assert(n % 32 == 0 || n == 1);
-    eigen_assert(k % 32 == 0);
-    if (!k || !m || !n) {
-      return;
+    if (ShardingType == ShardByCol) {
+      eigen_assert(n % 32 == 0);
+      nc_ = (((n / num_threads) + 31) / 32) * 32;
     }
-    // Special case to avoid breaking the unimplemented matrix-vector case
-    if (n == 1) {
-      n = 32;
+    else {
+      eigen_assert(n % 32 == 0 || n == 1);
+      // Special case to avoid breaking the unimplemented matrix-vector case
+      if (n == 1) {
+        nc_ = 32;
+      }
+      mc_ = (((m / num_threads) + 31) / 32) * 32;
     }
-    m = (((m / num_threads) + 31) / 32) * 32;
   }
+
+  EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
+  EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
+  EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
+
+ private:
+  Index kc_;
+  Index mc_;
+  Index nc_;
 };
 
 // Specialized blocking for quantized implementations.