about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
diff options
context:
space:
mode:
author: Eugene Zhulenev <ezhulenev@google.com> 2019-09-03 17:20:56 -0700
committer: Eugene Zhulenev <ezhulenev@google.com> 2019-09-03 17:20:56 -0700
commit: 47fefa235f73315bc57d685a7bc9cd8d3577349f (patch)
tree: b6a380d7ae558dcafa2fa586a54e6632564fe16b /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
parent: a8d264fa9c56e42f77e2129d4e504f5c854821c2 (diff)
Allow move-only done callback in TensorAsyncDevice
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 34
1 files changed, 19 insertions, 15 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index 10339e5e7..cf07656b3 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -101,8 +101,8 @@ class TensorExecutor {
* Default async execution strategy is not implemented. Currently it's only
* available for ThreadPoolDevice (see definition below).
*/
-template <typename Expression, typename Device, bool Vectorizable,
- bool Tileable>
+template <typename Expression, typename Device, typename DoneCallback,
+ bool Vectorizable, bool Tileable>
class TensorAsyncExecutor {};
/**
@@ -419,15 +419,17 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ tr
}
};
-template <typename Expression, bool Vectorizable, bool Tileable>
-class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable> {
+template <typename Expression, typename DoneCallback, bool Vectorizable,
+ bool Tileable>
+class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
+ Vectorizable, Tileable> {
public:
typedef typename Expression::Index StorageIndex;
typedef TensorEvaluator<Expression, ThreadPoolDevice> Evaluator;
static EIGEN_STRONG_INLINE void runAsync(const Expression& expr,
const ThreadPoolDevice& device,
- std::function<void()> done) {
+ DoneCallback done) {
TensorAsyncExecutorContext* const ctx =
new TensorAsyncExecutorContext(expr, device, std::move(done));
@@ -455,7 +457,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable>
struct TensorAsyncExecutorContext {
TensorAsyncExecutorContext(const Expression& expr,
const ThreadPoolDevice& thread_pool,
- std::function<void()> done)
+ DoneCallback done)
: evaluator(expr, thread_pool), on_done(std::move(done)) {}
~TensorAsyncExecutorContext() {
@@ -466,12 +468,13 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable>
Evaluator evaluator;
private:
- std::function<void()> on_done;
+ DoneCallback on_done;
};
};
-template <typename Expression, bool Vectorizable>
-class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ true> {
+template <typename Expression, typename DoneCallback, bool Vectorizable>
+class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
+ Vectorizable, /*Tileable*/ true> {
public:
typedef typename traits<Expression>::Index StorageIndex;
typedef typename traits<Expression>::Scalar Scalar;
@@ -485,7 +488,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable
static EIGEN_STRONG_INLINE void runAsync(const Expression& expr,
const ThreadPoolDevice& device,
- std::function<void()> done) {
+ DoneCallback done) {
TensorAsyncExecutorContext* const ctx =
new TensorAsyncExecutorContext(expr, device, std::move(done));
@@ -494,9 +497,10 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable
if (total_size < cache_size &&
!ExpressionHasTensorBroadcastingOp<Expression>::value) {
- internal::TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable,
- /*Tileable*/ false>::runAsync(
- expr, device, [ctx]() { delete ctx; });
+ auto delete_ctx = [ctx]() { delete ctx; };
+ internal::TensorAsyncExecutor<
+ Expression, ThreadPoolDevice, decltype(delete_ctx), Vectorizable,
+ /*Tileable*/ false>::runAsync(expr, device, std::move(delete_ctx));
return;
}
@@ -532,7 +536,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable
struct TensorAsyncExecutorContext {
TensorAsyncExecutorContext(const Expression& expr,
const ThreadPoolDevice& thread_pool,
- std::function<void()> done)
+ DoneCallback done)
: device(thread_pool),
evaluator(expr, thread_pool),
on_done(std::move(done)) {}
@@ -548,7 +552,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable
TilingContext tiling;
private:
- std::function<void()> on_done;
+ DoneCallback on_done;
};
};