aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-02-26 11:13:42 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-02-26 11:13:42 -0800
commit57154fdb32a4853bff458f8014b037d5e41b9858 (patch)
treebe601f42c4691458bc9d52cac259b755e9504e9a /unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
parent2fffe69b1be0b4448c5105edf4aeac22937ae5dc (diff)
Can now use the tensor 'reverse' operation as a lvalue
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h91
1 files changed, 79 insertions, 12 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
index ad21e966b..16bef2ad3 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
@@ -49,12 +49,9 @@ struct nested<TensorReverseOp<ReverseDimensions, XprType>, 1,
} // end namespace internal
-
-
-
template<typename ReverseDimensions, typename XprType>
class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
- XprType>, ReadOnlyAccessors>
+ XprType>, WriteAccessors>
{
public:
typedef typename Eigen::internal::traits<TensorReverseOp>::Scalar Scalar;
@@ -67,8 +64,8 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
StorageKind;
typedef typename Eigen::internal::traits<TensorReverseOp>::Index Index;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(const XprType& expr,
- const ReverseDimensions& reverse_dims)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(
+ const XprType& expr, const ReverseDimensions& reverse_dims)
: m_xpr(expr), m_reverse_dims(reverse_dims) {}
EIGEN_DEVICE_FUNC
@@ -78,12 +75,30 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
const typename internal::remove_all<typename XprType::Nested>::type&
expression() const { return m_xpr; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorReverseOp& operator = (const TensorReverseOp& other)
+ {
+ typedef TensorAssignOp<TensorReverseOp, const TensorReverseOp> Assign;
+ Assign assign(*this, other);
+ internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+ return *this;
+ }
+
+ template<typename OtherDerived>
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorReverseOp& operator = (const OtherDerived& other)
+ {
+ typedef TensorAssignOp<TensorReverseOp, const OtherDerived> Assign;
+ Assign assign(*this, other);
+ internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+ return *this;
+ }
+
protected:
typename XprType::Nested m_xpr;
const ReverseDimensions m_reverse_dims;
};
-
// Eval as rvalue
template<typename ReverseDimensions, typename ArgType, typename Device>
struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device>
@@ -134,8 +149,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
m_impl.cleanup();
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
- {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index reverseIndex(
+ Index index) const {
eigen_assert(index < dimensions().TotalSize());
Index inputIndex = 0;
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
@@ -152,7 +167,6 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
} else {
inputIndex += index;
}
- return m_impl.coeff(inputIndex);
} else {
for (int i = 0; i < NumDims - 1; ++i) {
Index idx = index / m_strides[i];
@@ -167,8 +181,13 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
} else {
inputIndex += index;
}
- return m_impl.coeff(inputIndex);
}
+ return inputIndex;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
+ Index index) const {
+ return m_impl.coeff(reverseIndex(index));
}
template<int LoadMode>
@@ -199,9 +218,57 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
ReverseDimensions m_reverse;
};
+// Eval as lvalue
+
+template <typename ReverseDimensions, typename ArgType, typename Device>
+struct TensorEvaluator<TensorReverseOp<ReverseDimensions, ArgType>, Device>
+ : public TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
+ Device> {
+ typedef TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
+ Device> Base;
+ typedef TensorReverseOp<ReverseDimensions, ArgType> XprType;
+ typedef typename XprType::Index Index;
+ static const int NumDims = internal::array_size<ReverseDimensions>::value;
+ typedef DSizes<Index, NumDims> Dimensions;
+
+ enum {
+ IsAligned = false,
+ PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
+ Layout = TensorEvaluator<ArgType, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ };
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op,
+ const Device& device)
+ : Base(op, device) {}
+
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const Dimensions& dimensions() const { return this->m_dimensions; }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
+ return this->m_impl.coeffRef(this->reverseIndex(index));
+ }
+
+ template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void writePacket(Index index, const PacketReturnType& x) {
+ const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(index+packetSize-1 < dimensions().TotalSize());
+ // This code is pilfered from TensorMorphing.h
+ EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
+ internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
+ for (int i = 0; i < packetSize; ++i) {
+ this->coeffRef(index+i) = values[i];
+ }
+ }
+
+};
-} // end namespace Eigen
+} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_REVERSE_H