aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-05-19 15:19:01 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-05-19 15:19:01 -0700
commit2451679951cf6befb69204c211e4d84902dd86e4 (patch)
treeb5e8ef31aeede6bc3d65b529a7a1c34f31b5aa60 /unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
parenta81d17b73a8544920eb69f9fe436b40ea126c798 (diff)
Avoid using the cuda memcpy for small tensor slices since the memcpy kernel is very expensive to launch
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h23
1 files changed, 22 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
index d15055727..fa1e6931c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -283,6 +283,26 @@ class TensorSlicingOp : public TensorBase<TensorSlicingOp<StartIndices, Sizes, X
};
+// Fixme: figure out the exact threshold
+namespace {
+template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
+ EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device) : threshold_(2 * device.numThreads()) { }
+ EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > threshold_; }
+
+ private:
+ Index threshold_;
+};
+
+// It is very expensive to start the memcpy kernel on GPU: we therefore only
+// use it for large copies.
+#ifdef EIGEN_USE_GPU
+template <typename Index> struct MemcpyTriggerForSlicing<Index, GpuDevice> {
+ EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const GpuDevice&) { }
+ EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+};
+#endif
+}
+
// Eval as rvalue
template<typename StartIndices, typename Sizes, typename ArgType, typename Device>
struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
@@ -364,7 +384,8 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
}
}
// Use memcpy if it's going to be faster than using the regular evaluation.
- if (contiguous_values > static_cast<Index>(2 * m_device.numThreads())) {
+ const MemcpyTriggerForSlicing<Index, Device> trigger(m_device);
+ if (trigger(contiguous_values)) {
Scalar* src = m_impl.data();
for (int i = 0; i < internal::array_prod(dimensions()); i += contiguous_values) {
Index offset = srcCoeff(i);