From debc97821c775518afd54e05e19dec9eb0c3bde1 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Tue, 28 Oct 2014 23:10:13 -0700 Subject: Added support for tensor references --- unsupported/Eigen/CXX11/src/Tensor/TensorRef.h | 360 +++++++++++++++++++++++++ 1 file changed, 360 insertions(+) create mode 100644 unsupported/Eigen/CXX11/src/Tensor/TensorRef.h (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorRef.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h new file mode 100644 index 000000000..db2027a5f --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h @@ -0,0 +1,360 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H +#define EIGEN_CXX11_TENSOR_TENSOR_REF_H + +namespace Eigen { + +namespace internal { + +template +class TensorLazyBaseEvaluator { + public: + TensorLazyBaseEvaluator() : m_refcount(0) { } + virtual ~TensorLazyBaseEvaluator() { } + + virtual const Dimensions& dimensions() const = 0; + virtual const Scalar* data() const = 0; + + virtual const Scalar coeff(DenseIndex index) const = 0; + virtual Scalar& coeffRef(DenseIndex index) = 0; + + void incrRefCount() { ++m_refcount; } + void decrRefCount() { --m_refcount; } + int refCount() const { return m_refcount; } + + private: + // No copy, no assigment; + TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other); + TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other); + + int m_refcount; +}; + +static char dummy[8]; + +template +class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator::Scalar> { + public: + // typedef typename TensorEvaluator::Dimensions Dimensions; + typedef typename TensorEvaluator::Scalar Scalar; + + TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device) { + m_dims = m_impl.dimensions(); + m_impl.evalSubExprsIfNeeded(NULL); + } + virtual ~TensorLazyEvaluatorReadOnly() { + m_impl.cleanup(); + } + + virtual const Dimensions& dimensions() const { + return m_dims; + } + virtual const Scalar* data() const { + return m_impl.data(); + } + + virtual const Scalar coeff(DenseIndex index) const { + return m_impl.coeff(index); + } + virtual Scalar& coeffRef(DenseIndex index) { + eigen_assert(false && "can't reference the coefficient of a rvalue"); + return *reinterpret_cast(dummy); + }; + + protected: + TensorEvaluator m_impl; + Dimensions m_dims; +}; + +template +class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly { + public: + typedef TensorLazyEvaluatorReadOnly Base; + typedef typename Base::Scalar Scalar; + + TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) { + } + virtual ~TensorLazyEvaluatorWritable() { + } + + virtual Scalar& coeffRef(DenseIndex index) { + return this->m_impl.coeffRef(index); + } +}; + +template +class TensorLazyEvaluator : public internal::conditional::value), + TensorLazyEvaluatorWritable, + TensorLazyEvaluatorReadOnly >::type { + public: + typedef typename internal::conditional::value), + TensorLazyEvaluatorWritable, + TensorLazyEvaluatorReadOnly >::type Base; + typedef typename Base::Scalar Scalar; + + TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) { + } + virtual ~TensorLazyEvaluator() { + } +}; + +} // namespace internal + + +/** \class TensorRef + * \ingroup CXX11_Tensor_Module + * + * \brief A reference to a tensor expression + * The expression will be evaluated lazily (as much as possible). + * + */ +template class TensorRef : public TensorBase > +{ + public: + typedef TensorRef Self; + typedef typename PlainObjectType::Base Base; + typedef typename Eigen::internal::nested::type Nested; + typedef typename internal::traits::StorageKind StorageKind; + typedef typename internal::traits::Index Index; + typedef typename internal::traits::Scalar Scalar; + typedef typename internal::packet_traits::type Packet; + typedef typename NumTraits::Real RealScalar; + typedef typename Base::CoeffReturnType CoeffReturnType; + typedef Scalar* PointerType; + typedef PointerType PointerArgType; + + static const Index NumIndices = PlainObjectType::NumIndices; + typedef typename PlainObjectType::Dimensions Dimensions; + + enum { + IsAligned = false, + PacketAccess = false, + }; + + EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) { + } + + template + EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator(expr, DefaultDevice())) { + m_evaluator->incrRefCount(); + } + + template + EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) { + unrefEvaluator(); + m_evaluator = new internal::TensorLazyEvaluator(expr, DefaultDevice()); + m_evaluator->incrRefCount(); + return *this; + } + + ~TensorRef() { + unrefEvaluator(); + } + + TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) { + eigen_assert(m_evaluator->refCount() > 0); + m_evaluator->incrRefCount(); + } + + TensorRef& operator = (const TensorRef& other) { + if (this != &other) { + unrefEvaluator(); + m_evaluator = other.m_evaluator; + eigen_assert(m_evaluator->refCount() > 0); + m_evaluator->incrRefCount(); + } + return *this; + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index index) const + { + return m_evaluator->coeff(index); + } + +#ifdef EIGEN_HAS_VARIADIC_TEMPLATES + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const + { + const std::size_t NumIndices = (sizeof...(otherIndices) + 1); + const array indices{{firstIndex, otherIndices...}}; + return coeff(indices); + } +#else + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const + { + array indices; + indices[0] = i0; + indices[1] = i1; + return coeff(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + return coeff(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + return coeff(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + indices[4] = i4; + return coeff(indices); + } +#endif + + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar coeff(const array& indices) const + { + const Dimensions& dims = this->dimensions(); + Index index = 0; + if (PlainObjectType::Options&RowMajor) { + index += indices[0]; + for (int i = 1; i < NumIndices; ++i) { + index = index * dims[i] + indices[i]; + } + } else { + index += indices[NumIndices-1]; + for (int i = NumIndices-2; i >= 0; --i) { + index = index * dims[i] + indices[i]; + } + } + return m_evaluator->coeff(index); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar coeff(Index index) const + { + return m_evaluator->coeff(index); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) + { + return m_evaluator->coeffRef(index); + } + + private: + EIGEN_STRONG_INLINE void unrefEvaluator() { + if (m_evaluator) { + m_evaluator->decrRefCount(); + if (m_evaluator->refCount() == 0) { + delete m_evaluator; + } + } + } + + internal::TensorLazyBaseEvaluator* m_evaluator; +}; + + +// evaluator for rvalues +template +struct TensorEvaluator, Device> +{ + typedef typename Derived::Index Index; + typedef typename Derived::Scalar Scalar; + typedef typename Derived::Packet Packet; + typedef typename Derived::Scalar CoeffReturnType; + typedef typename Derived::Packet PacketReturnType; + typedef typename Derived::Dimensions Dimensions; + + enum { + IsAligned = false, + PacketAccess = false, + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef& m, const Device&) + : m_ref(m) + { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + return true; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { + return m_ref.coeff(index); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { + return m_ref.coeffRef(index); + } + + Scalar* data() const { return m_ref.data(); } + + protected: + TensorRef m_ref; +}; + + +// evaluator for lvalues +template +struct TensorEvaluator, Device> : public TensorEvaluator, Device> +{ + typedef typename Derived::Index Index; + typedef typename Derived::Scalar Scalar; + typedef typename Derived::Packet Packet; + typedef typename Derived::Scalar CoeffReturnType; + typedef typename Derived::Packet PacketReturnType; + typedef typename Derived::Dimensions Dimensions; + + typedef TensorEvaluator, Device> Base; + + enum { + IsAligned = false, + PacketAccess = false, + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef& m, const Device& d) : Base(m, d) + { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { + return this->m_ref.coeffRef(index); + } +}; + + + +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H -- cgit v1.2.3