diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index 3dd32e9d1..bf52e490f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -84,6 +84,14 @@ struct functor_traits<scalar_sigmoid_op<T> > { }; +template<typename Reducer, typename Device> +struct reducer_traits { + enum { + Cost = 1, + PacketAccess = false + }; +}; + // Standard reduction functors template <typename T> struct SumReducer { @@ -119,6 +127,15 @@ template <typename T> struct SumReducer } }; +template <typename T, typename Device> +struct reducer_traits<SumReducer<T>, Device> { + enum { + Cost = NumTraits<T>::AddCost, + PacketAccess = PacketType<T, Device>::type::HasAdd + }; +}; + + template <typename T> struct MeanReducer { static const bool PacketAccess = packet_traits<T>::HasAdd && !NumTraits<T>::IsInteger; @@ -162,6 +179,15 @@ template <typename T> struct MeanReducer DenseIndex packetCount_; }; +template <typename T, typename Device> +struct reducer_traits<MeanReducer<T>, Device> { + enum { + Cost = NumTraits<T>::AddCost, + PacketAccess = PacketType<T, Device>::type::HasAdd + }; +}; + + template <typename T> struct MaxReducer { static const bool PacketAccess = packet_traits<T>::HasMax; @@ -195,6 +221,15 @@ template <typename T> struct MaxReducer } }; +template <typename T, typename Device> +struct reducer_traits<MaxReducer<T>, Device> { + enum { + Cost = NumTraits<T>::AddCost, + PacketAccess = PacketType<T, Device>::type::HasMax + }; +}; + + template <typename T> struct MinReducer { static const bool PacketAccess = packet_traits<T>::HasMin; @@ -228,6 +263,14 @@ template <typename T> struct MinReducer } }; +template <typename T, typename Device> +struct reducer_traits<MinReducer<T>, Device> { + enum { + Cost = NumTraits<T>::AddCost, + PacketAccess = PacketType<T, Device>::type::HasMin + }; +}; + template <typename T> struct ProdReducer { @@ -263,6 +306,14 @@ template <typename T> struct ProdReducer } }; +template <typename T, typename Device> +struct reducer_traits<ProdReducer<T>, Device> { + enum { + Cost = NumTraits<T>::MulCost, + PacketAccess = PacketType<T, Device>::type::HasMul + }; +}; + struct AndReducer { @@ -280,6 +331,15 @@ struct AndReducer } }; +template <typename Device> +struct reducer_traits<AndReducer, Device> { + enum { + Cost = 1, + PacketAccess = false + }; +}; + + struct OrReducer { static const bool PacketAccess = false; static const bool IsStateful = false; @@ -295,6 +355,15 @@ struct OrReducer { } }; +template <typename Device> +struct reducer_traits<OrReducer, Device> { + enum { + Cost = 1, + PacketAccess = false + }; +}; + + // Argmin/Argmax reducers template <typename T> struct ArgMaxTupleReducer { @@ -312,6 +381,15 @@ template <typename T> struct ArgMaxTupleReducer } }; +template <typename T, typename Device> +struct reducer_traits<ArgMaxTupleReducer<T>, Device> { + enum { + Cost = NumTraits<T>::AddCost, + PacketAccess = false + }; +}; + + template <typename T> struct ArgMinTupleReducer { static const bool PacketAccess = false; @@ -328,6 +406,14 @@ template <typename T> struct ArgMinTupleReducer } }; +template <typename T, typename Device> +struct reducer_traits<ArgMinTupleReducer<T>, Device> { + enum { + Cost = NumTraits<T>::AddCost, + PacketAccess = false + }; +}; + // Random number generation namespace { |