diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-12-21 12:32:06 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-12-21 12:32:06 -0800 |
commit | f9eff17e915e270e654287723cea67be495f5c5f (patch) | |
tree | 775eadae5593a88a9e58ab5e980a40f8520ad339 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | |
parent | c19fe5e9ed24923b5c80867b38c9823da13ff76e (diff) |
Leverage libxsmm kernels within single-threaded contractions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 291 |
1 files changed, 289 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 2ac6abf69..c446ba1af 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -20,6 +20,70 @@ namespace Eigen { * */ namespace internal { +#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM) +template<typename Scalar, typename Index> +void pack_simple(Scalar * dst, const Scalar * src, Index cols, Index rows, Index lddst, Index ldsrc) { + size_t psize = packet_traits<Scalar>::size; // Packet size + typedef typename packet_traits<Scalar>::type Packet; // Packet type + size_t alignment = psize*sizeof(Scalar); // Needed alignment + if (rows % psize == 0 && (lddst*sizeof(Scalar)) % alignment == 0 && + (ldsrc*sizeof(Scalar)) % alignment == 0 && + reinterpret_cast<uintptr_t>(src) % alignment == 0 && + reinterpret_cast<uintptr_t>(dst) % alignment == 0) { + // Optimized version using packets + size_t num_packets = rows / psize; + for (Index col = 0; col < cols; ++col) { + EIGEN_ASM_COMMENT("begin pack_simple inner copy"); + // Unrolled manually 4 times. 
+ for (size_t i=0; i < num_packets/4; ++i) { + internal::pstore(dst, internal::pload<Packet>(src)); + dst += psize; src += psize; + internal::pstore(dst, internal::pload<Packet>(src)); + dst += psize; src += psize; + internal::pstore(dst, internal::pload<Packet>(src)); + dst += psize; src += psize; + internal::pstore(dst, internal::pload<Packet>(src)); + dst += psize; src += psize; + } + for (size_t i=0; i < num_packets%4; ++i) { + internal::pstore(dst, internal::pload<Packet>(src)); + dst += psize; src += psize; + } + dst += lddst - num_packets*psize; + src += ldsrc - num_packets*psize; + EIGEN_ASM_COMMENT("end pack_simple inner copy"); + } + } else { + // Naive memcpy calls + for (Index col = 0; col < cols; ++col) { + memcpy(dst + col*lddst, src + col*ldsrc, rows*sizeof(Scalar)); + } + } +} + +template<typename LhsScalar, typename RhsScalar, typename Scalar> + struct libxsmm_wrapper { + libxsmm_wrapper() {} + libxsmm_wrapper(int flags, int m, int n, int k, int lda, int ldb, int ldc, float alpha, float beta, int prefetch) {} + void operator()(const LhsScalar* a, const RhsScalar* b, Scalar* c) {} + void operator()(const LhsScalar* a, const RhsScalar* b, Scalar* c, const LhsScalar* ap, const RhsScalar* bp, const Scalar* cp) {} + }; + + template<> + struct libxsmm_wrapper<float, float, float>: public libxsmm_mmfunction<float> { + libxsmm_wrapper(): libxsmm_mmfunction() {} + libxsmm_wrapper(int flags, int m, int n, int k, int lda, int ldb, int ldc, float alpha, float beta, int prefetch) : + libxsmm_mmfunction(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch) {} + }; + + template<> + struct libxsmm_wrapper<double, double, double>: public libxsmm_mmfunction<double> { + libxsmm_wrapper(): libxsmm_mmfunction() {} + libxsmm_wrapper(int flags, int m, int n, int k, int lda, int ldb, int ldc, float alpha, float beta, int prefetch) : + libxsmm_mmfunction(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch) {} + }; +#endif + template<typename Dimensions, typename 
LhsXprType, typename RhsXprType> struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > @@ -317,6 +381,8 @@ struct TensorContractionEvaluatorBase } } + EnableXSMMIfPossible(eval_op_indices); + // If the layout is RowMajor, we need to reverse the m_dimensions if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) { for (int i = 0, j = NumDims - 1; i < j; i++, j--) { @@ -422,6 +488,13 @@ struct TensorContractionEvaluatorBase template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const { + #if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM) + if (m_can_use_xsmm) { + evalGemmXSMM(buffer); + return; + } + #endif + // columns in left side, rows in right side const Index k = this->m_k_size; @@ -538,7 +611,221 @@ struct TensorContractionEvaluatorBase EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; } - protected: +protected: + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void EnableXSMMIfPossible(const array<IndexPair<Index>, ContractDims>& eval_op_indices) { + m_can_use_xsmm = false; + +#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM) + typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; + typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar; + if (!std::is_same<Scalar, LhsScalar>::value || + !std::is_same<Scalar, RhsScalar>::value || + !(std::is_same<Scalar, float>::value || + std::is_same<Scalar, double>::value) || + m_leftImpl.data() == NULL || + m_rightImpl.data() == NULL) { + return; + } + + // Check if we can use faster matmul algorithms. For contraction to be + // equivalent to matmul, we need both lhs and rhs contracting dims sequences + // to be either a prefix or suffix of all dims. Also, the order of both + // must be the same, so we don't have to do reordering. 
+ // For example: + // * OK: lhs 4D, rhs 4D, contraction: [(0, 2), (1, 3)] + // * BAD: lhs 3D, rhs 3D, contraction: [(1,1)] + // * BAD: lhs 3D, rhs 3D, contraction: [(0, 0), (2, 2)] + // * BAD: lhs 3D, rhs 3D, contraction: [(0, 2), (1, 1)] + // Depending if contraction dims are prefix or suffix of all dims we need to + // pre-transpose matrices in matmul algorithm: + // lhs: prefix -> transpose, suffix -> no transpose + // rhs: prefix -> no transpose, suffix -> transpose + // For example, for lhs 2D, rhs 2D, contraction [(1, 0)] is regular, + // non-transposed matmul. + if (ContractDims == 0) { + // This case is totally uninteresting, filter it out to avoid problems + // with iterations in further tests. + return; + } + + // Check if RHS dims list is increasing. LHS already is, so if not, the + // order is different and we cannot do matmul. + for (int i = 1; i < ContractDims; i++) { + if (eval_op_indices[i].second < eval_op_indices[i-1].second) { + return; + } + } + + // Check if no holes. + int diff; + for (int i = 1; i < ContractDims; i++) { + // LHS contract dims are sorted to form an increasing seq. + diff = eval_op_indices[i].first - eval_op_indices[i-1].first; + if (diff != 1) { + return; + } + // Now we may already assume RHS contract dims seq is increasing too. + diff = eval_op_indices[i].second - eval_op_indices[i-1].second; + if (diff != 1) { + return; + } + } + + // Check if suffix or prefix. 
+ if (eval_op_indices[0].first != 0 && + eval_op_indices[ContractDims-1].first != LDims-1) { + return; + } + if (eval_op_indices[0].second != 0 && + eval_op_indices[ContractDims-1].second != RDims-1) { + return; + } + + m_can_use_xsmm = true; + #endif + } + +#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM) + EIGEN_DEVICE_FUNC void evalGemmXSMM(Scalar* buffer) const { + // columns in left side, rows in right side + const Index k = this->m_k_size; + + // rows in left side + const Index m = this->m_i_size; + + // columns in right side + const Index n = this->m_j_size; + + const bool transposeA = !m_lhs_inner_dim_contiguous; + const bool transposeB = !m_rhs_inner_dim_contiguous; + + typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; + typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar; + + internal::TensorXsmmContractionBlocking<LhsScalar, RhsScalar, Index> blocking( + k, m, n, 1, transposeA, transposeB); + + // Outer blocks sizes + const Index mc_outer = blocking.outer_m(); + const Index nc_outer = blocking.outer_n(); + const Index kc_outer = blocking.outer_k(); + // Inner blocks sizes + const Index mc = blocking.mc(); + const Index nc = blocking.nc(); + const Index kc = blocking.kc(); + // Decisions whether we should copy parts of matrices + const bool copyA = blocking.copyA(); + const bool copyB = blocking.copyB(); + + const LhsScalar* leftData = m_leftImpl.data(); + const RhsScalar* rightData = m_rightImpl.data(); + + libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m); + libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k); + libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m); + + libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc); + // Use bigger stride to avoid hitting same cache line too often. + // This consistently gives +~0.5 Gflops. 
+ libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>( + kc % 32 == 0 ? kc + 16 : kc + ); + + // Kernel for the general case (not edges) + internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar> kernel; + + const LhsScalar *ap; + const RhsScalar *bp; + const Scalar *cp; + + LhsScalar* blockA = NULL; + RhsScalar* panelB = NULL; + + if (copyA) { + blockA = static_cast<LhsScalar*>(this->m_device.allocate(mc * kc * sizeof(LhsScalar))); + } + if (copyB) { + panelB = static_cast<RhsScalar*>(this->m_device.allocate(nc_outer * stride_panelB * sizeof(RhsScalar))); + } + + Index kernel_stride_A = copyA ? stride_blockA : stride_A; + Index kernel_stride_B = copyB ? stride_panelB : stride_B; + kernel = internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, mc, nc, kc, kernel_stride_A, kernel_stride_B, stride_C, 1, 1, blocking.prefetch()); + + // Outer blocking + for (Index ki_outer = 0; ki_outer < k; ki_outer += kc_outer) { + for (Index mi_outer = 0; mi_outer < m; mi_outer += mc_outer) { + for (Index ni_outer = 0; ni_outer < n; ni_outer += nc_outer) { + using numext::mini; + + Index actual_nc_outer = mini(ni_outer+nc_outer, n) - ni_outer; + + // Inner blocking + for (Index ki = ki_outer; ki < mini(ki_outer+kc_outer, k); ki += kc) { + const Index actual_kc = mini(ki_outer+kc_outer, mini(ki+kc, k)) - ki; + + if (copyB) { + if (transposeB) { + libxsmm_otrans(panelB, rightData + ki*stride_B + ni_outer, sizeof(RhsScalar), actual_nc_outer, actual_kc, stride_B, stride_panelB); + } else { + internal::pack_simple<RhsScalar, Index>(panelB, rightData + ni_outer*stride_B + ki, actual_nc_outer, actual_kc, stride_panelB, stride_B); + } + } + + for (Index mi = mi_outer; mi < mini(mi_outer+mc_outer, m); mi += mc) { + const Index actual_mc = mini(mi_outer+mc_outer, mini(mi+mc, m)) - mi; + + const LhsScalar * a = transposeA ? 
leftData + mi*stride_A + ki : + leftData + ki*stride_A + mi; + + if (copyA) { + if (transposeA) { + libxsmm_otrans(blockA, a, sizeof(LhsScalar), actual_kc, actual_mc, stride_A, stride_blockA); + } else { + internal::pack_simple<LhsScalar, Index>(blockA, a, actual_kc, actual_mc, stride_blockA, stride_A); + } + } + + for (Index ni = ni_outer; ni < mini(ni_outer+nc_outer, n); ni += nc) { + const Index actual_nc = mini(ni_outer+nc_outer, mini(ni+nc, n)) - ni; + + const RhsScalar * b = rightData + ni*stride_B + ki; + Scalar * c = buffer + ni*stride_C + mi; + cp = c + nc*stride_C; + + const LhsScalar * actual_a = copyA ? blockA : a; + const Index actual_lda = copyA ? stride_blockA : stride_A; + ap = copyA ? blockA : a; + + const RhsScalar * actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b; + const Index actual_ldb = copyB ? stride_panelB : stride_B; + bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B; + + float beta = ki == 0 ? 0 : 1; + + if (actual_mc == mc && actual_kc == kc && actual_nc == nc && beta == 1) { + // Most used, cached kernel. + kernel(actual_a, actual_b, c, ap, bp, cp); + } else { + // Edges - use libxsmm kernel cache. 
+ internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, actual_lda, actual_ldb, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, ap, bp, cp); + } + } + } + } + } + } + } + + if (copyA) { + this->m_device.deallocate(blockA); + } + if (copyB) { + this->m_device.deallocate(panelB); + } + } +#endif + // Prevent assignment TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&); Dimensions m_dimensions; @@ -567,6 +854,7 @@ struct TensorContractionEvaluatorBase /// required for sycl const Indices m_expr_indices; + bool m_can_use_xsmm; }; @@ -624,7 +912,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); } - }; } // end namespace Eigen |