aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-07-29 15:01:21 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-07-29 15:01:21 -0700
commite1d28b7ea7ee2aad4603121d5e1bec0c4484c838 (patch)
treeefb7046e83074e76eae2fac92c53832251857295
parent0570594f2c0c9fd241ff76f741d034e1daf106f9 (diff)
Added a test for shuffling
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h4
-rw-r--r--unsupported/test/cxx11_tensor_shuffling.cpp53
2 files changed, 49 insertions, 8 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
index c45530098..15a22aa1b 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
@@ -67,7 +67,7 @@ class TensorShufflingOp : public TensorBase<TensorShufflingOp<Shuffle, XprType>
: m_xpr(expr), m_shuffle(shuffle) {}
EIGEN_DEVICE_FUNC
- const Shuffle& shuffle() const { return m_shuffle; }
+ const Shuffle& shufflePermutation() const { return m_shuffle; }
EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename XprType::Nested>::type&
@@ -119,7 +119,7 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
: m_impl(op.expression(), device)
{
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
- const Shuffle& shuffle = op.shuffle();
+ const Shuffle& shuffle = op.shufflePermutation();
for (int i = 0; i < NumDims; ++i) {
m_dimensions[i] = input_dims[shuffle[i]];
}
diff --git a/unsupported/test/cxx11_tensor_shuffling.cpp b/unsupported/test/cxx11_tensor_shuffling.cpp
index 2f7fd9e50..d11444a14 100644
--- a/unsupported/test/cxx11_tensor_shuffling.cpp
+++ b/unsupported/test/cxx11_tensor_shuffling.cpp
@@ -176,12 +176,53 @@ static void test_shuffling_as_value()
}
}
+
+template <int DataLayout>
+static void test_shuffle_unshuffle()
+{
+ Tensor<float, 4, DataLayout> tensor(2,3,5,7);
+ tensor.setRandom();
+
+ // Choose a random permutation.
+ array<ptrdiff_t, 4> shuffles;
+ for (int i = 0; i < 4; ++i) {
+ shuffles[i] = i;
+ }
+ array<ptrdiff_t, 4> shuffles_inverse;
+ for (int i = 0; i < 4; ++i) {
+ const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3);
+ shuffles_inverse[shuffles[index]] = i;
+ std::swap(shuffles[i], shuffles[index]);
+ }
+
+ Tensor<float, 4, DataLayout> shuffle;
+ shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse);
+
+ VERIFY_IS_EQUAL(shuffle.dimension(0), 2);
+ VERIFY_IS_EQUAL(shuffle.dimension(1), 3);
+ VERIFY_IS_EQUAL(shuffle.dimension(2), 5);
+ VERIFY_IS_EQUAL(shuffle.dimension(3), 7);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 5; ++k) {
+ for (int l = 0; l < 7; ++l) {
+ VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l));
+ }
+ }
+ }
+ }
+}
+
+
void test_cxx11_tensor_shuffling()
{
- CALL_SUBTEST(test_simple_shuffling<ColMajor>());
- CALL_SUBTEST(test_simple_shuffling<RowMajor>());
- CALL_SUBTEST(test_expr_shuffling<ColMajor>());
- CALL_SUBTEST(test_expr_shuffling<RowMajor>());
- CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
- CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
+ CALL_SUBTEST(test_simple_shuffling<ColMajor>());
+ CALL_SUBTEST(test_simple_shuffling<RowMajor>());
+ CALL_SUBTEST(test_expr_shuffling<ColMajor>());
+ CALL_SUBTEST(test_expr_shuffling<RowMajor>());
+ CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
+ CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
+ CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
+ CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
}