diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-10-20 11:41:22 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-10-20 11:41:22 -0700 |
commit | eaf4b98180d7606abba69133e39e23537ced79e5 (patch) | |
tree | a9f9adc1fc6dfa452e3dc1393371096b5c668c6b /unsupported/Eigen | |
parent | f5c1587e4e7d9e3e5a57deedff8d27b866a0a47b (diff) |
Added support for boolean reductions (ie 'and' & 'or' reductions)
Diffstat (limited to 'unsupported/Eigen')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/README.md | 13 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 26 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | 27 |
3 files changed, 66 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/README.md b/unsupported/Eigen/CXX11/src/Tensor/README.md index 87e57cebb..407485090 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/README.md +++ b/unsupported/Eigen/CXX11/src/Tensor/README.md @@ -1149,6 +1149,19 @@ are the smallest of the reduced values. Reduce a tensor using the prod() operator. The resulting values are the product of the reduced values. +### <Operation> all(const Dimensions& new_dims) +### <Operation> all() +Reduce a tensor using the all() operator. Casts tensor to bool and then checks +whether all elements are true. Runs through all elements rather than +short-circuiting, so may be significantly inefficient. + +### <Operation> any(const Dimensions& new_dims) +### <Operation> any() +Reduce a tensor using the any() operator. Casts tensor to bool and then checks +whether any element is true. Runs through all elements rather than +short-circuiting, so may be significantly inefficient. + + ### <Operation> reduce(const Dimensions& new_dims, const Reducer& reducer) Reduce a tensor using a user-defined reduction operator. See ```SumReducer``` diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 477e4a174..c00f67950 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -363,6 +363,32 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>()); } + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::AndReducer, const Dims, const TensorConversionOp<bool, const Derived> > + all(const Dims& dims) const { + return cast<bool>().reduce(dims, internal::AndReducer()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::AndReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> > + all() const { + DimensionList<Index, NumDimensions> in_dims; + return cast<bool>().reduce(in_dims, internal::AndReducer()); + } + + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::OrReducer, const Dims, const TensorConversionOp<bool, const Derived> > + any(const Dims& dims) const { + return cast<bool>().reduce(dims, internal::OrReducer()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::OrReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> > + any() const { + DimensionList<Index, NumDimensions> in_dims; + return cast<bool>().reduce(in_dims, internal::OrReducer()); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorTupleReducerOp< internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >, diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index ed259399b..a98c6a2e3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -219,6 +219,33 @@ template <typename T> struct ProdReducer }; +struct AndReducer +{ + static const bool PacketAccess = false; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { + *accum = *accum && t; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool initialize() const { + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool finalize(bool accum) const { + return accum; + } +}; + +struct OrReducer { + static const bool PacketAccess = false; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { + *accum = *accum || t; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool initialize() const { + return false; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool finalize(bool accum) const { + return accum; + } +}; + // Argmin/Argmax reducers template <typename T> struct ArgMaxTupleReducer { |