aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 10:13:08 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 10:13:08 -0800
commitc94174b4fe76636ae5f027ad8e59023cd154d90d (patch)
tree3ec5cd282faf02704d6d5e29944e8e4d88b0779f
parent91dd53e54db5c85c37e05bce5af95d31ba337e34 (diff)
Improved tensor references
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorRef.h73
-rw-r--r--unsupported/test/cxx11_tensor_ref.cpp16
2 files changed, 87 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
index d43fb286e..0a87e67eb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
@@ -64,7 +64,7 @@ class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, t
virtual const Scalar coeff(DenseIndex index) const {
return m_impl.coeff(index);
}
- virtual Scalar& coeffRef(DenseIndex) {
+ virtual Scalar& coeffRef(DenseIndex /*index*/) {
eigen_assert(false && "can't reference the coefficient of a rvalue");
return *reinterpret_cast<Scalar*>(dummy);
};
@@ -137,6 +137,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
enum {
IsAligned = false,
PacketAccess = false,
+ Layout = PlainObjectType::Layout,
+ CoordAccess = false, // to be implemented
};
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
@@ -175,6 +177,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
@@ -197,6 +201,13 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
return coeff(indices);
}
+ template<typename... IndexTypes> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
+ {
+ const std::size_t NumIndices = (sizeof...(otherIndices) + 1);
+ const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
+ return coeffRef(indices);
+ }
#else
EIGEN_DEVICE_FUNC
@@ -237,6 +248,44 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
indices[4] = i4;
return coeff(indices);
}
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
+ {
+ array<Index, 2> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ return coeffRef(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
+ {
+ array<Index, 3> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ return coeffRef(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
+ {
+ array<Index, 4> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ indices[3] = i3;
+ return coeffRef(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
+ {
+ array<Index, 5> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ indices[3] = i3;
+ indices[4] = i4;
+ return coeffRef(indices);
+ }
#endif
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
@@ -244,7 +293,7 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
{
const Dimensions& dims = this->dimensions();
Index index = 0;
- if (PlainObjectType::Options&RowMajor) {
+ if (PlainObjectType::Options & RowMajor) {
index += indices[0];
for (int i = 1; i < NumIndices; ++i) {
index = index * dims[i] + indices[i];
@@ -257,6 +306,24 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
}
return m_evaluator->coeff(index);
}
+ template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
+ {
+ const Dimensions& dims = this->dimensions();
+ Index index = 0;
+ if (PlainObjectType::Options & RowMajor) {
+ index += indices[0];
+ for (int i = 1; i < NumIndices; ++i) {
+ index = index * dims[i] + indices[i];
+ }
+ } else {
+ index += indices[NumIndices-1];
+ for (int i = NumIndices-2; i >= 0; --i) {
+ index = index * dims[i] + indices[i];
+ }
+ }
+ return m_evaluator->coeffRef(index);
+ }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
@@ -298,6 +365,8 @@ struct TensorEvaluator<const TensorRef<Derived>, Device>
enum {
IsAligned = false,
PacketAccess = false,
+ Layout = TensorRef<Derived>::Layout,
+ CoordAccess = false, // to be implemented
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
diff --git a/unsupported/test/cxx11_tensor_ref.cpp b/unsupported/test/cxx11_tensor_ref.cpp
index 4ff94a059..aa369f278 100644
--- a/unsupported/test/cxx11_tensor_ref.cpp
+++ b/unsupported/test/cxx11_tensor_ref.cpp
@@ -181,6 +181,21 @@ static void test_ref_in_expr()
}
+static void test_coeff_ref()
+{
+ Tensor<float, 5> tensor(2,3,5,7,11);
+ tensor.setRandom();
+ Tensor<float, 5> original = tensor;
+
+ TensorRef<Tensor<float, 4>> slice = tensor.chip(7, 4);
+ slice.coeffRef(0, 0, 0, 0) = 1.0f;
+ slice.coeffRef(1, 0, 0, 0) += 2.0f;
+
+ VERIFY_IS_EQUAL(tensor(0,0,0,0,7), 1.0f);
+ VERIFY_IS_EQUAL(tensor(1,0,0,0,7), original(1,0,0,0,7) + 2.0f);
+}
+
+
void test_cxx11_tensor_ref()
{
CALL_SUBTEST(test_simple_lvalue_ref());
@@ -189,4 +204,5 @@ void test_cxx11_tensor_ref()
CALL_SUBTEST(test_slice());
CALL_SUBTEST(test_ref_of_ref());
CALL_SUBTEST(test_ref_in_expr());
+ CALL_SUBTEST(test_coeff_ref());
}