blob: 0068283367acb097807ccab80aa760c2a6cdfa9c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
#include "tensorflow/core/framework/bfloat16.h"

#include <cstdint>
#include <cstring>
namespace tensorflow {
void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
uint16_t* q = reinterpret_cast<uint16_t*>(dst);
for (; size; p += 2, q++, size--) {
*q = p[1];
}
}
void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
uint16_t* q = reinterpret_cast<uint16_t*>(dst);
for (; size; p++, q += 2, size--) {
q[0] = 0;
q[1] = *p;
}
}
} // end namespace tensorflow
|