path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
author     Benoit Steiner <benoit.steiner.goog@gmail.com>    2016-12-21 15:06:06 -0800
committer  Benoit Steiner <benoit.steiner.goog@gmail.com>    2016-12-21 15:06:06 -0800
commit     519d63d350222ddbed5db1883a8fb2c7aab4b4e9 (patch)
tree       aaa27bcab15ef967d1fcf293318e90ba2e696ad7 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent     0657228569b26c132bbe9a0016912e7cb0fdc2b0 (diff)
Added support for libxsmm kernel in multithreaded contractions
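For context, a minimal multithreaded contraction that reaches this evaluator might look like the sketch below. The tensor sizes, thread count, and include path are illustrative only; it assumes a build with EIGEN_USE_THREADS, and the libxsmm kernel added here is only selected when both EIGEN_VECTORIZE_AVX and EIGEN_USE_LIBXSMM are defined, as the guards in the diff show.

    #define EIGEN_USE_THREADS
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      // Two matrices contracted over one index pair, evaluated on a thread pool.
      Eigen::Tensor<float, 2> a(256, 256), b(256, 256), c(256, 256);
      a.setRandom();
      b.setRandom();

      Eigen::ThreadPool pool(4);                 // 4 worker threads (arbitrary)
      Eigen::ThreadPoolDevice device(&pool, 4);

      Eigen::array<Eigen::IndexPair<int>, 1> dims = {{Eigen::IndexPair<int>(1, 0)}};
      c.device(device) = a.contract(b, dims);    // evaluated by the thread-pool contraction evaluator
      return 0;
    }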
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h  | 208
1 file changed, 204 insertions(+), 4 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index ee16cde9b..d30cc96ab 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -116,6 +116,28 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered, int Alignment>
void evalProduct(Scalar* buffer) const {
+ const Index m = this->m_i_size;
+ const Index n = this->m_j_size;
+ const Index k = this->m_k_size;
+ if (m == 0 || n == 0 || k == 0) return;
+
+#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+ if (this->m_can_use_xsmm) {
+ bool transposeA = !this->m_lhs_inner_dim_contiguous;
+ bool transposeB = !this->m_rhs_inner_dim_contiguous;
+ internal::TensorXsmmContractionBlocking<LhsScalar, RhsScalar, Index>
+ blocking(k, m, n, this->m_device.numThreads(), transposeA,
+ transposeB);
+
+ if (blocking.num_threads() == 1) {
+ this->evalGemmXSMM(buffer);
+ } else {
+ ContextXsmm<Alignment>(this, buffer, m, n, k, blocking).run();
+ }
+ return;
+ }
+#endif
+
typedef
typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
LhsScalar;
@@ -147,10 +169,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Traits::mr, Traits::nr, false, false>
GebpKernel;
- const Index m = this->m_i_size;
- const Index n = this->m_j_size;
- const Index k = this->m_k_size;
- if (m == 0 || n == 0 || k == 0) return;
+
// Compute a set of algorithm parameters:
// - kernel block sizes (bm, bn, bk)
@@ -1044,6 +1063,187 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
rhsCost.dropMemoryCost();
return cost + lhsCost + rhsCost;
}
+
+#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+ template<int Alignment>
+ class ContextXsmm {
+ public:
+ ContextXsmm(const Self* self, Scalar* buffer, Index m, Index n, Index k,
+ const internal::TensorXsmmContractionBlocking<LhsScalar,
+ RhsScalar, Index>& blocking):
+ device(self->m_device),
+ m(m), k(k), n(n),
+ stride_a(blocking.transposeA() ? k : m),
+ stride_b(blocking.transposeB() ? n : k),
+ stride_c(m),
+ bm(blocking.mc()), bk(blocking.kc()), bn(blocking.nc()),
+ blocks_m(blocking.blocks_m()), blocks_k(blocking.blocks_k()),
+ blocks_n(blocking.blocks_n()),
+ copyA(blocking.copyA()), copyB(blocking.copyB()),
+ transposeA(blocking.transposeA()), transposeB(blocking.transposeB()),
+ num_threads(blocking.num_threads()),
+ buffer(buffer),
+ leftData(self->m_leftImpl.data()), rightData(self->m_rightImpl.data()),
+ workers_done(blocking.num_threads()),
+
+ packingA_jobs(0), packingB_jobs(0), compute_jobs(0),
+ packingA_done(blocking.blocks_m()), packingB_done(blocking.blocks_n()) {}
+
+ void worker() {
+ // Pack
+
+ if (copyA) {
+ while (true) {
+ uint32_t mk = packingA_jobs++;
+ Index mi = mk / blocks_k;
+ Index ki = mk % blocks_k;
+ if (mi >= blocks_m) break;
+
+ LhsScalar * blockA = blocksA + (bk*bm) * (mi*blocks_k+ki);
+ if (transposeA) {
+ const LhsScalar * current_a = leftData + (bm*mi)*stride_a + (bk*ki);
+ libxsmm_otrans(blockA, current_a, sizeof(LhsScalar), actual_bk(ki),
+ actual_bm(mi), stride_a, bm);
+ } else {
+ const LhsScalar * current_a = leftData + (bk*ki)*stride_a + (bm*mi);
+ internal::pack_simple<LhsScalar, Index>(blockA, current_a,
+ actual_bk(ki), actual_bm(mi), bm, stride_a);
+ }
+ packingA_done.at(mi)++;
+ }
+ }
+
+ if (copyB) {
+ while (true) {
+ uint32_t nk = packingB_jobs++;
+ Index ni = nk / blocks_k;
+ Index ki = nk % blocks_k;
+ if (ni >= blocks_n) break;
+
+ RhsScalar * blockB = blocksB + (bk*bn) * (ni*blocks_k+ki);
+ if (transposeB) {
+ const RhsScalar * current_b = rightData + (ki*bk)*stride_b +
+ (ni*bn);
+ libxsmm_otrans(blockB, current_b, sizeof(RhsScalar), actual_bn(ni),
+ actual_bk(ki), stride_b, bk);
+ } else {
+ const RhsScalar * current_b = rightData + (ni*bn)*stride_b +
+ (ki*bk);
+ internal::pack_simple<RhsScalar, Index>(blockB, current_b,
+ actual_bn(ni), actual_bk(ki), bk, stride_b);
+ }
+ packingB_done.at(ni)++;
+ }
+ }
+
+ // Compute
+
+ while (true) {
+ uint32_t mn = compute_jobs++;
+ Index mi = mn / blocks_n;
+ Index ni = mn % blocks_n;
+ if (mi >= blocks_m) break;
+
+ // Wait for mi, ni packings to be done. This is more fine-grained than
+ // waiting for all workers to finish packing.
+ while ((copyA && (packingA_done.at(mi) < blocks_k)) ||
+ (copyB && (packingB_done.at(ni) < blocks_k)))
+ {}
+
+ for (Index ki=0; ki < blocks_k; ++ki) {
+ const LhsScalar * current_a = copyA ?
+ blocksA + (bk*bm) * (mi*blocks_k+ki) :
+ leftData + (bk*ki)*stride_a + (bm*mi);
+ const RhsScalar * current_b = copyB ?
+ blocksB + (bk*bn) * (ni*blocks_k+ki) :
+ rightData + (ni*bn)*stride_b + (bk*ki);
+
+ Index current_stride_a = copyA ? bm : stride_a;
+ Index current_stride_b = copyB ? bk : stride_b;
+
+ // Memory may not be zeroed, overwrite instead of adding in first
+ // iteration.
+ float beta = ki == 0 ? 0 : 1;
+
+ Scalar * current_c = buffer + (mi*bm) + (ni*bn)*stride_c;
+ internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(
+ 0, actual_bm(mi), actual_bn(ni), actual_bk(ki),
+ current_stride_a, current_stride_b, stride_c, 1, beta, 0)
+ (current_a, current_b, current_c);
+ }
+ }
+
+ workers_done.Notify();
+ }
+
+ void run() {
+ // Parallelization strategy.
+ //
+ // First pack A into blocks (sharding by m, k) and B (sharding by n,k),
+ // then shard by m, n.
+ //
+ // Do not use advanced ThreadPool queuing, just run a single long-standing
+ // function in each thread.
+ if (copyA) {
+ blocksA = static_cast<LhsScalar*>(device.allocate(
+ (blocks_m*bm)*(blocks_k*bk)*sizeof(LhsScalar)));
+ }
+ if (copyB) {
+ blocksB = static_cast<RhsScalar*>(device.allocate(
+ (blocks_n*bn)*(blocks_k*bk)*sizeof(RhsScalar)));
+ }
+
+ for (Index i = 0; i < num_threads; ++i) {
+ device.enqueueNoNotification([=]() { worker(); });
+ }
+
+ workers_done.Wait();
+
+ if (copyA) {
+ device.deallocate(blocksA);
+ }
+ if (copyB) {
+ device.deallocate(blocksB);
+ }
+ }
+
+ private:
+ // real block size for block index in [0, ..., blocks - 1].
+ Index actual_bm(Index mi) const {
+ return mi != blocks_m - 1 ? bm : m + bm - bm * blocks_m;
+ }
+ Index actual_bk(Index ki) const {
+ return ki != blocks_k - 1 ? bk : k + bk - bk * blocks_k;
+ }
+ Index actual_bn(Index ni) const {
+ return ni != blocks_n - 1 ? bn : n + bn - bn * blocks_n;
+ }
+
+ const Device& device;
+ Index m, k, n;
+ Index stride_a, stride_b, stride_c;
+ Index bm, bk, bn; // Block sizes.
+ Index blocks_m, blocks_k, blocks_n; // Number of blocks in each dimension.
+ bool copyA, copyB, transposeA, transposeB;
+ Index num_threads;
+ Scalar *buffer;
+ const LhsScalar *leftData;
+ const RhsScalar *rightData;
+
+ LhsScalar *blocksA;
+ RhsScalar *blocksB;
+ // barrier for joining all threads after all done.
+ Barrier workers_done;
+ // "queues" of (mi,ki), (ki,ni), (mi,ni) jobs packed [0,p)x[0,q) -> [0, p*q)
+ std::atomic<uint32_t> packingA_jobs;
+ std::atomic<uint32_t> packingB_jobs;
+ std::atomic<uint32_t> compute_jobs;
+ // already packed blocks for each mi-panel in A and ni-panel in B.
+ std::vector<std::atomic<uint8_t>> packingA_done;
+ std::vector<std::atomic<uint8_t>> packingB_done;
+ };
+#endif
+
};
} // end namespace Eigen
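The packing and compute loops in ContextXsmm distribute work through flattened 2D job indices handed out by shared atomic counters (packingA_jobs, packingB_jobs, compute_jobs), with per-panel counters letting compute jobs start as soon as the blocks they depend on are packed. The standalone sketch below reproduces just that claiming pattern with plain std::thread in place of Eigen's ThreadPoolDevice and Barrier; all names and sizes in it are hypothetical.

    #include <atomic>
    #include <cstdint>
    #include <thread>
    #include <vector>

    int main() {
      const int blocks_m = 4, blocks_k = 3, num_threads = 2;
      std::atomic<uint32_t> packing_jobs(0);                 // shared job counter
      std::vector<std::atomic<int>> packing_done(blocks_m);  // per-panel progress

      auto worker = [&]() {
        while (true) {
          uint32_t mk = packing_jobs++;   // claim the next flattened (mi, ki) job
          int mi = mk / blocks_k;         // row-block index
          int ki = mk % blocks_k;         // k-block index
          if (mi >= blocks_m) break;      // counter ran past the last block
          (void)ki;                       // ... pack block (mi, ki) here ...
          packing_done[mi]++;             // one more k-slice of panel mi is ready
        }
      };

      std::vector<std::thread> threads;
      for (int i = 0; i < num_threads; ++i) threads.emplace_back(worker);
      for (std::thread& t : threads) t.join();  // join stands in for workers_done.Wait()
      return 0;
    }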