aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/bfloat16.cc
blob: 0068283367acb097807ccab80aa760c2a6cdfa9c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#include "tensorflow/core/framework/bfloat16.h"

#include <cstdint>
#include <cstring>

namespace tensorflow {

void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
  const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
  uint16_t* q = reinterpret_cast<uint16_t*>(dst);
  for (; size; p += 2, q++, size--) {
    *q = p[1];
  }
}

void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
  const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
  uint16_t* q = reinterpret_cast<uint16_t*>(dst);
  for (; size; p++, q += 2, size--) {
    q[0] = 0;
    q[1] = *p;
  }
}

}  // end namespace tensorflow