From 9cb8771e9c4a1f44ba59741c9fac495d1872bb25 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 25 Jun 2020 14:31:16 -0700 Subject: Fix tensor casts for large packets and casts to/from std::complex The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex`, which was missing to make it consistent with `Eigen:bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations. --- unsupported/test/cxx11_tensor_casts.cpp | 81 +++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 5 deletions(-) (limited to 'unsupported/test') diff --git a/unsupported/test/cxx11_tensor_casts.cpp b/unsupported/test/cxx11_tensor_casts.cpp index c4fe9a798..45456f3ef 100644 --- a/unsupported/test/cxx11_tensor_casts.cpp +++ b/unsupported/test/cxx11_tensor_casts.cpp @@ -8,6 +8,7 @@ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #include "main.h" +#include "random_without_cast_overflow.h" #include @@ -104,12 +105,82 @@ static void test_small_to_big_type_cast() } } +template +static void test_type_cast() { + Tensor ftensor(100, 200); + // Generate random values for a valid cast. + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 200; ++j) { + ftensor(i, j) = internal::random_without_cast_overflow::value(); + } + } + + Tensor ttensor(100, 200); + ttensor = ftensor.template cast(); + + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 200; ++j) { + const ToType ref = internal::cast(ftensor(i, j)); + VERIFY_IS_APPROX(ttensor(i, j), ref); + } + } +} + +template +struct test_cast_runner { + static void run() { + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast>(); + test_type_cast>(); + } +}; + +// Only certain types allow cast from std::complex<>. +template +struct test_cast_runner::IsComplex>::type> { + static void run() { + test_type_cast(); + test_type_cast(); + test_type_cast>(); + test_type_cast>(); + } +}; + EIGEN_DECLARE_TEST(cxx11_tensor_casts) { - CALL_SUBTEST(test_simple_cast()); - CALL_SUBTEST(test_vectorized_cast()); - CALL_SUBTEST(test_float_to_int_cast()); - CALL_SUBTEST(test_big_to_small_type_cast()); - CALL_SUBTEST(test_small_to_big_type_cast()); + CALL_SUBTEST(test_simple_cast()); + CALL_SUBTEST(test_vectorized_cast()); + CALL_SUBTEST(test_float_to_int_cast()); + CALL_SUBTEST(test_big_to_small_type_cast()); + CALL_SUBTEST(test_small_to_big_type_cast()); + + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner>::run()); + CALL_SUBTEST(test_cast_runner>::run()); + } -- cgit v1.2.3