From c4d10e921f02409d771e960647beae0b6d26dc88 Mon Sep 17 00:00:00 2001 From: Igor Babuschkin Date: Tue, 14 Jun 2016 19:44:07 +0100 Subject: Implement exclusive scan option --- unsupported/Eigen/CXX11/src/Tensor/TensorScan.h | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorScan.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h index 5207f6a8d..1aa196b84 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h @@ -57,8 +57,8 @@ public: typedef typename Eigen::internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorScanOp( - const XprType& expr, const Index& axis, const Op& op = Op()) - : m_expr(expr), m_axis(axis), m_accumulator(op) {} + const XprType& expr, const Index& axis, bool exclusive = false, const Op& op = Op()) + : m_expr(expr), m_axis(axis), m_accumulator(op), m_exclusive(exclusive) {} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Index axis() const { return m_axis; } @@ -66,11 +66,14 @@ public: const XprType& expression() const { return m_expr; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Op accumulator() const { return m_accumulator; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + bool exclusive() const { return m_exclusive; } protected: typename XprType::Nested m_expr; const Index m_axis; const Op m_accumulator; + const bool m_exclusive; }; // Eval as rvalue @@ -99,6 +102,7 @@ struct TensorEvaluator, Device> { : m_impl(op.expression(), device), m_device(device), m_axis(op.axis()), + m_exclusive(op.exclusive()), m_accumulator(op.accumulator()), m_dimensions(m_impl.dimensions()), m_size(m_dimensions[m_axis]), @@ -168,6 +172,7 @@ protected: TensorEvaluator m_impl; const Device& m_device; const Index m_axis; + const bool m_exclusive; Op m_accumulator; const Dimensions& m_dimensions; const Index& m_size; @@ -176,7 +181,7 @@ protected: // TODO(ibab) Parallelize this single-threaded implementation if desired EIGEN_DEVICE_FUNC void accumulateTo(Scalar* data) { - // We fix the index along the scan axis to 0 and perform an + // We fix the index along the scan axis to 0 and perform a // scan per remaining entry. The iteration is split into two nested // loops to avoid an integer division by keeping track of each idx1 and idx2. for (Index idx1 = 0; idx1 < dimensions().TotalSize() / m_size; idx1 += m_stride) { @@ -184,12 +189,17 @@ protected: // Calculate the starting offset for the scan Index offset = idx1 * m_size + idx2; - // Compute the prefix sum along the axis, starting at the calculated offset + // Compute the scan along the axis, starting at the calculated offset CoeffReturnType accum = m_accumulator.initialize(); for (Index idx3 = 0; idx3 < m_size; idx3++) { Index curr = offset + idx3 * m_stride; - m_accumulator.reduce(m_impl.coeff(curr), &accum); - data[curr] = m_accumulator.finalize(accum); + if (m_exclusive) { + data[curr] = m_accumulator.finalize(accum); + m_accumulator.reduce(m_impl.coeff(curr), &accum); + } else { + m_accumulator.reduce(m_impl.coeff(curr), &accum); + data[curr] = m_accumulator.finalize(accum); + } } } } -- cgit v1.2.3