From c94174b4fe76636ae5f027ad8e59023cd154d90d Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 14 Jan 2015 10:13:08 -0800 Subject: Improved tensor references --- unsupported/Eigen/CXX11/src/Tensor/TensorRef.h | 73 +++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorRef.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h index d43fb286e..0a87e67eb 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h @@ -64,7 +64,7 @@ class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator(dummy); }; @@ -137,6 +137,8 @@ template class TensorRef : public TensorBase class TensorRef : public TensorBasedimensions().size(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; } EIGEN_DEVICE_FUNC @@ -197,6 +201,13 @@ template class TensorRef : public TensorBase indices{{firstIndex, otherIndices...}}; return coeff(indices); } + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) + { + const std::size_t NumIndices = (sizeof...(otherIndices) + 1); + const array indices{{firstIndex, otherIndices...}}; + return coeffRef(indices); + } #else EIGEN_DEVICE_FUNC @@ -237,6 +248,44 @@ template class TensorRef : public TensorBase indices; + indices[0] = i0; + indices[1] = i1; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2) + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3) + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4) + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + indices[4] = i4; + return coeffRef(indices); + } #endif template EIGEN_DEVICE_FUNC @@ -244,7 +293,7 @@ template class TensorRef : public TensorBasedimensions(); Index index = 0; - if (PlainObjectType::Options&RowMajor) { + if (PlainObjectType::Options & RowMajor) { index += indices[0]; for (int i = 1; i < NumIndices; ++i) { index = index * dims[i] + indices[i]; @@ -257,6 +306,24 @@ template class TensorRef : public TensorBasecoeff(index); } + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(const array& indices) + { + const Dimensions& dims = this->dimensions(); + Index index = 0; + if (PlainObjectType::Options & RowMajor) { + index += indices[0]; + for (int i = 1; i < NumIndices; ++i) { + index = index * dims[i] + indices[i]; + } + } else { + index += indices[NumIndices-1]; + for (int i = NumIndices-2; i >= 0; --i) { + index = index * dims[i] + indices[i]; + } + } + return m_evaluator->coeffRef(index); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const @@ -298,6 +365,8 @@ struct TensorEvaluator, Device> enum { IsAligned = false, PacketAccess = false, + Layout = TensorRef::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef& m, const Device&) -- cgit v1.2.3