diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-09-03 17:20:56 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-09-03 17:20:56 -0700 |
commit | 47fefa235f73315bc57d685a7bc9cd8d3577349f (patch) | |
tree | b6a380d7ae558dcafa2fa586a54e6632564fe16b /unsupported/Eigen/CXX11/src | |
parent | a8d264fa9c56e42f77e2129d4e504f5c854821c2 (diff) |
Allow move-only done callback in TensorAsyncDevice
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
4 files changed, 30 insertions, 26 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 095c85dc4..f2aa37256 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -1065,12 +1065,12 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> { #ifdef EIGEN_USE_THREADS // Select the async device on which to evaluate the expression. - template <typename DeviceType> + template <typename DeviceType, typename DoneCallback> typename internal::enable_if< internal::is_same<DeviceType, ThreadPoolDevice>::value, - TensorAsyncDevice<Derived, DeviceType>>::type - device(const DeviceType& dev, std::function<void()> done) { - return TensorAsyncDevice<Derived, DeviceType>(dev, derived(), std::move(done)); + TensorAsyncDevice<Derived, DeviceType, DoneCallback>>::type + device(const DeviceType& dev, DoneCallback done) { + return TensorAsyncDevice<Derived, DeviceType, DoneCallback>(dev, derived(), std::move(done)); } #endif // EIGEN_USE_THREADS diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h index 5122b3623..cc9c65702 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h @@ -73,21 +73,21 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice { * ThreadPoolDevice). * * Example: - * std::function<void()> done = []() {}; + * auto done = []() { ... expression evaluation done ... }; * C.device(EIGEN_THREAD_POOL, std::move(done)) = A + B; */ -template <typename ExpressionType, typename DeviceType> +template <typename ExpressionType, typename DeviceType, typename DoneCallback> class TensorAsyncDevice { public: TensorAsyncDevice(const DeviceType& device, ExpressionType& expression, - std::function<void()> done) + DoneCallback done) : m_device(device), m_expression(expression), m_done(std::move(done)) {} template <typename OtherDerived> EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) { typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign; - typedef internal::TensorAsyncExecutor<const Assign, DeviceType> Executor; + typedef internal::TensorAsyncExecutor<const Assign, DeviceType, DoneCallback> Executor; // WARNING: After assignment 'm_done' callback will be in undefined state. Assign assign(m_expression, other); @@ -99,7 +99,7 @@ class TensorAsyncDevice { protected: const DeviceType& m_device; ExpressionType& m_expression; - std::function<void()> m_done; + DoneCallback m_done; }; #endif // EIGEN_USE_THREADS diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index 10339e5e7..cf07656b3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -101,8 +101,8 @@ class TensorExecutor { * Default async execution strategy is not implemented. Currently it's only * available for ThreadPoolDevice (see definition below). */ -template <typename Expression, typename Device, bool Vectorizable, - bool Tileable> +template <typename Expression, typename Device, typename DoneCallback, + bool Vectorizable, bool Tileable> class TensorAsyncExecutor {}; /** @@ -419,15 +419,17 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ tr } }; -template <typename Expression, bool Vectorizable, bool Tileable> -class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable> { +template <typename Expression, typename DoneCallback, bool Vectorizable, + bool Tileable> +class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback, + Vectorizable, Tileable> { public: typedef typename Expression::Index StorageIndex; typedef TensorEvaluator<Expression, ThreadPoolDevice> Evaluator; static EIGEN_STRONG_INLINE void runAsync(const Expression& expr, const ThreadPoolDevice& device, - std::function<void()> done) { + DoneCallback done) { TensorAsyncExecutorContext* const ctx = new TensorAsyncExecutorContext(expr, device, std::move(done)); @@ -455,7 +457,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable> struct TensorAsyncExecutorContext { TensorAsyncExecutorContext(const Expression& expr, const ThreadPoolDevice& thread_pool, - std::function<void()> done) + DoneCallback done) : evaluator(expr, thread_pool), on_done(std::move(done)) {} ~TensorAsyncExecutorContext() { @@ -466,12 +468,13 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable> Evaluator evaluator; private: - std::function<void()> on_done; + DoneCallback on_done; }; }; -template <typename Expression, bool Vectorizable> -class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ true> { +template <typename Expression, typename DoneCallback, bool Vectorizable> +class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback, + Vectorizable, /*Tileable*/ true> { public: typedef typename traits<Expression>::Index StorageIndex; typedef typename traits<Expression>::Scalar Scalar; @@ -485,7 +488,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable static EIGEN_STRONG_INLINE void runAsync(const Expression& expr, const ThreadPoolDevice& device, - std::function<void()> done) { + DoneCallback done) { TensorAsyncExecutorContext* const ctx = new TensorAsyncExecutorContext(expr, device, std::move(done)); @@ -494,9 +497,10 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable if (total_size < cache_size && !ExpressionHasTensorBroadcastingOp<Expression>::value) { - internal::TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, - /*Tileable*/ false>::runAsync( - expr, device, [ctx]() { delete ctx; }); + auto delete_ctx = [ctx]() { delete ctx; }; + internal::TensorAsyncExecutor< + Expression, ThreadPoolDevice, decltype(delete_ctx), Vectorizable, + /*Tileable*/ false>::runAsync(expr, device, std::move(delete_ctx)); return; } @@ -532,7 +536,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable struct TensorAsyncExecutorContext { TensorAsyncExecutorContext(const Expression& expr, const ThreadPoolDevice& thread_pool, - std::function<void()> done) + DoneCallback done) : device(thread_pool), evaluator(expr, thread_pool), on_done(std::move(done)) {} @@ -548,7 +552,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable TilingContext tiling; private: - std::function<void()> on_done; + DoneCallback on_done; }; }; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index e823bd932..772dbbe35 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -94,7 +94,7 @@ template<typename XprType, template <class> class MakePointer_ = MakePointer> cl template<typename XprType> class TensorForcedEvalOp; template<typename ExpressionType, typename DeviceType> class TensorDevice; -template<typename ExpressionType, typename DeviceType> class TensorAsyncDevice; +template<typename ExpressionType, typename DeviceType, typename DoneCallback> class TensorAsyncDevice; template<typename Derived, typename Device> struct TensorEvaluator; struct NoOpOutputKernel; @@ -168,7 +168,7 @@ template <typename Expression, typename Device, bool Tileable = IsTileable<Device, Expression>::value> class TensorExecutor; -template <typename Expression, typename Device, +template <typename Expression, typename Device, typename DoneCallback, bool Vectorizable = IsVectorizable<Device, Expression>::value, bool Tileable = IsTileable<Device, Expression>::value> class TensorAsyncExecutor; |