aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-10-28 23:10:13 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-10-28 23:10:13 -0700
commitdebc97821c775518afd54e05e19dec9eb0c3bde1 (patch)
tree75da8f87467dc352d4562007a03f24ef79e3af6a /unsupported/Eigen
parentf786897e4b96737767effc85bedb78f06dc46dc5 (diff)
Added support for tensor references
Diffstat (limited to 'unsupported/Eigen')
-rw-r--r--unsupported/Eigen/CXX11/Tensor2
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorRef.h360
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h40
4 files changed, 403 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor
index 47447f446..c36db96ec 100644
--- a/unsupported/Eigen/CXX11/Tensor
+++ b/unsupported/Eigen/CXX11/Tensor
@@ -76,6 +76,8 @@
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
+
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
index 67f478822..a72e11215 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
@@ -15,6 +15,7 @@ namespace Eigen {
template<typename Scalar_, std::size_t NumIndices_, int Options_ = 0> class Tensor;
template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFixedSize;
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
+template<typename PlainObjectType> class TensorRef;
template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase;
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
new file mode 100644
index 000000000..db2027a5f
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
@@ -0,0 +1,360 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
+#define EIGEN_CXX11_TENSOR_TENSOR_REF_H
+
+namespace Eigen {
+
+namespace internal {
+
+template <typename Dimensions, typename Scalar>
+class TensorLazyBaseEvaluator {
+ public:
+ TensorLazyBaseEvaluator() : m_refcount(0) { }
+ virtual ~TensorLazyBaseEvaluator() { }
+
+ virtual const Dimensions& dimensions() const = 0;
+ virtual const Scalar* data() const = 0;
+
+ virtual const Scalar coeff(DenseIndex index) const = 0;
+ virtual Scalar& coeffRef(DenseIndex index) = 0;
+
+ void incrRefCount() { ++m_refcount; }
+ void decrRefCount() { --m_refcount; }
+ int refCount() const { return m_refcount; }
+
+ private:
+ // No copy, no assigment;
+ TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other);
+ TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other);
+
+ int m_refcount;
+};
+
+static char dummy[8];
+
+template <typename Dimensions, typename Expr, typename Device>
+class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
+ public:
+ // typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions;
+ typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
+
+ TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device) {
+ m_dims = m_impl.dimensions();
+ m_impl.evalSubExprsIfNeeded(NULL);
+ }
+ virtual ~TensorLazyEvaluatorReadOnly() {
+ m_impl.cleanup();
+ }
+
+ virtual const Dimensions& dimensions() const {
+ return m_dims;
+ }
+ virtual const Scalar* data() const {
+ return m_impl.data();
+ }
+
+ virtual const Scalar coeff(DenseIndex index) const {
+ return m_impl.coeff(index);
+ }
+ virtual Scalar& coeffRef(DenseIndex index) {
+ eigen_assert(false && "can't reference the coefficient of a rvalue");
+ return *reinterpret_cast<Scalar*>(dummy);
+ };
+
+ protected:
+ TensorEvaluator<Expr, Device> m_impl;
+ Dimensions m_dims;
+};
+
+template <typename Dimensions, typename Expr, typename Device>
+class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
+ public:
+ typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
+ typedef typename Base::Scalar Scalar;
+
+ TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {
+ }
+ virtual ~TensorLazyEvaluatorWritable() {
+ }
+
+ virtual Scalar& coeffRef(DenseIndex index) {
+ return this->m_impl.coeffRef(index);
+ }
+};
+
+template <typename Dimensions, typename Expr, typename Device>
+class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value),
+ TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
+ TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type {
+ public:
+ typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value),
+ TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
+ TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base;
+ typedef typename Base::Scalar Scalar;
+
+ TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {
+ }
+ virtual ~TensorLazyEvaluator() {
+ }
+};
+
+} // namespace internal
+
+
+/** \class TensorRef
+ * \ingroup CXX11_Tensor_Module
+ *
+ * \brief A reference to a tensor expression
+ * The expression will be evaluated lazily (as much as possible).
+ *
+ */
+template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> >
+{
+ public:
+ typedef TensorRef<PlainObjectType> Self;
+ typedef typename PlainObjectType::Base Base;
+ typedef typename Eigen::internal::nested<Self>::type Nested;
+ typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
+ typedef typename internal::traits<PlainObjectType>::Index Index;
+ typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
+ typedef typename internal::packet_traits<Scalar>::type Packet;
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ typedef typename Base::CoeffReturnType CoeffReturnType;
+ typedef Scalar* PointerType;
+ typedef PointerType PointerArgType;
+
+ static const Index NumIndices = PlainObjectType::NumIndices;
+ typedef typename PlainObjectType::Dimensions Dimensions;
+
+ enum {
+ IsAligned = false,
+ PacketAccess = false,
+ };
+
+ EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
+ }
+
+ template <typename Expression>
+ EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
+ m_evaluator->incrRefCount();
+ }
+
+ template <typename Expression>
+ EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) {
+ unrefEvaluator();
+ m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
+ m_evaluator->incrRefCount();
+ return *this;
+ }
+
+ ~TensorRef() {
+ unrefEvaluator();
+ }
+
+ TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) {
+ eigen_assert(m_evaluator->refCount() > 0);
+ m_evaluator->incrRefCount();
+ }
+
+ TensorRef& operator = (const TensorRef& other) {
+ if (this != &other) {
+ unrefEvaluator();
+ m_evaluator = other.m_evaluator;
+ eigen_assert(m_evaluator->refCount() > 0);
+ m_evaluator->incrRefCount();
+ }
+ return *this;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar operator()(Index index) const
+ {
+ return m_evaluator->coeff(index);
+ }
+
+#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
+ template<typename... IndexTypes> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const
+ {
+ const std::size_t NumIndices = (sizeof...(otherIndices) + 1);
+ const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
+ return coeff(indices);
+ }
+#else
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const
+ {
+ array<Index, 2> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ return coeff(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const
+ {
+ array<Index, 3> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ return coeff(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const
+ {
+ array<Index, 4> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ indices[3] = i3;
+ return coeff(indices);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
+ {
+ array<Index, 5> indices;
+ indices[0] = i0;
+ indices[1] = i1;
+ indices[2] = i2;
+ indices[3] = i3;
+ indices[4] = i4;
+ return coeff(indices);
+ }
+#endif
+
+ template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const
+ {
+ const Dimensions& dims = this->dimensions();
+ Index index = 0;
+ if (PlainObjectType::Options&RowMajor) {
+ index += indices[0];
+ for (int i = 1; i < NumIndices; ++i) {
+ index = index * dims[i] + indices[i];
+ }
+ } else {
+ index += indices[NumIndices-1];
+ for (int i = NumIndices-2; i >= 0; --i) {
+ index = index * dims[i] + indices[i];
+ }
+ }
+ return m_evaluator->coeff(index);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
+ {
+ return m_evaluator->coeff(index);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
+ {
+ return m_evaluator->coeffRef(index);
+ }
+
+ private:
+ EIGEN_STRONG_INLINE void unrefEvaluator() {
+ if (m_evaluator) {
+ m_evaluator->decrRefCount();
+ if (m_evaluator->refCount() == 0) {
+ delete m_evaluator;
+ }
+ }
+ }
+
+ internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
+};
+
+
+// evaluator for rvalues
+template<typename Derived, typename Device>
+struct TensorEvaluator<const TensorRef<Derived>, Device>
+{
+ typedef typename Derived::Index Index;
+ typedef typename Derived::Scalar Scalar;
+ typedef typename Derived::Packet Packet;
+ typedef typename Derived::Scalar CoeffReturnType;
+ typedef typename Derived::Packet PacketReturnType;
+ typedef typename Derived::Dimensions Dimensions;
+
+ enum {
+ IsAligned = false,
+ PacketAccess = false,
+ };
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
+ : m_ref(m)
+ { }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
+ return true;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
+ return m_ref.coeff(index);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
+ return m_ref.coeffRef(index);
+ }
+
+ Scalar* data() const { return m_ref.data(); }
+
+ protected:
+ TensorRef<Derived> m_ref;
+};
+
+
+// evaluator for lvalues
+template<typename Derived, typename Device>
+struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device>
+{
+ typedef typename Derived::Index Index;
+ typedef typename Derived::Scalar Scalar;
+ typedef typename Derived::Packet Packet;
+ typedef typename Derived::Scalar CoeffReturnType;
+ typedef typename Derived::Packet PacketReturnType;
+ typedef typename Derived::Dimensions Dimensions;
+
+ typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
+
+ enum {
+ IsAligned = false,
+ PacketAccess = false,
+ };
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d)
+ { }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
+ return this->m_ref.coeffRef(index);
+ }
+};
+
+
+
+} // end namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
index 5940a8cf1..5c0f78489 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
@@ -84,6 +84,20 @@ struct traits<TensorMap<PlainObjectType, Options_> >
};
};
+template<typename PlainObjectType>
+struct traits<TensorRef<PlainObjectType> >
+ : public traits<PlainObjectType>
+{
+ typedef traits<PlainObjectType> BaseTraits;
+ typedef typename BaseTraits::Scalar Scalar;
+ typedef typename BaseTraits::StorageKind StorageKind;
+ typedef typename BaseTraits::Index Index;
+ enum {
+ Options = BaseTraits::Options,
+ Flags = ((BaseTraits::Flags | LvalueBit) & ~AlignedBit) | (Options&Aligned ? AlignedBit : 0),
+ };
+};
+
template<typename _Scalar, std::size_t NumIndices_, int Options>
struct eval<Tensor<_Scalar, NumIndices_, Options>, Eigen::Dense>
@@ -121,6 +135,19 @@ struct eval<const TensorMap<PlainObjectType, Options>, Eigen::Dense>
typedef const TensorMap<PlainObjectType, Options>& type;
};
+template<typename PlainObjectType>
+struct eval<TensorRef<PlainObjectType>, Eigen::Dense>
+{
+ typedef const TensorRef<PlainObjectType>& type;
+};
+
+template<typename PlainObjectType>
+struct eval<const TensorRef<PlainObjectType>, Eigen::Dense>
+{
+ typedef const TensorRef<PlainObjectType>& type;
+};
+
+
template <typename Scalar_, std::size_t NumIndices_, int Options_>
struct nested<Tensor<Scalar_, NumIndices_, Options_>, 1, typename eval<Tensor<Scalar_, NumIndices_, Options_> >::type>
{
@@ -145,6 +172,7 @@ struct nested<const TensorFixedSize<Scalar_, Dimensions, Options>, 1, typename e
typedef const TensorFixedSize<Scalar_, Dimensions, Options>& type;
};
+
template <typename PlainObjectType, int Options>
struct nested<TensorMap<PlainObjectType, Options>, 1, typename eval<TensorMap<PlainObjectType, Options> >::type>
{
@@ -157,6 +185,18 @@ struct nested<const TensorMap<PlainObjectType, Options>, 1, typename eval<Tensor
typedef const TensorMap<PlainObjectType, Options>& type;
};
+template <typename PlainObjectType>
+struct nested<TensorRef<PlainObjectType>, 1, typename eval<TensorRef<PlainObjectType> >::type>
+{
+ typedef const TensorRef<PlainObjectType>& type;
+};
+
+template <typename PlainObjectType>
+struct nested<const TensorRef<PlainObjectType>, 1, typename eval<TensorRef<PlainObjectType> >::type>
+{
+ typedef const TensorRef<PlainObjectType>& type;
+};
+
} // end namespace internal
} // end namespace Eigen