diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-29 11:30:36 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-29 11:30:36 -0700 |
commit | 670c71d906a4f0adc7edf266c996183ae8e4a2cc (patch) | |
tree | aa4ff719bb30b3a97f4b106dfac79e8104dbba10 /unsupported/Eigen | |
parent | d8098ee7d5683f09a11cbb2e72edff41e5d9768f (diff) |
Express the full reduction operations (such as sum, max, min) using TensorDimensionList
Diffstat (limited to 'unsupported/Eigen')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 35 |
1 files changed, 15 insertions, 20 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 944dbf03f..30432fbc8 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -301,11 +301,10 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::SumReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::SumReducer<CoeffReturnType>()); } - const TensorReductionOp<internal::SumReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + const TensorReductionOp<internal::SumReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived> sum() const { - array<Index, NumDimensions> in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp<internal::SumReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::SumReducer<CoeffReturnType>()); + DimensionList<Index, NumDimensions> in_dims; + return TensorReductionOp<internal::SumReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::SumReducer<CoeffReturnType>()); } template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -314,11 +313,10 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MeanReducer<CoeffReturnType>()); } - const TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + const TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived> mean() const { - array<Index, NumDimensions> in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MeanReducer<CoeffReturnType>()); + DimensionList<Index, NumDimensions> in_dims; + return TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MeanReducer<CoeffReturnType>()); } template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -327,11 +325,10 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::ProdReducer<CoeffReturnType>()); } - const TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + const TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived> prod() const { - array<Index, NumDimensions> in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::ProdReducer<CoeffReturnType>()); + DimensionList<Index, NumDimensions> in_dims; + return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::ProdReducer<CoeffReturnType>()); } template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -340,11 +337,10 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType>()); } - const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived> maximum() const { - array<Index, NumDimensions> in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType>()); + DimensionList<Index, NumDimensions> in_dims; + return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType>()); } template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -353,11 +349,10 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType>()); } - const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived> minimum() const { - array<Index, NumDimensions> in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>()); + DimensionList<Index, NumDimensions> in_dims; + return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>()); } template <typename Reducer, typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE |