author    Benoit Steiner <benoit.steiner.goog@gmail.com>  2017-01-30 15:25:57 -0800
committer Benoit Steiner <benoit.steiner.goog@gmail.com>  2017-01-30 15:25:57 -0800
commit    fbc39fd02c642119a2c49e517e1cd6e8fa1a008f (patch)
tree      6cf2142e4b740eb440c577ca08114e4e24912f91 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parent    82ce92419e25d8b9902c0f39e2e3b01787bf8687 (diff)
parent    63de19c0004933c7b2b1e418292b9f2ae6c138f4 (diff)
Merge latest changes from upstream
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 286
1 file changed, 284 insertions(+), 2 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 1b8017349..828db6d8b 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -20,6 +20,70 @@ namespace Eigen {
*
*/
namespace internal {
+#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+template<typename Scalar, typename Index>
+void pack_simple(Scalar * dst, const Scalar * src, Index cols, Index rows, Index lddst, Index ldsrc) {
+ size_t psize = packet_traits<Scalar>::size; // Packet size
+ typedef typename packet_traits<Scalar>::type Packet; // Packet type
+ size_t alignment = psize*sizeof(Scalar); // Needed alignment
+ if (rows % psize == 0 && (lddst*sizeof(Scalar)) % alignment == 0 &&
+ (ldsrc*sizeof(Scalar)) % alignment == 0 &&
+ reinterpret_cast<uintptr_t>(src) % alignment == 0 &&
+ reinterpret_cast<uintptr_t>(dst) % alignment == 0) {
+ // Optimized version using packets
+ size_t num_packets = rows / psize;
+ for (Index col = 0; col < cols; ++col) {
+ EIGEN_ASM_COMMENT("begin pack_simple inner copy");
+ // Unrolled manually 4 times.
+ for (size_t i=0; i < num_packets/4; ++i) {
+ internal::pstore(dst, internal::pload<Packet>(src));
+ dst += psize; src += psize;
+ internal::pstore(dst, internal::pload<Packet>(src));
+ dst += psize; src += psize;
+ internal::pstore(dst, internal::pload<Packet>(src));
+ dst += psize; src += psize;
+ internal::pstore(dst, internal::pload<Packet>(src));
+ dst += psize; src += psize;
+ }
+ for (size_t i=0; i < num_packets%4; ++i) {
+ internal::pstore(dst, internal::pload<Packet>(src));
+ dst += psize; src += psize;
+ }
+ dst += lddst - num_packets*psize;
+ src += ldsrc - num_packets*psize;
+ EIGEN_ASM_COMMENT("end pack_simple inner copy");
+ }
+ } else {
+ // Naive memcpy calls
+ for (Index col = 0; col < cols; ++col) {
+ memcpy(dst + col*lddst, src + col*ldsrc, rows*sizeof(Scalar));
+ }
+ }
+}
+
+template<typename LhsScalar, typename RhsScalar, typename Scalar>
+struct libxsmm_wrapper {
+  libxsmm_wrapper() {}
+  libxsmm_wrapper(int flags, int m, int n, int k, int lda, int ldb, int ldc, float alpha, float beta, int prefetch) {}
+  void operator()(const LhsScalar* a, const RhsScalar* b, Scalar* c) {}
+  void operator()(const LhsScalar* a, const RhsScalar* b, Scalar* c, const LhsScalar* ap, const RhsScalar* bp, const Scalar* cp) {}
+};
+
+template<>
+struct libxsmm_wrapper<float, float, float> : public libxsmm_mmfunction<float> {
+  libxsmm_wrapper() : libxsmm_mmfunction() {}
+  libxsmm_wrapper(int flags, int m, int n, int k, int lda, int ldb, int ldc, float alpha, float beta, int prefetch) :
+    libxsmm_mmfunction(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch) {}
+};
+
+template<>
+struct libxsmm_wrapper<double, double, double> : public libxsmm_mmfunction<double> {
+  libxsmm_wrapper() : libxsmm_mmfunction() {}
+  libxsmm_wrapper(int flags, int m, int n, int k, int lda, int ldb, int ldc, double alpha, double beta, int prefetch) :
+    libxsmm_mmfunction(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch) {}
+};
+#endif
+
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
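pack_simple above copies a rows x cols column-major panel between buffers with different leading dimensions, taking the unrolled packet path only when the row count, the leading dimensions, and both pointers satisfy the packet-alignment requirements. A minimal sketch of the same semantics follows; it mirrors only the memcpy fallback, not Eigen's packet code, and the name pack_panel_naive is illustrative:

    #include <cstring>

    // Copy a rows x cols column-major panel; lddst and ldsrc are the element
    // strides between consecutive columns of dst and src respectively.
    template <typename Scalar, typename Index>
    void pack_panel_naive(Scalar* dst, const Scalar* src,
                          Index cols, Index rows, Index lddst, Index ldsrc) {
      for (Index col = 0; col < cols; ++col) {
        std::memcpy(dst + col * lddst, src + col * ldsrc, rows * sizeof(Scalar));
      }
    }

When the fast branch applies, the same copy is performed with aligned packet loads and stores, manually unrolled four packets per iteration.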
@@ -317,6 +381,8 @@ struct TensorContractionEvaluatorBase
}
}
+ EnableXSMMIfPossible(eval_op_indices);
+
// If the layout is RowMajor, we need to reverse the m_dimensions
if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
@@ -422,6 +488,13 @@ struct TensorContractionEvaluatorBase
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
+ #if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+ if (m_can_use_xsmm) {
+ evalGemmXSMM(buffer);
+ return;
+ }
+ #endif
+
// columns in left side, rows in right side
const Index k = this->m_k_size;
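Two layers guard this fast path: the compile-time gate (EIGEN_VECTORIZE_AVX and EIGEN_USE_LIBXSMM) and the runtime flag m_can_use_xsmm, computed once at evaluator construction by EnableXSMMIfPossible below. A minimal sketch of that dispatch shape, with hypothetical names (USE_FAST_KERNELS, evalFast, evalGeneric are illustrative, not Eigen API):

    // Two-level guard: feature gate at compile time, eligibility flag at
    // runtime, with the generic path as the unconditional fallback.
    struct Evaluator {
      bool can_use_fast_path_ = false;  // set once, at construction

      void evalProduct(float* buffer) const {
    #if defined(USE_FAST_KERNELS)
        if (can_use_fast_path_) {
          evalFast(buffer);    // vendor kernel path
          return;
        }
    #endif
        evalGeneric(buffer);   // always-available path
      }

      void evalFast(float* buffer) const;
      void evalGeneric(float* buffer) const;
    };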
@@ -538,7 +611,212 @@ struct TensorContractionEvaluatorBase
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; }
- protected:
+protected:
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void EnableXSMMIfPossible(const array<IndexPair<Index>, ContractDims>& eval_op_indices) {
+ m_can_use_xsmm = false;
+
+#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+ typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
+ typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
+ if (!std::is_same<Scalar, LhsScalar>::value ||
+ !std::is_same<Scalar, RhsScalar>::value ||
+ !(std::is_same<Scalar, float>::value ||
+ std::is_same<Scalar, double>::value) ||
+ m_leftImpl.data() == NULL ||
+ m_rightImpl.data() == NULL) {
+ return;
+ }
+
+  // Check if we can use a faster matmul algorithm. For the contraction to be
+  // equivalent to matmul, both the lhs and rhs contracting dims sequences
+  // must be either a prefix or a suffix of all dims, and both must be in the
+  // same order, so that no reordering is needed.
+  // For example:
+  // * OK: lhs 4D, rhs 4D, contraction: [(0, 2), (1, 3)]
+  // * BAD: lhs 3D, rhs 3D, contraction: [(1, 1)]
+  // * BAD: lhs 3D, rhs 3D, contraction: [(0, 0), (2, 2)]
+  // * BAD: lhs 3D, rhs 3D, contraction: [(0, 2), (1, 1)]
+  // Depending on whether the contraction dims are a prefix or a suffix of all
+  // dims, we need to pre-transpose the matrices in the matmul algorithm:
+  // lhs: prefix -> transpose, suffix -> no transpose
+  // rhs: prefix -> no transpose, suffix -> transpose
+  // For example, for 2D lhs and rhs, contraction [(1, 0)] is a regular,
+  // non-transposed matmul.
+ if (ContractDims == 0) {
+    // This case is totally uninteresting; filter it out to avoid problems
+    // with the loop-based checks below.
+ return;
+ }
+
+ // Check if RHS dims list is increasing. LHS already is, so if not, the
+ // order is different and we cannot do matmul.
+ for (int i = 1; i < ContractDims; i++) {
+ if (eval_op_indices[i].second < eval_op_indices[i-1].second) {
+ return;
+ }
+ }
+
+  // Check that there are no holes.
+ int diff;
+ for (int i = 1; i < ContractDims; i++) {
+ // LHS contract dims are sorted to form an increasing seq.
+ diff = eval_op_indices[i].first - eval_op_indices[i-1].first;
+ if (diff != 1) {
+ return;
+ }
+    // We can now assume the RHS contract dims seq is increasing too.
+ diff = eval_op_indices[i].second - eval_op_indices[i-1].second;
+ if (diff != 1) {
+ return;
+ }
+ }
+
+  // Check that the contraction dims form a prefix or a suffix.
+ if (eval_op_indices[0].first != 0 &&
+ eval_op_indices[ContractDims-1].first != LDims-1) {
+ return;
+ }
+ if (eval_op_indices[0].second != 0 &&
+ eval_op_indices[ContractDims-1].second != RDims-1) {
+ return;
+ }
+
+ m_can_use_xsmm = true;
+#endif
+ }
+
+#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+ EIGEN_DEVICE_FUNC void evalGemmXSMM(Scalar* buffer) const {
+ // columns in left side, rows in right side
+ const Index k = this->m_k_size;
+
+ // rows in left side
+ const Index m = this->m_i_size;
+
+ // columns in right side
+ const Index n = this->m_j_size;
+
+ const bool transposeA = !m_lhs_inner_dim_contiguous;
+ const bool transposeB = !m_rhs_inner_dim_contiguous;
+
+ typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
+ typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
+
+ internal::TensorXsmmContractionBlocking<LhsScalar, RhsScalar, Index> blocking(
+ k, m, n, 1, transposeA, transposeB);
+
+ // Outer blocks sizes
+ const Index mc_outer = blocking.outer_m();
+ const Index nc_outer = blocking.outer_n();
+ const Index kc_outer = blocking.outer_k();
+ // Inner blocks sizes
+ const Index mc = blocking.mc();
+ const Index nc = blocking.nc();
+ const Index kc = blocking.kc();
+  // Decisions on whether to copy parts of the matrices
+ const bool copyA = blocking.copyA();
+ const bool copyB = blocking.copyB();
+
+ const LhsScalar* leftData = m_leftImpl.data();
+ const RhsScalar* rightData = m_rightImpl.data();
+
+ const libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m);
+ const libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k);
+ const libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m);
+
+ const libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc);
+  // Use a bigger stride to avoid hitting the same cache line too often.
+  // This consistently gains roughly 0.5 Gflops.
+ const libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>(
+ kc % 32 == 0 ? kc + 16 : kc
+ );
+
+  // Kernel for the general case (i.e., not the edge blocks)
+ internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar> kernel;
+
+ LhsScalar* blockA = NULL;
+ RhsScalar* panelB = NULL;
+
+ if (copyA) {
+ blockA = static_cast<LhsScalar*>(this->m_device.allocate(mc * kc * sizeof(LhsScalar)));
+ }
+ if (copyB) {
+ panelB = static_cast<RhsScalar*>(this->m_device.allocate(nc_outer * stride_panelB * sizeof(RhsScalar)));
+ }
+
+ const Index kernel_stride_A = copyA ? stride_blockA : stride_A;
+ const Index kernel_stride_B = copyB ? stride_panelB : stride_B;
+ kernel = internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, mc, nc, kc, kernel_stride_A, kernel_stride_B, stride_C, 1, 1, blocking.prefetch());
+
+ // Outer blocking
+ for (Index ki_outer = 0; ki_outer < k; ki_outer += kc_outer) {
+ for (Index mi_outer = 0; mi_outer < m; mi_outer += mc_outer) {
+ for (Index ni_outer = 0; ni_outer < n; ni_outer += nc_outer) {
+ using numext::mini;
+
+ Index actual_nc_outer = mini(ni_outer+nc_outer, n) - ni_outer;
+
+ // Inner blocking
+ for (Index ki = ki_outer; ki < mini(ki_outer+kc_outer, k); ki += kc) {
+ const Index actual_kc = mini(ki_outer+kc_outer, mini(ki+kc, k)) - ki;
+ const float beta = ki == 0 ? 0 : 1;
+
+ if (copyB) {
+ if (transposeB) {
+ libxsmm_otrans(panelB, rightData + ki*stride_B + ni_outer, sizeof(RhsScalar), actual_nc_outer, actual_kc, stride_B, stride_panelB);
+ } else {
+ internal::pack_simple<RhsScalar, Index>(panelB, rightData + ni_outer*stride_B + ki, actual_nc_outer, actual_kc, stride_panelB, stride_B);
+ }
+ }
+
+ for (Index mi = mi_outer; mi < mini(mi_outer+mc_outer, m); mi += mc) {
+ const Index actual_mc = mini(mi_outer+mc_outer, mini(mi+mc, m)) - mi;
+
+ const LhsScalar* a = transposeA ? leftData + mi*stride_A + ki :
+ leftData + ki*stride_A + mi;
+
+ if (copyA) {
+ if (transposeA) {
+ libxsmm_otrans(blockA, a, sizeof(LhsScalar), actual_kc, actual_mc, stride_A, stride_blockA);
+ } else {
+ internal::pack_simple<LhsScalar, Index>(blockA, a, actual_kc, actual_mc, stride_blockA, stride_A);
+ }
+ }
+ const LhsScalar* actual_a = copyA ? blockA : a;
+
+ for (Index ni = ni_outer; ni < mini(ni_outer+nc_outer, n); ni += nc) {
+ const Index actual_nc = mini(ni_outer+nc_outer, mini(ni+nc, n)) - ni;
+
+ const RhsScalar* b = rightData + ni*stride_B + ki;
+ Scalar* c = buffer + ni*stride_C + mi;
+ const Scalar* cp = c + nc*stride_C;
+
+ const RhsScalar* actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b;
+ const RhsScalar* bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B;
+
+ if (actual_mc == mc && actual_kc == kc && actual_nc == nc && beta == 1) {
+              // The most common case: use the pre-built, cached kernel.
+ kernel(actual_a, actual_b, c, actual_a, bp, cp);
+ } else {
+              // Edge blocks: rely on libxsmm's internal kernel cache.
+ internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, kernel_stride_A, kernel_stride_B, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, actual_a, bp, cp);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (copyA) {
+ this->m_device.deallocate(blockA);
+ }
+ if (copyB) {
+ this->m_device.deallocate(panelB);
+ }
+ }
+#endif
+
// Prevent assignment
TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);
Dimensions m_dimensions;
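To see the EnableXSMMIfPossible rules at the user-facing API level, here is a hedged usage sketch (the sizes are arbitrary; actually taking the xsmm path additionally requires matching float/double scalars, directly accessible operand data, and a build with EIGEN_USE_LIBXSMM and AVX):

    #include <unsupported/Eigen/CXX11/Tensor>

    void contraction_examples() {
      Eigen::Tensor<float, 4> lhs(2, 3, 4, 5), rhs(4, 5, 6, 7);
      lhs.setRandom(); rhs.setRandom();

      // Eligible: contracting dims (2, 3) form a suffix of lhs and (0, 1) a
      // prefix of rhs, both increasing with no holes -> equivalent to a
      // plain matmul with no pre-transpose on either side.
      Eigen::array<Eigen::IndexPair<int>, 2> ok_dims = {
          Eigen::IndexPair<int>(2, 0), Eigen::IndexPair<int>(3, 1)};
      Eigen::Tensor<float, 4> c_fast = lhs.contract(rhs, ok_dims);

      // Not eligible (a "BAD" case from the comment above): dim 1 is neither
      // the first nor the last dim of a 3D operand, so the contraction
      // falls back to the generic evalGemm path.
      Eigen::Tensor<float, 3> a(2, 3, 4), b(5, 3, 6);
      a.setRandom(); b.setRandom();
      Eigen::array<Eigen::IndexPair<int>, 1> bad_dims = {
          Eigen::IndexPair<int>(1, 1)};
      Eigen::Tensor<float, 4> c_generic = a.contract(b, bad_dims);
    }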
@@ -564,6 +842,11 @@ struct TensorContractionEvaluatorBase
TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
const Device& m_device;
Scalar* m_result;
+
+  /// Required for SYCL.
+ const Indices m_expr_indices;
+
+ bool m_can_use_xsmm;
};
@@ -621,7 +904,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
}
-
};
} // end namespace Eigen
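The core of evalGemmXSMM is the k-blocked accumulation: the first k-block writes the output with beta = 0, and every later k-block accumulates with beta = 1, so the result buffer never needs a separate zeroing pass. A scalar sketch of just that structure, with no libxsmm, packing, or outer blocking (column-major layouts as in the code above):

    // C (m x n) = A (m x k) * B (k x n), processed kc k-columns at a time.
    // All matrices are column-major: leading dimensions m, k, and m.
    void blocked_gemm(const float* A, const float* B, float* C,
                      int m, int n, int k, int kc) {
      for (int ki = 0; ki < k; ki += kc) {
        const int actual_kc = (ki + kc < k) ? kc : (k - ki);
        const bool first_block = (ki == 0);  // the beta == 0 case
        for (int j = 0; j < n; ++j) {
          for (int i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int kk = ki; kk < ki + actual_kc; ++kk) {
              acc += A[kk * m + i] * B[j * k + kk];
            }
            // beta = 0: overwrite C on the first k-block; beta = 1: accumulate.
            C[j * m + i] = first_block ? acc : C[j * m + i] + acc;
          }
        }
      }
    }

In the real code each (mc x kc) by (kc x nc) inner product is a single libxsmm kernel call, and blocks of A and panels of B are optionally packed into contiguous, aligned scratch buffers first.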