diff options
Diffstat (limited to 'third_party/eigen3/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h')
-rw-r--r-- | third_party/eigen3/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h | 236 |
1 files changed, 236 insertions, 0 deletions
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h new file mode 100644 index 0000000000..942293bd71 --- /dev/null +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h @@ -0,0 +1,236 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2013 Christian Seiler <christian@iwakd.de> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H +#define EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H + +namespace Eigen { + +namespace internal { + +template<typename list> struct tensor_static_symgroup_permutate; + +template<int... nn> +struct tensor_static_symgroup_permutate<numeric_list<int, nn...>> +{ + constexpr static std::size_t N = sizeof...(nn); + + template<typename T> + constexpr static inline std::array<T, N> run(const std::array<T, N>& indices) + { + return {{indices[nn]...}}; + } +}; + +template<typename indices_, int flags_> +struct tensor_static_symgroup_element +{ + typedef indices_ indices; + constexpr static int flags = flags_; +}; + +template<typename Gen, int N> +struct tensor_static_symgroup_element_ctor +{ + typedef tensor_static_symgroup_element< + typename gen_numeric_list_swapped_pair<int, N, Gen::One, Gen::Two>::type, + Gen::Flags + > type; +}; + +template<int N> +struct tensor_static_symgroup_identity_ctor +{ + typedef tensor_static_symgroup_element< + typename gen_numeric_list<int, N>::type, + 0 + > type; +}; + +template<typename iib> +struct tensor_static_symgroup_multiply_helper +{ + template<int... iia> + constexpr static inline numeric_list<int, get<iia, iib>::value...> helper(numeric_list<int, iia...>) { + return numeric_list<int, get<iia, iib>::value...>(); + } +}; + +template<typename A, typename B> +struct tensor_static_symgroup_multiply +{ + private: + typedef typename A::indices iia; + typedef typename B::indices iib; + constexpr static int ffa = A::flags; + constexpr static int ffb = B::flags; + + public: + static_assert(iia::count == iib::count, "Cannot multiply symmetry elements with different number of indices."); + + typedef tensor_static_symgroup_element< + decltype(tensor_static_symgroup_multiply_helper<iib>::helper(iia())), + ffa ^ ffb + > type; +}; + +template<typename A, typename B> +struct tensor_static_symgroup_equality +{ + typedef typename A::indices iia; + typedef typename B::indices iib; + constexpr static int ffa = A::flags; + constexpr static int ffb = B::flags; + static_assert(iia::count == iib::count, "Cannot compare symmetry elements with different number of indices."); + + constexpr static bool value = is_same<iia, iib>::value; + + private: + /* this should be zero if they are identical, or else the tensor + * will be forced to be pure real, pure imaginary or even pure zero + */ + constexpr static int flags_cmp_ = ffa ^ ffb; + + /* either they are not equal, then we don't care whether the flags + * match, or they are equal, and then we have to check + */ + constexpr static bool is_zero = value && flags_cmp_ == NegationFlag; + constexpr static bool is_real = value && flags_cmp_ == ConjugationFlag; + constexpr static bool is_imag = value && flags_cmp_ == (NegationFlag | ConjugationFlag); + + public: + constexpr static int global_flags = + (is_real ? GlobalRealFlag : 0) | + (is_imag ? GlobalImagFlag : 0) | + (is_zero ? GlobalZeroFlag : 0); +}; + +template<std::size_t NumIndices, typename... Gen> +struct tensor_static_symgroup +{ + typedef StaticSGroup<Gen...> type; + constexpr static std::size_t size = type::static_size; +}; + +template<typename Index, std::size_t N, int... ii, int... jj> +constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>, internal::numeric_list<int, jj...>) +{ + return {{ idx[ii]..., idx[jj]... }}; +} + +template<typename Index, int... ii> +static inline std::vector<Index> tensor_static_symgroup_index_permute(std::vector<Index> idx, internal::numeric_list<int, ii...>) +{ + std::vector<Index> result{{ idx[ii]... }}; + std::size_t target_size = idx.size(); + for (std::size_t i = result.size(); i < target_size; i++) + result.push_back(idx[i]); + return result; +} + +template<typename T> struct tensor_static_symgroup_do_apply; + +template<typename first, typename... next> +struct tensor_static_symgroup_do_apply<internal::type_list<first, next...>> +{ + template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args> + static inline RV run(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args) + { + static_assert(NumIndices >= SGNumIndices, "Can only apply symmetry group to objects that have at least the required amount of indices."); + typedef typename internal::gen_numeric_list<int, NumIndices - SGNumIndices, SGNumIndices>::type remaining_indices; + initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices(), remaining_indices()), first::flags, initial, std::forward<Args>(args)...); + return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...); + } + + template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args> + static inline RV run(const std::vector<Index>& idx, RV initial, Args&&... args) + { + eigen_assert(idx.size() >= SGNumIndices && "Can only apply symmetry group to objects that have at least the required amount of indices."); + initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...); + return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...); + } +}; + +template<EIGEN_TPL_PP_SPEC_HACK_DEF(typename, empty)> +struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HACK_USE(empty)>> +{ + template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args> + static inline RV run(const std::array<Index, NumIndices>&, RV initial, Args&&...) + { + // do nothing + return initial; + } + + template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args> + static inline RV run(const std::vector<Index>&, RV initial, Args&&...) + { + // do nothing + return initial; + } +}; + +} // end namespace internal + +template<typename... Gen> +class StaticSGroup +{ + constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value; + typedef internal::group_theory::enumerate_group_elements< + internal::tensor_static_symgroup_multiply, + internal::tensor_static_symgroup_equality, + typename internal::tensor_static_symgroup_identity_ctor<NumIndices>::type, + internal::type_list<typename internal::tensor_static_symgroup_element_ctor<Gen, NumIndices>::type...> + > group_elements; + typedef typename group_elements::type ge; + public: + constexpr inline StaticSGroup() {} + constexpr inline StaticSGroup(const StaticSGroup<Gen...>&) {} + constexpr inline StaticSGroup(StaticSGroup<Gen...>&&) {} + + template<typename Op, typename RV, typename Index, std::size_t N, typename... Args> + static inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args) + { + return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...); + } + + template<typename Op, typename RV, typename Index, typename... Args> + static inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args) + { + eigen_assert(idx.size() == NumIndices); + return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...); + } + + constexpr static std::size_t static_size = ge::count; + + constexpr static inline std::size_t size() { + return ge::count; + } + constexpr static inline int globalFlags() { return group_elements::global_flags; } + + template<typename Tensor_, typename... IndexTypes> + inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const + { + static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); + return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}}); + } + + template<typename Tensor_> + inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const + { + return internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>>(tensor, *this, indices); + } +}; + +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H + +/* + * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle; + */ |