diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h | 42 |
1 files changed, 29 insertions, 13 deletions
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h index b5738b778..0329278a9 100644 --- a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h +++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h @@ -15,7 +15,7 @@ namespace Eigen { class DynamicSGroup { public: - inline explicit DynamicSGroup(std::size_t numIndices) : m_numIndices(numIndices), m_elements(), m_generators(), m_globalFlags(0) { m_elements.push_back(ge(Generator(0, 0, 0))); } + inline explicit DynamicSGroup() : m_numIndices(1), m_elements(), m_generators(), m_globalFlags(0) { m_elements.push_back(ge(Generator(0, 0, 0))); } inline DynamicSGroup(const DynamicSGroup& o) : m_numIndices(o.m_numIndices), m_elements(o.m_elements), m_generators(o.m_generators), m_globalFlags(o.m_globalFlags) { } inline DynamicSGroup(DynamicSGroup&& o) : m_numIndices(o.m_numIndices), m_elements(), m_generators(o.m_generators), m_globalFlags(o.m_globalFlags) { std::swap(m_elements, o.m_elements); } inline DynamicSGroup& operator=(const DynamicSGroup& o) { m_numIndices = o.m_numIndices; m_elements = o.m_elements; m_generators = o.m_generators; m_globalFlags = o.m_globalFlags; return *this; } @@ -33,7 +33,7 @@ class DynamicSGroup template<typename Op, typename RV, typename Index, std::size_t N, typename... Args> inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args) const { - eigen_assert(N == m_numIndices); + eigen_assert(N >= m_numIndices && "Can only apply symmetry group to objects that have at least the required amount of indices."); for (std::size_t i = 0; i < size(); i++) initial = Op::run(h_permute(i, idx, typename internal::gen_numeric_list<int, N>::type()), m_elements[i].flags, initial, std::forward<Args>(args)...); return initial; @@ -42,7 +42,7 @@ class DynamicSGroup template<typename Op, typename RV, typename Index, typename... Args> inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args) const { - eigen_assert(idx.size() == m_numIndices); + eigen_assert(idx.size() >= m_numIndices && "Can only apply symmetry group to objects that have at least the required amount of indices."); for (std::size_t i = 0; i < size(); i++) initial = Op::run(h_permute(i, idx), m_elements[i].flags, initial, std::forward<Args>(args)...); return initial; @@ -77,7 +77,7 @@ class DynamicSGroup template<typename Index, std::size_t N, int... n> inline std::array<Index, N> h_permute(std::size_t which, const std::array<Index, N>& idx, internal::numeric_list<int, n...>) const { - return std::array<Index, N>{{ idx[m_elements[which].representation[n]]... }}; + return std::array<Index, N>{{ idx[n >= m_numIndices ? n : m_elements[which].representation[n]]... }}; } template<typename Index> @@ -87,6 +87,8 @@ class DynamicSGroup result.reserve(idx.size()); for (auto k : m_elements[which].representation) result.push_back(idx[k]); + for (std::size_t i = m_numIndices; i < idx.size(); i++) + result.push_back(idx[i]); return result; } @@ -135,18 +137,18 @@ class DynamicSGroup }; // dynamic symmetry group that auto-adds the template parameters in the constructor -template<std::size_t NumIndices, typename... Gen> +template<typename... Gen> class DynamicSGroupFromTemplateArgs : public DynamicSGroup { public: - inline DynamicSGroupFromTemplateArgs() : DynamicSGroup(NumIndices) + inline DynamicSGroupFromTemplateArgs() : DynamicSGroup() { add_all(internal::type_list<Gen...>()); } inline DynamicSGroupFromTemplateArgs(DynamicSGroupFromTemplateArgs const& other) : DynamicSGroup(other) { } inline DynamicSGroupFromTemplateArgs(DynamicSGroupFromTemplateArgs&& other) : DynamicSGroup(other) { } - inline DynamicSGroupFromTemplateArgs<NumIndices, Gen...>& operator=(const DynamicSGroupFromTemplateArgs<NumIndices, Gen...>& o) { DynamicSGroup::operator=(o); return *this; } - inline DynamicSGroupFromTemplateArgs<NumIndices, Gen...>& operator=(DynamicSGroupFromTemplateArgs<NumIndices, Gen...>&& o) { DynamicSGroup::operator=(o); return *this; } + inline DynamicSGroupFromTemplateArgs<Gen...>& operator=(const DynamicSGroupFromTemplateArgs<Gen...>& o) { DynamicSGroup::operator=(o); return *this; } + inline DynamicSGroupFromTemplateArgs<Gen...>& operator=(DynamicSGroupFromTemplateArgs<Gen...>&& o) { DynamicSGroup::operator=(o); return *this; } private: template<typename Gen1, typename... GenNext> @@ -168,18 +170,32 @@ inline DynamicSGroup::GroupElement DynamicSGroup::mul(GroupElement g1, GroupElem GroupElement result; result.representation.reserve(m_numIndices); - for (std::size_t i = 0; i < m_numIndices; i++) - result.representation.push_back(g2.representation[g1.representation[i]]); + for (std::size_t i = 0; i < m_numIndices; i++) { + int v = g2.representation[g1.representation[i]]; + eigen_assert(v >= 0); + result.representation.push_back(v); + } result.flags = g1.flags ^ g2.flags; return result; } inline void DynamicSGroup::add(int one, int two, int flags) { - eigen_assert(one >= 0 && (std::size_t)one < m_numIndices); - eigen_assert(two >= 0 && (std::size_t)two < m_numIndices); + eigen_assert(one >= 0); + eigen_assert(two >= 0); eigen_assert(one != two); - Generator g{one, two ,flags}; + + if ((std::size_t)one >= m_numIndices || (std::size_t)two >= m_numIndices) { + std::size_t newNumIndices = (one > two) ? one : two + 1; + for (auto& gelem : m_elements) { + gelem.representation.reserve(newNumIndices); + for (std::size_t i = m_numIndices; i < newNumIndices; i++) + gelem.representation.push_back(i); + } + m_numIndices = newNumIndices; + } + + Generator g{one, two, flags}; GroupElement e = ge(g); /* special case for first generator */ |