diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-09-11 10:08:10 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-09-11 10:08:10 -0700 |
commit | 46f88fc454e78484ebdf9d58990d0489c1103cf4 (patch) | |
tree | 3f5702d5b0bd589963a25b6f3f5e49286f467a5f /unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | |
parent | 43fd42a33b484914ca92931ea63583b672c5e67b (diff) |
Use numerically stable tree reduction in TensorReduction.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | 68 |
1 files changed, 30 insertions, 38 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index cd666c173..9fd276075 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -58,16 +58,15 @@ template<typename Reducer, typename Device> struct reducer_traits { enum { Cost = 1, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; // Standard reduction functors template <typename T> struct SumReducer { - static const bool PacketAccess = packet_traits<T>::HasAdd; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { internal::scalar_sum_op<T> sum_op; *accum = sum_op(*accum, t); @@ -103,16 +102,14 @@ template <typename T, typename Device> struct reducer_traits<SumReducer<T>, Device> { enum { Cost = NumTraits<T>::AddCost, - PacketAccess = PacketType<T, Device>::HasAdd + PacketAccess = PacketType<T, Device>::HasAdd, + IsStateful = false, + IsExactlyAssociative = NumTraits<T>::IsInteger }; }; - template <typename T> struct MeanReducer { - static const bool PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv && !NumTraits<T>::IsInteger; - static const bool IsStateful = true; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MeanReducer() : scalarCount_(0), packetCount_(0) { } @@ -161,7 +158,9 @@ template <typename T, typename Device> struct reducer_traits<MeanReducer<T>, Device> { enum { Cost = NumTraits<T>::AddCost, - PacketAccess = PacketType<T, Device>::HasAdd + PacketAccess = PacketType<T, Device>::HasAdd && !NumTraits<T>::IsInteger, + IsStateful = true, + IsExactlyAssociative = NumTraits<T>::IsInteger }; }; @@ -194,9 +193,6 @@ struct MinMaxBottomValue<T, false, false> { template <typename T> struct MaxReducer { - static const bool PacketAccess = packet_traits<T>::HasMax; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { if (t > *accum) { *accum = t; } } @@ -228,16 +224,15 @@ template <typename T, typename Device> struct reducer_traits<MaxReducer<T>, Device> { enum { Cost = NumTraits<T>::AddCost, - PacketAccess = PacketType<T, Device>::HasMax + PacketAccess = PacketType<T, Device>::HasMax, + IsStateful = false, + IsExactlyAssociative = true }; }; template <typename T> struct MinReducer { - static const bool PacketAccess = packet_traits<T>::HasMin; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { if (t < *accum) { *accum = t; } } @@ -269,16 +264,15 @@ template <typename T, typename Device> struct reducer_traits<MinReducer<T>, Device> { enum { Cost = NumTraits<T>::AddCost, - PacketAccess = PacketType<T, Device>::HasMin + PacketAccess = PacketType<T, Device>::HasMin, + IsStateful = false, + IsExactlyAssociative = true }; }; template <typename T> struct ProdReducer { - static const bool PacketAccess = packet_traits<T>::HasMul; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { internal::scalar_product_op<T> prod_op; (*accum) = prod_op(*accum, t); @@ -314,16 +308,15 @@ template <typename T, typename Device> struct reducer_traits<ProdReducer<T>, Device> { enum { Cost = NumTraits<T>::MulCost, - PacketAccess = PacketType<T, Device>::HasMul + PacketAccess = PacketType<T, Device>::HasMul, + IsStateful = false, + IsExactlyAssociative = true }; }; struct AndReducer { - static const bool PacketAccess = false; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { *accum = *accum && t; } @@ -339,15 +332,14 @@ template <typename Device> struct reducer_traits<AndReducer, Device> { enum { Cost = 1, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; struct OrReducer { - static const bool PacketAccess = false; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { *accum = *accum || t; } @@ -363,7 +355,9 @@ template <typename Device> struct reducer_traits<OrReducer, Device> { enum { Cost = 1, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; @@ -371,9 +365,6 @@ struct reducer_traits<OrReducer, Device> { // Argmin/Argmax reducers template <typename T> struct ArgMaxTupleReducer { - static const bool PacketAccess = false; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { if (t.second > accum->second) { *accum = t; } } @@ -389,16 +380,15 @@ template <typename T, typename Device> struct reducer_traits<ArgMaxTupleReducer<T>, Device> { enum { Cost = NumTraits<T>::AddCost, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; template <typename T> struct ArgMinTupleReducer { - static const bool PacketAccess = false; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const { if (t.second < accum->second) { *accum = t; } } @@ -414,7 +404,9 @@ template <typename T, typename Device> struct reducer_traits<ArgMinTupleReducer<T>, Device> { enum { Cost = NumTraits<T>::AddCost, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; |