From 3cd148f98338f8c03ce2757c528423338990a90d Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 9 Jul 2019 12:10:26 -0700 Subject: Fix expression evaluation heuristic for TensorSliceOp --- .../Eigen/CXX11/src/Tensor/TensorMorphing.h | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h index 8f6e987b3..9ab1415ac 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h @@ -479,9 +479,12 @@ class TensorSlicingOp : public TensorBase struct MemcpyTriggerForSlicing { +template struct MemcpyTriggerForSlicing { EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device) : threshold_(2 * device.numThreads()) { } - EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > threshold_; } + EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { + const bool prefer_block_evaluation = BlockAccess && total > 32*1024; + return !prefer_block_evaluation && contiguous > threshold_; + } private: Index threshold_; @@ -490,18 +493,18 @@ template struct MemcpyTriggerForSlicing { // It is very expensive to start the memcpy kernel on GPU: we therefore only // use it for large copies. #ifdef EIGEN_USE_GPU -template struct MemcpyTriggerForSlicing { +template struct MemcpyTriggerForSlicing { EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const GpuDevice&) { } - EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; } + EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; } }; #endif // It is very expensive to start the memcpy kernel on GPU: we therefore only // use it for large copies. #ifdef EIGEN_USE_SYCL -template struct MemcpyTriggerForSlicing { +template struct MemcpyTriggerForSlicing { EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const SyclDevice&) { } - EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; } + EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; } }; #endif @@ -592,8 +595,7 @@ struct TensorEvaluator, Devi EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) { m_impl.evalSubExprsIfNeeded(NULL); if (!NumTraits::type>::RequireInitialization - && data && m_impl.data() - && !BlockAccess) { + && data && m_impl.data()) { Index contiguous_values = 1; if (static_cast(Layout) == static_cast(ColMajor)) { for (int i = 0; i < NumDims; ++i) { @@ -611,8 +613,8 @@ struct TensorEvaluator, Devi } } // Use memcpy if it's going to be faster than using the regular evaluation. - const MemcpyTriggerForSlicing trigger(m_device); - if (trigger(contiguous_values)) { + const MemcpyTriggerForSlicing trigger(m_device); + if (trigger(internal::array_prod(dimensions()), contiguous_values)) { EvaluatorPointerType src = (EvaluatorPointerType)m_impl.data(); for (Index i = 0; i < internal::array_prod(dimensions()); i += contiguous_values) { Index offset = srcCoeff(i); -- cgit v1.2.3