path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
author    Eugene Zhulenev <ezhulenev@google.com>  2019-01-08 16:26:31 -0800
committer Eugene Zhulenev <ezhulenev@google.com>  2019-01-08 16:26:31 -0800
commit e70ffef9678f86ef465e93b89351e812ab47311d (patch)
tree   e49c963dd6fdf7bf4e0d15a35f81a3d5c8c012f4 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent 190d053e41ef8cb77e08e42a37b7e72f9c1d6d43 (diff)
Optimize evalShardedByInnerDim
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h  197
1 file changed, 161 insertions, 36 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 3946e2fc4..9666bf167 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -756,6 +756,36 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
}
+ template <int Alignment>
+ EIGEN_STRONG_INLINE void addAllToBuffer(size_t n, const Scalar* src_buf0,
+ const Scalar* src_buf1,
+ const Scalar* src_buf2,
+ Scalar* dst_buf) const {
+ using ::Eigen::internal::padd;
+ using ::Eigen::internal::pload;
+ using ::Eigen::internal::ploadt;
+ using ::Eigen::internal::pstoret;
+
+ const int output_packet_size =
+ internal::unpacket_traits<PacketReturnType>::size;
+
+ size_t i = 0;
+ const size_t num_packets = n / output_packet_size;
+ for (; i < output_packet_size * num_packets; i += output_packet_size) {
+ const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
+ const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
+ const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
+
+ const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
+ const auto sum = padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
+
+ pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
+ }
+ for (; i < n; ++i) {
+ dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
+ }
+ }
+
// Decide whether we want to shard m x k x n contraction over the inner
// (contraction) dimension (k).
static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
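For context, the addAllToBuffer helper added in the hunk above follows the same shape as the existing addToBuffer: a main loop that accumulates whole SIMD packets, followed by a scalar loop over the remaining tail elements. A minimal standalone sketch of that pattern, assuming float buffers, unaligned loads/stores, and hypothetical names (accumulate3, src0/src1/src2/dst), could look like this:

#include <cstddef>
#include <Eigen/Core>

// Adds three source buffers into dst: dst[i] += src0[i] + src1[i] + src2[i].
void accumulate3(std::size_t n, const float* src0, const float* src1,
                 const float* src2, float* dst) {
  using Packet = Eigen::internal::packet_traits<float>::type;
  constexpr std::size_t packet_size = Eigen::internal::unpacket_traits<Packet>::size;

  std::size_t i = 0;
  // Vectorized part: process as many whole packets as fit into n.
  for (; i + packet_size <= n; i += packet_size) {
    const Packet p0 = Eigen::internal::ploadu<Packet>(src0 + i);
    const Packet p1 = Eigen::internal::ploadu<Packet>(src1 + i);
    const Packet p2 = Eigen::internal::ploadu<Packet>(src2 + i);
    const Packet d  = Eigen::internal::ploadu<Packet>(dst + i);
    Eigen::internal::pstoreu(
        dst + i,
        Eigen::internal::padd(Eigen::internal::padd(d, p0),
                              Eigen::internal::padd(p1, p2)));
  }
  // Scalar tail for the remaining n % packet_size elements.
  for (; i < n; ++i) dst[i] += src0[i] + src1[i] + src2[i];
}

The real helper is additionally templated on the destination alignment (ploadt/pstoret on dst), which this sketch sidesteps by always using unaligned accesses.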
@@ -788,50 +818,145 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index m = this->m_i_size;
const Index n = this->m_j_size;
const Index k = this->m_k_size;
- const Index packet_size = internal::packet_traits<RhsScalar>::size;
- const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
+
+ // We will compute partial results into buffers of this size.
+ const Index buffer_size_bytes = m * n * sizeof(Scalar);
+
// The underlying GEMM kernel assumes that k is a multiple of
// the packet size and subtle breakage occurs if this is violated.
- Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads);
- Index num_blocks = divup<Index>(k, block_size);
- // we use 'result' for the first block's partial result.
- MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
- Barrier barrier(internal::convert_index<int>(num_blocks));
- auto process_block = [=, &barrier](Scalar* buf, Index begin, Index end) {
- ::memset(buf, 0, m * n * sizeof(Scalar));
+ const Index packet_size = internal::packet_traits<RhsScalar>::size;
+
+ const auto round_up = [=](Index index) -> Index {
+ const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
+ return divup<Index>(index, kmultiple) * kmultiple;
+ };
+
+ // The cost model doesn't capture well the cost associated with constructing
+ // tensor contraction mappers and computing loop bounds in gemm_pack_lhs and
+ // gemm_pack_rhs, so we specify a minimum desired block size.
+ const Index target_block_size = round_up(divup<Index>(k, num_threads));
+ const Index desired_min_block_size = 12 * packet_size;
+
+ const Index block_size = numext::mini<Index>(
+ k, numext::maxi<Index>(desired_min_block_size, target_block_size));
+ const Index num_blocks = divup<Index>(k, block_size);
+
+ // Compute block size accounting for the potentially incomplete last block.
+ const auto actual_block_size = [=](Index block_idx) -> Index {
+ return block_idx + 1 < num_blocks
+ ? block_size
+ : k + block_size - block_size * num_blocks;
+ };
+
+ // We compute partial gemm results in parallel, and to get the final result
+ // we need to add them all together. For a large number of threads (>= 48)
+ // this adds a very expensive sequential step at the end.
+ //
+ // We split [0, num_blocks) into small ranges, and when a task for a block
+ // finishes its partial gemm computation, it checks if it was the last
+ // gemm in the range, and if so, it adds all blocks of the range.
+ //
+ // After all tasks finish, we need to add only these pre-aggregated blocks.
+
+ // Compute range size accounting for the potentially incomplete last range.
+ const auto actual_range_size = [=](Index num_ranges, Index range_size,
+ Index range_idx) -> Index {
+ eigen_assert(range_idx < num_ranges);
+ return range_idx + 1 < num_ranges
+ ? range_size
+ : num_blocks + range_size - range_size * num_ranges;
+ };
+
+ // For now we use just a single level of ranges to compute pre-aggregated
+ // partial sums, but in general we can use more layers to compute tree
+ // aggregation in parallel and reduce the size of the sequential step.
+ //
+ // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
+ // sense only if number of threads >= ~128?
+ static const Index l0_size = 4;
+ const Index l0_ranges = divup<Index>(num_blocks, l0_size);
+
+ // Keep count of pending gemm tasks for each l0 range.
+ MaxSizeVector<std::atomic<int>> l0_state(l0_ranges);
+ for (int i = 0; i < l0_ranges; ++i) {
+ l0_state.emplace_back(actual_range_size(l0_ranges, l0_size, i));
+ }
+
+ MaxSizeVector<Scalar*> block_buffers(num_blocks);
+
+ auto process_block = [&, this](Index block_idx, Index begin, Index end) {
+ Scalar* buf = block_buffers[block_idx];
+ ::memset(buf, 0, buffer_size_bytes);
+
TENSOR_CONTRACTION_DISPATCH(
this->template evalGemmPartialWithoutOutputKernel, Alignment,
- (buf, begin, end, this->m_device.numThreads()));
- barrier.Notify();
- };
- Index start = 0;
- for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
- // The underlying GEMM kernel assumes that k is a multiple of packet size
- // (currently largest packet size is 16) and subtle breakage occurs if
- // this is violated.
- block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left);
- Scalar* buf;
- if (start == 0) {
- buf = result;
- } else {
- buf = static_cast<Scalar*>(
- this->m_device.allocate(m * n * sizeof(Scalar)));
- block_buffers.push_back(buf);
- }
- Index end = start + block_size;
- if (end > k) {
- end = k;
+ (buf, begin, end, /*num_threads=*/num_blocks));
+
+ // Check if it was the last task in l0 range.
+ const Index l0_index = block_idx / l0_size;
+ const int v = l0_state[l0_index].fetch_sub(1);
+ eigen_assert(v >= 1);
+
+ // If we processed the last block of the range, we can aggregate all
+ // partial results into the first block of the range.
+ if (v == 1) {
+ const Index rng_size = actual_range_size(l0_ranges, l0_size, l0_index);
+ const Index dst_block_idx = l0_index * l0_size;
+
+ if (rng_size == l0_size) {
+ addAllToBuffer<Alignment>(
+ m * n,
+ /*src_buf0=*/block_buffers[dst_block_idx + 1],
+ /*src_buf1=*/block_buffers[dst_block_idx + 2],
+ /*src_buf2=*/block_buffers[dst_block_idx + 3],
+ /*dst_buf= */ block_buffers[dst_block_idx]);
+ } else {
+ // Aggregate blocks of potentially incomplete last range.
+ for (int i = 1; i < rng_size; ++i) {
+ addToBuffer<Alignment>(m * n,
+ /*src_buf=*/block_buffers[dst_block_idx + i],
+ /*dst_buf=*/block_buffers[dst_block_idx]);
+ }
+ }
}
- this->m_device.enqueueNoNotification(
- [=, &process_block]() { process_block(buf, start, end); });
- start = end;
+ };
+
+ Barrier barrier(internal::convert_index<int>(num_blocks));
+ for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
+ Scalar* buf = block_idx == 0
+ ? result
+ : static_cast<Scalar*>(
+ this->m_device.allocate(buffer_size_bytes));
+ block_buffers.push_back(buf);
+
+ Index block_start = block_idx * block_size;
+ Index block_end = block_start + actual_block_size(block_idx);
+
+ this->m_device.enqueueNoNotification([=, &barrier, &process_block]() {
+ process_block(block_idx, block_start, block_end);
+ barrier.Notify();
+ });
}
barrier.Wait();
- // Add other partial results into first partial result.
- for (const auto& buf : block_buffers) {
- addToBuffer<Alignment>(m * n, buf, result);
- this->m_device.deallocate(buf);
+ // Aggregate partial sums from l0 ranges.
+ Index l0_index = 1;
+ for (; l0_index + 2 < l0_ranges; l0_index += 3) {
+ addAllToBuffer<Alignment>(
+ m * n,
+ /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
+ /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
+ /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
+ /*dst_buf= */block_buffers[0]);
+ }
+ for (; l0_index < l0_ranges; ++l0_index) {
+ addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
+ block_buffers[0]);
+ }
+
+ // Don't forget to deallocate ALL temporary buffers.
+ for (Index i = 1; i < num_blocks; ++i) {
+ this->m_device.deallocate(block_buffers[i]);
}
// Finally call output kernel with finalized output buffer.
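To make the new block/range sizing concrete, here is one worked example (the numbers are illustrative, not taken from the patch): assume float with AVX so packet_size = 8, k = 1024, and num_threads = 48. Then target_block_size = round_up(divup(1024, 48)) = round_up(22) = 24 and desired_min_block_size = 12 * 8 = 96, so block_size = min(1024, max(96, 24)) = 96 and num_blocks = divup(1024, 96) = 11, with a final partial block of 1024 - 10 * 96 = 64. With l0_size = 4 this gives l0_ranges = divup(11, 4) = 3 (ranges of 4, 4 and 3 blocks), so the sequential epilogue folds only 2 pre-aggregated range heads into block_buffers[0] instead of 10 separate partial buffers.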
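The core of the optimization is the per-range countdown: each task decrements an atomic counter for its l0 range after finishing its partial gemm, and the task that brings the counter to zero folds the range's partial buffers into the range's first buffer, so the sequential epilogue only touches one buffer per range instead of one per block. A simplified, self-contained sketch of that pattern, with hypothetical names, plain std::thread instead of Eigen's thread pool, and a trivial fill in place of the partial gemm:

#include <algorithm>
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

int main() {
  const int num_blocks = 10;   // number of partial gemm results
  const int range_size = 4;    // plays the role of l0_size
  const int num_ranges = (num_blocks + range_size - 1) / range_size;
  const int n = 1024;          // elements per partial-result buffer

  std::vector<std::vector<double>> buffers(num_blocks, std::vector<double>(n, 0.0));
  std::vector<std::atomic<int>> pending(num_ranges);
  for (int r = 0; r < num_ranges; ++r) {
    pending[r].store(std::min(range_size, num_blocks - r * range_size));
  }

  auto work = [&](int block_idx) {
    // Stand-in for the partial gemm: fill this block's buffer.
    for (int i = 0; i < n; ++i) buffers[block_idx][i] = 1.0;

    // Whoever finishes the range last folds the range into its first block.
    const int range = block_idx / range_size;
    if (pending[range].fetch_sub(1) == 1) {
      const int first = range * range_size;
      const int last = std::min(first + range_size, num_blocks);
      for (int b = first + 1; b < last; ++b)
        for (int i = 0; i < n; ++i) buffers[first][i] += buffers[b][i];
    }
  };

  std::vector<std::thread> threads;
  for (int b = 0; b < num_blocks; ++b) threads.emplace_back(work, b);
  for (auto& t : threads) t.join();

  // Sequential epilogue: only one pre-aggregated buffer per range remains.
  for (int r = 1; r < num_ranges; ++r)
    for (int i = 0; i < n; ++i) buffers[0][i] += buffers[r * range_size][i];

  assert(buffers[0][0] == num_blocks);  // each block contributed exactly once
  return 0;
}

This mirrors the patch's l0_state counters and the "if (v == 1)" branch in process_block; the real code adds the extra wrinkles that block 0 writes directly into result and that full ranges take the three-source addAllToBuffer fast path.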