diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
index d15055727..fa1e6931c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -283,6 +283,26 @@ class TensorSlicingOp : public TensorBase<TensorSlicingOp<StartIndices, Sizes, X
 };
 
 
+// Fixme: figure out the exact threshold
+namespace {
+template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
+  EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device) : threshold_(2 * device.numThreads()) { }
+  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > threshold_; }
+
+ private:
+  Index threshold_;
+};
+
+// It is very expensive to start the memcpy kernel on GPU: we therefore only
+// use it for large copies.
+#ifdef EIGEN_USE_GPU
+template <typename Index> struct MemcpyTriggerForSlicing<Index, GpuDevice> {
+  EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const GpuDevice&) { }
+  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+};
+#endif
+}
+
 // Eval as rvalue
 template<typename StartIndices, typename Sizes, typename ArgType, typename Device>
 struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
@@ -364,7 +384,8 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
         }
       }
       // Use memcpy if it's going to be faster than using the regular evaluation.
-      if (contiguous_values > static_cast<Index>(2 * m_device.numThreads())) {
+      const MemcpyTriggerForSlicing<Index, Device> trigger(m_device);
+      if (trigger(contiguous_values)) {
         Scalar* src = m_impl.data();
         for (int i = 0; i < internal::array_prod(dimensions()); i += contiguous_values) {
           Index offset = srcCoeff(i);