aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-05-06 11:18:37 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-05-06 11:18:37 -0700
commit0320f7e3a71406b9a03d1bab0d168fd76e63d457 (patch)
treefffaaacd58cb5088f66d868bbb172971aacf9b53 /unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
parentc0f2cb016e60b7dbde1d5946f42234a709a711f9 (diff)
Added support for fixed sized tensors.
Improved support for tensor expressions.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMap.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMap.h31
1 files changed, 25 insertions, 6 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
index 7dec1e08d..bb0b39c5a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
@@ -43,24 +43,38 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap
typedef Scalar* PointerType;
typedef PointerType PointerArgType;
+ // Fixed size plain object type only
+ /* EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr) {
+ // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
+ //EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ // todo: add assert to ensure we don't screw up here.
+ }*/
+
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions({{firstDimension}}) {
+ EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension}})) {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
}
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions({{firstDimension, otherDimensions...}}) {
+ EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension, otherDimensions...}})) {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
}
#endif
+ inline TensorMap(PointerArgType dataPtr, const array<Index, PlainObjectType::NumIndices>& dimensions)
+ : m_data(dataPtr), m_dimensions(dimensions)
+ { }
+
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index size() const { return internal::array_prod(m_dimensions); }
+ EIGEN_STRONG_INLINE const typename PlainObjectType::Dimensions& dimensions() const { return m_dimensions; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar* data() { return m_data; }
EIGEN_DEVICE_FUNC
@@ -78,8 +92,13 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap
EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices)
{
static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
- const Index index = internal::tensor_index_linearization_helper<Index, PlainObjectType::NumIndices, PlainObjectType::NumIndices - 1, PlainObjectType::Options&RowMajor>::run(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}, m_dimensions);
- return m_data[index];
+ if (PlainObjectType::Options&RowMajor) {
+ const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
+ return m_data[index];
+ } else {
+ const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
+ return m_data[index];
+ }
}
#endif
@@ -93,7 +112,7 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap
private:
typename PlainObjectType::Scalar* m_data;
- array<DenseIndex, PlainObjectType::NumIndices> m_dimensions;
+ typename PlainObjectType::Dimensions m_dimensions;
};
} // end namespace Eigen