aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/numeric_types.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-13 10:01:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-13 10:05:43 -0800
commit10d45f9ca118ed37b190140f9310e58f95d4d52c (patch)
tree7e36c299ae372bfb9cc2ff7903ad81c1f8c1ce0a /tensorflow/core/framework/numeric_types.h
parente31f38913d4018c2cee094e05a04833ac96f8b68 (diff)
Make bfloat16 works with complex
PiperOrigin-RevId: 178917043
Diffstat (limited to 'tensorflow/core/framework/numeric_types.h')
-rw-r--r--tensorflow/core/framework/numeric_types.h16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index 569a4c3756..70563d53ef 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -58,6 +58,14 @@ struct bfloat16 {
#endif
}
+ // Following the convention of numpy, converting between complex and
+ // float will lead to loss of imag value.
+ explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val)
+ : bfloat16(val.real()) {}
+
+ explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val)
+ : bfloat16(static_cast<float>(val.real())) {}
+
template <class T>
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
: bfloat16(static_cast<float>(val)) {}
@@ -129,6 +137,14 @@ struct bfloat16 {
return static_cast<double>(float(*this));
}
+ EIGEN_DEVICE_FUNC explicit operator complex64() const {
+ return complex64(float(*this), float(0.0));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator complex128() const {
+ return complex128(double(*this), double(0.0));
+ }
+
static bfloat16 epsilon() {
bfloat16 x;
x.value = 0x3c00; // 0x1.0p-7