diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-10-26 14:29:26 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-10-26 14:29:26 -0700 |
commit | 1c8312c811344beaa06f7ae9258f66c38337c607 (patch) | |
tree | 4436a04ce900a997aa6adf9d42d0c6dab0a07fac /unsupported/Eigen/CXX11/src/Tensor | |
parent | 1f4c98abb1634bdbdd6583b55ba36dcc09ef5773 (diff) |
Started to add support for tensors of rank 0
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/Tensor.h | 31 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorIO.h | 5 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h | 12 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h | 12 |
4 files changed, 57 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h index 3ac465d24..0df1345c2 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h @@ -140,6 +140,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp } #endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff() const + { + EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); + return m_storage.data()[0]; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const { eigen_internal_assert(index >= 0 && index < size()); @@ -174,6 +180,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp } #endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef() + { + EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); + return m_storage.data()[0]; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { eigen_internal_assert(index >= 0 && index < size()); @@ -234,6 +246,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp return coeff(index); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()() const + { + EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); + return coeff(); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator[](Index index) const { // The bracket operator is only for vectors, use the parenthesis operator instead. @@ -295,6 +313,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp return coeffRef(index); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()() + { + EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); + return coeffRef(); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator[](Index index) { // The bracket operator is only for vectors, use the parenthesis operator instead @@ -433,6 +457,13 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp resize(dims); } + EIGEN_DEVICE_FUNC + void resize() + { + EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); + // Nothing to do: rank 0 tensors have fixed size + } + /** Custom Dimension */ #ifdef EIGEN_HAS_SFINAE template<typename CustomDimension, diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h index 3b6f2c730..38a833f82 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h @@ -33,7 +33,10 @@ std::ostream& operator << (std::ostream& os, const TensorBase<T, ReadOnlyAccesso const Index total_size = internal::array_prod(tensor.dimensions()); // Print the tensor as a 1d vector or a 2d matrix. - if (internal::array_size<Dimensions>::value == 1) { + static const int rank = internal::array_size<Dimensions>::value; + if (rank == 0) { + os << tensor.coeff(0); + } else if (rank == 1) { Map<const Array<Scalar, Dynamic, 1> > array(const_cast<Scalar*>(tensor.data()), total_size); os << array; } else { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h b/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h index 4303e3536..ad2a1e6ac 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h @@ -55,6 +55,18 @@ struct Initializer<Derived, 1> { } }; +template <typename Derived> +struct Initializer<Derived, 0> { + typedef typename traits<Derived>::Scalar InitList; + + static void run(TensorEvaluator<Derived, DefaultDevice>& tensor, + Eigen::array<typename traits<Derived>::Index, traits<Derived>::NumDimensions>*/* indices*/, + const InitList& v) { + tensor.coeffRef(0) = v; + } +}; + + template <typename Derived, int N> void initialize_tensor(TensorEvaluator<Derived, DefaultDevice>& tensor, const typename Initializer<Derived, traits<Derived>::NumDimensions>::InitList& vals) { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h b/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h index 9e4cf039d..ee6f14b8f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h @@ -71,7 +71,11 @@ class TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_> typedef DSizes<IndexType, NumIndices_> Dimensions; typedef TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_> Self; - EIGEN_DEVICE_FUNC TensorStorage() : m_data(0), m_dimensions() {} + EIGEN_DEVICE_FUNC TensorStorage() : m_data(0), m_dimensions() { + if (NumIndices_ == 0) { + m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(1); + } + } EIGEN_DEVICE_FUNC TensorStorage(internal::constructor_without_unaligned_array_assert) : m_data(0), m_dimensions(internal::template repeat<NumIndices_, Index>(0)) {} EIGEN_DEVICE_FUNC TensorStorage(Index size, const array<Index, NumIndices_>& dimensions) @@ -101,13 +105,17 @@ class TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_> EIGEN_DEVICE_FUNC void resize(Index size, const array<Index, NumIndices_>& nbDimensions) { + eigen_assert(size >= 1); const Index currentSz = internal::array_prod(m_dimensions); if(size != currentSz) { internal::conditional_aligned_delete_auto<T,(Options_&DontAlign)==0>(m_data, currentSz); if (size) m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(size); - else + else if (NumIndices_ == 0) { + m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(1); + } + else m_data = 0; EIGEN_INTERNAL_DENSE_STORAGE_CTOR_PLUGIN } |