aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-02-26 11:13:42 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-02-26 11:13:42 -0800
commit57154fdb32a4853bff458f8014b037d5e41b9858 (patch)
treebe601f42c4691458bc9d52cac259b755e9504e9a
parent2fffe69b1be0b4448c5105edf4aeac22937ae5dc (diff)
Can now use the tensor 'reverse' operation as a lvalue
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h5
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h91
-rw-r--r--unsupported/test/cxx11_tensor_reverse.cpp35
3 files changed, 109 insertions, 22 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index cfcf18e8e..201b0fc9e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -549,6 +549,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
chip(const Index offset, const Index dim) const {
return TensorChippingOp<Dynamic, Derived>(derived(), offset, dim);
}
+ template <typename ReverseDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ TensorReverseOp<const ReverseDimensions, Derived>
+ reverse(const ReverseDimensions& rev) const {
+ return TensorReverseOp<const ReverseDimensions, Derived>(derived(), rev);
+ }
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorShufflingOp<const Shuffle, Derived>
shuffle(const Shuffle& shuffle) const {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
index ad21e966b..16bef2ad3 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
@@ -49,12 +49,9 @@ struct nested<TensorReverseOp<ReverseDimensions, XprType>, 1,
} // end namespace internal
-
-
-
template<typename ReverseDimensions, typename XprType>
class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
- XprType>, ReadOnlyAccessors>
+ XprType>, WriteAccessors>
{
public:
typedef typename Eigen::internal::traits<TensorReverseOp>::Scalar Scalar;
@@ -67,8 +64,8 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
StorageKind;
typedef typename Eigen::internal::traits<TensorReverseOp>::Index Index;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(const XprType& expr,
- const ReverseDimensions& reverse_dims)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(
+ const XprType& expr, const ReverseDimensions& reverse_dims)
: m_xpr(expr), m_reverse_dims(reverse_dims) {}
EIGEN_DEVICE_FUNC
@@ -78,12 +75,30 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
const typename internal::remove_all<typename XprType::Nested>::type&
expression() const { return m_xpr; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorReverseOp& operator = (const TensorReverseOp& other)
+ {
+ typedef TensorAssignOp<TensorReverseOp, const TensorReverseOp> Assign;
+ Assign assign(*this, other);
+ internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+ return *this;
+ }
+
+ template<typename OtherDerived>
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorReverseOp& operator = (const OtherDerived& other)
+ {
+ typedef TensorAssignOp<TensorReverseOp, const OtherDerived> Assign;
+ Assign assign(*this, other);
+ internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+ return *this;
+ }
+
protected:
typename XprType::Nested m_xpr;
const ReverseDimensions m_reverse_dims;
};
-
// Eval as rvalue
template<typename ReverseDimensions, typename ArgType, typename Device>
struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device>
@@ -134,8 +149,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
m_impl.cleanup();
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
- {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index reverseIndex(
+ Index index) const {
eigen_assert(index < dimensions().TotalSize());
Index inputIndex = 0;
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
@@ -152,7 +167,6 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
} else {
inputIndex += index;
}
- return m_impl.coeff(inputIndex);
} else {
for (int i = 0; i < NumDims - 1; ++i) {
Index idx = index / m_strides[i];
@@ -167,8 +181,13 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
} else {
inputIndex += index;
}
- return m_impl.coeff(inputIndex);
}
+ return inputIndex;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
+ Index index) const {
+ return m_impl.coeff(reverseIndex(index));
}
template<int LoadMode>
@@ -199,9 +218,57 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
ReverseDimensions m_reverse;
};
+// Eval as lvalue
+
+template <typename ReverseDimensions, typename ArgType, typename Device>
+struct TensorEvaluator<TensorReverseOp<ReverseDimensions, ArgType>, Device>
+ : public TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
+ Device> {
+ typedef TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
+ Device> Base;
+ typedef TensorReverseOp<ReverseDimensions, ArgType> XprType;
+ typedef typename XprType::Index Index;
+ static const int NumDims = internal::array_size<ReverseDimensions>::value;
+ typedef DSizes<Index, NumDims> Dimensions;
+
+ enum {
+ IsAligned = false,
+ PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
+ Layout = TensorEvaluator<ArgType, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ };
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op,
+ const Device& device)
+ : Base(op, device) {}
+
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const Dimensions& dimensions() const { return this->m_dimensions; }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
+ return this->m_impl.coeffRef(this->reverseIndex(index));
+ }
+
+ template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void writePacket(Index index, const PacketReturnType& x) {
+ const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(index+packetSize-1 < dimensions().TotalSize());
+ // This code is pilfered from TensorMorphing.h
+ EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
+ internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
+ for (int i = 0; i < packetSize; ++i) {
+ this->coeffRef(index+i) = values[i];
+ }
+ }
+
+};
-} // end namespace Eigen
+} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_REVERSE_H
diff --git a/unsupported/test/cxx11_tensor_reverse.cpp b/unsupported/test/cxx11_tensor_reverse.cpp
index 4c0be35da..f96c21fa3 100644
--- a/unsupported/test/cxx11_tensor_reverse.cpp
+++ b/unsupported/test/cxx11_tensor_reverse.cpp
@@ -94,7 +94,7 @@ static void test_simple_reverse()
template <int DataLayout>
-static void test_expr_reverse()
+static void test_expr_reverse(bool LValue)
{
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
tensor.setRandom();
@@ -105,9 +105,12 @@ static void test_expr_reverse()
dim_rev[2] = false;
dim_rev[3] = true;
-
- Tensor<float, 4, DataLayout> expected;
- expected = tensor.reverse(dim_rev);
+ Tensor<float, 4, DataLayout> expected(2, 3, 5, 7);
+ if (LValue) {
+ expected.reverse(dim_rev) = tensor;
+ } else {
+ expected = tensor.reverse(dim_rev);
+ }
Tensor<float, 4, DataLayout> result(2,3,5,7);
@@ -117,8 +120,13 @@ static void test_expr_reverse()
array<ptrdiff_t, 4> dst_slice_start{{0,0,0,0}};
for (int i = 0; i < 5; ++i) {
- result.slice(dst_slice_start, dst_slice_dim) =
- tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
+ if (LValue) {
+ result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
+ tensor.slice(src_slice_start, src_slice_dim);
+ } else {
+ result.slice(dst_slice_start, dst_slice_dim) =
+ tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
+ }
src_slice_start[2] += 1;
dst_slice_start[2] += 1;
}
@@ -141,8 +149,13 @@ static void test_expr_reverse()
dst_slice_start[2] = 0;
result.setRandom();
for (int i = 0; i < 5; ++i) {
- result.slice(dst_slice_start, dst_slice_dim) =
- tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
+ if (LValue) {
+ result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
+ tensor.slice(dst_slice_start, dst_slice_dim);
+ } else {
+ result.slice(dst_slice_start, dst_slice_dim) =
+ tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
+ }
dst_slice_start[2] += 1;
}
@@ -162,6 +175,8 @@ void test_cxx11_tensor_reverse()
{
CALL_SUBTEST(test_simple_reverse<ColMajor>());
CALL_SUBTEST(test_simple_reverse<RowMajor>());
- CALL_SUBTEST(test_expr_reverse<ColMajor>());
- CALL_SUBTEST(test_expr_reverse<RowMajor>());
+ CALL_SUBTEST(test_expr_reverse<ColMajor>(true));
+ CALL_SUBTEST(test_expr_reverse<RowMajor>(true));
+ CALL_SUBTEST(test_expr_reverse<ColMajor>(false));
+ CALL_SUBTEST(test_expr_reverse<RowMajor>(false));
}