diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMap.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorMap.h | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h index 7dec1e08d..bb0b39c5a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h @@ -43,24 +43,38 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap typedef Scalar* PointerType; typedef PointerType PointerArgType; + // Fixed size plain object type only + /* EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr) { + // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. + //EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) + // todo: add assert to ensure we don't screw up here. + }*/ + EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions({{firstDimension}}) { + EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension}})) { // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) } #ifdef EIGEN_HAS_VARIADIC_TEMPLATES template<typename... IndexTypes> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions({{firstDimension, otherDimensions...}}) { + EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension, otherDimensions...}})) { // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) } #endif + inline TensorMap(PointerArgType dataPtr, const array<Index, PlainObjectType::NumIndices>& dimensions) + : m_data(dataPtr), m_dimensions(dimensions) + { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Index size() const { return internal::array_prod(m_dimensions); } + EIGEN_STRONG_INLINE const typename PlainObjectType::Dimensions& dimensions() const { return m_dimensions; } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() { return m_data; } EIGEN_DEVICE_FUNC @@ -78,8 +92,13 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) { static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); - const Index index = internal::tensor_index_linearization_helper<Index, PlainObjectType::NumIndices, PlainObjectType::NumIndices - 1, PlainObjectType::Options&RowMajor>::run(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}, m_dimensions); - return m_data[index]; + if (PlainObjectType::Options&RowMajor) { + const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); + return m_data[index]; + } else { + const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); + return m_data[index]; + } } #endif @@ -93,7 +112,7 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap private: typename PlainObjectType::Scalar* m_data; - array<DenseIndex, PlainObjectType::NumIndices> m_dimensions; + typename PlainObjectType::Dimensions m_dimensions; }; } // end namespace Eigen |