Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 197 |
1 file changed, 161 insertions, 36 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 3946e2fc4..9666bf167 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -756,6 +756,36 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     }
   }
 
+  template <int Alignment>
+  EIGEN_STRONG_INLINE void addAllToBuffer(size_t n, const Scalar* src_buf0,
+                                          const Scalar* src_buf1,
+                                          const Scalar* src_buf2,
+                                          Scalar* dst_buf) const {
+    using ::Eigen::internal::padd;
+    using ::Eigen::internal::pload;
+    using ::Eigen::internal::ploadt;
+    using ::Eigen::internal::pstoret;
+
+    const int output_packet_size =
+        internal::unpacket_traits<PacketReturnType>::size;
+
+    size_t i = 0;
+    const size_t num_packets = n / output_packet_size;
+    for (; i < output_packet_size * num_packets; i += output_packet_size) {
+      const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
+      const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
+      const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
+
+      const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
+      const auto sum = padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
+
+      pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
+    }
+    for (; i < n; ++i) {
+      dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
+    }
+  }
+
   // Decide whether we want to shard m x k x n contraction over the inner
   // (contraction) dimension (k).
   static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
@@ -788,50 +818,145 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     const Index m = this->m_i_size;
     const Index n = this->m_j_size;
     const Index k = this->m_k_size;
-    const Index packet_size = internal::packet_traits<RhsScalar>::size;
-    const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
+
+    // We will compute partial results into the buffers of this size.
+    const Index buffer_size_bytes = m * n * sizeof(Scalar);
+
     // The underlying GEMM kernel assumes that k is a multiple of
     // the packet size and subtle breakage occurs if this is violated.
-    Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads);
-    Index num_blocks = divup<Index>(k, block_size);
-    // we use 'result' for the first block's partial result.
-    MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
-    Barrier barrier(internal::convert_index<int>(num_blocks));
-    auto process_block = [=, &barrier](Scalar* buf, Index begin, Index end) {
-      ::memset(buf, 0, m * n * sizeof(Scalar));
+    const Index packet_size = internal::packet_traits<RhsScalar>::size;
+
+    const auto round_up = [=](Index index) -> Index {
+      const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
+      return divup<Index>(index, kmultiple) * kmultiple;
+    };
+
+    // The cost model doesn't capture well the cost associated with constructing
+    // tensor contraction mappers and computing loop bounds in gemm_pack_lhs and
+    // gemm_pack_rhs, so we specify a minimum desired block size.
+    const Index target_block_size = round_up(divup<Index>(k, num_threads));
+    const Index desired_min_block_size = 12 * packet_size;
+
+    const Index block_size = numext::mini<Index>(
+        k, numext::maxi<Index>(desired_min_block_size, target_block_size));
+    const Index num_blocks = divup<Index>(k, block_size);
+
+    // Compute block size, accounting for a potentially incomplete last block.
+    const auto actual_block_size = [=](Index block_idx) -> Index {
+      return block_idx + 1 < num_blocks
+                 ? block_size
+                 : k + block_size - block_size * num_blocks;
+    };
+
+    // We compute partial gemm results in parallel, and to get the final result
+    // we need to add them all together. For a large number of threads (>= 48)
+    // this adds a very expensive sequential step at the end.
+    //
+    // We split the [0, num_blocks) range into small ranges, and when a task for
+    // a block finishes its partial gemm computation, it checks if it was the
+    // last gemm in the range, and if so, it will add all blocks of the range.
+    //
+    // After all tasks finish, we need to add only these pre-aggregated blocks.
+
+    // Compute range size, accounting for a potentially incomplete last range.
+    const auto actual_range_size = [=](Index num_ranges, Index range_size,
+                                       Index range_idx) -> Index {
+      eigen_assert(range_idx < num_ranges);
+      return range_idx + 1 < num_ranges
+                 ? range_size
+                 : num_blocks + range_size - range_size * num_ranges;
+    };
+
+    // For now we use just a single level of ranges to compute pre-aggregated
+    // partial sums, but in general we can use more layers to compute tree
+    // aggregation in parallel and reduce the size of the sequential step.
+    //
+    // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
+    // sense only if number of threads >= ~128?
+    static const Index l0_size = 4;
+    const Index l0_ranges = divup<Index>(num_blocks, l0_size);
+
+    // Keep a count of pending gemm tasks for each l0 range.
+    MaxSizeVector<std::atomic<int>> l0_state(l0_ranges);
+    for (int i = 0; i < l0_ranges; ++i) {
+      l0_state.emplace_back(actual_range_size(l0_ranges, l0_size, i));
+    }
+
+    MaxSizeVector<Scalar*> block_buffers(num_blocks);
+
+    auto process_block = [&, this](Index block_idx, Index begin, Index end) {
+      Scalar* buf = block_buffers[block_idx];
+      ::memset(buf, 0, buffer_size_bytes);
+
       TENSOR_CONTRACTION_DISPATCH(
           this->template evalGemmPartialWithoutOutputKernel, Alignment,
-          (buf, begin, end, this->m_device.numThreads()));
-      barrier.Notify();
-    };
-    Index start = 0;
-    for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
-      // The underlying GEMM kernel assumes that k is a multiple of packet size
-      // (currently largest packet size is 16) and subtle breakage occurs if
-      // this is violated.
-      block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left);
-      Scalar* buf;
-      if (start == 0) {
-        buf = result;
-      } else {
-        buf = static_cast<Scalar*>(
-            this->m_device.allocate(m * n * sizeof(Scalar)));
-        block_buffers.push_back(buf);
-      }
-      Index end = start + block_size;
-      if (end > k) {
-        end = k;
+          (buf, begin, end, /*num_threads=*/num_blocks));
+
+      // Check if it was the last task in the l0 range.
+      const Index l0_index = block_idx / l0_size;
+      const int v = l0_state[l0_index].fetch_sub(1);
+      eigen_assert(v >= 1);
+
+      // If we processed the last block of the range, we can aggregate all
+      // partial results into the first block of the range.
+      if (v == 1) {
+        const Index rng_size = actual_range_size(l0_ranges, l0_size, l0_index);
+        const Index dst_block_idx = l0_index * l0_size;
+
+        if (rng_size == l0_size) {
+          addAllToBuffer<Alignment>(
+              m * n,
+              /*src_buf0=*/block_buffers[dst_block_idx + 1],
+              /*src_buf1=*/block_buffers[dst_block_idx + 2],
+              /*src_buf2=*/block_buffers[dst_block_idx + 3],
+              /*dst_buf= */ block_buffers[dst_block_idx]);
+        } else {
+          // Aggregate blocks of the potentially incomplete last range.
+          for (int i = 1; i < rng_size; ++i) {
+            addToBuffer<Alignment>(m * n,
+                                   /*src_buf=*/block_buffers[dst_block_idx + i],
+                                   /*dst_buf=*/block_buffers[dst_block_idx]);
+          }
+        }
       }
-      this->m_device.enqueueNoNotification(
-          [=, &process_block]() { process_block(buf, start, end); });
-      start = end;
+    };
+
+    Barrier barrier(internal::convert_index<int>(num_blocks));
+    for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
+      Scalar* buf = block_idx == 0
+                        ? result
+                        : static_cast<Scalar*>(
+                              this->m_device.allocate(buffer_size_bytes));
+      block_buffers.push_back(buf);
+
+      Index block_start = block_idx * block_size;
+      Index block_end = block_start + actual_block_size(block_idx);
+
+      this->m_device.enqueueNoNotification([=, &barrier, &process_block]() {
+        process_block(block_idx, block_start, block_end);
+        barrier.Notify();
+      });
     }
     barrier.Wait();
 
-    // Add other partial results into first partial result.
-    for (const auto& buf : block_buffers) {
-      addToBuffer<Alignment>(m * n, buf, result);
-      this->m_device.deallocate(buf);
+    // Aggregate partial sums from l0 ranges.
+    Index l0_index = 1;
+    for (; l0_index + 2 < l0_ranges; l0_index += 3) {
+      addAllToBuffer<Alignment>(
+          m * n,
+          /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
+          /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
+          /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
+          /*dst_buf= */ block_buffers[0]);
    }
+    for (; l0_index < l0_ranges; ++l0_index) {
+      addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
+                             block_buffers[0]);
+    }
+
+    // Don't forget to deallocate ALL temporary buffers.
+    for (Index i = 1; i < num_blocks; ++i) {
+      this->m_device.deallocate(block_buffers[i]);
     }
 
     // Finally call output kernel with finalized output buffer.
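The new addAllToBuffer helper in this patch is a three-source variant of the existing addToBuffer: one sweep over the m * n output adds three partial buffers at once using packet loads (pload/ploadt), packet adds (padd) and a packet store (pstoret), then handles the elements that do not fill a whole packet with a scalar tail. The following standalone scalar sketch is not Eigen code; the function name and the fixed stand-in packet width of 4 are illustrative, but it shows the same two-phase loop structure:

    #include <cstddef>
    #include <cstdio>

    // Two-phase accumulation: whole "packets" first, scalar tail after.
    void add_all_to_buffer(std::size_t n, const float* s0, const float* s1,
                           const float* s2, float* dst) {
      constexpr std::size_t kPacket = 4;  // stand-in for unpacket_traits size
      std::size_t i = 0;
      for (; i + kPacket <= n; i += kPacket) {
        for (std::size_t j = 0; j < kPacket; ++j) {
          // One packet's worth of work; the patch does this with padd on
          // pload/ploadt packets and writes back with pstoret.
          dst[i + j] += s0[i + j] + s1[i + j] + s2[i + j];
        }
      }
      for (; i < n; ++i) {  // remainder that doesn't fill a packet
        dst[i] += s0[i] + s1[i] + s2[i];
      }
    }

    int main() {
      float d[6] = {0, 0, 0, 0, 0, 0};
      float a[6] = {1, 1, 1, 1, 1, 1}, b[6] = {2, 2, 2, 2, 2, 2},
            c[6] = {3, 3, 3, 3, 3, 3};
      add_all_to_buffer(6, a, b, c, d);
      std::printf("%g %g\n", d[0], d[5]);  // 6 6: packet loop and tail both ran
      return 0;
    }

Folding three sources per sweep is what lets the aggregation loops in the patch touch the m * n output roughly a third as many times as repeated pairwise addToBuffer calls would.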
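The block and range bookkeeping in evalShardedByInnerDim keeps every block size a multiple of the packet-derived kmultiple while letting the last block, and the last l0 range, come up short. The sketch below mirrors that arithmetic outside of Eigen so the "potentially incomplete last block/range" handling can be checked in isolation; the sample sizes (k = 1000, 48 threads, packet size 8) are illustrative, and divup mirrors Eigen's helper:

    #include <algorithm>
    #include <cassert>
    #include <cstdio>

    using Index = long;

    Index divup(Index x, Index y) { return (x + y - 1) / y; }

    int main() {
      const Index k = 1000, num_threads = 48, packet_size = 8;

      // Blocks must stay a multiple of kmultiple, as the GEMM kernel requires.
      const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
      const Index target_block_size =
          divup(divup(k, num_threads), kmultiple) * kmultiple;
      const Index desired_min_block_size = 12 * packet_size;

      const Index block_size =
          std::min(k, std::max(desired_min_block_size, target_block_size));
      const Index num_blocks = divup(k, block_size);

      Index total = 0;
      for (Index idx = 0; idx < num_blocks; ++idx) {
        // Last block absorbs the remainder, exactly as actual_block_size does.
        const Index actual = idx + 1 < num_blocks
                                 ? block_size
                                 : k + block_size - block_size * num_blocks;
        total += actual;
        std::printf("block %ld: size %ld\n", idx, actual);
      }
      assert(total == k);  // the blocks tile [0, k) exactly

      // Range arithmetic from actual_range_size: the last l0 range may hold
      // fewer than l0_size blocks (here: 11 blocks -> ranges of 4, 4, 3).
      const Index l0_size = 4;
      const Index l0_ranges = divup(num_blocks, l0_size);
      const Index last_range = num_blocks + l0_size - l0_size * l0_ranges;
      std::printf("%ld ranges, last holds %ld blocks\n", l0_ranges, last_range);
      return 0;
    }

With these sizes, target_block_size rounds divup(1000, 48) = 21 up to 24, the 12 * packet_size floor of 96 wins, so there are 11 blocks of 96 columns with the last holding 40, grouped into 3 l0 ranges of which the last holds 3 blocks.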
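The pre-aggregation trigger is the atomic per-range counter kept in l0_state: every gemm task decrements its range's counter, and the task that observes the old value 1 knows it finished last and folds the whole range into the range's first block, with no dedicated aggregation task or extra synchronization. A minimal sketch of that pattern, using plain std::thread in place of the device thread pool and illustrative values:

    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main() {
      constexpr int l0_size = 4;                  // blocks per l0 range
      std::atomic<int> pending(l0_size);          // pending tasks in the range
      std::vector<double> partial(l0_size, 0.0);  // per-block partial results

      std::vector<std::thread> tasks;
      for (int i = 0; i < l0_size; ++i) {
        tasks.emplace_back([&, i] {
          partial[i] = i + 1.0;  // stand-in for this block's partial gemm
          // Decrement the range counter; the task that sees 1 ran last.
          if (pending.fetch_sub(1) == 1) {
            double sum = 0;  // fold the whole range into partial[0]
            for (double p : partial) sum += p;
            partial[0] = sum;
            std::printf("aggregated by task %d: sum = %g\n", i, sum);  // 10
          }
        });
      }
      for (auto& t : tasks) t.join();
      return 0;
    }

Because fetch_sub is a read-modify-write operation (sequentially consistent by default), the last finisher is guaranteed to observe the partial results the other tasks wrote before their decrements, which is what makes reading the whole range safe.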