diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-02-26 11:13:42 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-02-26 11:13:42 -0800 |
commit | 57154fdb32a4853bff458f8014b037d5e41b9858 (patch) | |
tree | be601f42c4691458bc9d52cac259b755e9504e9a /unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h | |
parent | 2fffe69b1be0b4448c5105edf4aeac22937ae5dc (diff) |
Can now use the tensor 'reverse' operation as a lvalue
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h | 91 |
1 files changed, 79 insertions, 12 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h index ad21e966b..16bef2ad3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h @@ -49,12 +49,9 @@ struct nested<TensorReverseOp<ReverseDimensions, XprType>, 1, } // end namespace internal - - - template<typename ReverseDimensions, typename XprType> class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions, - XprType>, ReadOnlyAccessors> + XprType>, WriteAccessors> { public: typedef typename Eigen::internal::traits<TensorReverseOp>::Scalar Scalar; @@ -67,8 +64,8 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions, StorageKind; typedef typename Eigen::internal::traits<TensorReverseOp>::Index Index; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(const XprType& expr, - const ReverseDimensions& reverse_dims) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp( + const XprType& expr, const ReverseDimensions& reverse_dims) : m_xpr(expr), m_reverse_dims(reverse_dims) {} EIGEN_DEVICE_FUNC @@ -78,12 +75,30 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions, const typename internal::remove_all<typename XprType::Nested>::type& expression() const { return m_xpr; } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE TensorReverseOp& operator = (const TensorReverseOp& other) + { + typedef TensorAssignOp<TensorReverseOp, const TensorReverseOp> Assign; + Assign assign(*this, other); + internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice()); + return *this; + } + + template<typename OtherDerived> + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE TensorReverseOp& operator = (const OtherDerived& other) + { + typedef TensorAssignOp<TensorReverseOp, const OtherDerived> Assign; + Assign assign(*this, other); + internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice()); + return *this; + } + protected: typename XprType::Nested m_xpr; const ReverseDimensions m_reverse_dims; }; - // Eval as rvalue template<typename ReverseDimensions, typename ArgType, typename Device> struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device> @@ -134,8 +149,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device m_impl.cleanup(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const - { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index reverseIndex( + Index index) const { eigen_assert(index < dimensions().TotalSize()); Index inputIndex = 0; if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { @@ -152,7 +167,6 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device } else { inputIndex += index; } - return m_impl.coeff(inputIndex); } else { for (int i = 0; i < NumDims - 1; ++i) { Index idx = index / m_strides[i]; @@ -167,8 +181,13 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device } else { inputIndex += index; } - return m_impl.coeff(inputIndex); } + return inputIndex; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff( + Index index) const { + return m_impl.coeff(reverseIndex(index)); } template<int LoadMode> @@ -199,9 +218,57 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device ReverseDimensions m_reverse; }; +// Eval as lvalue + +template <typename ReverseDimensions, typename ArgType, typename Device> +struct TensorEvaluator<TensorReverseOp<ReverseDimensions, ArgType>, Device> + : public TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, + Device> { + typedef TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, + Device> Base; + typedef TensorReverseOp<ReverseDimensions, ArgType> XprType; + typedef typename XprType::Index Index; + static const int NumDims = internal::array_size<ReverseDimensions>::value; + typedef DSizes<Index, NumDims> Dimensions; + + enum { + IsAligned = false, + PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, + Layout = TensorEvaluator<ArgType, Device>::Layout, + CoordAccess = false, // to be implemented + }; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, + const Device& device) + : Base(op, device) {} + + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const Dimensions& dimensions() const { return this->m_dimensions; } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { + return this->m_impl.coeffRef(this->reverseIndex(index)); + } + + template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketReturnType& x) { + const int packetSize = internal::unpacket_traits<PacketReturnType>::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index+packetSize-1 < dimensions().TotalSize()); + // This code is pilfered from TensorMorphing.h + EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize]; + internal::pstore<CoeffReturnType, PacketReturnType>(values, x); + for (int i = 0; i < packetSize; ++i) { + this->coeffRef(index+i) = values[i]; + } + } + +}; -} // end namespace Eigen +} // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_REVERSE_H |