diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 10:13:08 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 10:13:08 -0800 |
commit | c94174b4fe76636ae5f027ad8e59023cd154d90d (patch) | |
tree | 3ec5cd282faf02704d6d5e29944e8e4d88b0779f /unsupported/Eigen/CXX11/src/Tensor/TensorRef.h | |
parent | 91dd53e54db5c85c37e05bce5af95d31ba337e34 (diff) |
Improved tensor references
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorRef.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorRef.h | 73 |
1 files changed, 71 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h index d43fb286e..0a87e67eb 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h @@ -64,7 +64,7 @@ class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, t virtual const Scalar coeff(DenseIndex index) const { return m_impl.coeff(index); } - virtual Scalar& coeffRef(DenseIndex) { + virtual Scalar& coeffRef(DenseIndex /*index*/) { eigen_assert(false && "can't reference the coefficient of a rvalue"); return *reinterpret_cast<Scalar*>(dummy); }; @@ -137,6 +137,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef enum { IsAligned = false, PacketAccess = false, + Layout = PlainObjectType::Layout, + CoordAccess = false, // to be implemented }; EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) { @@ -175,6 +177,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); } @@ -197,6 +201,13 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef const array<Index, NumIndices> indices{{firstIndex, otherIndices...}}; return coeff(indices); } + template<typename... IndexTypes> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) + { + const std::size_t NumIndices = (sizeof...(otherIndices) + 1); + const array<Index, NumIndices> indices{{firstIndex, otherIndices...}}; + return coeffRef(indices); + } #else EIGEN_DEVICE_FUNC @@ -237,6 +248,44 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef indices[4] = i4; return coeff(indices); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1) + { + array<Index, 2> indices; + indices[0] = i0; + indices[1] = i1; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2) + { + array<Index, 3> indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3) + { + array<Index, 4> indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4) + { + array<Index, 5> indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + indices[4] = i4; + return coeffRef(indices); + } #endif template <std::size_t NumIndices> EIGEN_DEVICE_FUNC @@ -244,7 +293,7 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef { const Dimensions& dims = this->dimensions(); Index index = 0; - if (PlainObjectType::Options&RowMajor) { + if (PlainObjectType::Options & RowMajor) { index += indices[0]; for (int i = 1; i < NumIndices; ++i) { index = index * dims[i] + indices[i]; @@ -257,6 +306,24 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef } return m_evaluator->coeff(index); } + template <std::size_t NumIndices> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices) + { + const Dimensions& dims = this->dimensions(); + Index index = 0; + if (PlainObjectType::Options & RowMajor) { + index += indices[0]; + for (int i = 1; i < NumIndices; ++i) { + index = index * dims[i] + indices[i]; + } + } else { + index += indices[NumIndices-1]; + for (int i = NumIndices-2; i >= 0; --i) { + index = index * dims[i] + indices[i]; + } + } + return m_evaluator->coeffRef(index); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const @@ -298,6 +365,8 @@ struct TensorEvaluator<const TensorRef<Derived>, Device> enum { IsAligned = false, PacketAccess = false, + Layout = TensorRef<Derived>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&) |