aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/TensorSymmetry
diff options
context:
space:
mode:
authorGravatar Christian Seiler <christian@iwakd.de>2014-06-04 20:27:42 +0200
committerGravatar Christian Seiler <christian@iwakd.de>2014-06-04 20:27:42 +0200
commitea9943352368b990d27ba22eb8670287cf96302d (patch)
tree68511340c561f263b14a8569be0f926fe75bbb60 /unsupported/Eigen/CXX11/src/TensorSymmetry
parentcee62018fc38f5408e0afe497c37fade64ca15d0 (diff)
unsupported/TensorSymmetry: make symgroup construction autodetect number of indices
When constructing a symmetry group, make the code automatically detect the number of indices required from the indices of the group's generators. Also, allow the symmetry group to be applied to lists of indices that are larger than the number of indices of the symmetry group. Before: SGroup<4, Symmetry<0, 1>, Symmetry<2,3>> group; group.apply<SomeOp, int>(std::array<int,4>{{0, 1, 2, 3}}, 0); After: SGroup<Symmetry<0, 1>, Symmetry<2,3>> group; group.apply<SomeOp, int>(std::array<int,4>{{0, 1, 2, 3}}, 0); group.apply<SomeOp, int>(std::array<int,5>{{0, 1, 2, 3, 4}}, 0); This should make the symmetry group easier to use - especially if one wants to reuse the same symmetry group for different tensors of maybe different rank. static/runtime asserts remain for the case where the length of the index list to which a symmetry group is to be applied is too small.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/TensorSymmetry')
-rw-r--r--unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h42
-rw-r--r--unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h50
-rw-r--r--unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h47
3 files changed, 95 insertions, 44 deletions
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h
index b5738b778..0329278a9 100644
--- a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h
+++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h
@@ -15,7 +15,7 @@ namespace Eigen {
class DynamicSGroup
{
public:
- inline explicit DynamicSGroup(std::size_t numIndices) : m_numIndices(numIndices), m_elements(), m_generators(), m_globalFlags(0) { m_elements.push_back(ge(Generator(0, 0, 0))); }
+ inline explicit DynamicSGroup() : m_numIndices(1), m_elements(), m_generators(), m_globalFlags(0) { m_elements.push_back(ge(Generator(0, 0, 0))); }
inline DynamicSGroup(const DynamicSGroup& o) : m_numIndices(o.m_numIndices), m_elements(o.m_elements), m_generators(o.m_generators), m_globalFlags(o.m_globalFlags) { }
inline DynamicSGroup(DynamicSGroup&& o) : m_numIndices(o.m_numIndices), m_elements(), m_generators(o.m_generators), m_globalFlags(o.m_globalFlags) { std::swap(m_elements, o.m_elements); }
inline DynamicSGroup& operator=(const DynamicSGroup& o) { m_numIndices = o.m_numIndices; m_elements = o.m_elements; m_generators = o.m_generators; m_globalFlags = o.m_globalFlags; return *this; }
@@ -33,7 +33,7 @@ class DynamicSGroup
template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args) const
{
- eigen_assert(N == m_numIndices);
+ eigen_assert(N >= m_numIndices && "Can only apply symmetry group to objects that have at least the required amount of indices.");
for (std::size_t i = 0; i < size(); i++)
initial = Op::run(h_permute(i, idx, typename internal::gen_numeric_list<int, N>::type()), m_elements[i].flags, initial, std::forward<Args>(args)...);
return initial;
@@ -42,7 +42,7 @@ class DynamicSGroup
template<typename Op, typename RV, typename Index, typename... Args>
inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args) const
{
- eigen_assert(idx.size() == m_numIndices);
+ eigen_assert(idx.size() >= m_numIndices && "Can only apply symmetry group to objects that have at least the required amount of indices.");
for (std::size_t i = 0; i < size(); i++)
initial = Op::run(h_permute(i, idx), m_elements[i].flags, initial, std::forward<Args>(args)...);
return initial;
@@ -77,7 +77,7 @@ class DynamicSGroup
template<typename Index, std::size_t N, int... n>
inline std::array<Index, N> h_permute(std::size_t which, const std::array<Index, N>& idx, internal::numeric_list<int, n...>) const
{
- return std::array<Index, N>{{ idx[m_elements[which].representation[n]]... }};
+ return std::array<Index, N>{{ idx[n >= m_numIndices ? n : m_elements[which].representation[n]]... }};
}
template<typename Index>
@@ -87,6 +87,8 @@ class DynamicSGroup
result.reserve(idx.size());
for (auto k : m_elements[which].representation)
result.push_back(idx[k]);
+ for (std::size_t i = m_numIndices; i < idx.size(); i++)
+ result.push_back(idx[i]);
return result;
}
@@ -135,18 +137,18 @@ class DynamicSGroup
};
// dynamic symmetry group that auto-adds the template parameters in the constructor
-template<std::size_t NumIndices, typename... Gen>
+template<typename... Gen>
class DynamicSGroupFromTemplateArgs : public DynamicSGroup
{
public:
- inline DynamicSGroupFromTemplateArgs() : DynamicSGroup(NumIndices)
+ inline DynamicSGroupFromTemplateArgs() : DynamicSGroup()
{
add_all(internal::type_list<Gen...>());
}
inline DynamicSGroupFromTemplateArgs(DynamicSGroupFromTemplateArgs const& other) : DynamicSGroup(other) { }
inline DynamicSGroupFromTemplateArgs(DynamicSGroupFromTemplateArgs&& other) : DynamicSGroup(other) { }
- inline DynamicSGroupFromTemplateArgs<NumIndices, Gen...>& operator=(const DynamicSGroupFromTemplateArgs<NumIndices, Gen...>& o) { DynamicSGroup::operator=(o); return *this; }
- inline DynamicSGroupFromTemplateArgs<NumIndices, Gen...>& operator=(DynamicSGroupFromTemplateArgs<NumIndices, Gen...>&& o) { DynamicSGroup::operator=(o); return *this; }
+ inline DynamicSGroupFromTemplateArgs<Gen...>& operator=(const DynamicSGroupFromTemplateArgs<Gen...>& o) { DynamicSGroup::operator=(o); return *this; }
+ inline DynamicSGroupFromTemplateArgs<Gen...>& operator=(DynamicSGroupFromTemplateArgs<Gen...>&& o) { DynamicSGroup::operator=(o); return *this; }
private:
template<typename Gen1, typename... GenNext>
@@ -168,18 +170,32 @@ inline DynamicSGroup::GroupElement DynamicSGroup::mul(GroupElement g1, GroupElem
GroupElement result;
result.representation.reserve(m_numIndices);
- for (std::size_t i = 0; i < m_numIndices; i++)
- result.representation.push_back(g2.representation[g1.representation[i]]);
+ for (std::size_t i = 0; i < m_numIndices; i++) {
+ int v = g2.representation[g1.representation[i]];
+ eigen_assert(v >= 0);
+ result.representation.push_back(v);
+ }
result.flags = g1.flags ^ g2.flags;
return result;
}
inline void DynamicSGroup::add(int one, int two, int flags)
{
- eigen_assert(one >= 0 && (std::size_t)one < m_numIndices);
- eigen_assert(two >= 0 && (std::size_t)two < m_numIndices);
+ eigen_assert(one >= 0);
+ eigen_assert(two >= 0);
eigen_assert(one != two);
- Generator g{one, two ,flags};
+
+ if ((std::size_t)one >= m_numIndices || (std::size_t)two >= m_numIndices) {
+ std::size_t newNumIndices = (one > two) ? one : two + 1;
+ for (auto& gelem : m_elements) {
+ gelem.representation.reserve(newNumIndices);
+ for (std::size_t i = m_numIndices; i < newNumIndices; i++)
+ gelem.representation.push_back(i);
+ }
+ m_numIndices = newNumIndices;
+ }
+
+ Generator g{one, two, flags};
GroupElement e = ge(g);
/* special case for first generator */
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
index c5a630105..0eb468fc0 100644
--- a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
+++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
@@ -114,20 +114,24 @@ struct tensor_static_symgroup_equality
template<std::size_t NumIndices, typename... Gen>
struct tensor_static_symgroup
{
- typedef StaticSGroup<NumIndices, Gen...> type;
+ typedef StaticSGroup<Gen...> type;
constexpr static std::size_t size = type::static_size;
};
-template<typename Index, std::size_t N, int... ii>
-constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>)
+template<typename Index, std::size_t N, int... ii, int... jj>
+constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>, internal::numeric_list<int, jj...>)
{
- return {{ idx[ii]... }};
+ return {{ idx[ii]..., idx[jj]... }};
}
template<typename Index, int... ii>
static inline std::vector<Index> tensor_static_symgroup_index_permute(std::vector<Index> idx, internal::numeric_list<int, ii...>)
{
- return {{ idx[ii]... }};
+ std::vector<Index> result{{ idx[ii]... }};
+ std::size_t target_size = idx.size();
+ for (std::size_t i = result.size(); i < target_size; i++)
+ result.push_back(idx[i]);
+ return result;
}
template<typename T> struct tensor_static_symgroup_do_apply;
@@ -135,32 +139,35 @@ template<typename T> struct tensor_static_symgroup_do_apply;
template<typename first, typename... next>
struct tensor_static_symgroup_do_apply<internal::type_list<first, next...>>
{
- template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
- static inline RV run(const std::array<Index, N>& idx, RV initial, Args&&... args)
+ template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args>
+ static inline RV run(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args)
{
- initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...);
- return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op>(idx, initial, args...);
+ static_assert(NumIndices >= SGNumIndices, "Can only apply symmetry group to objects that have at least the required amount of indices.");
+ typedef typename internal::gen_numeric_list<int, NumIndices - SGNumIndices, SGNumIndices>::type remaining_indices;
+ initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices(), remaining_indices()), first::flags, initial, std::forward<Args>(args)...);
+ return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
}
- template<typename Op, typename RV, typename Index, typename... Args>
+ template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args>
static inline RV run(const std::vector<Index>& idx, RV initial, Args&&... args)
{
+ eigen_assert(idx.size() >= SGNumIndices && "Can only apply symmetry group to objects that have at least the required amount of indices.");
initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...);
- return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op>(idx, initial, args...);
+ return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
}
};
template<EIGEN_TPL_PP_SPEC_HACK_DEF(typename, empty)>
struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HACK_USE(empty)>>
{
- template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
- static inline RV run(const std::array<Index, N>&, RV initial, Args&&...)
+ template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args>
+ static inline RV run(const std::array<Index, NumIndices>&, RV initial, Args&&...)
{
// do nothing
return initial;
}
- template<typename Op, typename RV, typename Index, typename... Args>
+ template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args>
static inline RV run(const std::vector<Index>&, RV initial, Args&&...)
{
// do nothing
@@ -170,9 +177,10 @@ struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HAC
} // end namespace internal
-template<std::size_t NumIndices, typename... Gen>
+template<typename... Gen>
class StaticSGroup
{
+ constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value;
typedef internal::group_theory::enumerate_group_elements<
internal::tensor_static_symgroup_multiply,
internal::tensor_static_symgroup_equality,
@@ -182,20 +190,20 @@ class StaticSGroup
typedef typename group_elements::type ge;
public:
constexpr inline StaticSGroup() {}
- constexpr inline StaticSGroup(const StaticSGroup<NumIndices, Gen...>&) {}
- constexpr inline StaticSGroup(StaticSGroup<NumIndices, Gen...>&&) {}
+ constexpr inline StaticSGroup(const StaticSGroup<Gen...>&) {}
+ constexpr inline StaticSGroup(StaticSGroup<Gen...>&&) {}
- template<typename Op, typename RV, typename Index, typename... Args>
- static inline RV apply(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args)
+ template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
+ static inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args)
{
- return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV>(idx, initial, args...);
+ return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
}
template<typename Op, typename RV, typename Index, typename... Args>
static inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args)
{
eigen_assert(idx.size() == NumIndices);
- return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV>(idx, initial, args...);
+ return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
}
constexpr static std::size_t static_size = ge::count;
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h
index f0813086a..f1ccc33ef 100644
--- a/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h
+++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h
@@ -30,6 +30,7 @@ template<std::size_t NumIndices, typename... Sym> struct tenso
template<bool instantiate, std::size_t NumIndices, typename... Sym> struct tensor_static_symgroup_if;
template<typename Tensor_> struct tensor_symmetry_calculate_flags;
template<typename Tensor_> struct tensor_symmetry_assign_value;
+template<typename... Sym> struct tensor_symmetry_num_indices;
} // end namespace internal
@@ -94,7 +95,7 @@ class DynamicSGroup;
* This class is a child class of DynamicSGroup. It uses the template arguments
* specified to initialize itself.
*/
-template<std::size_t NumIndices, typename... Gen>
+template<typename... Gen>
class DynamicSGroupFromTemplateArgs;
/** \class StaticSGroup
@@ -116,7 +117,7 @@ class DynamicSGroupFromTemplateArgs;
* group becomes too large. (In that case, unrolling may not even be
* beneficial.)
*/
-template<std::size_t NumIndices, typename... Gen>
+template<typename... Gen>
class StaticSGroup;
/** \class SGroup
@@ -131,24 +132,50 @@ class StaticSGroup;
* \sa StaticSGroup
* \sa DynamicSGroup
*/
-template<std::size_t NumIndices, typename... Gen>
-class SGroup : public internal::tensor_symmetry_pre_analysis<NumIndices, Gen...>::root_type
+template<typename... Gen>
+class SGroup : public internal::tensor_symmetry_pre_analysis<internal::tensor_symmetry_num_indices<Gen...>::value, Gen...>::root_type
{
public:
+ constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value;
typedef typename internal::tensor_symmetry_pre_analysis<NumIndices, Gen...>::root_type Base;
// make standard constructors + assignment operators public
inline SGroup() : Base() { }
- inline SGroup(const SGroup<NumIndices, Gen...>& other) : Base(other) { }
- inline SGroup(SGroup<NumIndices, Gen...>&& other) : Base(other) { }
- inline SGroup<NumIndices, Gen...>& operator=(const SGroup<NumIndices, Gen...>& other) { Base::operator=(other); return *this; }
- inline SGroup<NumIndices, Gen...>& operator=(SGroup<NumIndices, Gen...>&& other) { Base::operator=(other); return *this; }
+ inline SGroup(const SGroup<Gen...>& other) : Base(other) { }
+ inline SGroup(SGroup<Gen...>&& other) : Base(other) { }
+ inline SGroup<Gen...>& operator=(const SGroup<Gen...>& other) { Base::operator=(other); return *this; }
+ inline SGroup<Gen...>& operator=(SGroup<Gen...>&& other) { Base::operator=(other); return *this; }
// all else is defined in the base class
};
namespace internal {
+template<typename... Sym> struct tensor_symmetry_num_indices
+{
+ constexpr static std::size_t value = 1;
+};
+
+template<int One_, int Two_, typename... Sym> struct tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...>
+{
+private:
+ constexpr static std::size_t One = static_cast<std::size_t>(One_);
+ constexpr static std::size_t Two = static_cast<std::size_t>(Two_);
+ constexpr static std::size_t Three = tensor_symmetry_num_indices<Sym...>::value;
+
+ // don't use std::max, since it's not constexpr until C++14...
+ constexpr static std::size_t maxOneTwoPlusOne = ((One > Two) ? One : Two) + 1;
+public:
+ constexpr static std::size_t value = (maxOneTwoPlusOne > Three) ? maxOneTwoPlusOne : Three;
+};
+
+template<int One_, int Two_, typename... Sym> struct tensor_symmetry_num_indices<AntiSymmetry<One_, Two_>, Sym...>
+ : public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
+template<int One_, int Two_, typename... Sym> struct tensor_symmetry_num_indices<Hermiticity<One_, Two_>, Sym...>
+ : public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
+template<int One_, int Two_, typename... Sym> struct tensor_symmetry_num_indices<AntiHermiticity<One_, Two_>, Sym...>
+ : public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
+
/** \internal
*
* \class tensor_symmetry_pre_analysis
@@ -199,7 +226,7 @@ namespace internal {
template<std::size_t NumIndices>
struct tensor_symmetry_pre_analysis<NumIndices>
{
- typedef StaticSGroup<NumIndices> root_type;
+ typedef StaticSGroup<> root_type;
};
template<std::size_t NumIndices, typename Gen_, typename... Gens_>
@@ -212,7 +239,7 @@ struct tensor_symmetry_pre_analysis<NumIndices, Gen_, Gens_...>
typedef typename conditional<
possible_size == 0 || possible_size >= max_static_elements,
- DynamicSGroupFromTemplateArgs<NumIndices, Gen_, Gens_...>,
+ DynamicSGroupFromTemplateArgs<Gen_, Gens_...>,
typename helper::type
>::type root_type;
};