From 9cb8771e9c4a1f44ba59741c9fac495d1872bb25 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 25 Jun 2020 14:31:16 -0700 Subject: Fix tensor casts for large packets and casts to/from std::complex The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex`, which was missing, to make it consistent with `Eigen::bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations. --- .../Eigen/CXX11/src/Tensor/TensorConversion.h | 33 ++++++++- unsupported/test/cxx11_tensor_casts.cpp | 81 ++++++++++++++++++++-- 2 files changed, 107 insertions(+), 7 deletions(-) (limited to 'unsupported') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h index cdbafbbb1..44493906d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h @@ -51,7 +51,10 @@ struct nested, 1, typename eval -struct PacketConverter { +struct PacketConverter; + +template +struct PacketConverter { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketConverter(const TensorEvaluator& impl) : m_impl(impl) {} @@ -109,7 +112,33 @@ struct PacketConverter { }; template -struct PacketConverter { +struct PacketConverter { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + PacketConverter(const TensorEvaluator& impl) + : m_impl(impl) {} + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const { + const int SrcPacketSize = internal::unpacket_traits::size; + + SrcPacket src1 = m_impl.template packet(index); + SrcPacket src2 = m_impl.template packet(index + 1 * SrcPacketSize); + SrcPacket src3 = m_impl.template packet(index + 2 * SrcPacketSize); + SrcPacket src4 = m_impl.template packet(index + 3 * SrcPacketSize); + SrcPacket 
src5 = m_impl.template packet(index + 4 * SrcPacketSize); + SrcPacket src6 = m_impl.template packet(index + 5 * SrcPacketSize); + SrcPacket src7 = m_impl.template packet(index + 6 * SrcPacketSize); + SrcPacket src8 = m_impl.template packet(index + 7 * SrcPacketSize); + TgtPacket result = internal::pcast(src1, src2, src3, src4, src5, src6, src7, src8); + return result; + } + + private: + const TensorEvaluator& m_impl; +}; + +template +struct PacketConverter { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketConverter(const TensorEvaluator& impl) : m_impl(impl), m_maxIndex(impl.dimensions().TotalSize()) {} diff --git a/unsupported/test/cxx11_tensor_casts.cpp b/unsupported/test/cxx11_tensor_casts.cpp index c4fe9a798..45456f3ef 100644 --- a/unsupported/test/cxx11_tensor_casts.cpp +++ b/unsupported/test/cxx11_tensor_casts.cpp @@ -8,6 +8,7 @@ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #include "main.h" +#include "random_without_cast_overflow.h" #include @@ -104,12 +105,82 @@ static void test_small_to_big_type_cast() } } +template +static void test_type_cast() { + Tensor ftensor(100, 200); + // Generate random values for a valid cast. + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 200; ++j) { + ftensor(i, j) = internal::random_without_cast_overflow::value(); + } + } + + Tensor ttensor(100, 200); + ttensor = ftensor.template cast(); + + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 200; ++j) { + const ToType ref = internal::cast(ftensor(i, j)); + VERIFY_IS_APPROX(ttensor(i, j), ref); + } + } +} + +template +struct test_cast_runner { + static void run() { + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast(); + test_type_cast>(); + test_type_cast>(); + } +}; + +// Only certain types allow cast from std::complex<>. 
+template +struct test_cast_runner::IsComplex>::type> { + static void run() { + test_type_cast(); + test_type_cast(); + test_type_cast>(); + test_type_cast>(); + } +}; + EIGEN_DECLARE_TEST(cxx11_tensor_casts) { - CALL_SUBTEST(test_simple_cast()); - CALL_SUBTEST(test_vectorized_cast()); - CALL_SUBTEST(test_float_to_int_cast()); - CALL_SUBTEST(test_big_to_small_type_cast()); - CALL_SUBTEST(test_small_to_big_type_cast()); + CALL_SUBTEST(test_simple_cast()); + CALL_SUBTEST(test_vectorized_cast()); + CALL_SUBTEST(test_float_to_int_cast()); + CALL_SUBTEST(test_big_to_small_type_cast()); + CALL_SUBTEST(test_small_to_big_type_cast()); + + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner::run()); + CALL_SUBTEST(test_cast_runner>::run()); + CALL_SUBTEST(test_cast_runner>::run()); + } -- cgit v1.2.3