From 9cb8771e9c4a1f44ba59741c9fac495d1872bb25 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 25 Jun 2020 14:31:16 -0700 Subject: Fix tensor casts for large packets and casts to/from std::complex The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex`, which was missing to make it consistent with `Eigen:bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations. --- test/random_without_cast_overflow.h | 152 ++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 test/random_without_cast_overflow.h (limited to 'test/random_without_cast_overflow.h') diff --git a/test/random_without_cast_overflow.h b/test/random_without_cast_overflow.h new file mode 100644 index 000000000..000345110 --- /dev/null +++ b/test/random_without_cast_overflow.h @@ -0,0 +1,152 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2020 C. Antonio Sanchez +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Utilities for generating random numbers without overflows, which might +// otherwise result in undefined behavior. + +namespace Eigen { +namespace internal { + +// Default implementation assuming SrcScalar fits into TgtScalar. +template +struct random_without_cast_overflow { + static SrcScalar value() { return internal::random(); } +}; + +// Signed to unsigned integer widening cast. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if::IsInteger && NumTraits::IsInteger && + !NumTraits::IsSigned && + (std::numeric_limits::digits < std::numeric_limits::digits || + (std::numeric_limits::digits == std::numeric_limits::digits && + NumTraits::IsSigned))>::type> { + static SrcScalar value() { + SrcScalar a = internal::random(); + return a < SrcScalar(0) ? -(a + 1) : a; + } +}; + +// Integer to unsigned narrowing cast. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if< + NumTraits::IsInteger && NumTraits::IsInteger && !NumTraits::IsSigned && + (std::numeric_limits::digits > std::numeric_limits::digits)>::type> { + static SrcScalar value() { + TgtScalar b = internal::random(); + return static_cast(b < TgtScalar(0) ? -(b + 1) : b); + } +}; + +// Integer to signed narrowing cast. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if< + NumTraits::IsInteger && NumTraits::IsInteger && NumTraits::IsSigned && + (std::numeric_limits::digits > std::numeric_limits::digits)>::type> { + static SrcScalar value() { return static_cast(internal::random()); } +}; + +// Unsigned to signed integer narrowing cast. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if::IsInteger && NumTraits::IsInteger && + !NumTraits::IsSigned && NumTraits::IsSigned && + (std::numeric_limits::digits == + std::numeric_limits::digits)>::type> { + static SrcScalar value() { return internal::random() / 2; } +}; + +// Floating-point to integer, full precision. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if< + !NumTraits::IsInteger && !NumTraits::IsComplex && NumTraits::IsInteger && + (std::numeric_limits::digits <= std::numeric_limits::digits)>::type> { + static SrcScalar value() { return static_cast(internal::random()); } +}; + +// Floating-point to integer, narrowing precision. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if< + !NumTraits::IsInteger && !NumTraits::IsComplex && NumTraits::IsInteger && + (std::numeric_limits::digits > std::numeric_limits::digits)>::type> { + static SrcScalar value() { + // NOTE: internal::random() is limited by RAND_MAX, so random is always within that range. + // This prevents us from simply shifting bits, which would result in only 0 or -1. + // Instead, keep least-significant K bits and sign. + static const TgtScalar KeepMask = (static_cast(1) << std::numeric_limits::digits) - 1; + const TgtScalar a = internal::random(); + return static_cast(a > TgtScalar(0) ? (a & KeepMask) : -(a & KeepMask)); + } +}; + +// Integer to floating-point, re-use above logic. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if::IsInteger && !NumTraits::IsInteger && + !NumTraits::IsComplex>::type> { + static SrcScalar value() { + return static_cast(random_without_cast_overflow::value()); + } +}; + +// Floating-point narrowing conversion. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if::IsInteger && !NumTraits::IsComplex && + !NumTraits::IsInteger && !NumTraits::IsComplex && + (std::numeric_limits::digits > + std::numeric_limits::digits)>::type> { + static SrcScalar value() { return static_cast(internal::random()); } +}; + +// Complex to non-complex. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if::IsComplex && !NumTraits::IsComplex>::type> { + typedef typename NumTraits::Real SrcReal; + static SrcScalar value() { return SrcScalar(random_without_cast_overflow::value(), 0); } +}; + +// Non-complex to complex. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if::IsComplex && NumTraits::IsComplex>::type> { + typedef typename NumTraits::Real TgtReal; + static SrcScalar value() { return random_without_cast_overflow::value(); } +}; + +// Complex to complex. +template +struct random_without_cast_overflow< + SrcScalar, TgtScalar, + typename internal::enable_if::IsComplex && NumTraits::IsComplex>::type> { + typedef typename NumTraits::Real SrcReal; + typedef typename NumTraits::Real TgtReal; + static SrcScalar value() { + return SrcScalar(random_without_cast_overflow::value(), + random_without_cast_overflow::value()); + } +}; + +} // namespace internal +} // namespace Eigen -- cgit v1.2.3