about summary refs log tree commit diff homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-08-30 15:13:38 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-08-30 15:13:38 -0700
commitf0b36fb9a405400e82b73ea70097b8ae3cd1095a (patch)
treed3a2903422799257720d2d4989bcd845ab2ae27e /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
parent619cea94916e7531a839ee0ff657714857921db8 (diff)
evalSubExprsIfNeededAsync + async TensorContractionThreadPool
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 36
1 file changed, 23 insertions, 13 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index 18d9de9e6..ce2337b63 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -430,12 +430,14 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable>
std::function<void()> done) {
TensorAsyncExecutorContext* const ctx =
new TensorAsyncExecutorContext(expr, device, std::move(done));
- // TODO(ezhulenev): This is a potentially blocking operation. Make it async!
- const bool needs_assign = ctx->evaluator.evalSubExprsIfNeeded(nullptr);
- typedef EvalRange<Evaluator, StorageIndex, Vectorizable> EvalRange;
+ const auto on_eval_subexprs = [ctx, &device](bool need_assign) -> void {
+ if (!need_assign) {
+ delete ctx;
+ return;
+ }
- if (needs_assign) {
+ typedef EvalRange<Evaluator, StorageIndex, Vectorizable> EvalRange;
const StorageIndex size = array_prod(ctx->evaluator.dimensions());
device.parallelForAsync(
size, ctx->evaluator.costPerCoeff(Vectorizable),
@@ -444,7 +446,9 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable>
EvalRange::run(&ctx->evaluator, firstIdx, lastIdx);
},
[ctx]() { delete ctx; });
- }
+ };
+
+ ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
}
private:
@@ -496,26 +500,32 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable
return;
}
- // TODO(ezhulenev): This is a potentially blocking operation. Make it async!
- const bool needs_assign = ctx->evaluator.evalSubExprsIfNeeded(nullptr);
+ const auto on_eval_subexprs = [ctx, &device](bool need_assign) -> void {
+ if (!need_assign) {
+ delete ctx;
+ return;
+ }
- if (needs_assign) {
ctx->tiling =
- internal::GetTensorExecutorTilingContext<Evaluator, BlockMapper,
- Vectorizable>(device, ctx->evaluator);
+ GetTensorExecutorTilingContext<Evaluator, TensorBlockMapper,
+ Vectorizable>(device, ctx->evaluator);
device.parallelForAsync(
ctx->tiling.block_mapper.total_block_count(), ctx->tiling.cost,
[ctx](StorageIndex firstIdx, StorageIndex lastIdx) {
ScalarNoConst* thread_buf =
- ctx->tiling.template GetCurrentThreadBuffer<ScalarNoConst>(ctx->device);
+ ctx->tiling.template GetCurrentThreadBuffer<ScalarNoConst>(
+ ctx->device);
for (StorageIndex i = firstIdx; i < lastIdx; ++i) {
- auto block = ctx->tiling.block_mapper.GetBlockForIndex(i, thread_buf);
+ auto block =
+ ctx->tiling.block_mapper.GetBlockForIndex(i, thread_buf);
ctx->evaluator.evalBlock(&block);
}
},
[ctx]() { delete ctx; });
- }
+ };
+
+ ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
}
private: