aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h')
-rw-r--r--unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h50
1 files changed, 29 insertions, 21 deletions
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
index c5a630105..0eb468fc0 100644
--- a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
+++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
@@ -114,20 +114,24 @@ struct tensor_static_symgroup_equality
template<std::size_t NumIndices, typename... Gen>
struct tensor_static_symgroup
{
- typedef StaticSGroup<NumIndices, Gen...> type;
+ typedef StaticSGroup<Gen...> type;
constexpr static std::size_t size = type::static_size;
};
-template<typename Index, std::size_t N, int... ii>
-constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>)
+template<typename Index, std::size_t N, int... ii, int... jj>
+constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>, internal::numeric_list<int, jj...>)
{
- return {{ idx[ii]... }};
+ return {{ idx[ii]..., idx[jj]... }};
}
template<typename Index, int... ii>
static inline std::vector<Index> tensor_static_symgroup_index_permute(std::vector<Index> idx, internal::numeric_list<int, ii...>)
{
- return {{ idx[ii]... }};
+ std::vector<Index> result{{ idx[ii]... }};
+ std::size_t target_size = idx.size();
+ for (std::size_t i = result.size(); i < target_size; i++)
+ result.push_back(idx[i]);
+ return result;
}
template<typename T> struct tensor_static_symgroup_do_apply;
@@ -135,32 +139,35 @@ template<typename T> struct tensor_static_symgroup_do_apply;
template<typename first, typename... next>
struct tensor_static_symgroup_do_apply<internal::type_list<first, next...>>
{
- template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
- static inline RV run(const std::array<Index, N>& idx, RV initial, Args&&... args)
+ template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args>
+ static inline RV run(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args)
{
- initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...);
- return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op>(idx, initial, args...);
+ static_assert(NumIndices >= SGNumIndices, "Can only apply symmetry group to objects that have at least the required amount of indices.");
+ typedef typename internal::gen_numeric_list<int, NumIndices - SGNumIndices, SGNumIndices>::type remaining_indices;
+ initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices(), remaining_indices()), first::flags, initial, std::forward<Args>(args)...);
+ return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
}
- template<typename Op, typename RV, typename Index, typename... Args>
+ template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args>
static inline RV run(const std::vector<Index>& idx, RV initial, Args&&... args)
{
+ eigen_assert(idx.size() >= SGNumIndices && "Can only apply symmetry group to objects that have at least the required amount of indices.");
initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...);
- return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op>(idx, initial, args...);
+ return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
}
};
template<EIGEN_TPL_PP_SPEC_HACK_DEF(typename, empty)>
struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HACK_USE(empty)>>
{
- template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
- static inline RV run(const std::array<Index, N>&, RV initial, Args&&...)
+ template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args>
+ static inline RV run(const std::array<Index, NumIndices>&, RV initial, Args&&...)
{
// do nothing
return initial;
}
- template<typename Op, typename RV, typename Index, typename... Args>
+ template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args>
static inline RV run(const std::vector<Index>&, RV initial, Args&&...)
{
// do nothing
@@ -170,9 +177,10 @@ struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HAC
} // end namespace internal
-template<std::size_t NumIndices, typename... Gen>
+template<typename... Gen>
class StaticSGroup
{
+ constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value;
typedef internal::group_theory::enumerate_group_elements<
internal::tensor_static_symgroup_multiply,
internal::tensor_static_symgroup_equality,
@@ -182,20 +190,20 @@ class StaticSGroup
typedef typename group_elements::type ge;
public:
constexpr inline StaticSGroup() {}
- constexpr inline StaticSGroup(const StaticSGroup<NumIndices, Gen...>&) {}
- constexpr inline StaticSGroup(StaticSGroup<NumIndices, Gen...>&&) {}
+ constexpr inline StaticSGroup(const StaticSGroup<Gen...>&) {}
+ constexpr inline StaticSGroup(StaticSGroup<Gen...>&&) {}
- template<typename Op, typename RV, typename Index, typename... Args>
- static inline RV apply(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args)
+ template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
+ static inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args)
{
- return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV>(idx, initial, args...);
+ return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
}
template<typename Op, typename RV, typename Index, typename... Args>
static inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args)
{
eigen_assert(idx.size() == NumIndices);
- return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV>(idx, initial, args...);
+ return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
}
constexpr static std::size_t static_size = ge::count;