diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h | 50 |
1 files changed, 29 insertions, 21 deletions
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h index c5a630105..0eb468fc0 100644 --- a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h +++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h @@ -114,20 +114,24 @@ struct tensor_static_symgroup_equality template<std::size_t NumIndices, typename... Gen> struct tensor_static_symgroup { - typedef StaticSGroup<NumIndices, Gen...> type; + typedef StaticSGroup<Gen...> type; constexpr static std::size_t size = type::static_size; }; -template<typename Index, std::size_t N, int... ii> -constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>) +template<typename Index, std::size_t N, int... ii, int... jj> +constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>, internal::numeric_list<int, jj...>) { - return {{ idx[ii]... }}; + return {{ idx[ii]..., idx[jj]... }}; } template<typename Index, int... ii> static inline std::vector<Index> tensor_static_symgroup_index_permute(std::vector<Index> idx, internal::numeric_list<int, ii...>) { - return {{ idx[ii]... }}; + std::vector<Index> result{{ idx[ii]... }}; + std::size_t target_size = idx.size(); + for (std::size_t i = result.size(); i < target_size; i++) + result.push_back(idx[i]); + return result; } template<typename T> struct tensor_static_symgroup_do_apply; @@ -135,32 +139,35 @@ template<typename T> struct tensor_static_symgroup_do_apply; template<typename first, typename... next> struct tensor_static_symgroup_do_apply<internal::type_list<first, next...>> { - template<typename Op, typename RV, typename Index, std::size_t N, typename... Args> - static inline RV run(const std::array<Index, N>& idx, RV initial, Args&&... args) + template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args> + static inline RV run(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args) { - initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...); - return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op>(idx, initial, args...); + static_assert(NumIndices >= SGNumIndices, "Can only apply symmetry group to objects that have at least the required amount of indices."); + typedef typename internal::gen_numeric_list<int, NumIndices - SGNumIndices, SGNumIndices>::type remaining_indices; + initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices(), remaining_indices()), first::flags, initial, std::forward<Args>(args)...); + return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...); } - template<typename Op, typename RV, typename Index, typename... Args> + template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args> static inline RV run(const std::vector<Index>& idx, RV initial, Args&&... args) { + eigen_assert(idx.size() >= SGNumIndices && "Can only apply symmetry group to objects that have at least the required amount of indices."); initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...); - return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op>(idx, initial, args...); + return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...); } }; template<EIGEN_TPL_PP_SPEC_HACK_DEF(typename, empty)> struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HACK_USE(empty)>> { - template<typename Op, typename RV, typename Index, std::size_t N, typename... Args> - static inline RV run(const std::array<Index, N>&, RV initial, Args&&...) + template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args> + static inline RV run(const std::array<Index, NumIndices>&, RV initial, Args&&...) { // do nothing return initial; } - template<typename Op, typename RV, typename Index, typename... Args> + template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args> static inline RV run(const std::vector<Index>&, RV initial, Args&&...) { // do nothing @@ -170,9 +177,10 @@ struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HAC } // end namespace internal -template<std::size_t NumIndices, typename... Gen> +template<typename... Gen> class StaticSGroup { + constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value; typedef internal::group_theory::enumerate_group_elements< internal::tensor_static_symgroup_multiply, internal::tensor_static_symgroup_equality, @@ -182,20 +190,20 @@ class StaticSGroup typedef typename group_elements::type ge; public: constexpr inline StaticSGroup() {} - constexpr inline StaticSGroup(const StaticSGroup<NumIndices, Gen...>&) {} - constexpr inline StaticSGroup(StaticSGroup<NumIndices, Gen...>&&) {} + constexpr inline StaticSGroup(const StaticSGroup<Gen...>&) {} + constexpr inline StaticSGroup(StaticSGroup<Gen...>&&) {} - template<typename Op, typename RV, typename Index, typename... Args> - static inline RV apply(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args) + template<typename Op, typename RV, typename Index, std::size_t N, typename... Args> + static inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args) { - return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV>(idx, initial, args...); + return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...); } template<typename Op, typename RV, typename Index, typename... Args> static inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args) { eigen_assert(idx.size() == NumIndices); - return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV>(idx, initial, args...); + return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...); } constexpr static std::size_t static_size = ge::count; |