aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/bfloat16_float.cpp
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-07-22 18:09:00 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-07-22 18:09:00 -0700
commit1b84f21e321e9daa1efcd4422ae92c1782c5582c (patch)
tree949b6adeb3fcf26ff67b47561e754a3bc99e0640 /test/bfloat16_float.cpp
parent38b91f256be8bf498f0ba9e8dc4fa0abdd7abe70 (diff)
Revert change that made conversion from bfloat16 to {float, double} implicit.
Add roundtrip tests for casting between bfloat16 and complex types.
Diffstat (limited to 'test/bfloat16_float.cpp')
-rw-r--r--test/bfloat16_float.cpp31
1 files changed, 20 insertions, 11 deletions
diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp
index 96341929a..11fc31363 100644
--- a/test/bfloat16_float.cpp
+++ b/test/bfloat16_float.cpp
@@ -41,6 +41,19 @@ void test_truncate(float input, float expected_truncation, float expected_roundi
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
}
+template<typename T>
+ void test_roundtrip() {
+ // Representable T round trip via bfloat16
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(-std::numeric_limits<T>::infinity())), -std::numeric_limits<T>::infinity());
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(std::numeric_limits<T>::infinity())), std::numeric_limits<T>::infinity());
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-1.0))), T(-1.0));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.5))), T(-0.5));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.0))), T(-0.0));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(1.0))), T(1.0));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.5))), T(0.5));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.0))), T(0.0));
+}
+
void test_conversion()
{
using Eigen::bfloat16_impl::__bfloat16_raw;
@@ -53,9 +66,9 @@ void test_conversion()
VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80); // Becomes infinity.
// Verify round-to-nearest-even behavior.
- float val1 = bfloat16(__bfloat16_raw(0x3c00));
- float val2 = bfloat16(__bfloat16_raw(0x3c01));
- float val3 = bfloat16(__bfloat16_raw(0x3c02));
+ float val1 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c00)));
+ float val2 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c01)));
+ float val3 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c02)));
VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00);
VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02);
@@ -106,14 +119,10 @@ void test_conversion()
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
// Representable floats round trip via bfloat16
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-std::numeric_limits<float>::infinity())), -std::numeric_limits<float>::infinity());
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(std::numeric_limits<float>::infinity())), std::numeric_limits<float>::infinity());
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-1.0f)), -1.0f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.5f)), -0.5f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.0f)), -0.0f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(1.0f)), 1.0f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.5f)), 0.5f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.0f)), 0.0f);
+ test_roundtrip<float>();
+ test_roundtrip<double>();
+ test_roundtrip<std::complex<float> >();
+ test_roundtrip<std::complex<double> >();
// Truncate test
test_truncate(