path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
author     Eugene Zhulenev <ezhulenev@google.com>    2019-08-30 15:13:38 -0700
committer  Eugene Zhulenev <ezhulenev@google.com>    2019-08-30 15:13:38 -0700
commit     f0b36fb9a405400e82b73ea70097b8ae3cd1095a (patch)
tree       d3a2903422799257720d2d4989bcd845ab2ae27e /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent     619cea94916e7531a839ee0ff657714857921db8 (diff)
evalSubExprsIfNeededAsync + async TensorContractionThreadPool
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--    unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h    711
1 file changed, 479 insertions, 232 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index ca20038a4..f9d9d6d31 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -73,6 +73,34 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <int Alignment>
void evalProduct(Scalar* buffer) const {
+ evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
+ }
+
+ template <typename EvalToCallback, int Alignment>
+ void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
+ evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
+ }
+
+ template <typename DoneCallback, int Alignment>
+ void evalProductImpl(Scalar* buffer, DoneCallback done) const {
+ // This function computes a lot of heuristics in multiple steps, and it
+ // also has multiple exit points. To keep it sane, readable, and all in one
+ // place, the sync/async execution decision is made at runtime at the very
+ // end.
+ //
+ // (1) In sync mode we allocate Context on the stack, submit computations
+ // to the device thread pool, and block on a barrier until it is
+ // completed.
+ //
+ // (2) In async mode we allocate Context on the heap, and after all tasks
+ // are finished, we call the provided done callback and delete the
+ // context from the heap.
+ //
+ // (*) EvalParallelContext & EvalShardedByInnerDimContext own all the state
+ // and temporary buffers required for executing the tensor contraction. They
+ // are responsible for cleaning it up after the contraction is done.
+ static const bool IsEvalInSyncMode =
+ std::is_same<DoneCallback, NoCallback>::value;
+
const Index m = this->m_i_size;
const Index n = this->m_j_size;
const Index k = this->m_k_size;
@@ -134,8 +162,16 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
// We are in the scenario where it is more effective to shard by the
// inner dimension.
- this->template evalShardedByInnerDim<Alignment>(num_threads_by_k,
- buffer);
+ if (IsEvalInSyncMode) {
+ EvalShardedByInnerDimContext<DoneCallback> ctx(
+ this, num_threads_by_k, buffer, m, n, k, std::move(done));
+ ctx.template run<Alignment>();
+ } else {
+ auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>(
+ this, num_threads_by_k, buffer, m, n, k, std::move(done));
+ ctx->template runAsync<Alignment>();
+ }
+
return;
}
@@ -146,6 +182,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
if (num_threads == 1) {
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
Unaligned, (buffer));
+ if (!IsEvalInSyncMode) done();
return;
}
@@ -230,21 +267,89 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// optimization.
if (parallelize_by_sharding_dim_only) parallel_pack = false;
+ // TODO(ezhulenev): With if constexpr we don't need SyncEvalParallelContext.
+ if (IsEvalInSyncMode) {
#define CONTEXT_ARGS \
(this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
- nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only) \
+ nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only, \
+ NoCallback()) \
.run()
-
- TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS);
-
+ TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment,
+ CONTEXT_ARGS);
#undef CONTEXT_ARGS
+ } else {
+#define CONTEXT_ARGS \
+ (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
+ nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only, \
+ std::move(done))
+ TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback,
+ Alignment, CONTEXT_ARGS, run());
+#undef CONTEXT_ARGS
+ }
}
- // Context coordinates a single parallel gemm operation.
- template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
- bool rhs_inner_dim_reordered, int Alignment>
- class Context {
+ // ------------------------------------------------------------------------ //
+
+ // Dummy struct to represent an empty DoneCallback.
+
+ struct NoCallback {
+ void operator()() {
+ eigen_assert(false && "NoCallback should never be called");
+ }
+ };
+
+ // ------------------------------------------------------------------------ //
+
+ template <typename DoneCallback, typename Context>
+ class EvalParallelNotification;
+
+ // Synchronous evaluation notification that blocks the caller thread in Wait().
+ template <typename Context>
+ class EvalParallelNotification<NoCallback, Context> {
+ public:
+ EvalParallelNotification(Context*, NoCallback) {}
+ void Notify() { done_.Notify(); }
+ void Wait() { done_.Wait(); }
+ private:
+ Eigen::Notification done_;
+ };
+
+ // Asynchronous evaluation notification that does not block in Wait().
+ template <typename DoneCallback, typename Context>
+ class EvalParallelNotification {
+ public:
+ EvalParallelNotification(Context* ctx, DoneCallback done)
+ : ctx_(ctx), done_(std::move(done)) {}
+
+ void Notify() {
+ // Make a copy of the done callback, because it will be destroyed when we
+ // delete the context on the next line (EvalParallelNotification is a
+ // data member of the EvalParallelContext class).
+ DoneCallback done_copy = std::move(done_);
+
+ // Delete parallel evaluation context.
+ delete ctx_;
+
+ // Now safely call the done callback.
+ done_copy();
+ }
+
+ void Wait() {}
+
+ private:
+ Context* ctx_;
+ DoneCallback done_;
+ };
+
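The move-then-delete-then-invoke order in Notify() above is what makes the asynchronous teardown safe. A minimal standalone sketch of the same pattern (hypothetical names, std::function instead of a templated callback; not part of this patch):

#include <functional>
#include <utility>

// The callback is moved out of the heap-allocated owner before the owner is
// deleted, so invoking it never touches freed memory.
struct AsyncOwner {
  explicit AsyncOwner(std::function<void()> done) : done_(std::move(done)) {}

  void Finish() {
    std::function<void()> done_copy = std::move(done_);  // 1. move callback out
    delete this;                                         // 2. destroy the owner
    done_copy();                                         // 3. now safe to call
  }

 private:
  std::function<void()> done_;
};

// Usage: (new AsyncOwner([] { /* completion work */ }))->Finish();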
+ // EvalParallelContext orchestrates sync/async parallel contraction
+ // evaluation. When it is executed in asynchronous mode, it owns all the
+ // shared state that might be accessible by block packing and kernel tasks.
+
+ template <typename DoneCallback, bool lhs_inner_dim_contiguous,
+ bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
+ int Alignment>
+ class EvalParallelContext {
public:
typedef internal::TensorContractionInputMapper<
LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
@@ -267,11 +372,15 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
typedef typename TensorContractionKernel::RhsBlock RhsBlock;
typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
- Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
- Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
- Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
- bool parallel_pack, bool parallelize_by_sharding_dim_only)
- : device_(self->m_device),
+ EvalParallelContext(const Self* self, int num_threads, Scalar* buffer,
+ Index tm, Index tn, Index tk, Index bm, Index bn,
+ Index bk, Index nm, Index nn, Index nk, Index gm,
+ Index gn, Index nm0, Index nn0, bool shard_by_col,
+ bool parallel_pack,
+ bool parallelize_by_sharding_dim_only,
+ DoneCallback done)
+ : done_(this, std::move(done)),
+ device_(self->m_device),
lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
self->m_i_strides, self->m_left_contracting_strides,
self->m_k_strides),
@@ -299,8 +408,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
gn_(gn),
nm0_(nm0),
nn0_(nn0),
- kernel_(m_, k_, n_, bm_, bk_, bn_)
- {
+ kernel_(m_, k_, n_, bm_, bk_, bn_) {
// These two options are mutually exclusive.
eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
@@ -371,7 +479,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
}
- ~Context() {
+ ~EvalParallelContext() {
for (Index x = 0; x < P; x++) {
for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
delete[] state_kernel_[x];
@@ -386,16 +494,28 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
void run() {
// Kick off packing of the first slice.
signal_switch(0, 1);
+
// Wait for overall completion.
- // TODO(dvyukov): this wait can lead to deadlock.
- // If nthreads contractions are concurrently submitted from worker
- // threads, this wait will block all worker threads and the system will
- // deadlock.
+ //
+ // If parallel evaluation is executed in async mode, this is a no-op, and
+ // Wait() returns immediately. In synchronous mode it blocks the caller
+ // thread until it receives a notification from the last task.
+ //
+ // In async mode, the last task, when completed, calls the done callback
+ // from the same thread and deletes this context.
+ //
+ // TODO(dvyukov): This wait can lead to deadlock if the contraction is
+ // evaluated in synchronous mode. If nthreads contractions are
+ // concurrently submitted from worker threads, this wait will block all
+ // worker threads and the system will deadlock.
done_.Wait();
}
private:
- Notification done_;
+ // This notification is specialized on the type of DoneCallback and can be
+ // blocking or non-blocking.
+ EvalParallelNotification<DoneCallback, EvalParallelContext> done_;
+
const Device& device_;
LhsMapper lhs_;
RhsMapper rhs_;
@@ -780,10 +900,344 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
- Context(const Context&) = delete;
- void operator=(const Context&) = delete;
+ EvalParallelContext(const EvalParallelContext&) = delete;
+ void operator=(const EvalParallelContext&) = delete;
+ };
+
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
+ bool rhs_inner_dim_reordered, int Alignment>
+ using SyncEvalParallelContext =
+ EvalParallelContext<NoCallback, lhs_inner_dim_contiguous,
+ rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
+ Alignment>;
+
+ // ------------------------------------------------------------------------ //
+
+ // EvalShardedByInnerDimContext orchestrates sync/async contraction
+ // evaluation when we shard by the inner dimension. When it is executed in
+ // asynchronous mode, it owns all the shared state that might be accessible by
+ // block processing tasks.
+
+ template <typename DoneCallback>
+ struct EvalShardedByInnerDimContext {
+ EvalShardedByInnerDimContext(const Self* evaluator, int num_threads,
+ Scalar* result, Index m, Index n, Index k,
+ DoneCallback done)
+ : evaluator(evaluator),
+ m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
+ m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
+ m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
+ num_threads(num_threads),
+ result(result),
+ m(m),
+ n(n),
+ k(k),
+ done(std::move(done)),
+ buffer_size_bytes(m * n * sizeof(Scalar)),
+ block_size(blockSize(k, num_threads)),
+ num_blocks(divup<Index>(k, block_size)),
+ num_pending_blocks(internal::convert_index<int>(num_blocks)),
+ l0_ranges(divup<Index>(num_blocks, l0_size)),
+ l0_state(l0_ranges),
+ block_buffers(num_blocks) {
+ // Keep count of pending gemm tasks for each l0 range.
+ for (int i = 0; i < l0_ranges; ++i) {
+ const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
+ l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
+ }
+
+ // Allocate temporary buffers for each block.
+ for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
+ Scalar* buf = block_idx == 0
+ ? result
+ : static_cast<Scalar*>(evaluator->m_device.allocate(
+ buffer_size_bytes));
+ block_buffers.emplace_back(buf);
+ }
+ }
+
+ ~EvalShardedByInnerDimContext() {
+ for (Index i = 1; i < num_blocks; ++i) {
+ evaluator->m_device.deallocate(block_buffers[i]);
+ }
+ }
+
+ template <int Alignment>
+ void run() {
+ Barrier barrier(internal::convert_index<int>(num_blocks));
+ for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
+ evaluator->m_device.enqueueNoNotification(
+ [this, block_idx, &barrier]() {
+ Index block_start = block_idx * block_size;
+ Index block_end = block_start + actualBlockSize(block_idx);
+
+ processBlock<Alignment>(block_idx, block_start, block_end);
+ barrier.Notify();
+ });
+ }
+ barrier.Wait();
+
+ // Aggregate partial sums from l0 ranges.
+ aggregateL0Blocks<Alignment>();
+
+ // Apply output kernel.
+ applyOutputKernel();
+ }
+
+ template <int Alignment>
+ void runAsync() {
+ for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
+ evaluator->m_device.enqueueNoNotification([this, block_idx]() {
+ Index block_start = block_idx * block_size;
+ Index block_end = block_start + actualBlockSize(block_idx);
+
+ processBlock<Alignment>(block_idx, block_start, block_end);
+
+ int v = num_pending_blocks.fetch_sub(1);
+ eigen_assert(v >= 1);
+
+ if (v == 1) {
+ // Aggregate partial sums from l0 ranges.
+ aggregateL0Blocks<Alignment>();
+
+ // Apply output kernel.
+ applyOutputKernel();
+
+ // NOTE: If we call the `done` callback before deleting this context,
+ // it might deallocate the Self* pointer captured by the context, and
+ // we would fail in the destructor trying to deallocate the temporary
+ // buffers.
+
+ // Move the done callback out of the context before it is destroyed.
+ DoneCallback done_copy = std::move(done);
+
+ // We are confident that we are the last one to touch the context.
+ delete this;
+
+ // Now safely call the done callback.
+ done_copy();
+ }
+ });
+ }
+ }
+
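runAsync() above relies on an atomic countdown so that exactly one task, the last one to finish, performs the aggregation, the output kernel, the context deletion, and the done callback. A simplified self-contained sketch of that countdown (std::thread stands in for the device thread pool; all names are hypothetical):

#include <atomic>
#include <functional>
#include <thread>
#include <vector>

// Every worker decrements the counter; the worker that observes the old value
// 1 knows all other blocks have finished and runs the finalization step.
void RunBlocks(int num_blocks, const std::function<void(int)>& process_block,
               const std::function<void()>& finalize) {
  auto* pending = new std::atomic<int>(num_blocks);
  std::vector<std::thread> workers;
  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
    workers.emplace_back([=] {
      process_block(block_idx);
      if (pending->fetch_sub(1) == 1) {  // this was the last pending block
        delete pending;
        finalize();
      }
    });
  }
  for (std::thread& t : workers) t.join();
}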
+ private:
+ // The underlying GEMM kernel assumes that k is a multiple of
+ // the packet size and subtle breakage occurs if this is violated.
+ static const Index packet_size = internal::packet_traits<RhsScalar>::size;
+
+ const Self* evaluator; // TensorContraction evaluator
+
+ // These fields are required by the TENSOR_CONTRACTION_DISPATCH macro.
+ bool m_lhs_inner_dim_contiguous;
+ bool m_rhs_inner_dim_contiguous;
+ bool m_rhs_inner_dim_reordered;
+
+ int num_threads;
+ Scalar* result;
+
+ Index m;
+ Index n;
+ Index k;
+
+ DoneCallback done;
+
+ // ----------------------------------------------------------------------//
+ // Algorithm parameters.
+
+ // We will compute partial results into the buffers of this size.
+ Index buffer_size_bytes;
+
+ Index block_size;
+ Index num_blocks;
+
+ // Keep track of pending tasks when evaluating in async mode.
+ std::atomic<int> num_pending_blocks;
+
+ // We compute partial gemm results in parallel, and to get the final result
+ // we need to add them all together. For a large number of threads (>= 48)
+ // this adds a very expensive sequential step at the end.
+ //
+ // We split [0, num_blocks) into small ranges, and when a task for a block
+ // finishes its partial gemm computation, it checks whether it was the last
+ // gemm in its range; if so, it adds up all blocks of the range.
+ //
+ // After all tasks are done, we only need to add these pre-aggregated blocks.
+
+ // For now we use just a single level of ranges to compute pre-aggregated
+ // partial sums, but in general we can use more layers to compute tree
+ // aggregation in parallel and reduce the size of the sequential step.
+ //
+ // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
+ // sense only if number of threads >= ~128?
+ static const Index l0_size = 4;
+ Index l0_ranges;
+
+ // Keep count of pending gemm tasks for each l0 range.
+ MaxSizeVector<std::atomic<int>> l0_state; // [0, l0_ranges)
+
+ // Buffers allocated for each temporary block computation.
+ MaxSizeVector<Scalar*> block_buffers; // [0, num_blocks)
+
+ template <int Alignment>
+ void processBlock(Index block_idx, Index begin, Index end) {
+ Scalar* buf = block_buffers[block_idx];
+ ::memset(buf, 0, buffer_size_bytes);
+
+ TENSOR_CONTRACTION_DISPATCH(
+ evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
+ (buf, begin, end,
+ /*num_threads=*/internal::convert_index<int>(num_blocks)));
+
+ // Check if it was the last task in l0 range.
+ const Index l0_index = block_idx / l0_size;
+ const int v = l0_state[l0_index].fetch_sub(1);
+ eigen_assert(v >= 1);
+
+ // If we processed the last block of the range, we can aggregate all
+ // partial results into the first block of the range.
+ if (v == 1) {
+ const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
+ const Index dst_block_idx = l0_index * l0_size;
+
+ if (rng_size == l0_size) {
+ addAllToBuffer<Alignment>(
+ m * n,
+ /*src_buf0=*/block_buffers[dst_block_idx + 1],
+ /*src_buf1=*/block_buffers[dst_block_idx + 2],
+ /*src_buf2=*/block_buffers[dst_block_idx + 3],
+ /*dst_buf= */ block_buffers[dst_block_idx]);
+ } else {
+ // Aggregate blocks of potentially incomplete last range.
+ for (int i = 1; i < rng_size; ++i) {
+ addToBuffer<Alignment>(m * n,
+ /*src_buf=*/block_buffers[dst_block_idx + i],
+ /*dst_buf=*/block_buffers[dst_block_idx]);
+ }
+ }
+ }
+ }
+
+ // Aggregate partial sums from l0 ranges.
+ template <int Alignment>
+ void aggregateL0Blocks() const {
+ Index l0_index = 1;
+
+ for (; l0_index + 2 < l0_ranges; l0_index += 3) {
+ addAllToBuffer<Alignment>(
+ m * n,
+ /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
+ /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
+ /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
+ /*dst_buf= */ block_buffers[0]);
+ }
+
+ for (; l0_index < l0_ranges; ++l0_index) {
+ addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
+ block_buffers[0]);
+ }
+ }
+
+ void applyOutputKernel() const {
+ typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
+ evaluator->m_output_kernel(
+ OutputMapper(result, m), evaluator->m_tensor_contraction_params,
+ static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
+ }
+
+ // Compute block size, accounting for a potentially incomplete last block.
+ Index actualBlockSize(Index block_idx) const {
+ return block_idx + 1 < num_blocks
+ ? block_size
+ : k + block_size - block_size * num_blocks;
+ };
+
+ // Compute range size, accounting for a potentially incomplete last range.
+ Index actualRangeSize(Index num_ranges, Index range_size,
+ Index range_idx) const {
+ eigen_assert(range_idx < num_ranges);
+ return range_idx + 1 < num_ranges
+ ? range_size
+ : num_blocks + range_size - range_size * num_ranges;
+ };
+
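A quick numeric check of the two remainder formulas above (a hypothetical standalone snippet, not part of the patch): with k = 10 and block_size = 4 there are 3 blocks of sizes 4, 4 and 2, and with l0_size = 4 they form a single l0 range of size 3.

#include <cassert>

int main() {
  const int k = 10, block_size = 4;
  const int num_blocks = (k + block_size - 1) / block_size;           // 3
  const int last_block = k + block_size - block_size * num_blocks;    // 10 + 4 - 12 = 2
  assert(num_blocks == 3 && last_block == 2);

  const int l0_size = 4;
  const int l0_ranges = (num_blocks + l0_size - 1) / l0_size;         // 1
  const int last_range = num_blocks + l0_size - l0_size * l0_ranges;  // 3 + 4 - 4 = 3
  assert(l0_ranges == 1 && last_range == 3);
  return 0;
}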
+ template <int Alignment>
+ EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf,
+ Scalar* tgt_buf) {
+ const int output_packet_size =
+ internal::unpacket_traits<PacketReturnType>::size;
+ size_t i = 0;
+ const size_t num_packets = n / output_packet_size;
+ for (; i < output_packet_size * num_packets; i += output_packet_size) {
+ const PacketReturnType src_val =
+ internal::pload<PacketReturnType>(src_buf + i);
+ const PacketReturnType tgt_val =
+ internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
+ const PacketReturnType sum = internal::padd(src_val, tgt_val);
+ internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i,
+ sum);
+ }
+ for (; i < n; ++i) {
+ tgt_buf[i] += src_buf[i];
+ }
+ }
+
+ template <int Alignment>
+ EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n,
+ const Scalar* src_buf0,
+ const Scalar* src_buf1,
+ const Scalar* src_buf2,
+ Scalar* dst_buf) {
+ using ::Eigen::internal::padd;
+ using ::Eigen::internal::pload;
+ using ::Eigen::internal::ploadt;
+ using ::Eigen::internal::pstoret;
+
+ const int output_packet_size =
+ internal::unpacket_traits<PacketReturnType>::size;
+
+ size_t i = 0;
+ const size_t num_packets = n / output_packet_size;
+ for (; i < output_packet_size * num_packets; i += output_packet_size) {
+ const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
+ const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
+ const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
+
+ const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
+ const auto sum =
+ padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
+
+ pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
+ }
+ for (; i < n; ++i) {
+ dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
+ }
+ }
+
+ // The cost model doesn't capture well the cost associated with constructing
+ // tensor contraction mappers and computing loop bounds in gemm_pack_lhs
+ // and gemm_pack_rhs, so we specify a minimum desired block size.
+ static Index blockSize(Index k, int num_threads) {
+ const auto round_up = [=](Index index) -> Index {
+ const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
+ return divup<Index>(index, kmultiple) * kmultiple;
+ };
+
+ const Index target_block_size = round_up(divup<Index>(k, num_threads));
+ const Index desired_min_block_size = 12 * packet_size;
+
+ return numext::mini<Index>(
+ k, numext::maxi<Index>(desired_min_block_size, target_block_size));
+ }
+
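A worked example of blockSize(), assuming packet_size = 8 (e.g. float with AVX; the numbers are illustrative, not from the patch): for k = 1000 and num_threads = 16, the per-thread target is divup(1000, 16) = 63, rounded up to 64, which is below the desired minimum of 12 * 8 = 96, so block_size = 96 and num_blocks = divup(1000, 96) = 11.

#include <algorithm>
#include <cassert>

int main() {
  const long k = 1000, num_threads = 16, packet_size = 8;
  const long kmultiple = packet_size <= 8 ? 8 : packet_size;
  const long target = (((k + num_threads - 1) / num_threads) + kmultiple - 1) /
                      kmultiple * kmultiple;                           // 64
  const long desired_min = 12 * packet_size;                           // 96
  const long block_size = std::min(k, std::max(desired_min, target));  // 96
  const long num_blocks = (k + block_size - 1) / block_size;           // 11
  assert(block_size == 96 && num_blocks == 11);
  return 0;
}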
+ EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
+ void operator=(const EvalShardedByInnerDimContext&) = delete;
};
+ // ------------------------------------------------------------------------ //
+
+ // Below are the functions used by the evalProductImpl heuristics to select
+ // optimal parameters for the parallelization algorithm.
+
// Decide whether we want to shard m x n contraction by columns or by rows.
static bool shardByCol(Index m, Index n, Index num_threads) {
// Note: we are comparing both n and m against Traits::nr, it is not
@@ -916,55 +1370,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
return cost + lhsCost + rhsCost;
}
- template <int Alignment>
- EIGEN_STRONG_INLINE void addToBuffer(size_t n, const Scalar* src_buf,
- Scalar* tgt_buf) const {
- const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
- size_t i = 0;
- const size_t num_packets = n / output_packet_size;
- for (; i < output_packet_size * num_packets; i += output_packet_size) {
- const PacketReturnType src_val =
- internal::pload<PacketReturnType>(src_buf + i);
- const PacketReturnType tgt_val =
- internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
- const PacketReturnType sum = internal::padd(src_val, tgt_val);
- internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i, sum);
- }
- for (; i < n; ++i) {
- tgt_buf[i] += src_buf[i];
- }
- }
-
- template <int Alignment>
- EIGEN_STRONG_INLINE void addAllToBuffer(size_t n, const Scalar* src_buf0,
- const Scalar* src_buf1,
- const Scalar* src_buf2,
- Scalar* dst_buf) const {
- using ::Eigen::internal::padd;
- using ::Eigen::internal::pload;
- using ::Eigen::internal::ploadt;
- using ::Eigen::internal::pstoret;
-
- const int output_packet_size =
- internal::unpacket_traits<PacketReturnType>::size;
-
- size_t i = 0;
- const size_t num_packets = n / output_packet_size;
- for (; i < output_packet_size * num_packets; i += output_packet_size) {
- const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
- const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
- const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
-
- const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
- const auto sum = padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
-
- pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
- }
- for (; i < n; ++i) {
- dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
- }
- }
-
// Decide whether we want to shard m x k x n contraction over the inner
// (contraction) dimension (k).
static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
@@ -992,163 +1397,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
return shard_by_k;
}
- template <int Alignment>
- void evalShardedByInnerDim(int num_threads, Scalar* result) const {
- const Index m = this->m_i_size;
- const Index n = this->m_j_size;
- const Index k = this->m_k_size;
-
- // We will compute partial results into the buffers of this size.
- const Index buffer_size_bytes = m * n * sizeof(Scalar);
-
- // The underlying GEMM kernel assumes that k is a multiple of
- // the packet size and subtle breakage occurs if this is violated.
- const Index packet_size = internal::packet_traits<RhsScalar>::size;
-
- const auto round_up = [=](Index index) -> Index {
- const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
- return divup<Index>(index, kmultiple) * kmultiple;
- };
-
- // Cost model doesn't capture well the cost associated with constructing
- // tensor contraction mappers and computing loop bounds in gemm_pack_lhs and
- // gemm_pack_rhs, so we specify minimum desired block size.
- const Index target_block_size = round_up(divup<Index>(k, num_threads));
- const Index desired_min_block_size = 12 * packet_size;
-
- const Index block_size = numext::mini<Index>(
- k, numext::maxi<Index>(desired_min_block_size, target_block_size));
- const Index num_blocks = divup<Index>(k, block_size);
-
- // Compute block size with accounting for potentially incomplete last block.
- const auto actual_block_size = [=](Index block_idx) -> Index {
- return block_idx + 1 < num_blocks
- ? block_size
- : k + block_size - block_size * num_blocks;
- };
-
- // We compute partial gemm results in parallel, and to get the final result
- // we need to add them all together. For the large number of threads (>= 48)
- // this adds a very expensive sequential step at the end.
- //
- // We split the [0, num_blocks) into small ranges, and when a task for the
- // block finishes its partial gemm computation, it checks if it was the last
- // gemm in the range, and if so, it will add all blocks of the range.
- //
- // After all tasks finihes, we need to add only these pre-aggregated blocks.
-
- // Compute range size with accounting for potentially incomplete last range.
- const auto actual_range_size = [=](Index num_ranges, Index range_size,
- Index range_idx) -> Index {
- eigen_assert(range_idx < num_ranges);
- return range_idx + 1 < num_ranges
- ? range_size
- : num_blocks + range_size - range_size * num_ranges;
- };
-
- // For now we use just a single level of ranges to compute pre-aggregated
- // partial sums, but in general we can use more layers to compute tree
- // aggregation in parallel and reduce the size of the sequential step.
- //
- // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
- // sense only if number of threads >= ~128?
- static const Index l0_size = 4;
- const Index l0_ranges = divup<Index>(num_blocks, l0_size);
-
- // Keep count of pending gemm tasks for each l0 range.
- MaxSizeVector<std::atomic<int>> l0_state(l0_ranges);
- for (int i = 0; i < l0_ranges; ++i) {
- const Index num_pending_tasks = actual_range_size(l0_ranges, l0_size, i);
- l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
- }
-
- MaxSizeVector<Scalar*> block_buffers(num_blocks);
-
- auto process_block = [&, this](Index block_idx, Index begin, Index end) {
- Scalar* buf = block_buffers[block_idx];
- ::memset(buf, 0, buffer_size_bytes);
-
- TENSOR_CONTRACTION_DISPATCH(
- this->template evalGemmPartialWithoutOutputKernel, Alignment,
- (buf, begin, end,
- /*num_threads=*/internal::convert_index<int>(num_blocks)));
-
- // Check if it was the last task in l0 range.
- const Index l0_index = block_idx / l0_size;
- const int v = l0_state[l0_index].fetch_sub(1);
- eigen_assert(v >= 1);
-
- // If we processed the last block of the range, we can aggregate all
- // partial results into the first block of the range.
- if (v == 1) {
- const Index rng_size = actual_range_size(l0_ranges, l0_size, l0_index);
- const Index dst_block_idx = l0_index * l0_size;
-
- if (rng_size == l0_size) {
- addAllToBuffer<Alignment>(
- m * n,
- /*src_buf0=*/block_buffers[dst_block_idx + 1],
- /*src_buf1=*/block_buffers[dst_block_idx + 2],
- /*src_buf2=*/block_buffers[dst_block_idx + 3],
- /*dst_buf= */ block_buffers[dst_block_idx]);
- } else {
- // Aggregate blocks of potentially incomplete last range.
- for (int i = 1; i < rng_size; ++i) {
- addToBuffer<Alignment>(m * n,
- /*src_buf=*/block_buffers[dst_block_idx + i],
- /*dst_buf=*/block_buffers[dst_block_idx]);
- }
- }
- }
- };
-
- Barrier barrier(internal::convert_index<int>(num_blocks));
- for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
- Scalar* buf = block_idx == 0
- ? result
- : static_cast<Scalar*>(
- this->m_device.allocate(buffer_size_bytes));
- block_buffers.push_back(buf);
-
- Index block_start = block_idx * block_size;
- Index block_end = block_start + actual_block_size(block_idx);
-
- this->m_device.enqueueNoNotification([=, &barrier, &process_block]() {
- process_block(block_idx, block_start, block_end);
- barrier.Notify();
- });
- }
- barrier.Wait();
-
- // Aggregate partial sums from l0 ranges.
- Index l0_index = 1;
- for (; l0_index + 2 < l0_ranges; l0_index += 3) {
- addAllToBuffer<Alignment>(
- m * n,
- /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
- /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
- /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
- /*dst_buf= */block_buffers[0]);
- }
- for (; l0_index < l0_ranges; ++l0_index) {
- addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
- block_buffers[0]);
- }
-
- // Don't forget to deallocate ALL temporary buffers.
- for (Index i = 1; i < num_blocks; ++i) {
- this->m_device.deallocate(block_buffers[i]);
- }
-
- // Finally call output kernel with finalized output buffer.
- typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
- this->m_output_kernel(OutputMapper(result, m),
- this->m_tensor_contraction_params,
- static_cast<Eigen::Index>(0),
- static_cast<Eigen::Index>(0),
- m, n);
- }
-
TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
// Compute cost.
const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
@@ -1188,7 +1436,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
return num_threads;
}
-
double computeBandwidth(bool shard_by_col, Index bm, Index bn,
Index bk) const {
// Peak VFMA bandwidth is 0.5. However if we have not enough data for