diff options
Diffstat (limited to 'tensorflow/core/framework/numeric_types.h')
-rw-r--r-- | tensorflow/core/framework/numeric_types.h | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index 99a5d0a054..4c38fbbe59 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ #include <complex> - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" // Disable clang-format to prevent 'FixedPoint' header from being included // before 'Tensor' header on which it depends. @@ -43,12 +42,47 @@ typedef Eigen::QUInt16 quint16; } // namespace tensorflow + + + +static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) { +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return *reinterpret_cast<tensorflow::bfloat16*>( + reinterpret_cast<uint16_t*>(&float_val)); +#else + return *reinterpret_cast<tensorflow::bfloat16*>( + &(reinterpret_cast<uint16_t*>(&float_val)[1])); +#endif +} + namespace Eigen { -// TOOD(xpan): We probably need to overwrite more methods to have correct eigen -// behavior. E.g. loest(), is_integer, etc. See NumTraits.h in eigen. +// TODO(xpan): We probably need to overwrite more methods to have correct eigen +// behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen. template <> struct NumTraits<tensorflow::bfloat16> - : GenericNumTraits<tensorflow::bfloat16> {}; + : GenericNumTraits<tensorflow::bfloat16> { + enum { + IsInteger = 0, + IsSigned = 1, + RequireInitialization = 0 + }; + static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() { + return FloatToBFloat16(NumTraits<float>::highest()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() { + return FloatToBFloat16(NumTraits<float>::lowest()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() { + return FloatToBFloat16(NumTraits<float>::infinity()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() { + return FloatToBFloat16(NumTraits<float>::quiet_NaN()); + } +}; + using ::tensorflow::operator==; using ::tensorflow::operator!=; |