diff options
author | 2017-12-13 10:01:47 -0800 | |
---|---|---|
committer | 2017-12-13 10:05:43 -0800 | |
commit | 10d45f9ca118ed37b190140f9310e58f95d4d52c (patch) | |
tree | 7e36c299ae372bfb9cc2ff7903ad81c1f8c1ce0a /tensorflow/core/framework/numeric_types.h | |
parent | e31f38913d4018c2cee094e05a04833ac96f8b68 (diff) |
Make bfloat16 works with complex
PiperOrigin-RevId: 178917043
Diffstat (limited to 'tensorflow/core/framework/numeric_types.h')
-rw-r--r-- | tensorflow/core/framework/numeric_types.h | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index 569a4c3756..70563d53ef 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -58,6 +58,14 @@ struct bfloat16 { #endif } + // Following the convention of numpy, converting between complex and + // float will lead to loss of imag value. + explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val) + : bfloat16(val.real()) {} + + explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val) + : bfloat16(static_cast<float>(val.real())) {} + template <class T> explicit EIGEN_DEVICE_FUNC bfloat16(const T& val) : bfloat16(static_cast<float>(val)) {} @@ -129,6 +137,14 @@ struct bfloat16 { return static_cast<double>(float(*this)); } + EIGEN_DEVICE_FUNC explicit operator complex64() const { + return complex64(float(*this), float(0.0)); + } + + EIGEN_DEVICE_FUNC explicit operator complex128() const { + return complex128(double(*this), double(0.0)); + } + static bfloat16 epsilon() { bfloat16 x; x.value = 0x3c00; // 0x1.0p-7 |