aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJonas Harsch <jonas_harsch@web.de>2021-07-02 20:33:52 +0000
committerAntonio Sánchez <cantonios@google.com>2021-07-02 20:33:52 +0000
commitaab747021be5ed1a1e9667243d884eb72003599d (patch)
tree13d56e54624d49e4141661fc4ea04b6c5324160d
parentbbfc4d54cd863676b3ae874e25dbe150fb6d575c (diff)
Don't crash when attempting to shuffle an empty tensor.
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h6
-rw-r--r--unsupported/test/cxx11_tensor_shuffling.cpp55
2 files changed, 59 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
index 0999815d7..e5e5efdee 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
@@ -142,7 +142,8 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
m_unshuffledInputStrides[i] =
m_unshuffledInputStrides[i - 1] * input_dims[i - 1];
m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
- m_fastOutputStrides[i] = internal::TensorIntDivisor<Index>(m_outputStrides[i]);
+ m_fastOutputStrides[i] = internal::TensorIntDivisor<Index>(
+ m_outputStrides[i] > 0 ? m_outputStrides[i] : Index(1));
}
} else {
m_unshuffledInputStrides[NumDims - 1] = 1;
@@ -151,7 +152,8 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
m_unshuffledInputStrides[i] =
m_unshuffledInputStrides[i + 1] * input_dims[i + 1];
m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
- m_fastOutputStrides[i] = internal::TensorIntDivisor<Index>(m_outputStrides[i]);
+ m_fastOutputStrides[i] = internal::TensorIntDivisor<Index>(
+ m_outputStrides[i] > 0 ? m_outputStrides[i] : Index(1));
}
}
diff --git a/unsupported/test/cxx11_tensor_shuffling.cpp b/unsupported/test/cxx11_tensor_shuffling.cpp
index 2ec85d2d4..89a64c021 100644
--- a/unsupported/test/cxx11_tensor_shuffling.cpp
+++ b/unsupported/test/cxx11_tensor_shuffling.cpp
@@ -215,6 +215,59 @@ static void test_shuffle_unshuffle()
}
+template <int DataLayout>
+static void test_empty_shuffling()
+{
+ Tensor<float, 4, DataLayout> tensor(2,3,0,7);
+ tensor.setRandom();
+ array<ptrdiff_t, 4> shuffles;
+ shuffles[0] = 0;
+ shuffles[1] = 1;
+ shuffles[2] = 2;
+ shuffles[3] = 3;
+
+ Tensor<float, 4, DataLayout> no_shuffle;
+ no_shuffle = tensor.shuffle(shuffles);
+
+ VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
+ VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
+ VERIFY_IS_EQUAL(no_shuffle.dimension(2), 0);
+ VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 0; ++k) {
+ for (int l = 0; l < 7; ++l) {
+ VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
+ }
+ }
+ }
+ }
+
+ shuffles[0] = 2;
+ shuffles[1] = 3;
+ shuffles[2] = 1;
+ shuffles[3] = 0;
+ Tensor<float, 4, DataLayout> shuffle;
+ shuffle = tensor.shuffle(shuffles);
+
+ VERIFY_IS_EQUAL(shuffle.dimension(0), 0);
+ VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
+ VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
+ VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 0; ++k) {
+ for (int l = 0; l < 7; ++l) {
+ VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
+ }
+ }
+ }
+ }
+}
+
+
EIGEN_DECLARE_TEST(cxx11_tensor_shuffling)
{
CALL_SUBTEST(test_simple_shuffling<ColMajor>());
@@ -225,4 +278,6 @@ EIGEN_DECLARE_TEST(cxx11_tensor_shuffling)
CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
+ CALL_SUBTEST(test_empty_shuffling<ColMajor>());
+ CALL_SUBTEST(test_empty_shuffling<RowMajor>());
}