aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-11-17 15:32:44 -0800
committerGravatar Antonio Sánchez <cantonios@google.com>2020-11-18 20:32:35 +0000
commit17268b155d54422f1294130c0fb8c178757d911a (patch)
tree2be3d541729f3e9be6a180a58270bae10156df4f /unsupported/test
parent41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed (diff)
Add bit_cast for half/bfloat to/from uint16_t, fix TensorRandom
The existing `TensorRandom.h` implementation makes the assumption that `half` (`bfloat16`) has a `uint16_t` member `x` (`value`), which is not always true. This currently fails on arm64, where `x` has type `__fp16`. Added `bit_cast` specializations to allow casting to/from `uint16_t` for both `half` and `bfloat16`. Also added tests in `half_float`, `bfloat16_float`, and `cxx11_tensor_random` to catch these errors in the future.
Diffstat (limited to 'unsupported/test')
-rw-r--r--unsupported/test/cxx11_tensor_random.cpp18
1 files changed, 13 insertions, 5 deletions
diff --git a/unsupported/test/cxx11_tensor_random.cpp b/unsupported/test/cxx11_tensor_random.cpp
index 4740d5811..b9d4c5584 100644
--- a/unsupported/test/cxx11_tensor_random.cpp
+++ b/unsupported/test/cxx11_tensor_random.cpp
@@ -11,9 +11,10 @@
#include <Eigen/CXX11/Tensor>
+template<typename Scalar>
static void test_default()
{
- Tensor<float, 1> vec(6);
+ Tensor<Scalar, 1> vec(6);
vec.setRandom();
// Fixme: we should check that the generated numbers follow a uniform
@@ -23,10 +24,11 @@ static void test_default()
}
}
+template<typename Scalar>
static void test_normal()
{
- Tensor<float, 1> vec(6);
- vec.setRandom<Eigen::internal::NormalRandomGenerator<float>>();
+ Tensor<Scalar, 1> vec(6);
+ vec.template setRandom<Eigen::internal::NormalRandomGenerator<Scalar>>();
// Fixme: we should check that the generated numbers follow a gaussian
// distribution instead.
@@ -72,7 +74,13 @@ static void test_custom()
EIGEN_DECLARE_TEST(cxx11_tensor_random)
{
- CALL_SUBTEST(test_default());
- CALL_SUBTEST(test_normal());
+ CALL_SUBTEST((test_default<float>()));
+ CALL_SUBTEST((test_normal<float>()));
+ CALL_SUBTEST((test_default<double>()));
+ CALL_SUBTEST((test_normal<double>()));
+ CALL_SUBTEST((test_default<Eigen::half>()));
+ CALL_SUBTEST((test_normal<Eigen::half>()));
+ CALL_SUBTEST((test_default<Eigen::bfloat16>()));
+ CALL_SUBTEST((test_normal<Eigen::bfloat16>()));
CALL_SUBTEST(test_custom());
}