diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 5 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h | 91 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_reverse.cpp | 35 |
3 files changed, 109 insertions, 22 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index cfcf18e8e..201b0fc9e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -549,6 +549,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA chip(const Index offset, const Index dim) const { return TensorChippingOp<Dynamic, Derived>(derived(), offset, dim); } + template <typename ReverseDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + TensorReverseOp<const ReverseDimensions, Derived> + reverse(const ReverseDimensions& rev) const { + return TensorReverseOp<const ReverseDimensions, Derived>(derived(), rev); + } template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorShufflingOp<const Shuffle, Derived> shuffle(const Shuffle& shuffle) const { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h index ad21e966b..16bef2ad3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h @@ -49,12 +49,9 @@ struct nested<TensorReverseOp<ReverseDimensions, XprType>, 1, } // end namespace internal - - - template<typename ReverseDimensions, typename XprType> class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions, - XprType>, ReadOnlyAccessors> + XprType>, WriteAccessors> { public: typedef typename Eigen::internal::traits<TensorReverseOp>::Scalar Scalar; @@ -67,8 +64,8 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions, StorageKind; typedef typename Eigen::internal::traits<TensorReverseOp>::Index Index; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(const XprType& expr, - const ReverseDimensions& reverse_dims) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp( + const XprType& expr, const ReverseDimensions& reverse_dims) : m_xpr(expr), m_reverse_dims(reverse_dims) {} EIGEN_DEVICE_FUNC @@ -78,12 +75,30 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions, const typename internal::remove_all<typename XprType::Nested>::type& expression() const { return m_xpr; } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE TensorReverseOp& operator = (const TensorReverseOp& other) + { + typedef TensorAssignOp<TensorReverseOp, const TensorReverseOp> Assign; + Assign assign(*this, other); + internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice()); + return *this; + } + + template<typename OtherDerived> + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE TensorReverseOp& operator = (const OtherDerived& other) + { + typedef TensorAssignOp<TensorReverseOp, const OtherDerived> Assign; + Assign assign(*this, other); + internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice()); + return *this; + } + protected: typename XprType::Nested m_xpr; const ReverseDimensions m_reverse_dims; }; - // Eval as rvalue template<typename ReverseDimensions, typename ArgType, typename Device> struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device> @@ -134,8 +149,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device m_impl.cleanup(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const - { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index reverseIndex( + Index index) const { eigen_assert(index < dimensions().TotalSize()); Index inputIndex = 0; if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { @@ -152,7 +167,6 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device } else { inputIndex += index; } - return m_impl.coeff(inputIndex); } else { for (int i = 0; i < NumDims - 1; ++i) { Index idx = index / m_strides[i]; @@ -167,8 +181,13 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device } else { inputIndex += index; } - return m_impl.coeff(inputIndex); } + return inputIndex; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff( + Index index) const { + return m_impl.coeff(reverseIndex(index)); } template<int LoadMode> @@ -199,9 +218,57 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device ReverseDimensions m_reverse; }; +// Eval as lvalue + +template <typename ReverseDimensions, typename ArgType, typename Device> +struct TensorEvaluator<TensorReverseOp<ReverseDimensions, ArgType>, Device> + : public TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, + Device> { + typedef TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, + Device> Base; + typedef TensorReverseOp<ReverseDimensions, ArgType> XprType; + typedef typename XprType::Index Index; + static const int NumDims = internal::array_size<ReverseDimensions>::value; + typedef DSizes<Index, NumDims> Dimensions; + + enum { + IsAligned = false, + PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, + Layout = TensorEvaluator<ArgType, Device>::Layout, + CoordAccess = false, // to be implemented + }; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, + const Device& device) + : Base(op, device) {} + + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const Dimensions& dimensions() const { return this->m_dimensions; } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { + return this->m_impl.coeffRef(this->reverseIndex(index)); + } + + template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketReturnType& x) { + const int packetSize = internal::unpacket_traits<PacketReturnType>::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index+packetSize-1 < dimensions().TotalSize()); + // This code is pilfered from TensorMorphing.h + EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize]; + internal::pstore<CoeffReturnType, PacketReturnType>(values, x); + for (int i = 0; i < packetSize; ++i) { + this->coeffRef(index+i) = values[i]; + } + } + +}; -} // end namespace Eigen +} // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_REVERSE_H diff --git a/unsupported/test/cxx11_tensor_reverse.cpp b/unsupported/test/cxx11_tensor_reverse.cpp index 4c0be35da..f96c21fa3 100644 --- a/unsupported/test/cxx11_tensor_reverse.cpp +++ b/unsupported/test/cxx11_tensor_reverse.cpp @@ -94,7 +94,7 @@ static void test_simple_reverse() template <int DataLayout> -static void test_expr_reverse() +static void test_expr_reverse(bool LValue) { Tensor<float, 4, DataLayout> tensor(2,3,5,7); tensor.setRandom(); @@ -105,9 +105,12 @@ static void test_expr_reverse() dim_rev[2] = false; dim_rev[3] = true; - - Tensor<float, 4, DataLayout> expected; - expected = tensor.reverse(dim_rev); + Tensor<float, 4, DataLayout> expected(2, 3, 5, 7); + if (LValue) { + expected.reverse(dim_rev) = tensor; + } else { + expected = tensor.reverse(dim_rev); + } Tensor<float, 4, DataLayout> result(2,3,5,7); @@ -117,8 +120,13 @@ static void test_expr_reverse() array<ptrdiff_t, 4> dst_slice_start{{0,0,0,0}}; for (int i = 0; i < 5; ++i) { - result.slice(dst_slice_start, dst_slice_dim) = - tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev); + if (LValue) { + result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) = + tensor.slice(src_slice_start, src_slice_dim); + } else { + result.slice(dst_slice_start, dst_slice_dim) = + tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev); + } src_slice_start[2] += 1; dst_slice_start[2] += 1; } @@ -141,8 +149,13 @@ static void test_expr_reverse() dst_slice_start[2] = 0; result.setRandom(); for (int i = 0; i < 5; ++i) { - result.slice(dst_slice_start, dst_slice_dim) = - tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim); + if (LValue) { + result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) = + tensor.slice(dst_slice_start, dst_slice_dim); + } else { + result.slice(dst_slice_start, dst_slice_dim) = + tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim); + } dst_slice_start[2] += 1; } @@ -162,6 +175,8 @@ void test_cxx11_tensor_reverse() { CALL_SUBTEST(test_simple_reverse<ColMajor>()); CALL_SUBTEST(test_simple_reverse<RowMajor>()); - CALL_SUBTEST(test_expr_reverse<ColMajor>()); - CALL_SUBTEST(test_expr_reverse<RowMajor>()); + CALL_SUBTEST(test_expr_reverse<ColMajor>(true)); + CALL_SUBTEST(test_expr_reverse<RowMajor>(true)); + CALL_SUBTEST(test_expr_reverse<ColMajor>(false)); + CALL_SUBTEST(test_expr_reverse<RowMajor>(false)); } |