aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h12
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorScan.h22
-rw-r--r--unsupported/test/cxx11_tensor_scan.cpp24
3 files changed, 46 insertions, 12 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index 8f3580ba7..87fa672f4 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -486,22 +486,22 @@ class TensorBase<Derived, ReadOnlyAccessors>
typedef TensorScanOp<internal::SumReducer<CoeffReturnType>, const Derived> TensorScanSumOp;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorScanSumOp
- cumsum(const Index& axis) const {
- return TensorScanSumOp(derived(), axis);
+ cumsum(const Index& axis, bool exclusive = false) const {
+ return TensorScanSumOp(derived(), axis, exclusive);
}
typedef TensorScanOp<internal::ProdReducer<CoeffReturnType>, const Derived> TensorScanProdOp;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorScanProdOp
- cumprod(const Index& axis) const {
- return TensorScanProdOp(derived(), axis);
+ cumprod(const Index& axis, bool exclusive = false) const {
+ return TensorScanProdOp(derived(), axis, exclusive);
}
template <typename Reducer>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorScanOp<Reducer, const Derived>
- scan(const Index& axis, const Reducer& reducer) const {
- return TensorScanOp<Reducer, const Derived>(derived(), axis, reducer);
+ scan(const Index& axis, const Reducer& reducer, bool exclusive = false) const {
+ return TensorScanOp<Reducer, const Derived>(derived(), axis, exclusive, reducer);
}
// Reductions.
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h
index 5207f6a8d..1aa196b84 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h
@@ -57,8 +57,8 @@ public:
typedef typename Eigen::internal::traits<TensorScanOp>::Index Index;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorScanOp(
- const XprType& expr, const Index& axis, const Op& op = Op())
- : m_expr(expr), m_axis(axis), m_accumulator(op) {}
+ const XprType& expr, const Index& axis, bool exclusive = false, const Op& op = Op())
+ : m_expr(expr), m_axis(axis), m_accumulator(op), m_exclusive(exclusive) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Index axis() const { return m_axis; }
@@ -66,11 +66,14 @@ public:
const XprType& expression() const { return m_expr; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Op accumulator() const { return m_accumulator; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ bool exclusive() const { return m_exclusive; }
protected:
typename XprType::Nested m_expr;
const Index m_axis;
const Op m_accumulator;
+ const bool m_exclusive;
};
// Eval as rvalue
@@ -99,6 +102,7 @@ struct TensorEvaluator<const TensorScanOp<Op, ArgType>, Device> {
: m_impl(op.expression(), device),
m_device(device),
m_axis(op.axis()),
+ m_exclusive(op.exclusive()),
m_accumulator(op.accumulator()),
m_dimensions(m_impl.dimensions()),
m_size(m_dimensions[m_axis]),
@@ -168,6 +172,7 @@ protected:
TensorEvaluator<ArgType, Device> m_impl;
const Device& m_device;
const Index m_axis;
+ const bool m_exclusive;
Op m_accumulator;
const Dimensions& m_dimensions;
const Index& m_size;
@@ -176,7 +181,7 @@ protected:
// TODO(ibab) Parallelize this single-threaded implementation if desired
EIGEN_DEVICE_FUNC void accumulateTo(Scalar* data) {
- // We fix the index along the scan axis to 0 and perform an
+ // We fix the index along the scan axis to 0 and perform a
// scan per remaining entry. The iteration is split into two nested
// loops to avoid an integer division by keeping track of each idx1 and idx2.
for (Index idx1 = 0; idx1 < dimensions().TotalSize() / m_size; idx1 += m_stride) {
@@ -184,12 +189,17 @@ protected:
// Calculate the starting offset for the scan
Index offset = idx1 * m_size + idx2;
- // Compute the prefix sum along the axis, starting at the calculated offset
+ // Compute the scan along the axis, starting at the calculated offset
CoeffReturnType accum = m_accumulator.initialize();
for (Index idx3 = 0; idx3 < m_size; idx3++) {
Index curr = offset + idx3 * m_stride;
- m_accumulator.reduce(m_impl.coeff(curr), &accum);
- data[curr] = m_accumulator.finalize(accum);
+ if (m_exclusive) {
+ data[curr] = m_accumulator.finalize(accum);
+ m_accumulator.reduce(m_impl.coeff(curr), &accum);
+ } else {
+ m_accumulator.reduce(m_impl.coeff(curr), &accum);
+ data[curr] = m_accumulator.finalize(accum);
+ }
}
}
}
diff --git a/unsupported/test/cxx11_tensor_scan.cpp b/unsupported/test/cxx11_tensor_scan.cpp
index dbd3023d7..bafa6c96e 100644
--- a/unsupported/test/cxx11_tensor_scan.cpp
+++ b/unsupported/test/cxx11_tensor_scan.cpp
@@ -39,6 +39,30 @@ static void test_1d_scan()
}
template <int DataLayout, typename Type=float>
+static void test_1d_inclusive_scan()
+{
+ int size = 50;
+ Tensor<Type, 1, DataLayout> tensor(size);
+ tensor.setRandom();
+ Tensor<Type, 1, DataLayout> result = tensor.cumsum(0, true);
+
+ VERIFY_IS_EQUAL(tensor.dimension(0), result.dimension(0));
+
+ float accum = 0;
+ for (int i = 0; i < size; i++) {
+ VERIFY_IS_EQUAL(result(i), accum);
+ accum += tensor(i);
+ }
+
+ accum = 1;
+ result = tensor.cumprod(0, true);
+ for (int i = 0; i < size; i++) {
+ VERIFY_IS_EQUAL(result(i), accum);
+ accum *= tensor(i);
+ }
+}
+
+template <int DataLayout, typename Type=float>
static void test_4d_scan()
{
int size = 5;