author     Eugene Zhulenev <ezhulenev@google.com>   2019-08-30 15:13:38 -0700
committer  Eugene Zhulenev <ezhulenev@google.com>   2019-08-30 15:13:38 -0700
commit     f0b36fb9a405400e82b73ea70097b8ae3cd1095a (patch)
tree       d3a2903422799257720d2d4989bcd845ab2ae27e /unsupported/Eigen/CXX11
parent     619cea94916e7531a839ee0ff657714857921db8 (diff)
evalSubExprsIfNeededAsync + async TensorContractionThreadPool
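The patch threads an asynchronous evaluation path (evalSubExprsIfNeededAsync) through the tensor evaluators and the thread-pool contraction: each evaluator pre-evaluates its children and reports completion through a continuation instead of blocking the caller. A minimal standalone sketch of that chaining pattern, for orientation only; LeafEvaluator and AssignEvaluator are hypothetical stand-ins, not the Eigen types touched by this commit.

```cpp
#include <functional>
#include <iostream>

// Hypothetical stand-in for a leaf evaluator: nothing to pre-evaluate,
// so it completes immediately by invoking the continuation.
struct LeafEvaluator {
  void evalSubExprsIfNeededAsync(void* /*dest*/,
                                 std::function<void(bool)> done) {
    done(/*need_assign=*/true);
  }
};

// Hypothetical stand-in for an assignment evaluator: evaluate the lhs
// first, then the rhs, then report back to the caller. Each step completes
// by calling the next continuation instead of blocking.
struct AssignEvaluator {
  LeafEvaluator lhs, rhs;

  void evalSubExprsIfNeededAsync(void* dest,
                                 std::function<void(bool)> done) {
    lhs.evalSubExprsIfNeededAsync(nullptr, [this, dest, done](bool) {
      rhs.evalSubExprsIfNeededAsync(dest, [done](bool need_assign) {
        done(need_assign);
      });
    });
  }
};

int main() {
  AssignEvaluator assign;
  assign.evalSubExprsIfNeededAsync(nullptr, [](bool need_assign) {
    std::cout << "sub-expressions ready, need_assign=" << need_assign << "\n";
  });
}
```

In the actual patch the same shape appears in TensorAssign.h, TensorBroadcasting.h and the cwise evaluators, with the lhs data pointer handed to the rhs as its destination buffer.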
Diffstat (limited to 'unsupported/Eigen/CXX11')
8 files changed, 693 insertions, 302 deletions
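In the asynchronous contraction path introduced below, shared state lives in a heap-allocated context (EvalParallelContext / EvalShardedByInnerDimContext): each block task decrements an atomic count of pending blocks, and the last finisher aggregates the partial results, moves the done callback out of the context, deletes the context, and only then invokes the callback. A rough standalone sketch of that last-finisher pattern, assuming a hypothetical AsyncContext type and plain std::thread instead of the Eigen thread pool:

```cpp
#include <atomic>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

// Shared state for one async evaluation. The last task to finish deletes it
// and invokes the user's done callback -- nothing ever blocks.
struct AsyncContext {
  AsyncContext(int num_tasks, std::function<void()> done)
      : pending(num_tasks), done(std::move(done)) {}

  void taskFinished() {
    // fetch_sub returns the previous value, so 1 means "I was the last task".
    if (pending.fetch_sub(1) == 1) {
      auto done_copy = std::move(done);  // move out before `delete this`
      delete this;
      done_copy();
    }
  }

  std::atomic<int> pending;
  std::function<void()> done;
};

int main() {
  const int kTasks = 4;
  auto* ctx = new AsyncContext(kTasks, [] {
    std::cout << "all blocks processed, result is ready\n";
  });

  std::vector<std::thread> workers;
  for (int i = 0; i < kTasks; ++i) {
    workers.emplace_back([ctx] {
      // ... process one block of the contraction here ...
      ctx->taskFinished();
    });
  }
  for (auto& t : workers) t.join();
}
```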
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h index d6e51bc6c..270ad974e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h @@ -147,6 +147,18 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device> // by the rhs to the lhs. return m_rightImpl.evalSubExprsIfNeeded(m_leftImpl.data()); } + +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync( + EvaluatorPointerType, EvalSubExprsCallback done) { + m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done](bool) { + m_rightImpl.evalSubExprsIfNeededAsync( + m_leftImpl.data(), [done](bool need_assign) { done(need_assign); }); + }); + } +#endif // EIGEN_USE_THREADS + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_leftImpl.cleanup(); m_rightImpl.cleanup(); diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index 10bdbc6a0..b290de311 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -214,6 +214,14 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> return true; } +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync( + EvaluatorPointerType, EvalSubExprsCallback done) { + m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); }); + } +#endif // EIGEN_USE_THREADS + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); } diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index a398b2b3f..2f8656fbb 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -373,13 +373,13 @@ struct TensorContractionEvaluatorBase typedef typename Storage::Type EvaluatorPointerType; enum { - IsAligned = true, - PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1), - BlockAccess = false, + IsAligned = true, + PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1), + BlockAccess = false, PreferBlockAccess = false, - Layout = TensorEvaluator<LeftArgType, Device>::Layout, - CoordAccess = false, // to be implemented - RawAccess = true + Layout = TensorEvaluator<LeftArgType, Device>::Layout, + CoordAccess = false, // to be implemented + RawAccess = true }; // Most of the code is assuming that both input tensors are ColMajor. 
If the @@ -390,7 +390,7 @@ struct TensorContractionEvaluatorBase static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; typedef typename internal::conditional< static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; - + typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluatorType; typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluatorType; @@ -605,48 +605,99 @@ struct TensorContractionEvaluatorBase } } -#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ - if (this->m_lhs_inner_dim_contiguous) { \ - if (this->m_rhs_inner_dim_contiguous) { \ - if (this->m_rhs_inner_dim_reordered) { \ - METHOD<true, true, true, ALIGNMENT>ARGS; \ - } \ - else { \ - METHOD<true, true, false, ALIGNMENT>ARGS; \ - } \ - } \ - else { \ - if (this->m_rhs_inner_dim_reordered) { \ - METHOD<true, false, true, ALIGNMENT>ARGS; \ - } \ - else { \ - METHOD<true, false, false, ALIGNMENT>ARGS; \ - } \ - } \ - } \ - else { \ - if (this->m_rhs_inner_dim_contiguous) { \ - if (this->m_rhs_inner_dim_reordered) { \ - METHOD<false, true, true, ALIGNMENT>ARGS; \ - } \ - else { \ - METHOD<false, true, false, ALIGNMENT>ARGS; \ - } \ - } \ - else { \ - if (this->m_rhs_inner_dim_reordered) { \ - METHOD<false, false, true, ALIGNMENT>ARGS; \ - } \ - else { \ - METHOD<false, false, false, ALIGNMENT>ARGS; \ - } \ - } \ - } +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync( + EvaluatorPointerType dest, EvalSubExprsCallback done) { + m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) { + m_rightImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) { + if (dest) { + evalToAsync(dest, [done]() { done(false); }); + } else { + m_result = static_cast<EvaluatorPointerType>( + m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); + evalToAsync(m_result, [done]() { done(true); }); + } + }); + }); + } +#endif // EIGEN_USE_THREADS + +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<true, true, true, ALIGNMENT> ARGS; \ + } else { \ + METHOD<true, true, false, ALIGNMENT> ARGS; \ + } \ + } else { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<true, false, true, ALIGNMENT> ARGS; \ + } else { \ + METHOD<true, false, false, ALIGNMENT> ARGS; \ + } \ + } \ + } else { \ + if (this->m_rhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<false, true, true, ALIGNMENT> ARGS; \ + } else { \ + METHOD<false, true, false, ALIGNMENT> ARGS; \ + } \ + } else { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<false, false, true, ALIGNMENT> ARGS; \ + } else { \ + METHOD<false, false, false, ALIGNMENT> ARGS; \ + } \ + } \ + } + +#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \ + if (this->m_lhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_reordered) { \ + (new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN; \ + } else { \ + (new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN; \ + } \ + } else { \ + if (this->m_rhs_inner_dim_reordered) { \ + (new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN; \ + } else { \ + (new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN; \ + } \ + } \ + } else { \ + if 
(this->m_rhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_reordered) { \ + (new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN; \ + } else { \ + (new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN; \ + } \ + } else { \ + if (this->m_rhs_inner_dim_reordered) { \ + (new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN; \ + } else { \ + (new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN; \ + } \ + } \ + } EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const { static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer); } +#ifdef EIGEN_USE_THREADS + template <typename EvalToCallback> + void evalToAsync(Scalar* buffer, EvalToCallback done) const { + static_cast<const Derived*>(this) + ->template evalProductAsync<EvalToCallback, Unaligned>(buffer, + std::move(done)); + } +#endif // EIGEN_USE_THREADS + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> void evalProductSequential(Scalar* buffer) const { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index ca20038a4..f9d9d6d31 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -73,6 +73,34 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT template <int Alignment> void evalProduct(Scalar* buffer) const { + evalProductImpl<NoCallback, Alignment>(buffer, NoCallback()); + } + + template <typename EvalToCallback, int Alignment> + void evalProductAsync(Scalar* buffer, EvalToCallback done) const { + evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done)); + } + + template <typename DoneCallback, int Alignment> + void evalProductImpl(Scalar* buffer, DoneCallback done) const { + // This function computes a lot of heuristics in multiple steps, and it + // also has multiple exit points. To keep it sane, readable and all in one + // place, sync/async execution decision is made at runtime at the very end. + // + // (1) In sync mode we allocate Context on the stack, submit computations + // to the device thread pool, and block on a barrier until it is + // completed. + // + // (2) In async mode we allocate Context on the heap, and after all tasks + // are finished, we call provided the done callback, and delete a + // context from the heap. + // + // (*) EvalParallelContext & EvalShardedByInnerDimContext owns all the state + // and temporary buffers, requried for executing the tensor contraction. + // They are responsible for cleaning it up after contraction is done. + static const bool IsEvalInSyncMode = + std::is_same<DoneCallback, NoCallback>::value; + const Index m = this->m_i_size; const Index n = this->m_j_size; const Index k = this->m_k_size; @@ -134,8 +162,16 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) { // We are in the scenario where it is more effective to shard by the // inner dimension. 
- this->template evalShardedByInnerDim<Alignment>(num_threads_by_k, - buffer); + if (IsEvalInSyncMode) { + EvalShardedByInnerDimContext<DoneCallback> ctx( + this, num_threads_by_k, buffer, m, n, k, std::move(done)); + ctx.template run<Alignment>(); + } else { + auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>( + this, num_threads_by_k, buffer, m, n, k, std::move(done)); + ctx->template runAsync<Alignment>(); + } + return; } @@ -146,6 +182,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT if (num_threads == 1) { TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Unaligned, (buffer)); + if (!IsEvalInSyncMode) done(); return; } @@ -230,21 +267,89 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // optimization. if (parallelize_by_sharding_dim_only) parallel_pack = false; + // TODO(ezhulnev): With if contexpr we don't need SyncEvalParallelContext. + if (IsEvalInSyncMode) { #define CONTEXT_ARGS \ (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \ - nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only) \ + nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only, \ + NoCallback()) \ .run() - - TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS); - + TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment, + CONTEXT_ARGS); #undef CONTEXT_ARGS + } else { +#define CONTEXT_ARGS \ + (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \ + nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only, \ + std::move(done)) + TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback, + Alignment, CONTEXT_ARGS, run()); +#undef CONTEXT_ARGS + } } - // Context coordinates a single parallel gemm operation. - template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, - bool rhs_inner_dim_reordered, int Alignment> - class Context { + // ------------------------------------------------------------------------ // + + // Dummy struct to represent an empty DoneCallback. + + struct NoCallback { + void operator()() { + eigen_assert(false && "NoCallback should never be called"); + } + }; + + // ------------------------------------------------------------------------ // + + template <typename DoneCallback, typename Context> + class EvalParallelNotification; + + // Synchronous evaluation notification that blocks caller thread in Wait(). + template <typename Context> + class EvalParallelNotification<NoCallback, Context> { + public: + EvalParallelNotification(Context*, NoCallback) {} + void Notify() { done_.Notify(); } + void Wait() { done_.Wait(); } + private: + Eigen::Notification done_; + }; + + // Asynchronous evaluation notification that does not block in Wait(). + template <typename DoneCallback, typename Context> + class EvalParallelNotification { + public: + EvalParallelNotification(Context* ctx, DoneCallback done) + : ctx_(ctx), done_(std::move(done)) {} + + void Notify() { + // Make a copy of done callback, because it will be destructed when we + // will delete context in the next line (EvalParallelNotification is a + // data member of EvalParallelContext class). + DoneCallback done_copy = std::move(done_); + + // Delete parallel evaluation context. + delete ctx_; + + // Now safely call the done callback. + done_copy(); + } + + void Wait() {} + + private: + Context* ctx_; + DoneCallback done_; + }; + + // Context orchestrates sync/async parallel contraction evaluation. 
When it is + // executed in asynchronous mode, it owns all the shared state that might be + // accessible by block packing and kernel tasks. + + template <typename DoneCallback, bool lhs_inner_dim_contiguous, + bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, + int Alignment> + class EvalParallelContext { public: typedef internal::TensorContractionInputMapper< LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t, @@ -267,11 +372,15 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef typename TensorContractionKernel::RhsBlock RhsBlock; typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle; - Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn, - Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, - Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col, - bool parallel_pack, bool parallelize_by_sharding_dim_only) - : device_(self->m_device), + EvalParallelContext(const Self* self, int num_threads, Scalar* buffer, + Index tm, Index tn, Index tk, Index bm, Index bn, + Index bk, Index nm, Index nn, Index nk, Index gm, + Index gn, Index nm0, Index nn0, bool shard_by_col, + bool parallel_pack, + bool parallelize_by_sharding_dim_only, + DoneCallback done) + : done_(this, std::move(done)), + device_(self->m_device), lhs_(self->m_leftImpl, self->m_left_nocontract_strides, self->m_i_strides, self->m_left_contracting_strides, self->m_k_strides), @@ -299,8 +408,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT gn_(gn), nm0_(nm0), nn0_(nn0), - kernel_(m_, k_, n_, bm_, bk_, bn_) - { + kernel_(m_, k_, n_, bm_, bk_, bn_) { // These two options are mutually exclusive. eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only)); @@ -371,7 +479,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } } - ~Context() { + ~EvalParallelContext() { for (Index x = 0; x < P; x++) { for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m]; delete[] state_kernel_[x]; @@ -386,16 +494,28 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT void run() { // Kick off packing of the first slice. signal_switch(0, 1); + // Wait for overall completion. - // TODO(dvyukov): this wait can lead to deadlock. - // If nthreads contractions are concurrently submitted from worker - // threads, this wait will block all worker threads and the system will - // deadlock. + // + // If parallel evaluation is executed in async mode, this is a no-op, and + // Wait() will return immediately. In synchronous mode it will block the + // caller thread until it will receive notification from last task. + // + // In async mode, last task when completed will call done callback from + // the same thread, and will delete this context. + // + // TODO(dvyukov): This wait can lead to deadlock if contraction is + // evaluated in synchronous mode. If nthreads contractions are + // concurrently submitted from worker threads, this wait will block all + // worker threads and the system will deadlock. done_.Wait(); } private: - Notification done_; + // This notification is specialized on the type of DoneCallback and can be + // blocking or non-blocking. + EvalParallelNotification<DoneCallback, EvalParallelContext> done_; + const Device& device_; LhsMapper lhs_; RhsMapper rhs_; @@ -780,10 +900,344 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT Index gm(Index m) const { return m + 1 < nm_ ? 
gm_ : nm0_ + gm_ - gm_ * nm_; } Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; } - Context(const Context&) = delete; - void operator=(const Context&) = delete; + EvalParallelContext(const EvalParallelContext&) = delete; + void operator=(const EvalParallelContext&) = delete; + }; + + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, + bool rhs_inner_dim_reordered, int Alignment> + using SyncEvalParallelContext = + EvalParallelContext<NoCallback, lhs_inner_dim_contiguous, + rhs_inner_dim_contiguous, rhs_inner_dim_reordered, + Alignment>; + + // ------------------------------------------------------------------------ // + + // EvalShardedByInnerDimContext orchestrates sync/async contraction + // evaluation, when we shard by inner dimension. When it is executed in + // asynchronous mode, it owns all the shared state that might be accessible by + // block processing tasks. + + template <typename DoneCallback> + struct EvalShardedByInnerDimContext { + EvalShardedByInnerDimContext(const Self* evaluator, int num_threads, + Scalar* result, Index m, Index n, Index k, + DoneCallback done) + : evaluator(evaluator), + m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous), + m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous), + m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered), + num_threads(num_threads), + result(result), + m(m), + n(n), + k(k), + done(std::move(done)), + buffer_size_bytes(m * n * sizeof(Scalar)), + block_size(blockSize(k, num_threads)), + num_blocks(divup<Index>(k, block_size)), + num_pending_blocks(internal::convert_index<int>(num_blocks)), + l0_ranges(divup<Index>(num_blocks, l0_size)), + l0_state(l0_ranges), + block_buffers(num_blocks) { + // Keep count of pending gemm tasks for each l0 range. + for (int i = 0; i < l0_ranges; ++i) { + const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i); + l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks)); + } + + // Allocate temporary buffers for each block. + for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) { + Scalar* buf = block_idx == 0 + ? result + : static_cast<Scalar*>(evaluator->m_device.allocate( + buffer_size_bytes)); + block_buffers.emplace_back(buf); + } + } + + ~EvalShardedByInnerDimContext() { + for (Index i = 1; i < num_blocks; ++i) { + evaluator->m_device.deallocate(block_buffers[i]); + } + } + + template <int Alignment> + void run() { + Barrier barrier(internal::convert_index<int>(num_blocks)); + for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) { + evaluator->m_device.enqueueNoNotification( + [this, block_idx, &barrier]() { + Index block_start = block_idx * block_size; + Index block_end = block_start + actualBlockSize(block_idx); + + processBlock<Alignment>(block_idx, block_start, block_end); + barrier.Notify(); + }); + } + barrier.Wait(); + + // Aggregate partial sums from l0 ranges. + aggregateL0Blocks<Alignment>(); + + // Apply output kernel. + applyOutputKernel(); + } + + template <int Alignment> + void runAsync() { + for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) { + evaluator->m_device.enqueueNoNotification([this, block_idx]() { + Index block_start = block_idx * block_size; + Index block_end = block_start + actualBlockSize(block_idx); + + processBlock<Alignment>(block_idx, block_start, block_end); + + int v = num_pending_blocks.fetch_sub(1); + eigen_assert(v >= 1); + + if (v == 1) { + // Aggregate partial sums from l0 ranges. 
+ aggregateL0Blocks<Alignment>(); + + // Apply output kernel. + applyOutputKernel(); + + // NOTE: If we call `done` callback before deleting this (context), + // it might deallocate Self* pointer captured by context, and we'll + // fail in destructor trying to deallocate temporary buffers. + + // Move done call back from context before it will be destructed. + DoneCallback done_copy = std::move(done); + + // We are confident that we are the last one who touches context. + delete this; + + // Now safely call the done callback. + done_copy(); + } + }); + } + } + + private: + // The underlying GEMM kernel assumes that k is a multiple of + // the packet size and subtle breakage occurs if this is violated. + static const Index packet_size = internal::packet_traits<RhsScalar>::size; + + const Self* evaluator; // TensorContraction evaluator + + // These fields required fromTENSOR_CONTRACTION_DISPATCH macro. + bool m_lhs_inner_dim_contiguous; + bool m_rhs_inner_dim_contiguous; + bool m_rhs_inner_dim_reordered; + + int num_threads; + Scalar* result; + + Index m; + Index n; + Index k; + + DoneCallback done; + + // ----------------------------------------------------------------------// + // Algorithm parameters. + + // We will compute partial results into the buffers of this size. + Index buffer_size_bytes; + + Index block_size; + Index num_blocks; + + // Keep track of pending tasks when evaluate in async mode. + std::atomic<int> num_pending_blocks; + + // We compute partial gemm results in parallel, and to get the final result + // we need to add them all together. For the large number of threads (>= 48) + // this adds a very expensive sequential step at the end. + // + // We split the [0, num_blocks) into small ranges, and when a task for the + // block finishes its partial gemm computation, it checks if it was the last + // gemm in the range, and if so, it will add all blocks of the range. + // + // After all tasks done, we need to add only these pre-aggregated blocks. + + // For now we use just a single level of ranges to compute pre-aggregated + // partial sums, but in general we can use more layers to compute tree + // aggregation in parallel and reduce the size of the sequential step. + // + // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make + // sense only if number of threads >= ~128? + static const Index l0_size = 4; + Index l0_ranges; + + // Keep count of pending gemm tasks for each l0 range. + MaxSizeVector<std::atomic<int>> l0_state; // [0, l0_ranges) + + // Buffers allocated for each temporary block computation. + MaxSizeVector<Scalar*> block_buffers; // [0, num_blocks) + + template <int Alignment> + void processBlock(Index block_idx, Index begin, Index end) { + Scalar* buf = block_buffers[block_idx]; + ::memset(buf, 0, buffer_size_bytes); + + TENSOR_CONTRACTION_DISPATCH( + evaluator->template evalGemmPartialWithoutOutputKernel, Alignment, + (buf, begin, end, + /*num_threads=*/internal::convert_index<int>(num_blocks))); + + // Check if it was the last task in l0 range. + const Index l0_index = block_idx / l0_size; + const int v = l0_state[l0_index].fetch_sub(1); + eigen_assert(v >= 1); + + // If we processed the last block of the range, we can aggregate all + // partial results into the first block of the range. 
+ if (v == 1) { + const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index); + const Index dst_block_idx = l0_index * l0_size; + + if (rng_size == l0_size) { + addAllToBuffer<Alignment>( + m * n, + /*src_buf0=*/block_buffers[dst_block_idx + 1], + /*src_buf1=*/block_buffers[dst_block_idx + 2], + /*src_buf2=*/block_buffers[dst_block_idx + 3], + /*dst_buf= */ block_buffers[dst_block_idx]); + } else { + // Aggregate blocks of potentially incomplete last range. + for (int i = 1; i < rng_size; ++i) { + addToBuffer<Alignment>(m * n, + /*src_buf=*/block_buffers[dst_block_idx + i], + /*dst_buf=*/block_buffers[dst_block_idx]); + } + } + } + } + + // Aggregate partial sums from l0 ranges. + template <int Alignment> + void aggregateL0Blocks() const { + Index l0_index = 1; + + for (; l0_index + 2 < l0_ranges; l0_index += 3) { + addAllToBuffer<Alignment>( + m * n, + /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size], + /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size], + /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size], + /*dst_buf= */ block_buffers[0]); + } + + for (; l0_index < l0_ranges; ++l0_index) { + addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size], + block_buffers[0]); + } + } + + void applyOutputKernel() const { + typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; + evaluator->m_output_kernel( + OutputMapper(result, m), evaluator->m_tensor_contraction_params, + static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n); + } + + // Compute block size with accounting for potentially incomplete last block. + Index actualBlockSize(Index block_idx) const { + return block_idx + 1 < num_blocks + ? block_size + : k + block_size - block_size * num_blocks; + }; + + // Compute range size with accounting for potentially incomplete last range. + Index actualRangeSize(Index num_ranges, Index range_size, + Index range_idx) const { + eigen_assert(range_idx < num_ranges); + return range_idx + 1 < num_ranges + ? 
range_size + : num_blocks + range_size - range_size * num_ranges; + }; + + template <int Alignment> + EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf, + Scalar* tgt_buf) { + const int output_packet_size = + internal::unpacket_traits<PacketReturnType>::size; + size_t i = 0; + const size_t num_packets = n / output_packet_size; + for (; i < output_packet_size * num_packets; i += output_packet_size) { + const PacketReturnType src_val = + internal::pload<PacketReturnType>(src_buf + i); + const PacketReturnType tgt_val = + internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i); + const PacketReturnType sum = internal::padd(src_val, tgt_val); + internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i, + sum); + } + for (; i < n; ++i) { + tgt_buf[i] += src_buf[i]; + } + } + + template <int Alignment> + EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n, + const Scalar* src_buf0, + const Scalar* src_buf1, + const Scalar* src_buf2, + Scalar* dst_buf) { + using ::Eigen::internal::padd; + using ::Eigen::internal::pload; + using ::Eigen::internal::ploadt; + using ::Eigen::internal::pstoret; + + const int output_packet_size = + internal::unpacket_traits<PacketReturnType>::size; + + size_t i = 0; + const size_t num_packets = n / output_packet_size; + for (; i < output_packet_size * num_packets; i += output_packet_size) { + const auto src_val0 = pload<PacketReturnType>(src_buf0 + i); + const auto src_val1 = pload<PacketReturnType>(src_buf1 + i); + const auto src_val2 = pload<PacketReturnType>(src_buf2 + i); + + const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i); + const auto sum = + padd(padd(dst_val, src_val0), padd(src_val1, src_val2)); + + pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum); + } + for (; i < n; ++i) { + dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i]; + } + } + + // Cost model doesn't capture well the cost associated with constructing + // tensor contraction mappers and computing loop bounds in gemm_pack_lhs + // and gemm_pack_rhs, so we specify minimum desired block size. + static Index blockSize(Index k, int num_threads) { + const auto round_up = [=](Index index) -> Index { + const Index kmultiple = packet_size <= 8 ? 8 : packet_size; + return divup<Index>(index, kmultiple) * kmultiple; + }; + + const Index target_block_size = round_up(divup<Index>(k, num_threads)); + const Index desired_min_block_size = 12 * packet_size; + + return numext::mini<Index>( + k, numext::maxi<Index>(desired_min_block_size, target_block_size)); + } + + EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete; + void operator=(const EvalShardedByInnerDimContext&) = delete; }; + // ------------------------------------------------------------------------ // + + // Below are the function used by evalProductImpl heuristics, trying to select + // optimcal parameters for parallelization algorithm. + // Decide whether we want to shard m x n contraction by columns or by rows. 
static bool shardByCol(Index m, Index n, Index num_threads) { // Note: we are comparing both n and m against Traits::nr, it is not @@ -916,55 +1370,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT return cost + lhsCost + rhsCost; } - template <int Alignment> - EIGEN_STRONG_INLINE void addToBuffer(size_t n, const Scalar* src_buf, - Scalar* tgt_buf) const { - const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size; - size_t i = 0; - const size_t num_packets = n / output_packet_size; - for (; i < output_packet_size * num_packets; i += output_packet_size) { - const PacketReturnType src_val = - internal::pload<PacketReturnType>(src_buf + i); - const PacketReturnType tgt_val = - internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i); - const PacketReturnType sum = internal::padd(src_val, tgt_val); - internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i, sum); - } - for (; i < n; ++i) { - tgt_buf[i] += src_buf[i]; - } - } - - template <int Alignment> - EIGEN_STRONG_INLINE void addAllToBuffer(size_t n, const Scalar* src_buf0, - const Scalar* src_buf1, - const Scalar* src_buf2, - Scalar* dst_buf) const { - using ::Eigen::internal::padd; - using ::Eigen::internal::pload; - using ::Eigen::internal::ploadt; - using ::Eigen::internal::pstoret; - - const int output_packet_size = - internal::unpacket_traits<PacketReturnType>::size; - - size_t i = 0; - const size_t num_packets = n / output_packet_size; - for (; i < output_packet_size * num_packets; i += output_packet_size) { - const auto src_val0 = pload<PacketReturnType>(src_buf0 + i); - const auto src_val1 = pload<PacketReturnType>(src_buf1 + i); - const auto src_val2 = pload<PacketReturnType>(src_buf2 + i); - - const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i); - const auto sum = padd(padd(dst_val, src_val0), padd(src_val1, src_val2)); - - pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum); - } - for (; i < n; ++i) { - dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i]; - } - } - // Decide whether we want to shard m x k x n contraction over the inner // (contraction) dimension (k). static bool shardByInnerDim(Index m, Index n, Index k, int num_threads, @@ -992,163 +1397,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT return shard_by_k; } - template <int Alignment> - void evalShardedByInnerDim(int num_threads, Scalar* result) const { - const Index m = this->m_i_size; - const Index n = this->m_j_size; - const Index k = this->m_k_size; - - // We will compute partial results into the buffers of this size. - const Index buffer_size_bytes = m * n * sizeof(Scalar); - - // The underlying GEMM kernel assumes that k is a multiple of - // the packet size and subtle breakage occurs if this is violated. - const Index packet_size = internal::packet_traits<RhsScalar>::size; - - const auto round_up = [=](Index index) -> Index { - const Index kmultiple = packet_size <= 8 ? 8 : packet_size; - return divup<Index>(index, kmultiple) * kmultiple; - }; - - // Cost model doesn't capture well the cost associated with constructing - // tensor contraction mappers and computing loop bounds in gemm_pack_lhs and - // gemm_pack_rhs, so we specify minimum desired block size. 
- const Index target_block_size = round_up(divup<Index>(k, num_threads)); - const Index desired_min_block_size = 12 * packet_size; - - const Index block_size = numext::mini<Index>( - k, numext::maxi<Index>(desired_min_block_size, target_block_size)); - const Index num_blocks = divup<Index>(k, block_size); - - // Compute block size with accounting for potentially incomplete last block. - const auto actual_block_size = [=](Index block_idx) -> Index { - return block_idx + 1 < num_blocks - ? block_size - : k + block_size - block_size * num_blocks; - }; - - // We compute partial gemm results in parallel, and to get the final result - // we need to add them all together. For the large number of threads (>= 48) - // this adds a very expensive sequential step at the end. - // - // We split the [0, num_blocks) into small ranges, and when a task for the - // block finishes its partial gemm computation, it checks if it was the last - // gemm in the range, and if so, it will add all blocks of the range. - // - // After all tasks finihes, we need to add only these pre-aggregated blocks. - - // Compute range size with accounting for potentially incomplete last range. - const auto actual_range_size = [=](Index num_ranges, Index range_size, - Index range_idx) -> Index { - eigen_assert(range_idx < num_ranges); - return range_idx + 1 < num_ranges - ? range_size - : num_blocks + range_size - range_size * num_ranges; - }; - - // For now we use just a single level of ranges to compute pre-aggregated - // partial sums, but in general we can use more layers to compute tree - // aggregation in parallel and reduce the size of the sequential step. - // - // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make - // sense only if number of threads >= ~128? - static const Index l0_size = 4; - const Index l0_ranges = divup<Index>(num_blocks, l0_size); - - // Keep count of pending gemm tasks for each l0 range. - MaxSizeVector<std::atomic<int>> l0_state(l0_ranges); - for (int i = 0; i < l0_ranges; ++i) { - const Index num_pending_tasks = actual_range_size(l0_ranges, l0_size, i); - l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks)); - } - - MaxSizeVector<Scalar*> block_buffers(num_blocks); - - auto process_block = [&, this](Index block_idx, Index begin, Index end) { - Scalar* buf = block_buffers[block_idx]; - ::memset(buf, 0, buffer_size_bytes); - - TENSOR_CONTRACTION_DISPATCH( - this->template evalGemmPartialWithoutOutputKernel, Alignment, - (buf, begin, end, - /*num_threads=*/internal::convert_index<int>(num_blocks))); - - // Check if it was the last task in l0 range. - const Index l0_index = block_idx / l0_size; - const int v = l0_state[l0_index].fetch_sub(1); - eigen_assert(v >= 1); - - // If we processed the last block of the range, we can aggregate all - // partial results into the first block of the range. - if (v == 1) { - const Index rng_size = actual_range_size(l0_ranges, l0_size, l0_index); - const Index dst_block_idx = l0_index * l0_size; - - if (rng_size == l0_size) { - addAllToBuffer<Alignment>( - m * n, - /*src_buf0=*/block_buffers[dst_block_idx + 1], - /*src_buf1=*/block_buffers[dst_block_idx + 2], - /*src_buf2=*/block_buffers[dst_block_idx + 3], - /*dst_buf= */ block_buffers[dst_block_idx]); - } else { - // Aggregate blocks of potentially incomplete last range. 
- for (int i = 1; i < rng_size; ++i) { - addToBuffer<Alignment>(m * n, - /*src_buf=*/block_buffers[dst_block_idx + i], - /*dst_buf=*/block_buffers[dst_block_idx]); - } - } - } - }; - - Barrier barrier(internal::convert_index<int>(num_blocks)); - for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) { - Scalar* buf = block_idx == 0 - ? result - : static_cast<Scalar*>( - this->m_device.allocate(buffer_size_bytes)); - block_buffers.push_back(buf); - - Index block_start = block_idx * block_size; - Index block_end = block_start + actual_block_size(block_idx); - - this->m_device.enqueueNoNotification([=, &barrier, &process_block]() { - process_block(block_idx, block_start, block_end); - barrier.Notify(); - }); - } - barrier.Wait(); - - // Aggregate partial sums from l0 ranges. - Index l0_index = 1; - for (; l0_index + 2 < l0_ranges; l0_index += 3) { - addAllToBuffer<Alignment>( - m * n, - /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size], - /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size], - /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size], - /*dst_buf= */block_buffers[0]); - } - for (; l0_index < l0_ranges; ++l0_index) { - addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size], - block_buffers[0]); - } - - // Don't forget to deallocate ALL temporary buffers. - for (Index i = 1; i < num_blocks; ++i) { - this->m_device.deallocate(block_buffers[i]); - } - - // Finally call output kernel with finalized output buffer. - typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; - this->m_output_kernel(OutputMapper(result, m), - this->m_tensor_contraction_params, - static_cast<Eigen::Index>(0), - static_cast<Eigen::Index>(0), - m, n); - } - TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const { // Compute cost. const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size; @@ -1188,7 +1436,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT return num_threads; } - double computeBandwidth(bool shard_by_col, Index bm, Index bn, Index bk) const { // Peak VFMA bandwidth is 0.5. However if we have not enough data for diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h index edb0b3e25..cee46634c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h @@ -52,7 +52,7 @@ class Allocator { // Build a thread pool device on top the an existing pool of threads. struct ThreadPoolDevice { // The ownership of the thread pool remains with the caller. - ThreadPoolDevice(ThreadPoolInterface* pool, int num_cores, Allocator* allocator = NULL) + ThreadPoolDevice(ThreadPoolInterface* pool, int num_cores, Allocator* allocator = nullptr) : pool_(pool), num_threads_(num_cores), allocator_(allocator) { } EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const { @@ -234,7 +234,7 @@ struct ThreadPoolDevice { // Convenience wrapper for parallelFor that does not align blocks. void parallelFor(Index n, const TensorOpCost& cost, std::function<void(Index, Index)> f) const { - parallelFor(n, cost, NULL, std::move(f)); + parallelFor(n, cost, nullptr, std::move(f)); } // WARNING: This function is asynchronous and will not block the calling thread. 
@@ -248,6 +248,14 @@ struct ThreadPoolDevice { std::function<Index(Index)> block_align, std::function<void(Index, Index)> f, std::function<void()> done) const { + // Compute small problems directly in the caller thread. + if (n <= 1 || numThreads() == 1 || + CostModel::numThreads(n, cost, static_cast<int>(numThreads())) == 1) { + f(0, n); + done(); + return; + } + // Compute block size and total count of blocks. ParallelForBlock block = CalculateParallelForBlock(n, cost, block_align); @@ -269,24 +277,26 @@ struct ThreadPoolDevice { // Single block or less, execute directly. ctx->f(firstIdx, lastIdx); - // Call 'done' callback if it was the last block. - if (ctx->count.fetch_sub(1) == 1) { - (ctx->done)(); - // We can't delete ctx right now, because it will deallocate the closure - // we are currently in. - pool_->Schedule([ctx]() { delete ctx; }); - } + // Delete async context if it was the last block. + if (ctx->count.fetch_sub(1) == 1) delete ctx; }; - // Execute the root in the thread pool. - pool_->Schedule([ctx, n]() { ctx->handle_range(0, n); }); + if (block.count <= numThreads()) { + // Avoid a thread hop by running the root of the tree and one block on the + // main thread. + ctx->handle_range(0, n); + } else { + // Execute the root in the thread pool to avoid running work on more than + // numThreads() threads. + pool_->Schedule([ctx, n]() { ctx->handle_range(0, n); }); + } } // Convenience wrapper for parallelForAsync that does not align blocks. void parallelForAsync(Index n, const TensorOpCost& cost, std::function<void(Index, Index)> f, std::function<void()> done) const { - parallelForAsync(n, cost, NULL, std::move(f), std::move(done)); + parallelForAsync(n, cost, nullptr, std::move(f), std::move(done)); } // Thread pool accessor. @@ -307,6 +317,7 @@ struct ThreadPoolDevice { : count(block_count), f(std::move(block_f)), done(std::move(done_callback)) {} + ~ParallelForAsyncContext() { done(); } std::atomic<Index> count; std::function<void(Index, Index)> f; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index a3a79d4e9..fec735868 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -79,7 +79,16 @@ struct TensorEvaluator return true; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync( + EvaluatorPointerType dest, EvalSubExprsCallback done) { + // TODO(ezhulenev): ThreadPoolDevice memcpy is blockign operation. + done(evalSubExprsIfNeeded(dest)); + } +#endif // EIGEN_USE_THREADS + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { eigen_assert(m_data != NULL); @@ -247,6 +256,15 @@ struct TensorEvaluator<const Derived, Device> return true; } +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync( + EvaluatorPointerType dest, EvalSubExprsCallback done) { + // TODO(ezhulenev): ThreadPoolDevice memcpy is a blockign operation. 
+ done(evalSubExprsIfNeeded(dest)); + } +#endif // EIGEN_USE_THREADS + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { @@ -346,6 +364,15 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device> EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) { return true; } + +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync( + EvaluatorPointerType, EvalSubExprsCallback done) { + done(true); + } +#endif // EIGEN_USE_THREADS + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const @@ -425,6 +452,15 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device> m_argImpl.evalSubExprsIfNeeded(NULL); return true; } + +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync( + EvaluatorPointerType, EvalSubExprsCallback done) { + m_argImpl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); }); + } +#endif // EIGEN_USE_THREADS + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_argImpl.cleanup(); } @@ -546,6 +582,19 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg m_rightImpl.evalSubExprsIfNeeded(NULL); return true; } + +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync( + EvaluatorPointerType, EvalSubExprsCallback done) { + // TODO(ezhulenev): Evaluate two expression in parallel? + m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done](bool) { + m_rightImpl.evalSubExprsIfNeededAsync(nullptr, + [done](bool) { done(true); }); + }); + } +#endif // EIGEN_USE_THREADS + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_leftImpl.cleanup(); m_rightImpl.cleanup(); diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index 18d9de9e6..ce2337b63 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -430,12 +430,14 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable> std::function<void()> done) { TensorAsyncExecutorContext* const ctx = new TensorAsyncExecutorContext(expr, device, std::move(done)); - // TODO(ezhulenev): This is a potentially blocking operation. Make it async! 
- const bool needs_assign = ctx->evaluator.evalSubExprsIfNeeded(nullptr); - typedef EvalRange<Evaluator, StorageIndex, Vectorizable> EvalRange; + const auto on_eval_subexprs = [ctx, &device](bool need_assign) -> void { + if (!need_assign) { + delete ctx; + return; + } - if (needs_assign) { + typedef EvalRange<Evaluator, StorageIndex, Vectorizable> EvalRange; const StorageIndex size = array_prod(ctx->evaluator.dimensions()); device.parallelForAsync( size, ctx->evaluator.costPerCoeff(Vectorizable), @@ -444,7 +446,9 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable> EvalRange::run(&ctx->evaluator, firstIdx, lastIdx); }, [ctx]() { delete ctx; }); - } + }; + + ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs); } private: @@ -496,26 +500,32 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable return; } - // TODO(ezhulenev): This is a potentially blocking operation. Make it async! - const bool needs_assign = ctx->evaluator.evalSubExprsIfNeeded(nullptr); + const auto on_eval_subexprs = [ctx, &device](bool need_assign) -> void { + if (!need_assign) { + delete ctx; + return; + } - if (needs_assign) { ctx->tiling = - internal::GetTensorExecutorTilingContext<Evaluator, BlockMapper, - Vectorizable>(device, ctx->evaluator); + GetTensorExecutorTilingContext<Evaluator, TensorBlockMapper, + Vectorizable>(device, ctx->evaluator); device.parallelForAsync( ctx->tiling.block_mapper.total_block_count(), ctx->tiling.cost, [ctx](StorageIndex firstIdx, StorageIndex lastIdx) { ScalarNoConst* thread_buf = - ctx->tiling.template GetCurrentThreadBuffer<ScalarNoConst>(ctx->device); + ctx->tiling.template GetCurrentThreadBuffer<ScalarNoConst>( + ctx->device); for (StorageIndex i = firstIdx; i < lastIdx; ++i) { - auto block = ctx->tiling.block_mapper.GetBlockForIndex(i, thread_buf); + auto block = + ctx->tiling.block_mapper.GetBlockForIndex(i, thread_buf); ctx->evaluator.evalBlock(&block); } }, [ctx]() { delete ctx; }); - } + }; + + ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs); } private: diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/Barrier.h b/unsupported/Eigen/CXX11/src/ThreadPool/Barrier.h index bae68e1fb..e4c59dc3d 100644 --- a/unsupported/Eigen/CXX11/src/ThreadPool/Barrier.h +++ b/unsupported/Eigen/CXX11/src/ThreadPool/Barrier.h @@ -25,6 +25,9 @@ class Barrier { void Notify() { unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2; if (v != 1) { + // Clear the lowest bit (waiter flag) and check that the original state + // value was not zero. If it was zero, it means that notify was called + // more times than the original count. eigen_plain_assert(((v + 2) & ~1) != 0); return; // either count has not dropped to 0, or waiter is not waiting } |