aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11
diff options
context:
space:
mode:
authorGravatar Teng Lu <teng.lu@intel.com>2020-06-20 19:16:24 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-06-20 19:16:24 +0000
commit386d809bde475c65b7940f290efe80e6a05878c4 (patch)
treec38e161a53393d15be0ddb02a7a4e22dec738484 /unsupported/Eigen/CXX11
parent6b9c92fe7eff0dedb031cec38004c9c3667f3057 (diff)
Support BFloat16 in Eigen
Diffstat (limited to 'unsupported/Eigen/CXX11')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h11
1 files changed, 11 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h
index 445248163..ea286fee1 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h
@@ -101,6 +101,17 @@ Eigen::half RandomToTypeUniform<Eigen::half>(uint64_t* state, uint64_t stream) {
return result - Eigen::half(1.0f);
}
+template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+Eigen::bfloat16 RandomToTypeUniform<Eigen::bfloat16>(uint64_t* state, uint64_t stream) {
+ Eigen::bfloat16 result;
+ // Generate 7 random bits for the mantissa
+ unsigned rnd = PCG_XSH_RS_generator(state, stream);
+ result.value = static_cast<uint16_t>(rnd & 0x7fu);
+ // Set the exponent
+ result.value |= (static_cast<uint16_t>(127) << 7);
+ // Return the final result
+ return result - Eigen::bfloat16(1.0f);
+}
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float RandomToTypeUniform<float>(uint64_t* state, uint64_t stream) {