diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-07-29 15:01:21 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-07-29 15:01:21 -0700 |
commit | e1d28b7ea7ee2aad4603121d5e1bec0c4484c838 (patch) | |
tree | efb7046e83074e76eae2fac92c53832251857295 | |
parent | 0570594f2c0c9fd241ff76f741d034e1daf106f9 (diff) |
Added a test for shuffling
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h | 4 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_shuffling.cpp | 53 |
2 files changed, 49 insertions, 8 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h index c45530098..15a22aa1b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h @@ -67,7 +67,7 @@ class TensorShufflingOp : public TensorBase<TensorShufflingOp<Shuffle, XprType> : m_xpr(expr), m_shuffle(shuffle) {} EIGEN_DEVICE_FUNC - const Shuffle& shuffle() const { return m_shuffle; } + const Shuffle& shufflePermutation() const { return m_shuffle; } EIGEN_DEVICE_FUNC const typename internal::remove_all<typename XprType::Nested>::type& @@ -119,7 +119,7 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device> : m_impl(op.expression(), device) { const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions(); - const Shuffle& shuffle = op.shuffle(); + const Shuffle& shuffle = op.shufflePermutation(); for (int i = 0; i < NumDims; ++i) { m_dimensions[i] = input_dims[shuffle[i]]; } diff --git a/unsupported/test/cxx11_tensor_shuffling.cpp b/unsupported/test/cxx11_tensor_shuffling.cpp index 2f7fd9e50..d11444a14 100644 --- a/unsupported/test/cxx11_tensor_shuffling.cpp +++ b/unsupported/test/cxx11_tensor_shuffling.cpp @@ -176,12 +176,53 @@ static void test_shuffling_as_value() } } + +template <int DataLayout> +static void test_shuffle_unshuffle() +{ + Tensor<float, 4, DataLayout> tensor(2,3,5,7); + tensor.setRandom(); + + // Choose a random permutation. + array<ptrdiff_t, 4> shuffles; + for (int i = 0; i < 4; ++i) { + shuffles[i] = i; + } + array<ptrdiff_t, 4> shuffles_inverse; + for (int i = 0; i < 4; ++i) { + const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3); + shuffles_inverse[shuffles[index]] = i; + std::swap(shuffles[i], shuffles[index]); + } + + Tensor<float, 4, DataLayout> shuffle; + shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse); + + VERIFY_IS_EQUAL(shuffle.dimension(0), 2); + VERIFY_IS_EQUAL(shuffle.dimension(1), 3); + VERIFY_IS_EQUAL(shuffle.dimension(2), 5); + VERIFY_IS_EQUAL(shuffle.dimension(3), 7); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l)); + } + } + } + } +} + + void test_cxx11_tensor_shuffling() { - CALL_SUBTEST(test_simple_shuffling<ColMajor>()); - CALL_SUBTEST(test_simple_shuffling<RowMajor>()); - CALL_SUBTEST(test_expr_shuffling<ColMajor>()); - CALL_SUBTEST(test_expr_shuffling<RowMajor>()); - CALL_SUBTEST(test_shuffling_as_value<ColMajor>()); - CALL_SUBTEST(test_shuffling_as_value<RowMajor>()); + CALL_SUBTEST(test_simple_shuffling<ColMajor>()); + CALL_SUBTEST(test_simple_shuffling<RowMajor>()); + CALL_SUBTEST(test_expr_shuffling<ColMajor>()); + CALL_SUBTEST(test_expr_shuffling<RowMajor>()); + CALL_SUBTEST(test_shuffling_as_value<ColMajor>()); + CALL_SUBTEST(test_shuffling_as_value<RowMajor>()); + CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>()); + CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>()); } |