diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h | 17 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_dimension.cpp | 4 |
2 files changed, 14 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h index 7a1d40d7d..2b5de4f55 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h @@ -401,15 +401,21 @@ template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::si #endif -template <typename Dims1, typename Dims2, size_t n> +template <typename Dims1, typename Dims2, size_t n, size_t m> struct sizes_match_up_to_dim { static inline bool run(Dims1& dims1, Dims2& dims2) { + return false; + } +}; +template <typename Dims1, typename Dims2, size_t n> +struct sizes_match_up_to_dim<Dims1, Dims2, n, n> { + static inline bool run(Dims1& dims1, Dims2& dims2) { return (array_get<n>(dims1) == array_get<n>(dims2)) & - sizes_match_up_to_dim<Dims1, Dims2, n-1>::run(dims1, dims2); + sizes_match_up_to_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2); } }; template <typename Dims1, typename Dims2> -struct sizes_match_up_to_dim<Dims1, Dims2, 0> { +struct sizes_match_up_to_dim<Dims1, Dims2, 0, 0> { static inline bool run(Dims1& dims1, Dims2& dims2) { return (array_get<0>(dims1) == array_get<0>(dims2)); } @@ -420,10 +426,7 @@ struct sizes_match_up_to_dim<Dims1, Dims2, 0> { template <typename Dims1, typename Dims2> bool dimensions_match(Dims1& dims1, Dims2& dims2) { - if (static_cast<size_t>(internal::array_size<Dims1>::value) != static_cast<size_t>(internal::array_size<Dims2>::value)) { - return false; - } - return internal::sizes_match_up_to_dim<Dims1, Dims2, internal::array_size<Dims1>::value-1>::run(dims1, dims2); + return internal::sizes_match_up_to_dim<Dims1, Dims2, internal::array_size<Dims1>::value-1, internal::array_size<Dims2>::value-1>::run(dims1, dims2); } } // end namespace Eigen diff --git a/unsupported/test/cxx11_tensor_dimension.cpp b/unsupported/test/cxx11_tensor_dimension.cpp index 247d312ae..22c58450c 100644 --- a/unsupported/test/cxx11_tensor_dimension.cpp +++ b/unsupported/test/cxx11_tensor_dimension.cpp @@ -43,6 +43,10 @@ static void test_match() Eigen::DSizes<int, 3> dyn(2,3,7); Eigen::Sizes<2,3,7> stat; VERIFY_IS_EQUAL(Eigen::dimensions_match(dyn, stat), true); + + Eigen::DSizes<int, 3> dyn1(2,3,7); + Eigen::DSizes<int, 2> dyn2(2,3); + VERIFY_IS_EQUAL(Eigen::dimensions_match(dyn, stat), false); } |