diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h index 445248163..ea286fee1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h @@ -101,6 +101,17 @@ Eigen::half RandomToTypeUniform<Eigen::half>(uint64_t* state, uint64_t stream) { return result - Eigen::half(1.0f); } +template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +Eigen::bfloat16 RandomToTypeUniform<Eigen::bfloat16>(uint64_t* state, uint64_t stream) { + Eigen::bfloat16 result; + // Generate 7 random bits for the mantissa + unsigned rnd = PCG_XSH_RS_generator(state, stream); + result.value = static_cast<uint16_t>(rnd & 0x7fu); + // Set the exponent + result.value |= (static_cast<uint16_t>(127) << 7); + // Return the final result + return result - Eigen::bfloat16(1.0f); +} template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float RandomToTypeUniform<float>(uint64_t* state, uint64_t stream) { |