From 46f88fc454e78484ebdf9d58990d0489c1103cf4 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Tue, 11 Sep 2018 10:08:10 -0700 Subject: Use numerically stable tree reduction in TensorReduction. --- .../Eigen/CXX11/src/Tensor/TensorFunctors.h | 68 ++++++++++------------ 1 file changed, 30 insertions(+), 38 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index cd666c173..9fd276075 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -58,16 +58,15 @@ template struct reducer_traits { enum { Cost = 1, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; // Standard reduction functors template struct SumReducer { - static const bool PacketAccess = packet_traits::HasAdd; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { internal::scalar_sum_op sum_op; *accum = sum_op(*accum, t); @@ -103,16 +102,14 @@ template struct reducer_traits, Device> { enum { Cost = NumTraits::AddCost, - PacketAccess = PacketType::HasAdd + PacketAccess = PacketType::HasAdd, + IsStateful = false, + IsExactlyAssociative = NumTraits::IsInteger }; }; - template struct MeanReducer { - static const bool PacketAccess = packet_traits::HasAdd && packet_traits::HasDiv && !NumTraits::IsInteger; - static const bool IsStateful = true; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MeanReducer() : scalarCount_(0), packetCount_(0) { } @@ -161,7 +158,9 @@ template struct reducer_traits, Device> { enum { Cost = NumTraits::AddCost, - PacketAccess = PacketType::HasAdd + PacketAccess = PacketType::HasAdd && !NumTraits::IsInteger, + IsStateful = true, + IsExactlyAssociative = NumTraits::IsInteger }; }; @@ -194,9 +193,6 @@ struct MinMaxBottomValue { template struct MaxReducer { - static const bool PacketAccess = packet_traits::HasMax; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { if (t > *accum) { *accum = t; } } @@ -228,16 +224,15 @@ template struct reducer_traits, Device> { enum { Cost = NumTraits::AddCost, - PacketAccess = PacketType::HasMax + PacketAccess = PacketType::HasMax, + IsStateful = false, + IsExactlyAssociative = true }; }; template struct MinReducer { - static const bool PacketAccess = packet_traits::HasMin; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { if (t < *accum) { *accum = t; } } @@ -269,16 +264,15 @@ template struct reducer_traits, Device> { enum { Cost = NumTraits::AddCost, - PacketAccess = PacketType::HasMin + PacketAccess = PacketType::HasMin, + IsStateful = false, + IsExactlyAssociative = true }; }; template struct ProdReducer { - static const bool PacketAccess = packet_traits::HasMul; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { internal::scalar_product_op prod_op; (*accum) = prod_op(*accum, t); @@ -314,16 +308,15 @@ template struct reducer_traits, Device> { enum { Cost = NumTraits::MulCost, - PacketAccess = PacketType::HasMul + PacketAccess = PacketType::HasMul, + IsStateful = false, + IsExactlyAssociative = true }; }; struct AndReducer { - static const bool PacketAccess = false; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { *accum = *accum && t; } @@ -339,15 +332,14 @@ template struct reducer_traits { enum { Cost = 1, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; struct OrReducer { - static const bool PacketAccess = false; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const { *accum = *accum || t; } @@ -363,7 +355,9 @@ template struct reducer_traits { enum { Cost = 1, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; @@ -371,9 +365,6 @@ struct reducer_traits { // Argmin/Argmax reducers template struct ArgMaxTupleReducer { - static const bool PacketAccess = false; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { if (t.second > accum->second) { *accum = t; } } @@ -389,16 +380,15 @@ template struct reducer_traits, Device> { enum { Cost = NumTraits::AddCost, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; template struct ArgMinTupleReducer { - static const bool PacketAccess = false; - static const bool IsStateful = false; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const { if (t.second < accum->second) { *accum = t; } } @@ -414,7 +404,9 @@ template struct reducer_traits, Device> { enum { Cost = NumTraits::AddCost, - PacketAccess = false + PacketAccess = false, + IsStateful = false, + IsExactlyAssociative = true }; }; -- cgit v1.2.3