From 519d63d350222ddbed5db1883a8fb2c7aab4b4e9 Mon Sep 17 00:00:00 2001
From: Benoit Steiner
Date: Wed, 21 Dec 2016 15:06:06 -0800
Subject: Added support for libxsmm kernel in multithreaded contractions

---
 .../CXX11/src/Tensor/TensorContractionThreadPool.h | 208 ++++++++++++++++++++-
 1 file changed, 204 insertions(+), 4 deletions(-)

(limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')

diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index ee16cde9b..d30cc96ab 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -116,6 +116,28 @@ struct TensorEvaluator
   void evalProduct(Scalar* buffer) const {
+    const Index m = this->m_i_size;
+    const Index n = this->m_j_size;
+    const Index k = this->m_k_size;
+    if (m == 0 || n == 0 || k == 0) return;
+
+#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+    if (this->m_can_use_xsmm) {
+      bool transposeA = !this->m_lhs_inner_dim_contiguous;
+      bool transposeB = !this->m_rhs_inner_dim_contiguous;
+      internal::TensorXsmmContractionBlocking<LhsScalar, RhsScalar, Index>
+          blocking(k, m, n, this->m_device.numThreads(), transposeA,
+                   transposeB);
+
+      if (blocking.num_threads() == 1) {
+        this->evalGemmXSMM(buffer);
+      } else {
+        ContextXsmm(this, buffer, m, n, k, blocking).run();
+      }
+      return;
+    }
+#endif
+
     typedef
         typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
             LhsScalar;
@@ -147,10 +169,7 @@ struct TensorEvaluator
         GebpKernel;
-    const Index m = this->m_i_size;
-    const Index n = this->m_j_size;
-    const Index k = this->m_k_size;
-    if (m == 0 || n == 0 || k == 0) return;
+

     // Compute a set of algorithm parameters:
     // - kernel block sizes (bm, bn, bk)
@@ -1044,6 +1063,187 @@ struct TensorEvaluator
+#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+  class ContextXsmm {
+   public:
+    ContextXsmm(const Self* self, Scalar* buffer, Index m, Index n, Index k,
+                const internal::TensorXsmmContractionBlocking<LhsScalar,
+                    RhsScalar, Index>& blocking):
+        device(self->m_device),
+        m(m), k(k), n(n),
+        stride_a(blocking.transposeA() ? k : m),
+        stride_b(blocking.transposeB() ? n : k),
+        stride_c(m),
+        bm(blocking.mc()), bk(blocking.kc()), bn(blocking.nc()),
+        blocks_m(blocking.blocks_m()), blocks_k(blocking.blocks_k()),
+        blocks_n(blocking.blocks_n()),
+        copyA(blocking.copyA()), copyB(blocking.copyB()),
+        transposeA(blocking.transposeA()), transposeB(blocking.transposeB()),
+        num_threads(blocking.num_threads()),
+        buffer(buffer),
+        leftData(self->m_leftImpl.data()), rightData(self->m_rightImpl.data()),
+        workers_done(blocking.num_threads()),
+
+        packingA_jobs(0), packingB_jobs(0), compute_jobs(0),
+        packingA_done(blocking.blocks_m()), packingB_done(blocking.blocks_n()) {}
+
+    void worker() {
+      // Pack
+
+      if (copyA) {
+        while (true) {
+          uint32_t mk = packingA_jobs++;
+          Index mi = mk / blocks_k;
+          Index ki = mk % blocks_k;
+          if (mi >= blocks_m) break;
+
+          LhsScalar * blockA = blocksA + (bk*bm) * (mi*blocks_k+ki);
+          if (transposeA) {
+            const LhsScalar * current_a = leftData + (bm*mi)*stride_a + (bk*ki);
+            libxsmm_otrans(blockA, current_a, sizeof(LhsScalar), actual_bk(ki),
+                           actual_bm(mi), stride_a, bm);
+          } else {
+            const LhsScalar * current_a = leftData + (bk*ki)*stride_a + (bm*mi);
+            internal::pack_simple(blockA, current_a,
+                                  actual_bk(ki), actual_bm(mi), bm, stride_a);
+          }
+          packingA_done.at(mi)++;
+        }
+      }
+
+      if (copyB) {
+        while (true) {
+          uint32_t nk = packingB_jobs++;
+          Index ni = nk / blocks_k;
+          Index ki = nk % blocks_k;
+          if (ni >= blocks_n) break;
+
+          RhsScalar * blockB = blocksB + (bk*bn) * (ni*blocks_k+ki);
+          if (transposeB) {
+            const RhsScalar * current_b = rightData + (ki*bk)*stride_b +
+                                          (ni*bn);
+            libxsmm_otrans(blockB, current_b, sizeof(RhsScalar), actual_bn(ni),
+                           actual_bk(ki), stride_b, bk);
+          } else {
+            const RhsScalar * current_b = rightData + (ni*bn)*stride_b +
+                                          (ki*bk);
+            internal::pack_simple(blockB, current_b,
+                                  actual_bn(ni), actual_bk(ki), bk, stride_b);
+          }
+          packingB_done.at(ni)++;
+        }
+      }
+
+      // Compute
+
+      while (true) {
+        uint32_t mn = compute_jobs++;
+        Index mi = mn / blocks_n;
+        Index ni = mn % blocks_n;
+        if (mi >= blocks_m) break;
+
+        // Wait for mi, ni packings to be done. This is more fine-grained than
+        // waiting for all workers to finish packing.
+        while ((copyA && (packingA_done.at(mi) < blocks_k)) ||
+               (copyB && (packingB_done.at(ni) < blocks_k)))
+        {}
+
+        for (Index ki=0; ki < blocks_k; ++ki) {
+          const LhsScalar * current_a = copyA ?
+              blocksA + (bk*bm) * (mi*blocks_k+ki) :
+              leftData + (bk*ki)*stride_a + (bm*mi);
+          const RhsScalar * current_b = copyB ?
+              blocksB + (bk*bn) * (ni*blocks_k+ki) :
+              rightData + (ni*bn)*stride_b + (bk*ki);
+
+          Index current_stride_a = copyA ? bm : stride_a;
+          Index current_stride_b = copyB ? bk : stride_b;
+
+          // Memory may not be zeroed, overwrite instead of adding in first
+          // iteration.
+          float beta = ki == 0 ? 0 : 1;
+
+          Scalar * current_c = buffer + (mi*bm) + (ni*bn)*stride_c;
+          internal::libxsmm_wrapper(
+              0, actual_bm(mi), actual_bn(ni), actual_bk(ki),
+              current_stride_a, current_stride_b, stride_c, 1, beta, 0)
+            (current_a, current_b, current_c);
+        }
+      }
+
+      workers_done.Notify();
+    }
+
+    void run() {
+      // Parallelization strategy.
+      //
+      // First pack A into blocks (sharding by m, k) and B (sharding by n,k),
+      // then shard by m, n.
+      //
+      // Do not use advanced ThreadPool queuing, just run a single long-standing
+      // function in each thread.
+      if (copyA) {
+        blocksA = static_cast<LhsScalar*>(device.allocate(
+            (blocks_m*bm)*(blocks_k*bk)*sizeof(LhsScalar)));
+      }
+      if (copyB) {
+        blocksB = static_cast<RhsScalar*>(device.allocate(
+            (blocks_n*bn)*(blocks_k*bk)*sizeof(RhsScalar)));
+      }
+
+      for (Index i = 0; i < num_threads; ++i) {
+        device.enqueueNoNotification([=]() { worker(); });
+      }
+
+      workers_done.Wait();
+
+      if (copyA) {
+        device.deallocate(blocksA);
+      }
+      if (copyB) {
+        device.deallocate(blocksB);
+      }
+    }
+
+   private:
+    // real block size for block index in [0, ..., blocks - 1].
+    Index actual_bm(Index mi) const {
+      return mi != blocks_m - 1 ? bm : m + bm - bm * blocks_m;
+    }
+    Index actual_bk(Index ki) const {
+      return ki != blocks_k - 1 ? bk : k + bk - bk * blocks_k;
+    }
+    Index actual_bn(Index ni) const {
+      return ni != blocks_n - 1 ? bn : n + bn - bn * blocks_n;
+    }
+
+    const Device& device;
+    Index m, k, n;
+    Index stride_a, stride_b, stride_c;
+    Index bm, bk, bn;  // Block sizes.
+    Index blocks_m, blocks_k, blocks_n;  // Number of blocks in each dimension.
+    bool copyA, copyB, transposeA, transposeB;
+    Index num_threads;
+    Scalar *buffer;
+    const LhsScalar *leftData;
+    const RhsScalar *rightData;
+
+    LhsScalar *blocksA;
+    RhsScalar *blocksB;
+    // barrier for joining all threads after all done.
+    Barrier workers_done;
+    // "queues" of (mi,ki), (ki,ni), (mi,ni) jobs packed [0,p)x[0,q) -> [0, p*q)
+    std::atomic<uint32_t> packingA_jobs;
+    std::atomic<uint32_t> packingB_jobs;
+    std::atomic<uint32_t> compute_jobs;
+    // already packed blocks for each mi-panel in A and ni-panel in B.
+    std::vector<std::atomic<uint8_t>> packingA_done;
+    std::vector<std::atomic<uint8_t>> packingB_done;
+  };
+#endif
+
 };

 } // end namespace Eigen
-- 
cgit v1.2.3
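For context, a minimal usage sketch (not part of the patch) of the code path this commit touches: a contraction assigned through a ThreadPoolDevice enters evalProduct() above, and when Eigen is built with AVX enabled and -DEIGEN_USE_LIBXSMM (and linked against libxsmm), qualifying contractions are dispatched to the new ContextXsmm runner instead of the gebp kernels. The include path, tensor sizes, and thread counts below are illustrative assumptions, not taken from the commit.

// Standalone sketch: a rank-2 contraction (matrix product) evaluated on a
// thread pool, which is the multithreaded path modified by this patch.
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> a(256, 512);
  Eigen::Tensor<float, 2> b(512, 128);
  a.setRandom();
  b.setRandom();

  // Contract the second dimension of a with the first dimension of b,
  // i.e. c(i, j) = sum_k a(i, k) * b(k, j).
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};

  Eigen::ThreadPool pool(4);                       // 4 worker threads
  Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);

  Eigen::Tensor<float, 2> c(256, 128);
  c.device(device) = a.contract(b, dims);          // runs evalProduct()
  return 0;
}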