aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-09-11 10:08:10 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-09-11 10:08:10 -0700
commit46f88fc454e78484ebdf9d58990d0489c1103cf4 (patch)
tree3f5702d5b0bd589963a25b6f3f5e49286f467a5f /unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
parent43fd42a33b484914ca92931ea63583b672c5e67b (diff)
Use numerically stable tree reduction in TensorReduction.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h68
1 files changed, 30 insertions, 38 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index cd666c173..9fd276075 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -58,16 +58,15 @@ template<typename Reducer, typename Device>
struct reducer_traits {
enum {
Cost = 1,
- PacketAccess = false
+ PacketAccess = false,
+ IsStateful = false,
+ IsExactlyAssociative = true
};
};
// Standard reduction functors
template <typename T> struct SumReducer
{
- static const bool PacketAccess = packet_traits<T>::HasAdd;
- static const bool IsStateful = false;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
internal::scalar_sum_op<T> sum_op;
*accum = sum_op(*accum, t);
@@ -103,16 +102,14 @@ template <typename T, typename Device>
struct reducer_traits<SumReducer<T>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
- PacketAccess = PacketType<T, Device>::HasAdd
+ PacketAccess = PacketType<T, Device>::HasAdd,
+ IsStateful = false,
+ IsExactlyAssociative = NumTraits<T>::IsInteger
};
};
-
template <typename T> struct MeanReducer
{
- static const bool PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv && !NumTraits<T>::IsInteger;
- static const bool IsStateful = true;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
MeanReducer() : scalarCount_(0), packetCount_(0) { }
@@ -161,7 +158,9 @@ template <typename T, typename Device>
struct reducer_traits<MeanReducer<T>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
- PacketAccess = PacketType<T, Device>::HasAdd
+ PacketAccess = PacketType<T, Device>::HasAdd && !NumTraits<T>::IsInteger,
+ IsStateful = true,
+ IsExactlyAssociative = NumTraits<T>::IsInteger
};
};
@@ -194,9 +193,6 @@ struct MinMaxBottomValue<T, false, false> {
template <typename T> struct MaxReducer
{
- static const bool PacketAccess = packet_traits<T>::HasMax;
- static const bool IsStateful = false;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
if (t > *accum) { *accum = t; }
}
@@ -228,16 +224,15 @@ template <typename T, typename Device>
struct reducer_traits<MaxReducer<T>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
- PacketAccess = PacketType<T, Device>::HasMax
+ PacketAccess = PacketType<T, Device>::HasMax,
+ IsStateful = false,
+ IsExactlyAssociative = true
};
};
template <typename T> struct MinReducer
{
- static const bool PacketAccess = packet_traits<T>::HasMin;
- static const bool IsStateful = false;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
if (t < *accum) { *accum = t; }
}
@@ -269,16 +264,15 @@ template <typename T, typename Device>
struct reducer_traits<MinReducer<T>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
- PacketAccess = PacketType<T, Device>::HasMin
+ PacketAccess = PacketType<T, Device>::HasMin,
+ IsStateful = false,
+ IsExactlyAssociative = true
};
};
template <typename T> struct ProdReducer
{
- static const bool PacketAccess = packet_traits<T>::HasMul;
- static const bool IsStateful = false;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
internal::scalar_product_op<T> prod_op;
(*accum) = prod_op(*accum, t);
@@ -314,16 +308,15 @@ template <typename T, typename Device>
struct reducer_traits<ProdReducer<T>, Device> {
enum {
Cost = NumTraits<T>::MulCost,
- PacketAccess = PacketType<T, Device>::HasMul
+ PacketAccess = PacketType<T, Device>::HasMul,
+ IsStateful = false,
+ IsExactlyAssociative = true
};
};
struct AndReducer
{
- static const bool PacketAccess = false;
- static const bool IsStateful = false;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const {
*accum = *accum && t;
}
@@ -339,15 +332,14 @@ template <typename Device>
struct reducer_traits<AndReducer, Device> {
enum {
Cost = 1,
- PacketAccess = false
+ PacketAccess = false,
+ IsStateful = false,
+ IsExactlyAssociative = true
};
};
struct OrReducer {
- static const bool PacketAccess = false;
- static const bool IsStateful = false;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const {
*accum = *accum || t;
}
@@ -363,7 +355,9 @@ template <typename Device>
struct reducer_traits<OrReducer, Device> {
enum {
Cost = 1,
- PacketAccess = false
+ PacketAccess = false,
+ IsStateful = false,
+ IsExactlyAssociative = true
};
};
@@ -371,9 +365,6 @@ struct reducer_traits<OrReducer, Device> {
// Argmin/Argmax reducers
template <typename T> struct ArgMaxTupleReducer
{
- static const bool PacketAccess = false;
- static const bool IsStateful = false;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
if (t.second > accum->second) { *accum = t; }
}
@@ -389,16 +380,15 @@ template <typename T, typename Device>
struct reducer_traits<ArgMaxTupleReducer<T>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
- PacketAccess = false
+ PacketAccess = false,
+ IsStateful = false,
+ IsExactlyAssociative = true
};
};
template <typename T> struct ArgMinTupleReducer
{
- static const bool PacketAccess = false;
- static const bool IsStateful = false;
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const {
if (t.second < accum->second) { *accum = t; }
}
@@ -414,7 +404,9 @@ template <typename T, typename Device>
struct reducer_traits<ArgMinTupleReducer<T>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
- PacketAccess = false
+ PacketAccess = false,
+ IsStateful = false,
+ IsExactlyAssociative = true
};
};