diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h | 45 |
1 files changed, 36 insertions, 9 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h index a5c293cf9..d826cfb7e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h @@ -689,15 +689,14 @@ struct TensorReductionEvaluatorBase<const TensorReductionOp<Op, Dims, ArgType, M EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_STRONG_INLINE - #if !defined(EIGEN_HIPCC) - // Marking this as EIGEN_DEVICE_FUNC for HIPCC requires also doing the same for all the functions - // being called within here, which then leads to proliferation of EIGEN_DEVICE_FUNC markings, one - // of which will eventually result in an NVCC error - EIGEN_DEVICE_FUNC - #endif - bool evalSubExprsIfNeeded(EvaluatorPointerType data) { - m_impl.evalSubExprsIfNeeded(NULL); - +#if !defined(EIGEN_HIPCC) + // Marking this as EIGEN_DEVICE_FUNC for HIPCC requires also doing the same + // for all the functions being called within here, which then leads to + // proliferation of EIGEN_DEVICE_FUNC markings, one of which will eventually + // result in an NVCC error + EIGEN_DEVICE_FUNC +#endif + bool evalSubExprsIfNeededCommon(EvaluatorPointerType data) { // Use the FullReducer if possible. if ((RunningFullReduction && RunningOnSycl) ||(RunningFullReduction && internal::FullReducer<Self, Op, Device>::HasOptimizedImplementation && @@ -802,6 +801,34 @@ struct TensorReductionEvaluatorBase<const TensorReductionOp<Op, Dims, ArgType, M return true; } +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_STRONG_INLINE +#if !defined(EIGEN_HIPCC) + EIGEN_DEVICE_FUNC +#endif + void + evalSubExprsIfNeededAsync(EvaluatorPointerType data, + EvalSubExprsCallback done) { + m_impl.evalSubExprsIfNeededAsync(NULL, [this, data, done](bool) { + done(evalSubExprsIfNeededCommon(data)); + }); + } +#endif + + EIGEN_STRONG_INLINE +#if !defined(EIGEN_HIPCC) + // Marking this as EIGEN_DEVICE_FUNC for HIPCC requires also doing the same + // for all the functions being called within here, which then leads to + // proliferation of EIGEN_DEVICE_FUNC markings, one of which will eventually + // result in an NVCC error + EIGEN_DEVICE_FUNC +#endif + bool evalSubExprsIfNeeded(EvaluatorPointerType data) { + m_impl.evalSubExprsIfNeeded(NULL); + return evalSubExprsIfNeededCommon(data); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); if (m_result) { |