diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-10-20 11:41:22 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-10-20 11:41:22 -0700 |
commit | eaf4b98180d7606abba69133e39e23537ced79e5 (patch) | |
tree | a9f9adc1fc6dfa452e3dc1393371096b5c668c6b | |
parent | f5c1587e4e7d9e3e5a57deedff8d27b866a0a47b (diff) |
Added support for boolean reductions (ie 'and' & 'or' reductions)
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/README.md | 13 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 26 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | 27 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_reduction.cpp | 17 |
4 files changed, 83 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/README.md b/unsupported/Eigen/CXX11/src/Tensor/README.md index 87e57cebb..407485090 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/README.md +++ b/unsupported/Eigen/CXX11/src/Tensor/README.md @@ -1149,6 +1149,19 @@ are the smallest of the reduced values. Reduce a tensor using the prod() operator. The resulting values are the product of the reduced values. +### <Operation> all(const Dimensions& new_dims) +### <Operation> all() +Reduce a tensor using the all() operator. Casts tensor to bool and then checks +whether all elements are true. Runs through all elements rather than +short-circuiting, so may be significantly inefficient. + +### <Operation> any(const Dimensions& new_dims) +### <Operation> any() +Reduce a tensor using the any() operator. Casts tensor to bool and then checks +whether any element is true. Runs through all elements rather than +short-circuiting, so may be significantly inefficient. + + ### <Operation> reduce(const Dimensions& new_dims, const Reducer& reducer) Reduce a tensor using a user-defined reduction operator. See ```SumReducer``` diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 477e4a174..c00f67950 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -363,6 +363,32 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>()); } + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::AndReducer, const Dims, const TensorConversionOp<bool, const Derived> > + all(const Dims& dims) const { + return cast<bool>().reduce(dims, internal::AndReducer()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::AndReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> > + all() const { + DimensionList<Index, NumDimensions> in_dims; + return cast<bool>().reduce(in_dims, internal::AndReducer()); + } + + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::OrReducer, const Dims, const TensorConversionOp<bool, const Derived> > + any(const Dims& dims) const { + return cast<bool>().reduce(dims, internal::OrReducer()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::OrReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> > + any() const { + DimensionList<Index, NumDimensions> in_dims; + return cast<bool>().reduce(in_dims, internal::OrReducer()); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorTupleReducerOp< internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >, diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index ed259399b..a98c6a2e3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -219,6 +219,33 @@ template <typename T> struct ProdReducer }; +struct AndReducer +{ + static const bool PacketAccess = false; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { + *accum = *accum && t; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool initialize() const { + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool finalize(bool accum) const { + return accum; + } +}; + +struct OrReducer { + static const bool PacketAccess = false; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { + *accum = *accum || t; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool initialize() const { + return false; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool finalize(bool accum) const { + return accum; + } +}; + // Argmin/Argmax reducers template <typename T> struct ArgMaxTupleReducer { diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp index b2c85a879..e8180c061 100644 --- a/unsupported/test/cxx11_tensor_reduction.cpp +++ b/unsupported/test/cxx11_tensor_reduction.cpp @@ -180,6 +180,23 @@ static void test_simple_reductions() { VERIFY_IS_APPROX(mean1(0), mean2(0)); } + + { + Tensor<int, 1> ints(10); + std::iota(ints.data(), ints.data() + ints.dimension(0), 0); + + TensorFixedSize<bool, Sizes<1> > all; + all = ints.all(); + VERIFY(!all(0)); + all = (ints >= ints.constant(0)).all(); + VERIFY(all(0)); + + TensorFixedSize<bool, Sizes<1> > any; + any = (ints > ints.constant(10)).any(); + VERIFY(!any(0)); + any = (ints < ints.constant(1)).any(); + VERIFY(any(0)); + } } template <int DataLayout> |