aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-08-28 17:46:05 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-08-28 17:46:05 -0700
commitbc40d4522c56fdf861fcdab28f4b7db609d8065e (patch)
tree82b4f2497c79da6a7f77e46e585ce7c61632fbdc /unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
parent1187bb65ad196161a07f4e0125e478d022ea1b08 (diff)
Const correctness in TensorMap<const Tensor<T, ...>> expressions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMap.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMap.h66
1 files changed, 40 insertions, 26 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
index 395cdf9c8..b28cd822f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
@@ -42,13 +42,27 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef typename Base::CoeffReturnType CoeffReturnType;
- /* typedef typename internal::conditional<
- bool(internal::is_lvalue<PlainObjectType>::value),
- Scalar *,
- const Scalar *>::type
- PointerType;*/
typedef typename MakePointer_<Scalar>::Type PointerType;
- typedef PointerType PointerArgType;
+ typedef typename MakePointer_<Scalar>::ConstType PointerConstType;
+
+ // WARN: PointerType still can be a pointer to const (const Scalar*), for
+ // example in TensorMap<Tensor<const Scalar, ...>> expression. This type of
+ // expression should be illegal, but adding this restriction is not possible
+ // in practice (see https://bitbucket.org/eigen/eigen/pull-requests/488).
+ typedef typename internal::conditional<
+ bool(internal::is_lvalue<PlainObjectType>::value),
+ PointerType, // use simple pointer in lvalue expressions
+ PointerConstType // use const pointer in rvalue expressions
+ >::type StoragePointerType;
+
+ // If TensorMap was constructed over rvalue expression (e.g. const Tensor),
+ // we should return a reference to const from operator() (and others), even
+ // if TensorMap itself is not const.
+ typedef typename internal::conditional<
+ bool(internal::is_lvalue<PlainObjectType>::value),
+ Scalar&,
+ const Scalar&
+ >::type StorageRefType;
static const int Options = Options_;
@@ -63,47 +77,47 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
};
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr), m_dimensions() {
+ EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr) : m_data(dataPtr), m_dimensions() {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
EIGEN_STATIC_ASSERT((0 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
}
#if EIGEN_HAS_VARIADIC_TEMPLATES
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) {
+ EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
EIGEN_STATIC_ASSERT((sizeof...(otherDimensions) + 1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
}
#else
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) {
+ EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
EIGEN_STATIC_ASSERT((1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) {
+ EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) {
EIGEN_STATIC_ASSERT(2 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) {
+ EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) {
EIGEN_STATIC_ASSERT(3 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) {
+ EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) {
EIGEN_STATIC_ASSERT(4 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) {
+ EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) {
EIGEN_STATIC_ASSERT(5 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
}
#endif
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, const array<Index, NumIndices>& dimensions)
: m_data(dataPtr), m_dimensions(dimensions)
{ }
template <typename Dimensions>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, const Dimensions& dimensions)
: m_data(dataPtr), m_dimensions(dimensions)
{ }
@@ -120,9 +134,9 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE PointerType data() { return m_data; }
+ EIGEN_STRONG_INLINE StoragePointerType data() { return m_data; }
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const PointerType data() const { return m_data; }
+ EIGEN_STRONG_INLINE PointerConstType data() const { return m_data; }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
@@ -213,7 +227,7 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
#endif
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
+ EIGEN_STRONG_INLINE StorageRefType operator()(const array<Index, NumIndices>& indices)
{
// eigen_assert(checkIndexRange(indices));
if (PlainObjectType::Options&RowMajor) {
@@ -226,14 +240,14 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()()
+ EIGEN_STRONG_INLINE StorageRefType operator()()
{
EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
return m_data[0];
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()(Index index)
+ EIGEN_STRONG_INLINE StorageRefType operator()(Index index)
{
eigen_internal_assert(index >= 0 && index < size());
return m_data[index];
@@ -241,7 +255,7 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
#if EIGEN_HAS_VARIADIC_TEMPLATES
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices)
+ EIGEN_STRONG_INLINE StorageRefType operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices)
{
static_assert(sizeof...(otherIndices) + 2 == NumIndices || NumIndices == Dynamic, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
eigen_assert(internal::all((Eigen::NumTraits<Index>::highest() >= otherIndices)...));
@@ -256,7 +270,7 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
}
#else
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1)
+ EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1)
{
if (PlainObjectType::Options&RowMajor) {
const Index index = i1 + i0 * m_dimensions[1];
@@ -267,7 +281,7 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
}
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2)
+ EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2)
{
if (PlainObjectType::Options&RowMajor) {
const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
@@ -278,7 +292,7 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
}
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
+ EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2, Index i3)
{
if (PlainObjectType::Options&RowMajor) {
const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
@@ -289,7 +303,7 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
}
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4)
+ EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2, Index i3, Index i4)
{
if (PlainObjectType::Options&RowMajor) {
const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
@@ -320,7 +334,7 @@ template<typename PlainObjectType, int Options_, template <class> class MakePoin
}
private:
- typename MakePointer_<Scalar>::Type m_data;
+ StoragePointerType m_data;
Dimensions m_dimensions;
};