author Eugene Zhulenev <ezhulenev@google.com> 2019-07-09 12:10:26 -0700
committer Eugene Zhulenev <ezhulenev@google.com> 2019-07-09 12:10:26 -0700
commit 3cd148f98338f8c03ce2757c528423338990a90d (patch)
tree 5828bb84e96965cb057ef0df19f7711a74187e48 /unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
parent 23b958818e2d645b7e92d80072c06602cd6bbd28 (diff)
Fix expression evaluation heuristic for TensorSliceOp
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h | 22
1 file changed, 12 insertions, 10 deletions
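
In outline: the old heuristic chose memcpy whenever the contiguous run length cleared a device-specific threshold, and skipped the memcpy path entirely whenever the expression supported block access. The patch instead threads the BlockAccess flag into MemcpyTriggerForSlicing and passes the total coefficient count alongside the contiguous run length, so only large block-capable slices fall back to block evaluation. A minimal standalone sketch of the new CPU-side decision (hypothetical names; the 32*1024 total threshold and the 2 * numThreads contiguous threshold are taken from the patch):

#include <cstddef>
#include <iostream>

// Sketch of the patched trigger: memcpy is used only when block evaluation
// is not preferable AND the contiguous run clears the device threshold.
template <typename Index, bool BlockAccess>
struct TriggerSketch {
  explicit TriggerSketch(int num_threads) : threshold_(2 * num_threads) {}
  bool operator()(Index total, Index contiguous) const {
    // New in the patch: a large, block-capable expression prefers
    // tiled block evaluation over a per-run memcpy loop.
    const bool prefer_block_evaluation = BlockAccess && total > 32 * 1024;
    return !prefer_block_evaluation && contiguous > threshold_;
  }
 private:
  Index threshold_;
};

int main() {
  TriggerSketch<std::ptrdiff_t, /*BlockAccess=*/true> trigger(/*num_threads=*/4);
  std::cout << trigger(1024, 256) << "\n";     // 1: small slice, memcpy wins
  std::cout << trigger(1 << 20, 256) << "\n";  // 0: large slice, block eval preferred
}

On GPU and SYCL devices the specializations below ignore the total count and keep the 4*1024*1024 contiguous-run threshold, since kernel launch overhead dominates there.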
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
index 8f6e987b3..9ab1415ac 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -479,9 +479,12 @@ class TensorSlicingOp : public TensorBase<TensorSlicingOp<StartIndices, Sizes, X
// Fixme: figure out the exact threshold
namespace {
-template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
+template <typename Index, typename Device, bool BlockAccess> struct MemcpyTriggerForSlicing {
EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device) : threshold_(2 * device.numThreads()) { }
- EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > threshold_; }
+ EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const {
+ const bool prefer_block_evaluation = BlockAccess && total > 32*1024;
+ return !prefer_block_evaluation && contiguous > threshold_;
+ }
private:
Index threshold_;
@@ -490,18 +493,18 @@ template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
// It is very expensive to start the memcpy kernel on GPU: we therefore only
// use it for large copies.
#ifdef EIGEN_USE_GPU
-template <typename Index> struct MemcpyTriggerForSlicing<Index, GpuDevice> {
+template <typename Index, bool BlockAccess> struct MemcpyTriggerForSlicing<Index, GpuDevice, BlockAccess> {
EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const GpuDevice&) { }
- EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+ EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; }
};
#endif
// It is very expensive to start the memcpy kernel on GPU: we therefore only
// use it for large copies.
#ifdef EIGEN_USE_SYCL
-template <typename Index> struct MemcpyTriggerForSlicing<Index, Eigen::SyclDevice> {
+template <typename Index, bool BlockAccess> struct MemcpyTriggerForSlicing<Index, Eigen::SyclDevice, BlockAccess> {
EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const SyclDevice&) { }
- EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+ EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; }
};
#endif
@@ -592,8 +595,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
m_impl.evalSubExprsIfNeeded(NULL);
if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization
- && data && m_impl.data()
- && !BlockAccess) {
+ && data && m_impl.data()) {
Index contiguous_values = 1;
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
for (int i = 0; i < NumDims; ++i) {
@@ -611,8 +613,8 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
}
}
// Use memcpy if it's going to be faster than using the regular evaluation.
- const MemcpyTriggerForSlicing<Index, Device> trigger(m_device);
- if (trigger(contiguous_values)) {
+ const MemcpyTriggerForSlicing<Index, Device, BlockAccess> trigger(m_device);
+ if (trigger(internal::array_prod(dimensions()), contiguous_values)) {
EvaluatorPointerType src = (EvaluatorPointerType)m_impl.data();
for (Index i = 0; i < internal::array_prod(dimensions()); i += contiguous_values) {
Index offset = srcCoeff(i);
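
For context, a hedged sketch of the kind of slice expression this evaluator handles (assumes the standard unsupported Eigen Tensor module API; not part of the patch). Taking whole leading columns keeps the sliced coefficients contiguous in ColMajor memory, which is the case where the memcpy path can apply:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> input(512, 512);
  input.setRandom();

  // All 512 rows of the first 256 columns: one contiguous block in
  // ColMajor layout, so the evaluator may choose the memcpy path.
  Eigen::array<Eigen::Index, 2> offsets = {0, 0};
  Eigen::array<Eigen::Index, 2> extents = {512, 256};
  Eigen::Tensor<float, 2> out = input.slice(offsets, extents);
  return out.size() == 512 * 256 ? 0 : 1;
}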