diff options
-rw-r--r-- | Eigen/src/Core/arch/Default/BFloat16.h | 75 | ||||
-rw-r--r-- | test/bfloat16_float.cpp | 17 |
2 files changed, 26 insertions, 66 deletions
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index 30c998249..c963ece6c 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -13,16 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #ifndef EIGEN_BFLOAT16_H #define EIGEN_BFLOAT16_H -#if __cplusplus > 199711L -#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type() -#else -#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type() -#endif - #define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \ template <> \ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \ @@ -34,20 +27,6 @@ namespace Eigen { struct bfloat16; -// explicit conversion operators are no available before C++11 so we first cast -// bfloat16 to RealScalar rather than to std::complex<RealScalar> directly -#if !EIGEN_HAS_CXX11 -namespace internal { -template <typename RealScalar> -struct cast_impl<bfloat16, std::complex<RealScalar> > { - EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const bfloat16 &x) - { - return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x)); - } -}; -} // namespace internal -#endif // EIGEN_HAS_CXX11 - namespace bfloat16_impl { // Make our own __bfloat16_raw definition. @@ -86,66 +65,32 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { explicit EIGEN_DEVICE_FUNC bfloat16(bool b) : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {} + template<class T> explicit EIGEN_DEVICE_FUNC bfloat16(const T& val) : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {} + explicit EIGEN_DEVICE_FUNC bfloat16(float f) : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {} + // Following the convention of numpy, converting between complex and // float will lead to loss of imag value. template<typename RealScalar> explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val) : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {} + EIGEN_DEVICE_FUNC operator float() const { + return bfloat16_impl::bfloat16_to_float(*this); + } + +#if EIGEN_HAS_CXX11 EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const { // +0.0 and -0.0 become false, everything else becomes true. return (value & 0x7fff) != 0; } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(signed char) const { - return static_cast<signed char>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned char) const { - return static_cast<unsigned char>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(short) const { - return static_cast<short>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned short) const { - return static_cast<unsigned short>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(int) const { - return static_cast<int>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned int) const { - return static_cast<unsigned int>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long) const { - return static_cast<long>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long) const { - return static_cast<unsigned long>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long long) const { - return static_cast<long long>(bfloat16_impl::bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const { - return static_cast<unsigned long long>(bfloat16_to_float(*this)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const { - return bfloat16_impl::bfloat16_to_float(*this); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { - return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)); - } - template<typename RealScalar> - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const { - return std::complex<RealScalar>(static_cast<RealScalar>(bfloat16_impl::bfloat16_to_float(*this)), RealScalar(0)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const { - return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this)); - } -}; +#endif +}; } // end namespace Eigen namespace std { diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp index 11fc31363..94226e999 100644 --- a/test/bfloat16_float.cpp +++ b/test/bfloat16_float.cpp @@ -84,10 +84,20 @@ void test_conversion() VERIFY_IS_EQUAL(bfloat16(false).value, 0x0000); VERIFY_IS_EQUAL(bfloat16(true).value, 0x3f80); - // Conversion to float. + // Conversion to bool + VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(3)), true); + VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(0.33333f)), true); + VERIFY_IS_EQUAL(bfloat16(-0.0), false); + VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(0.0)), false); + + // Explicit conversion to float. VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x0000))), 0.0f); VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x3f80))), 1.0f); + // Implicit conversion to float + VERIFY_IS_EQUAL(bfloat16(__bfloat16_raw(0x0000)), 0.0f); + VERIFY_IS_EQUAL(bfloat16(__bfloat16_raw(0x3f80)), 1.0f); + // Zero representations VERIFY_IS_EQUAL(bfloat16(0.0f), bfloat16(0.0f)); VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(0.0f)); @@ -101,6 +111,11 @@ void test_conversion() denorm = nextafterf(denorm, 1.0f)) { bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm); VERIFY_IS_EQUAL(static_cast<float>(bf_trunc), 0.0f); + + // Implicit conversion of denormls to bool is correct + VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(denorm)), false); + VERIFY_IS_EQUAL(bfloat16(denorm), false); + if (std::signbit(denorm)) { VERIFY_IS_EQUAL(bf_trunc.value, 0x8000); } else { |