From eaf4b98180d7606abba69133e39e23537ced79e5 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Tue, 20 Oct 2015 11:41:22 -0700 Subject: Added support for boolean reductions (ie 'and' & 'or' reductions) --- unsupported/Eigen/CXX11/src/Tensor/README.md | 13 +++++++++++ unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 26 +++++++++++++++++++++ .../Eigen/CXX11/src/Tensor/TensorFunctors.h | 27 ++++++++++++++++++++++ unsupported/test/cxx11_tensor_reduction.cpp | 17 ++++++++++++++ 4 files changed, 83 insertions(+) diff --git a/unsupported/Eigen/CXX11/src/Tensor/README.md b/unsupported/Eigen/CXX11/src/Tensor/README.md index 87e57cebb..407485090 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/README.md +++ b/unsupported/Eigen/CXX11/src/Tensor/README.md @@ -1149,6 +1149,19 @@ are the smallest of the reduced values. Reduce a tensor using the prod() operator. The resulting values are the product of the reduced values. +### <Operation> all(const Dimensions& new_dims) +### <Operation> all() +Reduce a tensor using the all() operator. Casts tensor to bool and then checks +whether all elements are true. Runs through all elements rather than +short-circuiting, so may be significantly inefficient. + +### <Operation> any(const Dimensions& new_dims) +### <Operation> any() +Reduce a tensor using the any() operator. Casts tensor to bool and then checks +whether any element is true. Runs through all elements rather than +short-circuiting, so may be significantly inefficient. + + ### <Operation> reduce(const Dimensions& new_dims, const Reducer& reducer) Reduce a tensor using a user-defined reduction operator. See ```SumReducer``` diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 477e4a174..c00f67950 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -363,6 +363,32 @@ class TensorBase return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MinReducer()); } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp > + all(const Dims& dims) const { + return cast().reduce(dims, internal::AndReducer()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp, const TensorConversionOp > + all() const { + DimensionList in_dims; + return cast().reduce(in_dims, internal::AndReducer()); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp > + any(const Dims& dims) const { + return cast().reduce(dims, internal::OrReducer()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp, const TensorConversionOp > + any() const { + DimensionList in_dims; + return cast().reduce(in_dims, internal::OrReducer()); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorTupleReducerOp< internal::ArgMaxTupleReducer >, diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index ed259399b..a98c6a2e3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -219,6 +219,33 @@ template struct ProdReducer }; +struct AndReducer +{ + static const bool PacketAccess = false; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { + *accum = *accum && t; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool initialize() const { + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool finalize(bool accum) const { + return accum; + } +}; + +struct OrReducer { + static const bool PacketAccess = false; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { + *accum = *accum || t; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool initialize() const { + return false; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool finalize(bool accum) const { + return accum; + } +}; + // Argmin/Argmax reducers template struct ArgMaxTupleReducer { diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp index b2c85a879..e8180c061 100644 --- a/unsupported/test/cxx11_tensor_reduction.cpp +++ b/unsupported/test/cxx11_tensor_reduction.cpp @@ -180,6 +180,23 @@ static void test_simple_reductions() { VERIFY_IS_APPROX(mean1(0), mean2(0)); } + + { + Tensor ints(10); + std::iota(ints.data(), ints.data() + ints.dimension(0), 0); + + TensorFixedSize > all; + all = ints.all(); + VERIFY(!all(0)); + all = (ints >= ints.constant(0)).all(); + VERIFY(all(0)); + + TensorFixedSize > any; + any = (ints > ints.constant(10)).any(); + VERIFY(!any(0)); + any = (ints < ints.constant(1)).any(); + VERIFY(any(0)); + } } template -- cgit v1.2.3