diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2019-10-22 12:42:44 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2019-10-22 12:42:44 -0700 |
commit | 97c0c5d485ddec0369326825a41db48d8505cf4c (patch) | |
tree | 9072616f37eacc24f407061ac74954d67da8c5ee /unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h | |
parent | 668ab3fc474e54c7919eda4fbaf11f3a99246494 (diff) |
Add block evaluation V2 to TensorAsyncExecutor.
Add async evaluation to a number of ops.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h | 45 |
1 files changed, 36 insertions, 9 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h index a5c293cf9..d826cfb7e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h @@ -689,15 +689,14 @@ struct TensorReductionEvaluatorBase<const TensorReductionOp<Op, Dims, ArgType, M EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_STRONG_INLINE - #if !defined(EIGEN_HIPCC) - // Marking this as EIGEN_DEVICE_FUNC for HIPCC requires also doing the same for all the functions - // being called within here, which then leads to proliferation of EIGEN_DEVICE_FUNC markings, one - // of which will eventually result in an NVCC error - EIGEN_DEVICE_FUNC - #endif - bool evalSubExprsIfNeeded(EvaluatorPointerType data) { - m_impl.evalSubExprsIfNeeded(NULL); - +#if !defined(EIGEN_HIPCC) + // Marking this as EIGEN_DEVICE_FUNC for HIPCC requires also doing the same + // for all the functions being called within here, which then leads to + // proliferation of EIGEN_DEVICE_FUNC markings, one of which will eventually + // result in an NVCC error + EIGEN_DEVICE_FUNC +#endif + bool evalSubExprsIfNeededCommon(EvaluatorPointerType data) { // Use the FullReducer if possible. if ((RunningFullReduction && RunningOnSycl) ||(RunningFullReduction && internal::FullReducer<Self, Op, Device>::HasOptimizedImplementation && @@ -802,6 +801,34 @@ struct TensorReductionEvaluatorBase<const TensorReductionOp<Op, Dims, ArgType, M return true; } +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_STRONG_INLINE +#if !defined(EIGEN_HIPCC) + EIGEN_DEVICE_FUNC +#endif + void + evalSubExprsIfNeededAsync(EvaluatorPointerType data, + EvalSubExprsCallback done) { + m_impl.evalSubExprsIfNeededAsync(NULL, [this, data, done](bool) { + done(evalSubExprsIfNeededCommon(data)); + }); + } +#endif + + EIGEN_STRONG_INLINE +#if !defined(EIGEN_HIPCC) + // Marking this as EIGEN_DEVICE_FUNC for HIPCC requires also doing the same + // for all the functions being called within here, which then leads to + // proliferation of EIGEN_DEVICE_FUNC markings, one of which will eventually + // result in an NVCC error + EIGEN_DEVICE_FUNC +#endif + bool evalSubExprsIfNeeded(EvaluatorPointerType data) { + m_impl.evalSubExprsIfNeeded(NULL); + return evalSubExprsIfNeededCommon(data); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); if (m_result) { |