diff options
5 files changed, 32 insertions, 21 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h index c6216e14c..70ca1433f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h @@ -91,9 +91,6 @@ struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor> return std_array_get<RowMajor ? 0 : NumIndices - 1>(indices); } }; - -/* Forward-declaration required for the symmetry support. */ -template<typename Tensor_, typename Symmetry_, int Flags = 0> class tensor_symmetry_value_setter; } // end namespace internal template<typename Scalar_, std::size_t NumIndices_, int Options_> @@ -285,18 +282,6 @@ class Tensor #endif } - template<typename Symmetry_, typename... IndexTypes> - internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, Index firstIndex, IndexTypes... otherIndices) - { - return symCoeff(symmetry, std::array<Index, NumIndices>{{firstIndex, otherIndices...}}); - } - - template<typename Symmetry_, typename... IndexTypes> - internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, std::array<Index, NumIndices> const& indices) - { - return internal::tensor_symmetry_value_setter<Self, Symmetry_>(*this, symmetry, indices); - } - protected: bool checkIndexRange(const std::array<Index, NumIndices>& indices) const { diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h index 0329278a9..bc4f2025f 100644 --- a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h +++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h @@ -50,6 +50,19 @@ class DynamicSGroup inline int globalFlags() const { return m_globalFlags; } inline std::size_t size() const { return m_elements.size(); } + + template<typename Tensor_, typename... IndexTypes> + inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const + { + static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); + return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}}); + } + + template<typename Tensor_> + inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const + { + return internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup>(tensor, *this, indices); + } private: struct GroupElement { std::vector<int> representation; diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h index 0eb468fc0..942293bd7 100644 --- a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h +++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h @@ -212,6 +212,19 @@ class StaticSGroup return ge::count; } constexpr static inline int globalFlags() { return group_elements::global_flags; } + + template<typename Tensor_, typename... IndexTypes> + inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const + { + static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); + return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}}); + } + + template<typename Tensor_> + inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const + { + return internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>>(tensor, *this, indices); + } }; } // end namespace Eigen diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h index f1ccc33ef..879d6cd77 100644 --- a/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h +++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h @@ -293,7 +293,7 @@ struct tensor_symmetry_calculate_flags } }; -template<typename Tensor_, typename Symmetry_, int Flags> +template<typename Tensor_, typename Symmetry_, int Flags = 0> class tensor_symmetry_value_setter { public: diff --git a/unsupported/test/cxx11_tensor_symmetry.cpp b/unsupported/test/cxx11_tensor_symmetry.cpp index 2a1669995..d680e9b3b 100644 --- a/unsupported/test/cxx11_tensor_symmetry.cpp +++ b/unsupported/test/cxx11_tensor_symmetry.cpp @@ -661,7 +661,7 @@ static void test_tensor_epsilon() Tensor<int, 3> epsilon(3,3,3); epsilon.setZero(); - epsilon.symCoeff(sym, 0, 1, 2) = 1; + sym(epsilon, 0, 1, 2) = 1; for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { @@ -683,7 +683,7 @@ static void test_tensor_sym() for (int k = l; k < 10; k++) { for (int j = 0; j < 10; j++) { for (int i = j; i < 10; i++) { - t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l); + sym(t, i, j, k, l) = (i + j) * (k + l); } } } @@ -712,7 +712,7 @@ static void test_tensor_asym() for (int k = l + 1; k < 10; k++) { for (int j = 0; j < 10; j++) { for (int i = j + 1; i < 10; i++) { - t.symCoeff(sym, i, j, k, l) = ((i * j) + (k * l)); + sym(t, i, j, k, l) = ((i * j) + (k * l)); } } } @@ -751,7 +751,7 @@ static void test_tensor_dynsym() for (int k = l; k < 10; k++) { for (int j = 0; j < 10; j++) { for (int i = j; i < 10; i++) { - t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l); + sym(t, i, j, k, l) = (i + j) * (k + l); } } } @@ -787,7 +787,7 @@ static void test_tensor_randacc() std::swap(i, j); if (k < l) std::swap(k, l); - t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l); + sym(t, i, j, k, l) = (i + j) * (k + l); } for (int l = 0; l < 10; l++) { |