/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
#define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_

#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// Disable clang-format to prevent 'FixedPoint' header from being included
// before 'Tensor' header on which it depends.
// clang-format off
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on

#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Single precision complex.
typedef std::complex<float> complex64;
// Double precision complex.
typedef std::complex<double> complex128;

// We use Eigen's QInt implementations for our quantized int types.
typedef Eigen::QInt8 qint8;
typedef Eigen::QUInt8 quint8;
typedef Eigen::QInt32 qint32;
typedef Eigen::QInt16 qint16;
typedef Eigen::QUInt16 quint16;
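
// Usage sketch (illustrative only, not part of this header's API): each
// quantized type wraps a raw integer payload. Assuming Eigen's FixedPoint
// types expose that payload via a 'value' member, as in the unsupported
// FixedPoint header:
//   qint8 q(42);                                  // q.value == 42
//   int32 widened = static_cast<int32>(q.value);  // recover the raw bits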

}  // namespace tensorflow

static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    return *reinterpret_cast<tensorflow::bfloat16*>(
        reinterpret_cast<uint16_t*>(&float_val));
#else
    return *reinterpret_cast<tensorflow::bfloat16*>(
        &(reinterpret_cast<uint16_t*>(&float_val)[1]));
#endif
}
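
// Worked example (informal): FloatToBFloat16 keeps only the high 16 bits of
// the float's bit pattern and drops the low 16 mantissa bits without rounding:
//   1.0f (bits 0x3F800000) -> bfloat16 bits 0x3F80
//   1.5f (bits 0x3FC00000) -> bfloat16 bits 0x3FC0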
    
namespace Eigen {
// TODO(xpan): We probably need to override more methods to get correct Eigen
// behavior, e.g. epsilon(), dummy_precision(), etc. See NumTraits.h in Eigen.
template <>
struct NumTraits<tensorflow::bfloat16>
    : GenericNumTraits<tensorflow::bfloat16> {
  enum {
    IsInteger = 0,
    IsSigned = 1,
    RequireInitialization = 0
  };
  static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() {
    return FloatToBFloat16(NumTraits<float>::highest());
  }

  static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() {
    return FloatToBFloat16(NumTraits<float>::lowest());
  }

  static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() {
    return FloatToBFloat16(NumTraits<float>::infinity());
  }

  static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() {
    return FloatToBFloat16(NumTraits<float>::quiet_NaN());
  }
};
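
// Usage sketch (assumes only what the specialization above provides): generic
// Eigen code can now query numeric limits for bfloat16, e.g.
//   tensorflow::bfloat16 hi = NumTraits<tensorflow::bfloat16>::highest();
//   tensorflow::bfloat16 lo = NumTraits<tensorflow::bfloat16>::lowest();
// Both are the corresponding float limits truncated via FloatToBFloat16.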


using ::tensorflow::operator==;
using ::tensorflow::operator!=;
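// Note: the two using declarations above pull the bfloat16 comparison
// operators defined in the tensorflow namespace into Eigen, so generic Eigen
// code comparing bfloat16 values can find them by unqualified lookup.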

namespace numext {

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log(
    const tensorflow::bfloat16& x) {
  return static_cast<tensorflow::bfloat16>(::logf(static_cast<float>(x)));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp(
    const tensorflow::bfloat16& x) {
  return static_cast<tensorflow::bfloat16>(::expf(static_cast<float>(x)));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs(
    const tensorflow::bfloat16& x) {
  return static_cast<tensorflow::bfloat16>(::fabsf(static_cast<float>(x)));
}
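
// Informal note on the three overloads above: each widens its bfloat16
// argument to float, evaluates the corresponding C math routine (::logf,
// ::expf, ::fabsf), and casts the float result back to bfloat16. For example,
//   numext::exp(tensorflow::bfloat16(1.0f))
// computes ::expf(1.0f) and returns the result narrowed to bfloat16.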

}  // namespace numext
}  // namespace Eigen

#if defined(COMPILER_MSVC) && !defined(__clang__)
namespace std {
template <>
struct hash<Eigen::half> {
  std::size_t operator()(const Eigen::half& a) const {
    return static_cast<std::size_t>(a.x);
  }
};
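
// Usage sketch: with this specialization, Eigen::half can serve as a key in
// hash-based containers when building with MSVC, e.g.
//   std::unordered_map<Eigen::half, int> counts;
// The hash is simply the raw 16-bit representation stored in 'a.x'.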
}  // namespace std
#endif  // COMPILER_MSVC

#endif  // TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_