diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-29 10:49:55 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-29 10:49:55 -0700 |
commit | 3625734bc8e63b4ccb4c190d170fd5d29fba5b54 (patch) | |
tree | 802298cb02cda49d1d786b8b850af633d29fbe16 /unsupported/Eigen/CXX11/src | |
parent | 392a30db824161a0466f7970e12e5a48ed05ebd9 (diff) |
Moved some utilities to TensorMeta.h to make it easier to reuse them accross several tensor operations.
Created the TensorDimensionList class to encode the list of all the dimensions of a tensor of rank n. This could be done using TensorIndexList, however TensorIndexList require cxx11 which isn't yet supported as widely as we'd like.
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
5 files changed, 283 insertions, 33 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 270383020..fd2f3abc4 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -364,14 +364,6 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co }; -template <size_t n> struct max_n_1 { - static const size_t size = n; -}; -template <> struct max_n_1<0> { - static const size_t size = 1; -}; - - template<typename Dimensions, typename LhsXprType, typename RhsXprType> struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > { @@ -459,19 +451,6 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp }; -template<bool cond> struct Cond {}; - -template<typename T1, typename T2> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE -const T1& choose(Cond<true>, const T1& first, const T2&) { - return first; -} - -template<typename T1, typename T2> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE -const T2& choose(Cond<false>, const T1&, const T2& second) { - return second; -} - - template<typename Derived> struct TensorContractionEvaluatorBase { @@ -508,13 +487,13 @@ struct TensorContractionEvaluatorBase static const int RDims = internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value; static const unsigned int ContractDims = internal::array_size<Indices>::value; - static const int NumDims = internal::max_n_1<LDims + RDims - 2 * ContractDims>::size; + static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; typedef array<Index, LDims> left_dim_mapper_t; typedef array<Index, RDims> right_dim_mapper_t; typedef array<Index, ContractDims> contract_t; - typedef array<Index, internal::max_n_1<LDims - ContractDims>::size> left_nocontract_t; - typedef array<Index, internal::max_n_1<RDims - ContractDims>::size> right_nocontract_t; + typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t; typedef DSizes<Index, NumDims> Dimensions; @@ -869,10 +848,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef array<Index, RDims> right_dim_mapper_t; typedef array<Index, ContractDims> contract_t; - typedef array<Index, internal::max_n_1<LDims - ContractDims>::size> left_nocontract_t; - typedef array<Index, internal::max_n_1<RDims - ContractDims>::size> right_nocontract_t; + typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t; - static const int NumDims = internal::max_n_1<LDims + RDims - 2 * ContractDims>::size; + static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; // Could we use NumDimensions here? typedef DSizes<Index, NumDims> Dimensions; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h index 588770bb4..f6bd949bd 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h @@ -1241,10 +1241,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef array<Index, RDims> right_dim_mapper_t; typedef array<Index, ContractDims> contract_t; - typedef array<Index, internal::max_n_1<LDims - ContractDims>::size> left_nocontract_t; - typedef array<Index, internal::max_n_1<RDims - ContractDims>::size> right_nocontract_t; + typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t; - static const int NumDims = internal::max_n_1<LDims + RDims - 2 * ContractDims>::size; + static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; typedef DSizes<Index, NumDims> Dimensions; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index ed87d3100..57030229d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -93,10 +93,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typedef array<Index, RDims> right_dim_mapper_t; typedef array<Index, ContractDims> contract_t; - typedef array<Index, internal::max_n_1<LDims - ContractDims>::size> left_nocontract_t; - typedef array<Index, internal::max_n_1<RDims - ContractDims>::size> right_nocontract_t; + typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; + typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t; - static const int NumDims = internal::max_n_1<LDims + RDims - 2 * ContractDims>::size; + static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; typedef DSizes<Index, NumDims> Dimensions; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h new file mode 100644 index 000000000..19e922f92 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h @@ -0,0 +1,235 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_DIMENSION_LIST_H +#define EIGEN_CXX11_TENSOR_TENSOR_DIMENSION_LIST_H + +namespace Eigen { + +/** \internal + * + * \class TensorDimensionList + * \ingroup CXX11_Tensor_Module + * + * \brief Special case of tensor index list used to list all the dimensions of a tensor of rank n. + * + * \sa Tensor + */ + +template <typename Index, std::size_t Rank> struct DimensionList { + const Index operator[] (const Index i) const { return i; } +}; + +namespace internal { + +template<typename Index, std::size_t Rank> struct array_size<DimensionList<Index, Rank> > { + static const size_t value = Rank; +}; +template<typename Index, std::size_t Rank> struct array_size<const DimensionList<Index, Rank> > { + static const size_t value = Rank; +}; + +template<DenseIndex n, typename Index, std::size_t Rank> const Index array_get(DimensionList<Index, Rank>& a) { + return n; +} +template<DenseIndex n, typename Index, std::size_t Rank> const Index array_get(const DimensionList<Index, Rank>& a) { + return n; +} + + +#if defined(EIGEN_HAS_CONSTEXPR) +template <typename Index, std::size_t Rank> +struct index_known_statically<DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex) const { + return true; + } +}; +template <typename Index, std::size_t Rank> +struct index_known_statically<const DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex) const { + return true; + } +}; + +template <typename Index, std::size_t Rank> +struct all_indices_known_statically<DimensionList<Index, Rank> > { + constexpr bool operator() () const { + return true; + } +}; +template <typename Index, std::size_t Rank> +struct all_indices_known_statically<const DimensionList<Index, Rank> > { + constexpr bool operator() () const { + return true; + } +}; + +template <typename Index, std::size_t Rank> +struct indices_statically_known_to_increase<DimensionList<Index, Rank> > { + constexpr bool operator() () const { + return true; + } +}; +template <typename Index, std::size_t Rank> +struct indices_statically_known_to_increase<const DimensionList<Index, Rank> > { + constexpr bool operator() () const { + return true; + } +}; + +template <typename Index, std::size_t Rank> +struct index_statically_eq<DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return i == value; + } +}; +template <typename Index, std::size_t Rank> +struct index_statically_eq<const DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return i == value; + } +}; + +template <typename Index, std::size_t Rank> +struct index_statically_ne<DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return i != value; + } +}; +template <typename Index, std::size_t Rank> +struct index_statically_ne<const DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return i != value; + } +}; + +template <typename Index, std::size_t Rank> +struct index_statically_gt<DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return i > value; + } +}; +template <typename Index, std::size_t Rank> +struct index_statically_gt<const DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return i > value; + } +}; + +template <typename Index, std::size_t Rank> +struct index_statically_lt<DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return i < value; + } +}; +template <typename Index, std::size_t Rank> +struct index_statically_lt<const DimensionList<Index, Rank> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return i < value; + } +}; + +#else +template <typename Index, std::size_t Rank> +struct index_known_statically<DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex) const { + return true; + } +}; +template <typename Index, std::size_t Rank> +struct index_known_statically<const DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex) const { + return true; + } +}; + +template <typename Index, std::size_t Rank> +struct all_indices_known_statically<DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() () const { + return true; + } +}; +template <typename Index, std::size_t Rank> +struct all_indices_known_statically<const DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() () const { + return true; + } +}; + +template <typename Index, std::size_t Rank> +struct indices_statically_known_to_increase<DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() () const { + return true; + } +}; +template <typename Index, std::size_t Rank> +struct indices_statically_known_to_increase<const DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() () const { + return true; + } +}; + +template <typename Index, std::size_t Rank> +struct index_statically_eq<DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex i, const DenseIndex value) const { + return false; + } +}; +template <typename Index, std::size_t Rank> +struct index_statically_eq<const DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex i, const DenseIndex value) const { + return false; + } +}; + +template <typename Index, std::size_t Rank> +struct index_statically_ne<DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex i, const DenseIndex value) const { + return false; + } +}; +template <typename Index, std::size_t Rank> +struct index_statically_ne<const DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex i, const DenseIndex value) const { + return false; + } +}; + +template <typename Index, std::size_t Rank> +struct index_statically_gt<DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex i, const DenseIndex value) const { + return false; + } +}; +template <typename Index, std::size_t Rank> +struct index_statically_gt<const DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex i, const DenseIndex value) const { + return false; + } +}; + +template <typename Index, std::size_t Rank> +struct index_statically_lt<DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex i, const DenseIndex value) const { + return false; + } +}; +template <typename Index, std::size_t Rank> +struct index_statically_lt<const DimensionList<Index, Rank> > { + EIGEN_ALWAYS_INLINE bool operator() (const DenseIndex i, const DenseIndex value) const { + return false; + } +}; +#endif + +} // end namespace internal +} // end namespace Eigen + + +#endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSION_LIST_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h new file mode 100644 index 000000000..78feb85cd --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h @@ -0,0 +1,36 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_META_H +#define EIGEN_CXX11_TENSOR_TENSOR_META_H + +namespace Eigen { + +template<bool cond> struct Cond {}; + +template<typename T1, typename T2> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +const T1& choose(Cond<true>, const T1& first, const T2&) { + return first; +} + +template<typename T1, typename T2> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +const T2& choose(Cond<false>, const T1&, const T2& second) { + return second; +} + +template <size_t n> struct max_n_1 { + static const size_t size = n; +}; +template <> struct max_n_1<0> { + static const size_t size = 1; +}; + +} // namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_META_H |