diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-06-25 14:31:16 -0700 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2020-06-30 18:53:55 +0000 |
commit | 9cb8771e9c4a1f44ba59741c9fac495d1872bb25 (patch) | |
tree | 5348c34ac0673d09fe97aea29770e7b236e85510 | |
parent | 145e51516fdac7b30d22c11c6878c2805fc3d724 (diff) |
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for
`SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the
missing 1:N and 8:1.
We also add casting `Eigen::half` to/from `std::complex<T>`, which
was missing to make it consistent with `Eigen:bfloat16`, and
generalize the overload to work for any complex type.
Tests were added to `basicstuff`, `packetmath`, and
`cxx11_tensor_casts` to test all cast configurations.
-rw-r--r-- | Eigen/src/Core/arch/Default/BFloat16.h | 17 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/Half.h | 12 | ||||
-rw-r--r-- | test/basicstuff.cpp | 92 | ||||
-rw-r--r-- | test/packetmath.cpp | 185 | ||||
-rw-r--r-- | test/random_without_cast_overflow.h | 152 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h | 33 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_casts.cpp | 81 |
7 files changed, 388 insertions, 184 deletions
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index c3725d473..99ce99a27 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -65,13 +65,8 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {} // Following the convention of numpy, converting between complex and // float will lead to loss of imag value. - // Single precision complex. - typedef std::complex<float> complex64; - // Double precision complex. - typedef std::complex<double> complex128; - explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val) - : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {} - explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val) + template<typename RealScalar> + explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val) : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {} EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const { @@ -114,11 +109,9 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)); } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const { - return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const { - return complex128(static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)), double(0.0)); + template<typename RealScalar> + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const { + return std::complex<RealScalar>(static_cast<RealScalar>(bfloat16_impl::bfloat16_to_float(*this)), RealScalar(0)); } EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const { return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this)); diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index cfd0bdc06..b84cfc7db 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -86,7 +86,7 @@ struct half_base : public __half_raw { #if (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000) EIGEN_DEVICE_FUNC half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {} #endif - #endif + #endif #endif }; @@ -133,6 +133,11 @@ struct half : public half_impl::half_base { : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {} explicit EIGEN_DEVICE_FUNC half(float f) : half_impl::half_base(half_impl::float_to_half_rtne(f)) {} + // Following the convention of numpy, converting between complex and + // float will lead to loss of imag value. + template<typename RealScalar> + explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c) + : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {} EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const { // +0.0 and -0.0 become false, everything else becomes true. @@ -174,6 +179,11 @@ struct half : public half_impl::half_base { EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { return static_cast<double>(half_impl::half_to_float(*this)); } + + template<typename RealScalar> + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const { + return std::complex<RealScalar>(static_cast<RealScalar>(*this), RealScalar(0)); + } }; } // end namespace Eigen diff --git a/test/basicstuff.cpp b/test/basicstuff.cpp index 85af603d8..80fc8a07f 100644 --- a/test/basicstuff.cpp +++ b/test/basicstuff.cpp @@ -10,6 +10,7 @@ #define EIGEN_NO_STATIC_ASSERT #include "main.h" +#include "random_without_cast_overflow.h" template<typename MatrixType> void basicStuff(const MatrixType& m) { @@ -90,7 +91,7 @@ template<typename MatrixType> void basicStuff(const MatrixType& m) Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> cv(rows); rv = square.row(r); cv = square.col(r); - + VERIFY_IS_APPROX(rv, cv.transpose()); if(cols!=1 && rows!=1 && MatrixType::SizeAtCompileTime!=Dynamic) @@ -120,28 +121,28 @@ template<typename MatrixType> void basicStuff(const MatrixType& m) m1 = m2; VERIFY(m1==m2); VERIFY(!(m1!=m2)); - + // check automatic transposition sm2.setZero(); for(Index i=0;i<rows;++i) sm2.col(i) = sm1.row(i); VERIFY_IS_APPROX(sm2,sm1.transpose()); - + sm2.setZero(); for(Index i=0;i<rows;++i) sm2.col(i).noalias() = sm1.row(i); VERIFY_IS_APPROX(sm2,sm1.transpose()); - + sm2.setZero(); for(Index i=0;i<rows;++i) sm2.col(i).noalias() += sm1.row(i); VERIFY_IS_APPROX(sm2,sm1.transpose()); - + sm2.setZero(); for(Index i=0;i<rows;++i) sm2.col(i).noalias() -= sm1.row(i); VERIFY_IS_APPROX(sm2,-sm1.transpose()); - + // check ternary usage { bool b = internal::random<int>(0,10)>5; @@ -194,14 +195,72 @@ template<typename MatrixType> void basicStuffComplex(const MatrixType& m) VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero()); } -template<int> -void casting() +template<typename SrcScalar, typename TgtScalar> +void casting_test() { - Matrix4f m = Matrix4f::Random(), m2; - Matrix4d n = m.cast<double>(); - VERIFY(m.isApprox(n.cast<float>())); - m2 = m.cast<float>(); // check the specialization when NewType == Type - VERIFY(m.isApprox(m2)); + Matrix<SrcScalar,4,4> m; + for (int i=0; i<m.rows(); ++i) { + for (int j=0; j<m.cols(); ++j) { + m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value(); + } + } + Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>(); + for (int i=0; i<m.rows(); ++i) { + for (int j=0; j<m.cols(); ++j) { + VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j))); + } + } +} + +template<typename SrcScalar, typename EnableIf = void> +struct casting_test_runner { + static void run() { + casting_test<SrcScalar, bool>(); + casting_test<SrcScalar, int8_t>(); + casting_test<SrcScalar, uint8_t>(); + casting_test<SrcScalar, int16_t>(); + casting_test<SrcScalar, uint16_t>(); + casting_test<SrcScalar, int32_t>(); + casting_test<SrcScalar, uint32_t>(); + casting_test<SrcScalar, int64_t>(); + casting_test<SrcScalar, uint64_t>(); + casting_test<SrcScalar, half>(); + casting_test<SrcScalar, bfloat16>(); + casting_test<SrcScalar, float>(); + casting_test<SrcScalar, double>(); + casting_test<SrcScalar, std::complex<float>>(); + casting_test<SrcScalar, std::complex<double>>(); + } +}; + +template<typename SrcScalar> +struct casting_test_runner<SrcScalar, typename internal::enable_if<(NumTraits<SrcScalar>::IsComplex)>::type> +{ + static void run() { + // Only a few casts from std::complex<T> are defined. + casting_test<SrcScalar, half>(); + casting_test<SrcScalar, bfloat16>(); + casting_test<SrcScalar, std::complex<float>>(); + casting_test<SrcScalar, std::complex<double>>(); + } +}; + +void casting_all() { + casting_test_runner<bool>::run(); + casting_test_runner<int8_t>::run(); + casting_test_runner<uint8_t>::run(); + casting_test_runner<int16_t>::run(); + casting_test_runner<uint16_t>::run(); + casting_test_runner<int32_t>::run(); + casting_test_runner<uint32_t>::run(); + casting_test_runner<int64_t>::run(); + casting_test_runner<uint64_t>::run(); + casting_test_runner<half>::run(); + casting_test_runner<bfloat16>::run(); + casting_test_runner<float>::run(); + casting_test_runner<double>::run(); + casting_test_runner<std::complex<float>>::run(); + casting_test_runner<std::complex<double>>::run(); } template <typename Scalar> @@ -210,12 +269,12 @@ void fixedSizeMatrixConstruction() Scalar raw[4]; for(int k=0; k<4; ++k) raw[k] = internal::random<Scalar>(); - + { Matrix<Scalar,4,1> m(raw); Array<Scalar,4,1> a(raw); for(int k=0; k<4; ++k) VERIFY(m(k) == raw[k]); - for(int k=0; k<4; ++k) VERIFY(a(k) == raw[k]); + for(int k=0; k<4; ++k) VERIFY(a(k) == raw[k]); VERIFY_IS_EQUAL(m,(Matrix<Scalar,4,1>(raw[0],raw[1],raw[2],raw[3]))); VERIFY((a==(Array<Scalar,4,1>(raw[0],raw[1],raw[2],raw[3]))).all()); } @@ -277,6 +336,7 @@ EIGEN_DECLARE_TEST(basicstuff) CALL_SUBTEST_5( basicStuff(MatrixXcd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) ); CALL_SUBTEST_6( basicStuff(Matrix<float, 100, 100>()) ); CALL_SUBTEST_7( basicStuff(Matrix<long double,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE),internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) ); + CALL_SUBTEST_8( casting_all() ); CALL_SUBTEST_3( basicStuffComplex(MatrixXcf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) ); CALL_SUBTEST_5( basicStuffComplex(MatrixXcd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) ); @@ -288,6 +348,4 @@ EIGEN_DECLARE_TEST(basicstuff) CALL_SUBTEST_1(fixedSizeMatrixConstruction<int>()); CALL_SUBTEST_1(fixedSizeMatrixConstruction<long int>()); CALL_SUBTEST_1(fixedSizeMatrixConstruction<std::ptrdiff_t>()); - - CALL_SUBTEST_2(casting<0>()); } diff --git a/test/packetmath.cpp b/test/packetmath.cpp index dbc1d3f5a..7821877db 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -8,8 +8,8 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -#include <limits> #include "packetmath_test_shared.h" +#include "random_without_cast_overflow.h" template <typename T> inline T REF_ADD(const T& a, const T& b) { @@ -126,129 +126,6 @@ struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, fals static void run() {} }; -// Generates random values that fit in both SrcScalar and TgtScalar without -// overflowing when cast. -template <typename SrcScalar, typename TgtScalar, typename EnableIf = void> -struct random_without_cast_overflow { - static SrcScalar value() { return internal::random<SrcScalar>(); } -}; - -// Widening integer cast signed to unsigned. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && - !NumTraits<TgtScalar>::IsSigned && - (std::numeric_limits<SrcScalar>::digits < std::numeric_limits<TgtScalar>::digits || - (std::numeric_limits<SrcScalar>::digits == std::numeric_limits<TgtScalar>::digits && - NumTraits<SrcScalar>::IsSigned))>::type> { - static SrcScalar value() { - SrcScalar a = internal::random<SrcScalar>(); - return a < SrcScalar(0) ? -(a + 1) : a; - } -}; - -// Narrowing integer cast to unsigned. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if< - NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && !NumTraits<SrcScalar>::IsSigned && - (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> { - static SrcScalar value() { - TgtScalar b = internal::random<TgtScalar>(); - return static_cast<SrcScalar>(b < TgtScalar(0) ? -(b + 1) : b); - } -}; - -// Narrowing integer cast to signed. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if< - NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && NumTraits<SrcScalar>::IsSigned && - (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> { - static SrcScalar value() { - TgtScalar b = internal::random<TgtScalar>(); - return static_cast<SrcScalar>(b); - } -}; - -// Unsigned to signed narrowing cast. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && - !NumTraits<SrcScalar>::IsSigned && NumTraits<TgtScalar>::IsSigned && - (std::numeric_limits<SrcScalar>::digits == - std::numeric_limits<TgtScalar>::digits)>::type> { - static SrcScalar value() { return internal::random<SrcScalar>() / 2; } -}; - -template <typename Scalar> -struct is_floating_point { - enum { value = 0 }; -}; -template <> -struct is_floating_point<float> { - enum { value = 1 }; -}; -template <> -struct is_floating_point<double> { - enum { value = 1 }; -}; -template <> -struct is_floating_point<half> { - enum { value = 1 }; -}; -template <> -struct is_floating_point<bfloat16> { - enum { value = 1 }; -}; - -// Floating-point to integer, full precision. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger && - (std::numeric_limits<TgtScalar>::digits <= - std::numeric_limits<SrcScalar>::digits)>::type> { - static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); } -}; - -// Floating-point to integer, narrowing precision. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger && - (std::numeric_limits<TgtScalar>::digits > - std::numeric_limits<SrcScalar>::digits)>::type> { - static SrcScalar value() { - static const int BitShift = std::numeric_limits<TgtScalar>::digits - std::numeric_limits<SrcScalar>::digits; - return static_cast<SrcScalar>(internal::random<TgtScalar>() >> BitShift); - } -}; - -// Floating-point target from integer, re-use above logic. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && is_floating_point<TgtScalar>::value>::type> { - static SrcScalar value() { - return static_cast<SrcScalar>(random_without_cast_overflow<TgtScalar, SrcScalar>::value()); - } -}; - -// Floating-point narrowing conversion. -template <typename SrcScalar, typename TgtScalar> -struct random_without_cast_overflow< - SrcScalar, TgtScalar, - typename internal::enable_if<is_floating_point<SrcScalar>::value && is_floating_point<TgtScalar>::value && - (std::numeric_limits<SrcScalar>::digits > - std::numeric_limits<TgtScalar>::digits)>::type> { - static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); } -}; - template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio> struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, true> { static void run() { @@ -266,10 +143,12 @@ struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, true // Construct a packet of scalars that will not overflow when casting for (int i = 0; i < DataSize; ++i) { - data1[i] = random_without_cast_overflow<SrcScalar, TgtScalar>::value(); + data1[i] = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value(); } - for (int i = 0; i < DataSize; ++i) ref[i] = static_cast<const TgtScalar>(data1[i]); + for (int i = 0; i < DataSize; ++i) { + ref[i] = static_cast<const TgtScalar>(data1[i]); + } pcast_array<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio>::cast(data1, DataSize, data2); @@ -318,21 +197,37 @@ struct test_cast_runner<SrcPacket, TgtScalar, TgtPacket, false, false> { static void run() {} }; +template <typename Scalar, typename Packet, typename EnableIf = void> +struct packetmath_pcast_ops_runner { + static void run() { + test_cast_runner<Packet, float>::run(); + test_cast_runner<Packet, double>::run(); + test_cast_runner<Packet, int8_t>::run(); + test_cast_runner<Packet, uint8_t>::run(); + test_cast_runner<Packet, int16_t>::run(); + test_cast_runner<Packet, uint16_t>::run(); + test_cast_runner<Packet, int32_t>::run(); + test_cast_runner<Packet, uint32_t>::run(); + test_cast_runner<Packet, int64_t>::run(); + test_cast_runner<Packet, uint64_t>::run(); + test_cast_runner<Packet, bool>::run(); + test_cast_runner<Packet, std::complex<float>>::run(); + test_cast_runner<Packet, std::complex<double>>::run(); + test_cast_runner<Packet, half>::run(); + test_cast_runner<Packet, bfloat16>::run(); + } +}; + +// Only some types support cast from std::complex<>. template <typename Scalar, typename Packet> -void packetmath_pcast_ops() { - test_cast_runner<Packet, float>::run(); - test_cast_runner<Packet, double>::run(); - test_cast_runner<Packet, int8_t>::run(); - test_cast_runner<Packet, uint8_t>::run(); - test_cast_runner<Packet, int16_t>::run(); - test_cast_runner<Packet, uint16_t>::run(); - test_cast_runner<Packet, int32_t>::run(); - test_cast_runner<Packet, uint32_t>::run(); - test_cast_runner<Packet, int64_t>::run(); - test_cast_runner<Packet, uint64_t>::run(); - test_cast_runner<Packet, bool>::run(); - test_cast_runner<Packet, half>::run(); -} +struct packetmath_pcast_ops_runner<Scalar, Packet, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> { + static void run() { + test_cast_runner<Packet, std::complex<float>>::run(); + test_cast_runner<Packet, std::complex<double>>::run(); + test_cast_runner<Packet, half>::run(); + test_cast_runner<Packet, bfloat16>::run(); + } +}; template <typename Scalar, typename Packet> void packetmath_boolean_mask_ops() { @@ -356,10 +251,8 @@ void packetmath_boolean_mask_ops() { // Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path // (for some compilers) compute the bitwise and with 0x1 of the results to keep the value in [0,1]. -#ifdef EIGEN_PACKET_MATH_SSE_H -template <> -void packetmath_boolean_mask_ops<bool, internal::Packet16b>() {} -#endif +template<> +void packetmath_boolean_mask_ops<bool, typename internal::packet_traits<bool>::type>() {} template <typename Scalar, typename Packet> void packetmath() { @@ -560,7 +453,7 @@ void packetmath() { CHECK_CWISE2_IF(true, internal::pand, internal::pand); packetmath_boolean_mask_ops<Scalar, Packet>(); - packetmath_pcast_ops<Scalar, Packet>(); + packetmath_pcast_ops_runner<Scalar, Packet>::run(); } template <typename Scalar, typename Packet> @@ -975,9 +868,7 @@ EIGEN_DECLARE_TEST(packetmath) { CALL_SUBTEST_11(test::runner<std::complex<float> >::run()); CALL_SUBTEST_12(test::runner<std::complex<double> >::run()); CALL_SUBTEST_13((packetmath<half, internal::packet_traits<half>::type>())); -#ifdef EIGEN_PACKET_MATH_SSE_H CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>())); -#endif CALL_SUBTEST_15((packetmath<bfloat16, internal::packet_traits<bfloat16>::type>())); g_first_pass = false; } diff --git a/test/random_without_cast_overflow.h b/test/random_without_cast_overflow.h new file mode 100644 index 000000000..000345110 --- /dev/null +++ b/test/random_without_cast_overflow.h @@ -0,0 +1,152 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2020 C. Antonio Sanchez <cantonios@google.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Utilities for generating random numbers without overflows, which might +// otherwise result in undefined behavior. + +namespace Eigen { +namespace internal { + +// Default implementation assuming SrcScalar fits into TgtScalar. +template <typename SrcScalar, typename TgtScalar, typename EnableIf = void> +struct random_without_cast_overflow { + static SrcScalar value() { return internal::random<SrcScalar>(); } +}; + +// Signed to unsigned integer widening cast. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && + !NumTraits<TgtScalar>::IsSigned && + (std::numeric_limits<SrcScalar>::digits < std::numeric_limits<TgtScalar>::digits || + (std::numeric_limits<SrcScalar>::digits == std::numeric_limits<TgtScalar>::digits && + NumTraits<SrcScalar>::IsSigned))>::type> { + static SrcScalar value() { + SrcScalar a = internal::random<SrcScalar>(); + return a < SrcScalar(0) ? -(a + 1) : a; + } +}; + +// Integer to unsigned narrowing cast. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if< + NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && !NumTraits<SrcScalar>::IsSigned && + (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> { + static SrcScalar value() { + TgtScalar b = internal::random<TgtScalar>(); + return static_cast<SrcScalar>(b < TgtScalar(0) ? -(b + 1) : b); + } +}; + +// Integer to signed narrowing cast. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if< + NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && NumTraits<SrcScalar>::IsSigned && + (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> { + static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); } +}; + +// Unsigned to signed integer narrowing cast. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && + !NumTraits<SrcScalar>::IsSigned && NumTraits<TgtScalar>::IsSigned && + (std::numeric_limits<SrcScalar>::digits == + std::numeric_limits<TgtScalar>::digits)>::type> { + static SrcScalar value() { return internal::random<SrcScalar>() / 2; } +}; + +// Floating-point to integer, full precision. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if< + !NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsInteger && + (std::numeric_limits<TgtScalar>::digits <= std::numeric_limits<SrcScalar>::digits)>::type> { + static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); } +}; + +// Floating-point to integer, narrowing precision. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if< + !NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsInteger && + (std::numeric_limits<TgtScalar>::digits > std::numeric_limits<SrcScalar>::digits)>::type> { + static SrcScalar value() { + // NOTE: internal::random<T>() is limited by RAND_MAX, so random<int64_t> is always within that range. + // This prevents us from simply shifting bits, which would result in only 0 or -1. + // Instead, keep least-significant K bits and sign. + static const TgtScalar KeepMask = (static_cast<TgtScalar>(1) << std::numeric_limits<SrcScalar>::digits) - 1; + const TgtScalar a = internal::random<TgtScalar>(); + return static_cast<SrcScalar>(a > TgtScalar(0) ? (a & KeepMask) : -(a & KeepMask)); + } +}; + +// Integer to floating-point, re-use above logic. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && !NumTraits<TgtScalar>::IsInteger && + !NumTraits<TgtScalar>::IsComplex>::type> { + static SrcScalar value() { + return static_cast<SrcScalar>(random_without_cast_overflow<TgtScalar, SrcScalar>::value()); + } +}; + +// Floating-point narrowing conversion. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if<!NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex && + !NumTraits<TgtScalar>::IsInteger && !NumTraits<TgtScalar>::IsComplex && + (std::numeric_limits<SrcScalar>::digits > + std::numeric_limits<TgtScalar>::digits)>::type> { + static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); } +}; + +// Complex to non-complex. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if<NumTraits<SrcScalar>::IsComplex && !NumTraits<TgtScalar>::IsComplex>::type> { + typedef typename NumTraits<SrcScalar>::Real SrcReal; + static SrcScalar value() { return SrcScalar(random_without_cast_overflow<SrcReal, TgtScalar>::value(), 0); } +}; + +// Non-complex to complex. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if<!NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsComplex>::type> { + typedef typename NumTraits<TgtScalar>::Real TgtReal; + static SrcScalar value() { return random_without_cast_overflow<SrcScalar, TgtReal>::value(); } +}; + +// Complex to complex. +template <typename SrcScalar, typename TgtScalar> +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if<NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsComplex>::type> { + typedef typename NumTraits<SrcScalar>::Real SrcReal; + typedef typename NumTraits<TgtScalar>::Real TgtReal; + static SrcScalar value() { + return SrcScalar(random_without_cast_overflow<SrcReal, TgtReal>::value(), + random_without_cast_overflow<SrcReal, TgtReal>::value()); + } +}; + +} // namespace internal +} // namespace Eigen diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h index cdbafbbb1..44493906d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h @@ -51,7 +51,10 @@ struct nested<TensorConversionOp<TargetType, XprType>, 1, typename eval<TensorCo template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio> -struct PacketConverter { +struct PacketConverter; + +template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket> +struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 1> { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketConverter(const TensorEvaluator& impl) : m_impl(impl) {} @@ -109,7 +112,33 @@ struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 4, 1> { }; template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket> -struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 2> { +struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 8, 1> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + PacketConverter(const TensorEvaluator& impl) + : m_impl(impl) {} + + template<int LoadMode, typename Index> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const { + const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size; + + SrcPacket src1 = m_impl.template packet<LoadMode>(index); + SrcPacket src2 = m_impl.template packet<LoadMode>(index + 1 * SrcPacketSize); + SrcPacket src3 = m_impl.template packet<LoadMode>(index + 2 * SrcPacketSize); + SrcPacket src4 = m_impl.template packet<LoadMode>(index + 3 * SrcPacketSize); + SrcPacket src5 = m_impl.template packet<LoadMode>(index + 4 * SrcPacketSize); + SrcPacket src6 = m_impl.template packet<LoadMode>(index + 5 * SrcPacketSize); + SrcPacket src7 = m_impl.template packet<LoadMode>(index + 6 * SrcPacketSize); + SrcPacket src8 = m_impl.template packet<LoadMode>(index + 7 * SrcPacketSize); + TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2, src3, src4, src5, src6, src7, src8); + return result; + } + + private: + const TensorEvaluator& m_impl; +}; + +template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int TgtCoeffRatio> +struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, TgtCoeffRatio> { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketConverter(const TensorEvaluator& impl) : m_impl(impl), m_maxIndex(impl.dimensions().TotalSize()) {} diff --git a/unsupported/test/cxx11_tensor_casts.cpp b/unsupported/test/cxx11_tensor_casts.cpp index c4fe9a798..45456f3ef 100644 --- a/unsupported/test/cxx11_tensor_casts.cpp +++ b/unsupported/test/cxx11_tensor_casts.cpp @@ -8,6 +8,7 @@ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #include "main.h" +#include "random_without_cast_overflow.h" #include <Eigen/CXX11/Tensor> @@ -104,12 +105,82 @@ static void test_small_to_big_type_cast() } } +template <typename FromType, typename ToType> +static void test_type_cast() { + Tensor<FromType, 2> ftensor(100, 200); + // Generate random values for a valid cast. + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 200; ++j) { + ftensor(i, j) = internal::random_without_cast_overflow<FromType,ToType>::value(); + } + } + + Tensor<ToType, 2> ttensor(100, 200); + ttensor = ftensor.template cast<ToType>(); + + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 200; ++j) { + const ToType ref = internal::cast<FromType,ToType>(ftensor(i, j)); + VERIFY_IS_APPROX(ttensor(i, j), ref); + } + } +} + +template<typename Scalar, typename EnableIf = void> +struct test_cast_runner { + static void run() { + test_type_cast<Scalar, bool>(); + test_type_cast<Scalar, int8_t>(); + test_type_cast<Scalar, int16_t>(); + test_type_cast<Scalar, int32_t>(); + test_type_cast<Scalar, int64_t>(); + test_type_cast<Scalar, uint8_t>(); + test_type_cast<Scalar, uint16_t>(); + test_type_cast<Scalar, uint32_t>(); + test_type_cast<Scalar, uint64_t>(); + test_type_cast<Scalar, half>(); + test_type_cast<Scalar, bfloat16>(); + test_type_cast<Scalar, float>(); + test_type_cast<Scalar, double>(); + test_type_cast<Scalar, std::complex<float>>(); + test_type_cast<Scalar, std::complex<double>>(); + } +}; + +// Only certain types allow cast from std::complex<>. +template<typename Scalar> +struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> { + static void run() { + test_type_cast<Scalar, half>(); + test_type_cast<Scalar, bfloat16>(); + test_type_cast<Scalar, std::complex<float>>(); + test_type_cast<Scalar, std::complex<double>>(); + } +}; + EIGEN_DECLARE_TEST(cxx11_tensor_casts) { - CALL_SUBTEST(test_simple_cast()); - CALL_SUBTEST(test_vectorized_cast()); - CALL_SUBTEST(test_float_to_int_cast()); - CALL_SUBTEST(test_big_to_small_type_cast()); - CALL_SUBTEST(test_small_to_big_type_cast()); + CALL_SUBTEST(test_simple_cast()); + CALL_SUBTEST(test_vectorized_cast()); + CALL_SUBTEST(test_float_to_int_cast()); + CALL_SUBTEST(test_big_to_small_type_cast()); + CALL_SUBTEST(test_small_to_big_type_cast()); + + CALL_SUBTEST(test_cast_runner<bool>::run()); + CALL_SUBTEST(test_cast_runner<int8_t>::run()); + CALL_SUBTEST(test_cast_runner<int16_t>::run()); + CALL_SUBTEST(test_cast_runner<int32_t>::run()); + CALL_SUBTEST(test_cast_runner<int64_t>::run()); + CALL_SUBTEST(test_cast_runner<uint8_t>::run()); + CALL_SUBTEST(test_cast_runner<uint16_t>::run()); + CALL_SUBTEST(test_cast_runner<uint32_t>::run()); + CALL_SUBTEST(test_cast_runner<uint64_t>::run()); + CALL_SUBTEST(test_cast_runner<half>::run()); + CALL_SUBTEST(test_cast_runner<bfloat16>::run()); + CALL_SUBTEST(test_cast_runner<float>::run()); + CALL_SUBTEST(test_cast_runner<double>::run()); + CALL_SUBTEST(test_cast_runner<std::complex<float>>::run()); + CALL_SUBTEST(test_cast_runner<std::complex<double>>::run()); + } |