aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-15 15:17:38 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-15 15:17:38 -0700
commit66f056776f1220ede1aa5cbfe058b88d6df3e359 (patch)
treea4330631913375478020054258515dde7a8e63fd
parent42705ba574e8c0a1764ef96e41831ed353b4057e (diff)
Add DSizes index type promotion
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h19
-rw-r--r--unsupported/test/cxx11_tensor_dimension.cpp18
2 files changed, 36 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
index 94871ef43..5de0d0de7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
@@ -290,10 +290,27 @@ struct DSizes : array<DenseIndex, NumDims> {
}
}
+#ifdef EIGEN_HAS_CXX11
+ // Enable DSizes index type promotion only if we are promoting to the
+ // larger type, e.g. allow to promote dimensions of type int to long.
+ template<typename OtherIndex,
+ typename std::enable_if<
+ std::is_same<
+ DenseIndex,
+ typename internal::promote_index_type<DenseIndex, OtherIndex>::type
+ >::value, int>::type = 0>
+ EIGEN_DEVICE_FUNC
+ explicit DSizes(const array<OtherIndex, NumDims>& other) {
+ for (int i = 0; i < NumDims; ++i) {
+ (*this)[i] = static_cast<DenseIndex>(other[i]);
+ }
+ }
+#endif // EIGEN_HAS_CXX11
+
#ifdef EIGEN_HAS_INDEX_LIST
template <typename FirstType, typename... OtherTypes>
EIGEN_DEVICE_FUNC
- DSizes(const Eigen::IndexList<FirstType, OtherTypes...>& dimensions) {
+ explicit DSizes(const Eigen::IndexList<FirstType, OtherTypes...>& dimensions) {
for (int i = 0; i < dimensions.count; ++i) {
(*this)[i] = dimensions[i];
}
diff --git a/unsupported/test/cxx11_tensor_dimension.cpp b/unsupported/test/cxx11_tensor_dimension.cpp
index 10364d4b4..26f8edd8a 100644
--- a/unsupported/test/cxx11_tensor_dimension.cpp
+++ b/unsupported/test/cxx11_tensor_dimension.cpp
@@ -60,10 +60,28 @@ static void test_rank_zero()
VERIFY_IS_EQUAL((int)dscalar.rank(), 0);
}
+static void test_index_type_promotion() {
+#ifdef EIGEN_HAS_CXX11
+ Eigen::DSizes<int, 3> src0(1, 2, 3);
+ Eigen::array<int, 3> src1 = {4, 5, 6};
+
+ Eigen::DSizes<long, 3> dst0(src0);
+ Eigen::DSizes<long, 3> dst1(src1);
+
+ VERIFY_IS_EQUAL(dst0[0], 1L);
+ VERIFY_IS_EQUAL(dst0[1], 2L);
+ VERIFY_IS_EQUAL(dst0[2], 3L);
+ VERIFY_IS_EQUAL(dst1[0], 4L);
+ VERIFY_IS_EQUAL(dst1[1], 5L);
+ VERIFY_IS_EQUAL(dst1[2], 6L);
+#endif // EIGEN_HAS_CXX11
+}
+
EIGEN_DECLARE_TEST(cxx11_tensor_dimension)
{
CALL_SUBTEST(test_dynamic_size());
CALL_SUBTEST(test_fixed_size());
CALL_SUBTEST(test_match());
CALL_SUBTEST(test_rank_zero());
+ CALL_SUBTEST(test_index_type_promotion());
}