diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-08-30 15:13:38 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-08-30 15:13:38 -0700 |
commit | f0b36fb9a405400e82b73ea70097b8ae3cd1095a (patch) | |
tree | d3a2903422799257720d2d4989bcd845ab2ae27e /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | |
parent | 619cea94916e7531a839ee0ff657714857921db8 (diff) |
evalSubExprsIfNeededAsync + async TensorContractionThreadPool
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 36 |
1 file changed, 23 insertions(+), 13 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index 18d9de9e6..ce2337b63 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -430,12 +430,14 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable>
                   std::function<void()> done) {
     TensorAsyncExecutorContext* const ctx =
         new TensorAsyncExecutorContext(expr, device, std::move(done));

-    // TODO(ezhulenev): This is a potentially blocking operation. Make it async!
-    const bool needs_assign = ctx->evaluator.evalSubExprsIfNeeded(nullptr);
-    typedef EvalRange<Evaluator, StorageIndex, Vectorizable> EvalRange;
+    const auto on_eval_subexprs = [ctx, &device](bool need_assign) -> void {
+      if (!need_assign) {
+        delete ctx;
+        return;
+      }

-    if (needs_assign) {
+      typedef EvalRange<Evaluator, StorageIndex, Vectorizable> EvalRange;
       const StorageIndex size = array_prod(ctx->evaluator.dimensions());
       device.parallelForAsync(
           size, ctx->evaluator.costPerCoeff(Vectorizable),
@@ -444,7 +446,9 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable>
             EvalRange::run(&ctx->evaluator, firstIdx, lastIdx);
           },
           [ctx]() { delete ctx; });
-    }
+    };
+
+    ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
   }

  private:
@@ -496,26 +500,32 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable
       return;
     }

-    // TODO(ezhulenev): This is a potentially blocking operation. Make it async!
-    const bool needs_assign = ctx->evaluator.evalSubExprsIfNeeded(nullptr);
+    const auto on_eval_subexprs = [ctx, &device](bool need_assign) -> void {
+      if (!need_assign) {
+        delete ctx;
+        return;
+      }

-    if (needs_assign) {
       ctx->tiling =
-          internal::GetTensorExecutorTilingContext<Evaluator, BlockMapper,
-                                                   Vectorizable>(device, ctx->evaluator);
+          GetTensorExecutorTilingContext<Evaluator, TensorBlockMapper,
+                                         Vectorizable>(device, ctx->evaluator);
       device.parallelForAsync(
           ctx->tiling.block_mapper.total_block_count(), ctx->tiling.cost,
           [ctx](StorageIndex firstIdx, StorageIndex lastIdx) {
             ScalarNoConst* thread_buf =
-                ctx->tiling.template GetCurrentThreadBuffer<ScalarNoConst>(ctx->device);
+                ctx->tiling.template GetCurrentThreadBuffer<ScalarNoConst>(
+                    ctx->device);
             for (StorageIndex i = firstIdx; i < lastIdx; ++i) {
-              auto block = ctx->tiling.block_mapper.GetBlockForIndex(i, thread_buf);
+              auto block =
+                  ctx->tiling.block_mapper.GetBlockForIndex(i, thread_buf);
               ctx->evaluator.evalBlock(&block);
             }
           },
           [ctx]() { delete ctx; });
-    }
+    };
+
+    ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
   }

  private: