aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Christian Seiler <christian@iwakd.de>2014-06-04 20:44:22 +0200
committerGravatar Christian Seiler <christian@iwakd.de>2014-06-04 20:44:22 +0200
commit96cb58fa3b83448fcb2af2d131434a7ac10b915c (patch)
tree531c23ae8ad7b265dc256d6dcaf178ea2eec231c /unsupported
parentea9943352368b990d27ba22eb8670287cf96302d (diff)
unsupported/TensorSymmetry: factor out completely from Tensor module
Remove the symCoeff() method of the the Tensor module and move the functionality into a new operator() of the symmetry classes. This makes the Tensor module now completely self-contained without symmetry support (even though previously it was only a forward declaration and a otherwise harmless trivial templated method) and also removes the inconsistency with the rest of eigen w.r.t. the method's naming scheme.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/Tensor.h15
-rw-r--r--unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h13
-rw-r--r--unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h13
-rw-r--r--unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h2
-rw-r--r--unsupported/test/cxx11_tensor_symmetry.cpp10
5 files changed, 32 insertions, 21 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
index c6216e14c..70ca1433f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
@@ -91,9 +91,6 @@ struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
return std_array_get<RowMajor ? 0 : NumIndices - 1>(indices);
}
};
-
-/* Forward-declaration required for the symmetry support. */
-template<typename Tensor_, typename Symmetry_, int Flags = 0> class tensor_symmetry_value_setter;
} // end namespace internal
template<typename Scalar_, std::size_t NumIndices_, int Options_>
@@ -285,18 +282,6 @@ class Tensor
#endif
}
- template<typename Symmetry_, typename... IndexTypes>
- internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, Index firstIndex, IndexTypes... otherIndices)
- {
- return symCoeff(symmetry, std::array<Index, NumIndices>{{firstIndex, otherIndices...}});
- }
-
- template<typename Symmetry_, typename... IndexTypes>
- internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, std::array<Index, NumIndices> const& indices)
- {
- return internal::tensor_symmetry_value_setter<Self, Symmetry_>(*this, symmetry, indices);
- }
-
protected:
bool checkIndexRange(const std::array<Index, NumIndices>& indices) const
{
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h
index 0329278a9..bc4f2025f 100644
--- a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h
+++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h
@@ -50,6 +50,19 @@ class DynamicSGroup
inline int globalFlags() const { return m_globalFlags; }
inline std::size_t size() const { return m_elements.size(); }
+
+ template<typename Tensor_, typename... IndexTypes>
+ inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
+ {
+ static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
+ return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
+ }
+
+ template<typename Tensor_>
+ inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
+ {
+ return internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup>(tensor, *this, indices);
+ }
private:
struct GroupElement {
std::vector<int> representation;
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
index 0eb468fc0..942293bd7 100644
--- a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
+++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h
@@ -212,6 +212,19 @@ class StaticSGroup
return ge::count;
}
constexpr static inline int globalFlags() { return group_elements::global_flags; }
+
+ template<typename Tensor_, typename... IndexTypes>
+ inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
+ {
+ static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
+ return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
+ }
+
+ template<typename Tensor_>
+ inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
+ {
+ return internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>>(tensor, *this, indices);
+ }
};
} // end namespace Eigen
diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h
index f1ccc33ef..879d6cd77 100644
--- a/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h
+++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h
@@ -293,7 +293,7 @@ struct tensor_symmetry_calculate_flags
}
};
-template<typename Tensor_, typename Symmetry_, int Flags>
+template<typename Tensor_, typename Symmetry_, int Flags = 0>
class tensor_symmetry_value_setter
{
public:
diff --git a/unsupported/test/cxx11_tensor_symmetry.cpp b/unsupported/test/cxx11_tensor_symmetry.cpp
index 2a1669995..d680e9b3b 100644
--- a/unsupported/test/cxx11_tensor_symmetry.cpp
+++ b/unsupported/test/cxx11_tensor_symmetry.cpp
@@ -661,7 +661,7 @@ static void test_tensor_epsilon()
Tensor<int, 3> epsilon(3,3,3);
epsilon.setZero();
- epsilon.symCoeff(sym, 0, 1, 2) = 1;
+ sym(epsilon, 0, 1, 2) = 1;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
@@ -683,7 +683,7 @@ static void test_tensor_sym()
for (int k = l; k < 10; k++) {
for (int j = 0; j < 10; j++) {
for (int i = j; i < 10; i++) {
- t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
+ sym(t, i, j, k, l) = (i + j) * (k + l);
}
}
}
@@ -712,7 +712,7 @@ static void test_tensor_asym()
for (int k = l + 1; k < 10; k++) {
for (int j = 0; j < 10; j++) {
for (int i = j + 1; i < 10; i++) {
- t.symCoeff(sym, i, j, k, l) = ((i * j) + (k * l));
+ sym(t, i, j, k, l) = ((i * j) + (k * l));
}
}
}
@@ -751,7 +751,7 @@ static void test_tensor_dynsym()
for (int k = l; k < 10; k++) {
for (int j = 0; j < 10; j++) {
for (int i = j; i < 10; i++) {
- t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
+ sym(t, i, j, k, l) = (i + j) * (k + l);
}
}
}
@@ -787,7 +787,7 @@ static void test_tensor_randacc()
std::swap(i, j);
if (k < l)
std::swap(k, l);
- t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
+ sym(t, i, j, k, l) = (i + j) * (k + l);
}
for (int l = 0; l < 10; l++) {