aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 10:13:08 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 10:13:08 -0800
commitc94174b4fe76636ae5f027ad8e59023cd154d90d (patch)
tree3ec5cd282faf02704d6d5e29944e8e4d88b0779f /unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
parent91dd53e54db5c85c37e05bce5af95d31ba337e34 (diff)
Improved tensor references
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorRef.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorRef.h73
1 files changed, 71 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
index d43fb286e..0a87e67eb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
@@ -64,7 +64,7 @@ class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, t
virtual const Scalar coeff(DenseIndex index) const {
return m_impl.coeff(index);
}
- virtual Scalar& coeffRef(DenseIndex) {
+ virtual Scalar& coeffRef(DenseIndex /*index*/) {
eigen_assert(false && "can't reference the coefficient of a rvalue");
return *reinterpret_cast<Scalar*>(dummy);
};
@@ -137,6 +137,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
enum {
IsAligned = false,
PacketAccess = false,
+ Layout = PlainObjectType::Layout,
+ CoordAccess = false, // to be implemented
};
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
@@ -175,6 +177,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
@@ -197,6 +201,13 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
return coeff(indices);
}
+ template<typename... IndexTypes> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
+ {
+ const std::size_t NumIndices = (sizeof...(otherIndices) + 1);
+ const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
+ return coeffRef(indices);
+ }
#else
EIGEN_DEVICE_FUNC
@@ -237,6 +248,44 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
indices[4] = i4;
return coeff(indices);
}
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
+ {
+ array<Index, 2> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ return coeffRef(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
+ {
+ array<Index, 3> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ return coeffRef(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
+ {
+ array<Index, 4> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ indices[3] = i3;
+ return coeffRef(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
+ {
+ array<Index, 5> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ indices[3] = i3;
+ indices[4] = i4;
+ return coeffRef(indices);
+ }
#endif
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
@@ -244,7 +293,7 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
{
const Dimensions& dims = this->dimensions();
Index index = 0;
- if (PlainObjectType::Options&RowMajor) {
+ if (PlainObjectType::Options & RowMajor) {
index += indices[0];
for (int i = 1; i < NumIndices; ++i) {
index = index * dims[i] + indices[i];
@@ -257,6 +306,24 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
}
return m_evaluator->coeff(index);
}
+ template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
+ {
+ const Dimensions& dims = this->dimensions();
+ Index index = 0;
+ if (PlainObjectType::Options & RowMajor) {
+ index += indices[0];
+ for (int i = 1; i < NumIndices; ++i) {
+ index = index * dims[i] + indices[i];
+ }
+ } else {
+ index += indices[NumIndices-1];
+ for (int i = NumIndices-2; i >= 0; --i) {
+ index = index * dims[i] + indices[i];
+ }
+ }
+ return m_evaluator->coeffRef(index);
+ }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
@@ -298,6 +365,8 @@ struct TensorEvaluator<const TensorRef<Derived>, Device>
enum {
IsAligned = false,
PacketAccess = false,
+ Layout = TensorRef<Derived>::Layout,
+ CoordAccess = false, // to be implemented
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)