diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-08-31 08:18:53 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-08-31 08:18:53 -0700 |
commit | f41831e445f3fdd9dc324561135b2a19eafd9a56 (patch) | |
tree | 045cb917d62685b342ce129e384e03e63c916898 /unsupported/Eigen/CXX11/src/Tensor | |
parent | 2ab603316af7c1bcf1d5e87d9ba50a2589b36e37 (diff) |
Added support for argmax/argmin
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h | 288 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 59 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h | 2 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | 34 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h | 54 |
5 files changed, 437 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h new file mode 100644 index 000000000..ee3bf7fe3 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h @@ -0,0 +1,288 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com> +// Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H +#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H + +namespace Eigen { +namespace internal { + +/** \class TensorIndexTuple + * \ingroup CXX11_Tensor_Module + * + * \brief Tensor + Index Tuple class. + * + * + */ +template<typename XprType> +struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType> +{ + typedef traits<XprType> XprTraits; + typedef typename XprTraits::StorageKind StorageKind; + typedef typename XprTraits::Index Index; + typedef Tuple<Index, typename XprTraits::Scalar> Scalar; + typedef typename XprType::Nested Nested; + typedef typename remove_reference<Nested>::type _Nested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; +}; + +template<typename XprType> +struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense> +{ + typedef const TensorIndexTupleOp<XprType>& type; +}; + +template<typename XprType> +struct nested<TensorIndexTupleOp<XprType>, 1, + typename eval<TensorIndexTupleOp<XprType> >::type> +{ + typedef TensorIndexTupleOp<XprType> type; +}; + +} // end namespace internal + +template<typename XprType> +class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors> +{ + public: + typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested; + typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index; + typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr) + : m_xpr(expr) {} + + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename XprType::Nested>::type& + expression() const { return m_xpr; } + + protected: + typename XprType::Nested m_xpr; +}; + +// Eval as rvalue +template<typename ArgType, typename Device> +struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> +{ + typedef TensorIndexTupleOp<ArgType> XprType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + + typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions; + static const int NumDims = internal::array_size<Dimensions>::value; + + enum { + IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false, + PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false, + BlockAccess = false, + Layout = TensorEvaluator<ArgType, Device>::Layout, + CoordAccess = false, // to be implemented + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) + : m_impl(op.expression(), device) { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { + return m_impl.dimensions(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) { + m_impl.evalSubExprsIfNeeded(NULL); + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { + m_impl.cleanup(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const + { + return CoeffReturnType(index, m_impl.coeff(index)); + } + + EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; } + + protected: + TensorEvaluator<ArgType, Device> m_impl; +}; + +namespace internal { + +/** \class TensorTupleIndex + * \ingroup CXX11_Tensor_Module + * + * \brief Converts to Tensor<Tuple<Index, Scalar> > and reduces to Tensor<Index>. + * + */ +template<typename ReduceOp, typename Dims, typename XprType> +struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType> +{ + typedef traits<XprType> XprTraits; + typedef typename XprTraits::StorageKind StorageKind; + typedef typename XprTraits::Index Index; + typedef Index Scalar; + typedef typename XprType::Nested Nested; + typedef typename remove_reference<Nested>::type _Nested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; +}; + +template<typename ReduceOp, typename Dims, typename XprType> +struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense> +{ + typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type; +}; + +template<typename ReduceOp, typename Dims, typename XprType> +struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1, + typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type> +{ + typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type; +}; + +} // end namespace internal + +template<typename ReduceOp, typename Dims, typename XprType> +class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors> +{ + public: + typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested; + typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index; + typedef Index CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr, + const ReduceOp& reduce_op, + const int return_dim, + const Dims& reduce_dims) + : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {} + + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename XprType::Nested>::type& + expression() const { return m_xpr; } + + EIGEN_DEVICE_FUNC + const ReduceOp& reduce_op() const { return m_reduce_op; } + + EIGEN_DEVICE_FUNC + const Dims& reduce_dims() const { return m_reduce_dims; } + + EIGEN_DEVICE_FUNC + int return_dim() const { return m_return_dim; } + + protected: + typename XprType::Nested m_xpr; + const ReduceOp m_reduce_op; + const int m_return_dim; + const Dims m_reduce_dims; +}; + +// Eval as rvalue +template<typename ReduceOp, typename Dims, typename ArgType, typename Device> +struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device> +{ + typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType; + typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions; + typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType> , Device>::Dimensions InputDimensions; + static const int NumDims = internal::array_size<InputDimensions>::value; + typedef array<Index, NumDims> StrideDims; + + enum { + IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false, + PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false, + BlockAccess = false, + Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout, + CoordAccess = false, // to be implemented + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) + : m_orig_impl(op.expression(), device), + m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device), + m_return_dim(op.return_dim()), + m_strides(gen_strides(m_orig_impl.dimensions())), + m_stride_mod(gen_stride_mod(m_orig_impl.dimensions())), + m_stride_div(gen_stride_div()) { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { + return m_impl.dimensions(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) { + m_impl.evalSubExprsIfNeeded(NULL); + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { + m_impl.cleanup(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { + const TupleType v = m_impl.coeff(index); + return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div; + } + + EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; } + + private: + EIGEN_DEVICE_FUNC StrideDims gen_strides(const InputDimensions& dims) { + StrideDims strides; + if (m_return_dim < 0) return strides; // Won't be using these. + eigen_assert(m_return_dim < NumDims && + "Asking to convert index to a dimension outside of the rank"); + + // Calculate m_stride_div and m_stride_mod, which are used to + // calculate the value of an index w.r.t. the m_return_dim. + if (Layout == static_cast<int>(ColMajor)) { + strides[0] = 1; + for (int i = 1; i < NumDims; ++i) { + strides[i] = strides[i-1] * dims[i-1]; + } + } else { + strides[NumDims-1] = 1; + for (int i = NumDims - 2; i >= 0; --i) { + strides[i] = strides[i+1] * dims[i+1]; + } + } + return strides; + } + + EIGEN_DEVICE_FUNC Index gen_stride_mod(const InputDimensions& dims) { + if (Layout == static_cast<int>(ColMajor)) { + return (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : dims.TotalSize(); + } else { + return (m_return_dim > 0) ? m_strides[m_return_dim - 1] : dims.TotalSize(); + } + } + + EIGEN_DEVICE_FUNC Index gen_stride_div() { + return m_strides[m_return_dim]; + } + + protected: + TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl; + TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl; + const int m_return_dim; + const StrideDims m_strides; + const Index m_stride_mod; + const Index m_stride_div; +}; + +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 0e5e4b426..477e4a174 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -363,6 +363,58 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>()); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorTupleReducerOp< + internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >, + const array<Index, NumDimensions>, const Derived> + argmax() const { + array<Index, NumDimensions> in_dims; + for (int d = 0; d < NumDimensions; ++d) in_dims[d] = d; + return TensorTupleReducerOp< + internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >, + const array<Index, NumDimensions>, + const Derived>(derived(), internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >(), -1, in_dims); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorTupleReducerOp< + internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >, + const array<Index, NumDimensions>, const Derived> + argmin() const { + array<Index, NumDimensions> in_dims; + for (int d = 0; d < NumDimensions; ++d) in_dims[d] = d; + return TensorTupleReducerOp< + internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >, + const array<Index, NumDimensions>, + const Derived>(derived(), internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >(), -1, in_dims); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorTupleReducerOp< + internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >, + const array<Index, 1>, const Derived> + argmax(const int return_dim) const { + array<Index, 1> in_dims; + in_dims[0] = return_dim; + return TensorTupleReducerOp< + internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >, + const array<Index, 1>, + const Derived>(derived(), internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >(), return_dim, in_dims); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorTupleReducerOp< + internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >, + const array<Index, 1>, const Derived> + argmin(const int return_dim) const { + array<Index, 1> in_dims; + in_dims[0] = return_dim; + return TensorTupleReducerOp< + internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >, + const array<Index, 1>, + const Derived>(derived(), internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >(), return_dim, in_dims); + } + template <typename Reducer, typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp<Reducer, const Dims, const Derived> reduce(const Dims& dims, const Reducer& reducer) const { @@ -483,6 +535,13 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorInflationOp<const Strides, const Derived>(derived(), strides); } + // Returns a tensor containing index/value tuples + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorIndexTupleOp<const Derived> + index_tuples() const { + return TensorIndexTupleOp<const Derived>(derived()); + } + // Support for custom unary and binary operations template <typename CustomUnaryFunc> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index 17b0e6153..c22444e6f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -23,6 +23,8 @@ template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp; template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp; template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp; template<typename Op, typename Dims, typename XprType> class TensorReductionOp; +template<typename XprType> class TensorIndexTupleOp; +template<typename ReduceOp, typename Dims, typename XprType> class TensorTupleReducerOp; template<typename Axis, typename LeftXprType, typename RightXprType> class TensorConcatenationOp; template<typename Dimensions, typename LeftXprType, typename RightXprType> class TensorContractionOp; template<typename TargetType, typename XprType> class TensorConversionOp; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index d9061c216..ed259399b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -219,6 +219,40 @@ template <typename T> struct ProdReducer }; +// Argmin/Argmax reducers +template <typename T> struct ArgMaxTupleReducer +{ + static const bool PacketAccess = false; + static const bool IsStateful = false; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { + if (t.second > accum->second) { *accum = t; } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { + return T(0, NumTraits<typename T::second_type>::lowest()); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T& accum) const { + return accum; + } +}; + +template <typename T> struct ArgMinTupleReducer +{ + static const bool PacketAccess = false; + static const bool IsStateful = false; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const { + if (t.second < accum->second) { *accum = t; } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { + return T(0, NumTraits<typename T::second_type>::highest()); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T& accum) const { + return accum; + } +}; + + // Random number generation namespace { #ifdef __CUDA_ARCH__ diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h index 78feb85cd..7dfa04760 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h @@ -31,6 +31,60 @@ template <> struct max_n_1<0> { static const size_t size = 1; }; + + + +#if defined(EIGEN_HAS_CONSTEXPR) +#define EIGEN_CONSTEXPR constexpr +#else +#define EIGEN_CONSTEXPR +#endif + +// Tuple mimics std::pair but works on e.g. nvcc. +template <typename U, typename V> struct Tuple { + public: + U first; + V second; + + typedef U first_type; + typedef V second_type; + + EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Tuple() : first(), second() {} + + EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Tuple(const U& f, const V& s) : first(f), second(s) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Tuple& operator= (const Tuple& rhs) { + if (&rhs == this) return *this; + first = rhs.first; + second = rhs.second; + return *this; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + void swap(Tuple& rhs) { + using numext::swap; + swap(first, rhs.first); + swap(second, rhs.second); + } +}; + +template <typename U, typename V> +EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +bool operator==(const Tuple<U, V>& x, const Tuple<U, V>& y) { + return (x.first == y.first && x.second == y.second); +} + +template <typename U, typename V> +EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +bool operator!=(const Tuple<U, V>& x, const Tuple<U, V>& y) { + return !(x == y); +} + +#undef EIGEN_CONSTEXPR + } // namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_META_H |