diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-06-25 14:31:16 -0700 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2020-06-30 18:53:55 +0000 |
commit | 9cb8771e9c4a1f44ba59741c9fac495d1872bb25 (patch) | |
tree | 5348c34ac0673d09fe97aea29770e7b236e85510 /unsupported/test | |
parent | 145e51516fdac7b30d22c11c6878c2805fc3d724 (diff) |
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for
`SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the
missing 1:N and 8:1.
We also add casting `Eigen::half` to/from `std::complex<T>`, which
was missing to make it consistent with `Eigen:bfloat16`, and
generalize the overload to work for any complex type.
Tests were added to `basicstuff`, `packetmath`, and
`cxx11_tensor_casts` to test all cast configurations.
Diffstat (limited to 'unsupported/test')
-rw-r--r-- | unsupported/test/cxx11_tensor_casts.cpp | 81 |
1 files changed, 76 insertions, 5 deletions
diff --git a/unsupported/test/cxx11_tensor_casts.cpp b/unsupported/test/cxx11_tensor_casts.cpp index c4fe9a798..45456f3ef 100644 --- a/unsupported/test/cxx11_tensor_casts.cpp +++ b/unsupported/test/cxx11_tensor_casts.cpp @@ -8,6 +8,7 @@ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #include "main.h" +#include "random_without_cast_overflow.h" #include <Eigen/CXX11/Tensor> @@ -104,12 +105,82 @@ static void test_small_to_big_type_cast() } } +template <typename FromType, typename ToType> +static void test_type_cast() { + Tensor<FromType, 2> ftensor(100, 200); + // Generate random values for a valid cast. + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 200; ++j) { + ftensor(i, j) = internal::random_without_cast_overflow<FromType,ToType>::value(); + } + } + + Tensor<ToType, 2> ttensor(100, 200); + ttensor = ftensor.template cast<ToType>(); + + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 200; ++j) { + const ToType ref = internal::cast<FromType,ToType>(ftensor(i, j)); + VERIFY_IS_APPROX(ttensor(i, j), ref); + } + } +} + +template<typename Scalar, typename EnableIf = void> +struct test_cast_runner { + static void run() { + test_type_cast<Scalar, bool>(); + test_type_cast<Scalar, int8_t>(); + test_type_cast<Scalar, int16_t>(); + test_type_cast<Scalar, int32_t>(); + test_type_cast<Scalar, int64_t>(); + test_type_cast<Scalar, uint8_t>(); + test_type_cast<Scalar, uint16_t>(); + test_type_cast<Scalar, uint32_t>(); + test_type_cast<Scalar, uint64_t>(); + test_type_cast<Scalar, half>(); + test_type_cast<Scalar, bfloat16>(); + test_type_cast<Scalar, float>(); + test_type_cast<Scalar, double>(); + test_type_cast<Scalar, std::complex<float>>(); + test_type_cast<Scalar, std::complex<double>>(); + } +}; + +// Only certain types allow cast from std::complex<>. +template<typename Scalar> +struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> { + static void run() { + test_type_cast<Scalar, half>(); + test_type_cast<Scalar, bfloat16>(); + test_type_cast<Scalar, std::complex<float>>(); + test_type_cast<Scalar, std::complex<double>>(); + } +}; + EIGEN_DECLARE_TEST(cxx11_tensor_casts) { - CALL_SUBTEST(test_simple_cast()); - CALL_SUBTEST(test_vectorized_cast()); - CALL_SUBTEST(test_float_to_int_cast()); - CALL_SUBTEST(test_big_to_small_type_cast()); - CALL_SUBTEST(test_small_to_big_type_cast()); + CALL_SUBTEST(test_simple_cast()); + CALL_SUBTEST(test_vectorized_cast()); + CALL_SUBTEST(test_float_to_int_cast()); + CALL_SUBTEST(test_big_to_small_type_cast()); + CALL_SUBTEST(test_small_to_big_type_cast()); + + CALL_SUBTEST(test_cast_runner<bool>::run()); + CALL_SUBTEST(test_cast_runner<int8_t>::run()); + CALL_SUBTEST(test_cast_runner<int16_t>::run()); + CALL_SUBTEST(test_cast_runner<int32_t>::run()); + CALL_SUBTEST(test_cast_runner<int64_t>::run()); + CALL_SUBTEST(test_cast_runner<uint8_t>::run()); + CALL_SUBTEST(test_cast_runner<uint16_t>::run()); + CALL_SUBTEST(test_cast_runner<uint32_t>::run()); + CALL_SUBTEST(test_cast_runner<uint64_t>::run()); + CALL_SUBTEST(test_cast_runner<half>::run()); + CALL_SUBTEST(test_cast_runner<bfloat16>::run()); + CALL_SUBTEST(test_cast_runner<float>::run()); + CALL_SUBTEST(test_cast_runner<double>::run()); + CALL_SUBTEST(test_cast_runner<std::complex<float>>::run()); + CALL_SUBTEST(test_cast_runner<std::complex<double>>::run()); + } |