From f0b36fb9a405400e82b73ea70097b8ae3cd1095a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 30 Aug 2019 15:13:38 -0700 Subject: evalSubExprsIfNeededAsync + async TensorContractionThreadPool --- .../Eigen/CXX11/src/Tensor/TensorExecutor.h | 36 ++++++++++++++-------- 1 file changed, 23 insertions(+), 13 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index 18d9de9e6..ce2337b63 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -430,12 +430,14 @@ class TensorAsyncExecutor std::function done) { TensorAsyncExecutorContext* const ctx = new TensorAsyncExecutorContext(expr, device, std::move(done)); - // TODO(ezhulenev): This is a potentially blocking operation. Make it async! - const bool needs_assign = ctx->evaluator.evalSubExprsIfNeeded(nullptr); - typedef EvalRange EvalRange; + const auto on_eval_subexprs = [ctx, &device](bool need_assign) -> void { + if (!need_assign) { + delete ctx; + return; + } - if (needs_assign) { + typedef EvalRange EvalRange; const StorageIndex size = array_prod(ctx->evaluator.dimensions()); device.parallelForAsync( size, ctx->evaluator.costPerCoeff(Vectorizable), @@ -444,7 +446,9 @@ class TensorAsyncExecutor EvalRange::run(&ctx->evaluator, firstIdx, lastIdx); }, [ctx]() { delete ctx; }); - } + }; + + ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs); } private: @@ -496,26 +500,32 @@ class TensorAsyncExecutorevaluator.evalSubExprsIfNeeded(nullptr); + const auto on_eval_subexprs = [ctx, &device](bool need_assign) -> void { + if (!need_assign) { + delete ctx; + return; + } - if (needs_assign) { ctx->tiling = - internal::GetTensorExecutorTilingContext(device, ctx->evaluator); + GetTensorExecutorTilingContext(device, ctx->evaluator); device.parallelForAsync( ctx->tiling.block_mapper.total_block_count(), ctx->tiling.cost, [ctx](StorageIndex firstIdx, StorageIndex lastIdx) { ScalarNoConst* thread_buf = - ctx->tiling.template GetCurrentThreadBuffer(ctx->device); + ctx->tiling.template GetCurrentThreadBuffer( + ctx->device); for (StorageIndex i = firstIdx; i < lastIdx; ++i) { - auto block = ctx->tiling.block_mapper.GetBlockForIndex(i, thread_buf); + auto block = + ctx->tiling.block_mapper.GetBlockForIndex(i, thread_buf); ctx->evaluator.evalBlock(&block); } }, [ctx]() { delete ctx; }); - } + }; + + ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs); } private: -- cgit v1.2.3