diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-06-09 09:45:30 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-06-09 09:45:30 -0700 |
commit | a669052f12d6d71ba815764d6419726d64fef675 (patch) | |
tree | a087876a5b341c0c3f2380d3530579cfbeb25c1c /unsupported/Eigen | |
parent | 36a2b2e9dc9368356b3f327a1fb00616397c1e0e (diff) |
Improved support for rvalues in tensor expressions.
Diffstat (limited to 'unsupported/Eigen')
7 files changed, 71 insertions, 20 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 932e5c82d..e447a5d40 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -22,7 +22,7 @@ namespace Eigen { */ template<typename Derived> -class TensorBase +class TensorBase<Derived, ReadOnlyAccessors> { public: typedef typename internal::traits<Derived>::Scalar Scalar; @@ -30,19 +30,6 @@ class TensorBase typedef Scalar CoeffReturnType; typedef typename internal::packet_traits<Scalar>::type PacketReturnType; - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Derived& setZero() { - return setConstant(Scalar(0)); - } - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) { - return derived() = constant(val); - } - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Derived& setRandom() { - return derived() = random(); - } - // Nullary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> @@ -224,14 +211,53 @@ class TensorBase return TensorReshapingOp<const Derived, const NewDimensions>(derived(), newDimensions); } + protected: + template <typename OtherDerived, int AccessLevel> friend class TensorBase; + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); } +}; + + +template<typename Derived> +class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyAccessors> { + public: + typedef typename internal::traits<Derived>::Scalar Scalar; + typedef typename internal::traits<Derived>::Index Index; + typedef Scalar CoeffReturnType; + typedef typename internal::packet_traits<Scalar>::type PacketReturnType; + + template <typename OtherDerived, int AccessLevel> friend class TensorBase; + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& setZero() { + return setConstant(Scalar(0)); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) { + return derived() = this->constant(val); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& setRandom() { + return derived() = this->random(); + } + + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Derived& operator+=(const OtherDerived& other) { + return derived() = TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Derived& operator-=(const OtherDerived& other) { + return derived() = TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + // Select the device on which to evaluate the expression. template <typename DeviceType> TensorDevice<Derived, DeviceType> device(const DeviceType& device) { return TensorDevice<Derived, DeviceType>(device, derived()); } - protected: - template <typename OtherDerived> friend class TensorBase; + protected: EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& derived() { return *static_cast<Derived*>(this); } EIGEN_DEVICE_FUNC diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index d424df36e..d371eb76d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -35,6 +35,10 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference<LhsNested>::type _LhsNested; typedef typename remove_reference<RhsNested>::type _RhsNested; + + enum { + Flags = 0, + }; }; template<typename Dimensions, typename LhsXprType, typename RhsXprType> diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h index ca2e0e562..501e9a522 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h @@ -35,6 +35,10 @@ struct traits<TensorConvolutionOp<Dimensions, InputXprType, KernelXprType> > typedef typename KernelXprType::Nested RhsNested; typedef typename remove_reference<LhsNested>::type _LhsNested; typedef typename remove_reference<RhsNested>::type _RhsNested; + + enum { + Flags = 0, + }; }; template<typename Dimensions, typename InputXprType, typename KernelXprType> diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index 60908ee94..de66da13f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -36,6 +36,10 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > typedef typename XprType::Scalar Scalar; typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; + + enum { + Flags = 0, + }; }; } // end namespace internal @@ -153,6 +157,10 @@ struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference<LhsNested>::type _LhsNested; typedef typename remove_reference<RhsNested>::type _RhsNested; + + enum { + Flags = 0, + }; }; template<typename BinaryOp, typename LhsXprType, typename RhsXprType> diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index b8833362c..1fb90478f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -15,7 +15,7 @@ namespace Eigen { template<typename Scalar_, std::size_t NumIndices_, int Options_ = 0> class Tensor; template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFixedSize; template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap; -template<typename Derived> class TensorBase; +template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase; template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp; template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp; @@ -29,6 +29,10 @@ template<typename ExpressionType, typename DeviceType> class TensorDevice; // Move to internal? template<typename Derived> struct TensorEvaluator; +namespace internal { +template<typename Derived, typename OtherDerived, bool Vectorizable> struct TensorAssign; +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_FORWARD_DECLARATIONS_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h index 3e089fe1e..7d5f9271e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h @@ -21,7 +21,7 @@ namespace Eigen { */ namespace internal { template<typename XprType, typename NewDimensions> -struct traits<TensorReshapingOp<XprType, NewDimensions> > +struct traits<TensorReshapingOp<XprType, NewDimensions> > : public traits<XprType> { // Type promotion to handle the case where the types of the lhs and the rhs are different. typedef typename XprType::Scalar Scalar; @@ -81,6 +81,7 @@ template<typename ArgType, typename NewDimensions> struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> > { typedef TensorReshapingOp<ArgType, NewDimensions> XprType; + typedef NewDimensions Dimensions; enum { IsAligned = TensorEvaluator<ArgType>::IsAligned, @@ -95,7 +96,7 @@ struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> > typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - const NewDimensions& dimensions() const { return m_dimensions; } + const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h index 2de698a57..40f805741 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h @@ -52,7 +52,7 @@ struct traits<Tensor<Scalar_, NumIndices_, Options_> > typedef DenseIndex Index; enum { Options = Options_, - Flags = compute_tensor_flags<Scalar_, Options_>::ret, + Flags = compute_tensor_flags<Scalar_, Options_>::ret | LvalueBit, }; }; @@ -63,6 +63,10 @@ struct traits<TensorFixedSize<Scalar_, Dimensions, Options_> > typedef Scalar_ Scalar; typedef Dense StorageKind; typedef DenseIndex Index; + enum { + Options = Options_, + Flags = compute_tensor_flags<Scalar_, Options_>::ret | LvalueBit, + }; }; |