From 2386fc8528fa8f923b0300af6ddc4cd46a178afd Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 27 Feb 2015 12:57:13 -0800 Subject: Added support for 32bit index on a per tensor/tensor expression. This enables us to use 32bit indices to evaluate expressions on GPU faster while keeping the ability to use 64 bit indices to manipulate large tensors on CPU in the same binary. --- unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h | 46 +++++++++++++--------- 1 file changed, 27 insertions(+), 19 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h b/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h index 1b227e8c2..91aae091c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h @@ -66,14 +66,16 @@ template class TensorStorage : public TensorStorage::type> { + typedef typename internal::compute_index_type::type Index; + typedef DSizes Dimensions; typedef TensorStorage::type> Base_; public: - TensorStorage() { } - TensorStorage(const TensorStorage& other) : Base_(other) { } + EIGEN_DEVICE_FUNC TensorStorage() { } + EIGEN_DEVICE_FUNC TensorStorage(const TensorStorage& other) : Base_(other) { } - TensorStorage(internal::constructor_without_unaligned_array_assert) : Base_(internal::constructor_without_unaligned_array_assert()) {} - TensorStorage(DenseIndex size, const array& dimensions) : Base_(size, dimensions) {} + EIGEN_DEVICE_FUNC TensorStorage(internal::constructor_without_unaligned_array_assert) : Base_(internal::constructor_without_unaligned_array_assert()) {} + EIGEN_DEVICE_FUNC TensorStorage(DenseIndex size, const array& dimensions) : Base_(size, dimensions) {} // TensorStorage& operator=(const TensorStorage&) = default; }; @@ -82,24 +84,26 @@ class TensorStorage template class TensorStorage::type> { - T *m_data; - DSizes m_dimensions; + public: + typedef typename internal::compute_index_type::type Index; + typedef DSizes Dimensions; typedef TensorStorage::type> Self_; - public: - TensorStorage() : m_data(0), m_dimensions() {} - TensorStorage(internal::constructor_without_unaligned_array_assert) - : m_data(0), m_dimensions(internal::template repeat(0)) {} - TensorStorage(DenseIndex size, const array& dimensions) + + EIGEN_DEVICE_FUNC TensorStorage() : m_data(0), m_dimensions() {} + EIGEN_DEVICE_FUNC TensorStorage(internal::constructor_without_unaligned_array_assert) + : m_data(0), m_dimensions(internal::template repeat(0)) {} + EIGEN_DEVICE_FUNC TensorStorage(Index size, const array& dimensions) : m_data(internal::conditional_aligned_new_auto(size)), m_dimensions(dimensions) { EIGEN_INTERNAL_TENSOR_STORAGE_CTOR_PLUGIN } - TensorStorage(const Self_& other) + + EIGEN_DEVICE_FUNC TensorStorage(const Self_& other) : m_data(internal::conditional_aligned_new_auto(internal::array_prod(other.m_dimensions))) , m_dimensions(other.m_dimensions) { internal::smart_copy(other.m_data, other.m_data+internal::array_prod(other.m_dimensions), m_data); } - Self_& operator=(const Self_& other) + EIGEN_DEVICE_FUNC Self_& operator=(const Self_& other) { if (this != &other) { Self_ tmp(other); @@ -108,15 +112,15 @@ class TensorStorage(m_data, internal::array_prod(m_dimensions)); } - void swap(Self_& other) + EIGEN_DEVICE_FUNC ~TensorStorage() { internal::conditional_aligned_delete_auto(m_data, internal::array_prod(m_dimensions)); } + EIGEN_DEVICE_FUNC void swap(Self_& other) { std::swap(m_data,other.m_data); std::swap(m_dimensions,other.m_dimensions); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DSizes& dimensions() const {return m_dimensions;} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {return m_dimensions;} - EIGEN_DEVICE_FUNC void resize(DenseIndex size, const array& nbDimensions) + EIGEN_DEVICE_FUNC void resize(Index size, const array& nbDimensions) { - const DenseIndex currentSz = internal::array_prod(m_dimensions); + const Index currentSz = internal::array_prod(m_dimensions); if(size != currentSz) { internal::conditional_aligned_delete_auto(m_data, currentSz); @@ -132,7 +136,11 @@ class TensorStorage