diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-06-25 14:31:16 -0700 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2020-06-30 18:53:55 +0000 |
commit | 9cb8771e9c4a1f44ba59741c9fac495d1872bb25 (patch) | |
tree | 5348c34ac0673d09fe97aea29770e7b236e85510 /test/packetmath.cpp | |
parent | 145e51516fdac7b30d22c11c6878c2805fc3d724 (diff) |
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for
`SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the
missing 1:N and 8:1.
We also add casting `Eigen::half` to/from `std::complex<T>`, which
was missing to make it consistent with `Eigen:bfloat16`, and
generalize the overload to work for any complex type.
Tests were added to `basicstuff`, `packetmath`, and
`cxx11_tensor_casts` to test all cast configurations.
Diffstat (limited to 'test/packetmath.cpp')
-rw-r--r-- | test/packetmath.cpp | 185 |
1 files changed, 38 insertions, 147 deletions
diff --git a/test/packetmath.cpp b/test/packetmath.cpp index dbc1d3f5a..7821877db 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -8,8 +8,8 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -#include <limits> #include "packetmath_test_shared.h" +#include "random_without_cast_overflow.h" template <typename T> inline T REF_ADD(const T& a, const T& b) { @@ -126,129 +126,6 @@ struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, fals static void run() {} }; -// Generates random values that fit in both SrcScalar and TgtScalar without -// overflowing when cast. -template <typename SrcScalar, typename TgtScalar, typename EnableIf = void> -struct random_without_cast_overflow { - static SrcScalar value() { return internal::random<SrcScalar>(); } -}; - -// Widening integer cast signed to unsigned. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && - !NumTraits<TgtScalar>::IsSigned && - (std::numeric_limits<SrcScalar>::digits < std::numeric_limits<TgtScalar>::digits || - (std::numeric_limits<SrcScalar>::digits == std::numeric_limits<TgtScalar>::digits && - NumTraits<SrcScalar>::IsSigned))>::type> { - static SrcScalar value() { - SrcScalar a = internal::random<SrcScalar>(); - return a < SrcScalar(0) ? -(a + 1) : a; - } -}; - -// Narrowing integer cast to unsigned. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if< - NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && !NumTraits<SrcScalar>::IsSigned && - (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> { - static SrcScalar value() { - TgtScalar b = internal::random<TgtScalar>(); - return static_cast<SrcScalar>(b < TgtScalar(0) ? -(b + 1) : b); - } -}; - -// Narrowing integer cast to signed. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if< - NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && NumTraits<SrcScalar>::IsSigned && - (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> { - static SrcScalar value() { - TgtScalar b = internal::random<TgtScalar>(); - return static_cast<SrcScalar>(b); - } -}; - -// Unsigned to signed narrowing cast. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && - !NumTraits<SrcScalar>::IsSigned && NumTraits<TgtScalar>::IsSigned && - (std::numeric_limits<SrcScalar>::digits == - std::numeric_limits<TgtScalar>::digits)>::type> { - static SrcScalar value() { return internal::random<SrcScalar>() / 2; } -}; - -template <typename Scalar> -struct is_floating_point { - enum { value = 0 }; -}; -template <> -struct is_floating_point<float> { - enum { value = 1 }; -}; -template <> -struct is_floating_point<double> { - enum { value = 1 }; -}; -template <> -struct is_floating_point<half> { - enum { value = 1 }; -}; -template <> -struct is_floating_point<bfloat16> { - enum { value = 1 }; -}; - -// Floating-point to integer, full precision. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger && - (std::numeric_limits<TgtScalar>::digits <= - std::numeric_limits<SrcScalar>::digits)>::type> { - static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); } -}; - -// Floating-point to integer, narrowing precision. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger && - (std::numeric_limits<TgtScalar>::digits > - std::numeric_limits<SrcScalar>::digits)>::type> { - static SrcScalar value() { - static const int BitShift = std::numeric_limits<TgtScalar>::digits - std::numeric_limits<SrcScalar>::digits; - return static_cast<SrcScalar>(internal::random<TgtScalar>() >> BitShift); - } -}; - -// Floating-point target from integer, re-use above logic. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && is_floating_point<TgtScalar>::value>::type> { - static SrcScalar value() { - return static_cast<SrcScalar>(random_without_cast_overflow<TgtScalar, SrcScalar>::value()); - } -}; - -// Floating-point narrowing conversion. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<is_floating_point<SrcScalar>::value && is_floating_point<TgtScalar>::value && - (std::numeric_limits<SrcScalar>::digits > - std::numeric_limits<TgtScalar>::digits)>::type> { - static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); } -}; - template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio> struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, true> { static void run() { @@ -266,10 +143,12 @@ struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, true // Construct a packet of scalars that will not overflow when casting for (int i = 0; i < DataSize; ++i) { - data1[i] = random_without_cast_overflow<SrcScalar, TgtScalar>::value(); + data1[i] = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value(); } - for (int i = 0; i < DataSize; ++i) ref[i] = static_cast<const TgtScalar>(data1[i]); + for (int i = 0; i < DataSize; ++i) { + ref[i] = static_cast<const TgtScalar>(data1[i]); + } pcast_array<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio>::cast(data1, DataSize, data2); @@ -318,21 +197,37 @@ struct test_cast_runner<SrcPacket, TgtScalar, TgtPacket, false, false> { static void run() {} }; +template <typename Scalar, typename Packet, typename EnableIf = void> +struct packetmath_pcast_ops_runner { + static void run() { + test_cast_runner<Packet, float>::run(); + test_cast_runner<Packet, double>::run(); + test_cast_runner<Packet, int8_t>::run(); + test_cast_runner<Packet, uint8_t>::run(); + test_cast_runner<Packet, int16_t>::run(); + test_cast_runner<Packet, uint16_t>::run(); + test_cast_runner<Packet, int32_t>::run(); + test_cast_runner<Packet, uint32_t>::run(); + test_cast_runner<Packet, int64_t>::run(); + test_cast_runner<Packet, uint64_t>::run(); + test_cast_runner<Packet, bool>::run(); + test_cast_runner<Packet, std::complex<float>>::run(); + test_cast_runner<Packet, std::complex<double>>::run(); + test_cast_runner<Packet, half>::run(); + test_cast_runner<Packet, bfloat16>::run(); + } +}; + +// Only some types support cast from std::complex<>. template <typename Scalar, typename Packet> -void packetmath_pcast_ops() { - test_cast_runner<Packet, float>::run(); - test_cast_runner<Packet, double>::run(); - test_cast_runner<Packet, int8_t>::run(); - test_cast_runner<Packet, uint8_t>::run(); - test_cast_runner<Packet, int16_t>::run(); - test_cast_runner<Packet, uint16_t>::run(); - test_cast_runner<Packet, int32_t>::run(); - test_cast_runner<Packet, uint32_t>::run(); - test_cast_runner<Packet, int64_t>::run(); - test_cast_runner<Packet, uint64_t>::run(); - test_cast_runner<Packet, bool>::run(); - test_cast_runner<Packet, half>::run(); -} +struct packetmath_pcast_ops_runner<Scalar, Packet, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> { + static void run() { + test_cast_runner<Packet, std::complex<float>>::run(); + test_cast_runner<Packet, std::complex<double>>::run(); + test_cast_runner<Packet, half>::run(); + test_cast_runner<Packet, bfloat16>::run(); + } +}; template <typename Scalar, typename Packet> void packetmath_boolean_mask_ops() { @@ -356,10 +251,8 @@ void packetmath_boolean_mask_ops() { // Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path // (for some compilers) compute the bitwise and with 0x1 of the results to keep the value in [0,1]. -#ifdef EIGEN_PACKET_MATH_SSE_H -template <> -void packetmath_boolean_mask_ops<bool, internal::Packet16b>() {} -#endif +template<> +void packetmath_boolean_mask_ops<bool, typename internal::packet_traits<bool>::type>() {} template <typename Scalar, typename Packet> void packetmath() { @@ -560,7 +453,7 @@ void packetmath() { CHECK_CWISE2_IF(true, internal::pand, internal::pand); packetmath_boolean_mask_ops<Scalar, Packet>(); - packetmath_pcast_ops<Scalar, Packet>(); + packetmath_pcast_ops_runner<Scalar, Packet>::run(); } template <typename Scalar, typename Packet> @@ -975,9 +868,7 @@ EIGEN_DECLARE_TEST(packetmath) { CALL_SUBTEST_11(test::runner<std::complex<float> >::run()); CALL_SUBTEST_12(test::runner<std::complex<double> >::run()); CALL_SUBTEST_13((packetmath<half, internal::packet_traits<half>::type>())); -#ifdef EIGEN_PACKET_MATH_SSE_H CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>())); -#endif CALL_SUBTEST_15((packetmath<bfloat16, internal::packet_traits<bfloat16>::type>())); g_first_pass = false; } |