diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-11-17 15:32:44 -0800 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2020-11-18 20:32:35 +0000 |
commit | 17268b155d54422f1294130c0fb8c178757d911a (patch) | |
tree | 2be3d541729f3e9be6a180a58270bae10156df4f /unsupported/test | |
parent | 41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed (diff) |
Add bit_cast for half/bfloat to/from uint16_t, fix TensorRandom
The existing `TensorRandom.h` implementation makes the assumption that
`half` (`bfloat16`) has a `uint16_t` member `x` (`value`), which is not
always true. This currently fails on arm64, where `x` has type `__fp16`.
Added `bit_cast` specializations to allow casting to/from `uint16_t`
for both `half` and `bfloat16`. Also added tests in
`half_float`, `bfloat16_float`, and `cxx11_tensor_random` to catch
these errors in the future.
Diffstat (limited to 'unsupported/test')
-rw-r--r-- | unsupported/test/cxx11_tensor_random.cpp | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/unsupported/test/cxx11_tensor_random.cpp b/unsupported/test/cxx11_tensor_random.cpp index 4740d5811..b9d4c5584 100644 --- a/unsupported/test/cxx11_tensor_random.cpp +++ b/unsupported/test/cxx11_tensor_random.cpp @@ -11,9 +11,10 @@ #include <Eigen/CXX11/Tensor> +template<typename Scalar> static void test_default() { - Tensor<float, 1> vec(6); + Tensor<Scalar, 1> vec(6); vec.setRandom(); // Fixme: we should check that the generated numbers follow a uniform @@ -23,10 +24,11 @@ static void test_default() } } +template<typename Scalar> static void test_normal() { - Tensor<float, 1> vec(6); - vec.setRandom<Eigen::internal::NormalRandomGenerator<float>>(); + Tensor<Scalar, 1> vec(6); + vec.template setRandom<Eigen::internal::NormalRandomGenerator<Scalar>>(); // Fixme: we should check that the generated numbers follow a gaussian // distribution instead. @@ -72,7 +74,13 @@ static void test_custom() EIGEN_DECLARE_TEST(cxx11_tensor_random) { - CALL_SUBTEST(test_default()); - CALL_SUBTEST(test_normal()); + CALL_SUBTEST((test_default<float>())); + CALL_SUBTEST((test_normal<float>())); + CALL_SUBTEST((test_default<double>())); + CALL_SUBTEST((test_normal<double>())); + CALL_SUBTEST((test_default<Eigen::half>())); + CALL_SUBTEST((test_normal<Eigen::half>())); + CALL_SUBTEST((test_default<Eigen::bfloat16>())); + CALL_SUBTEST((test_normal<Eigen::bfloat16>())); CALL_SUBTEST(test_custom()); } |