author     Peter Hawkins <phawkins@google.com>             2017-12-13 15:21:11 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-12-13 15:25:12 -0800
commit     d80d6de2890112f0013ffede31767907ec3291ca (patch)
tree       2133473e0938b551c2ac79e683336952f660c4c7
parent     b09be8eff9505486b0f838e2cb281c3ebe8ecfc6 (diff)
Fix bfloat16 serialization of Tensors.
Previously, Python serialization and deserialization used the half_val field of TensorProto, whereas C++ serialization used the int_val field. However, C++ bfloat16 deserialization was always broken, so it was never possible to correctly deserialize a bfloat16 Tensor from the typed fields; serialization appeared to work only because of the generic tensor_content bytes path.

PiperOrigin-RevId: 178966536
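To make the effect of the change concrete, here is a minimal round-trip sketch (not part of this commit; the tensor values and the function name are illustrative only). It uses the public Tensor::AsProtoField and Tensor::FromProto entry points, which route through the ProtoHelper / FromProtoField code modified below; with this fix the bfloat16 values travel through half_val in both directions.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"

// Sketch: serialize a bfloat16 Tensor into the typed half_val field and read
// it back. Before this change, the typed-field round trip failed for bfloat16.
void BfloatRoundTripSketch() {
  tensorflow::Tensor t(tensorflow::DT_BFLOAT16, tensorflow::TensorShape({3}));
  auto flat = t.flat<tensorflow::bfloat16>();
  for (int i = 0; i < 3; ++i) {
    flat(i) = static_cast<tensorflow::bfloat16>(i * 0.5f);
  }

  tensorflow::TensorProto proto;
  t.AsProtoField(&proto);  // values now land in proto.half_val()

  tensorflow::Tensor restored;
  bool ok = restored.FromProto(proto);  // previously could not recover bfloat16
  (void)ok;
}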
-rw-r--r--  tensorflow/core/framework/tensor.cc                  | 42
-rw-r--r--  tensorflow/core/framework/tensor.proto               |  4
-rw-r--r--  tensorflow/core/framework/tensor_test.cc             | 22
-rw-r--r--  tensorflow/python/kernel_tests/constant_op_test.py   | 19
4 files changed, 70 insertions(+), 17 deletions(-)
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 24b7b08ebc..4f08cdc1d7 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -415,18 +415,10 @@ struct ProtoHelper<qint32> {
 
 template <>
 struct ProtoHelper<bfloat16> {
-  typedef Helper<float>::RepeatedFieldType FieldType;
-  static const bfloat16* Begin(const TensorProto& proto) {
-    // TODO: Isn't this wrong, given that int_val is 32 bits long?
-    return reinterpret_cast<const bfloat16*>(proto.int_val().data());
-  }
-  static size_t NumElements(const TensorProto& proto) {
-    return proto.int_val().size();
-  }
   static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
-    proto->mutable_int_val()->Reserve(n);
+    proto->mutable_half_val()->Reserve(n);
     for (size_t i = 0; i < n; ++i) {
-      proto->mutable_int_val()->AddAlreadyReserved(data[i].value);
+      proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
     }
   }
 };
@@ -529,9 +521,9 @@ TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
   return buf;
 }
 
-// fp16 is opaque to the protobuf, so we deserialize these identical to uint16
-// but with data stored in half_val instead of int_val (ie., we don't use
-// ProtoHelper<uint16>).
+// fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
+// identical to uint16 but with data stored in half_val instead of int_val (ie.,
+// we don't use ProtoHelper<uint16>).
 template <>
 TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
                                           int64 n) {
@@ -556,6 +548,30 @@ TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
   return buf;
 }
 
+template <>
+TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
+                                       int64 n) {
+  CHECK_GT(n, 0);
+  Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
+  uint16* data = buf->template base<uint16>();
+  if (data == nullptr) {
+    buf->Unref();
+    return nullptr;
+  }
+  const int64 in_n = in.half_val().size();
+  auto begin = in.half_val().begin();
+  if (n <= in_n) {
+    std::copy_n(begin, n, data);
+  } else if (in_n > 0) {
+    std::copy_n(begin, in_n, data);
+    const uint16 last = *(data + in_n - 1);
+    std::fill_n(data + in_n, n - in_n, last);
+  } else {
+    std::fill_n(data, n, 0);
+  }
+  return buf;
+}
+
 // Copies T[n] stored in the buffer "in" into the repeated field in
 // "out" corresponding to type T.
 template <typename T>
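A note on the fill semantics of the new FromProtoField<bfloat16> above: when the proto carries fewer than n values in half_val, the last stored value is repeated to fill the remaining elements, and an empty half_val yields zeros (matching the Eigen::half path). A hypothetical sketch of the splat case, assuming standard protobuf-generated accessors (0x3F80 is the bfloat16 bit pattern of 1.0f):

// Sketch only: a DT_BFLOAT16 proto carrying a single half_val entry expands
// to a full tensor, with the last (here, only) value repeated.
tensorflow::TensorProto proto;
proto.set_dtype(tensorflow::DT_BFLOAT16);
proto.mutable_tensor_shape()->add_dim()->set_size(4);
proto.add_half_val(0x3F80);  // upper 16 bits of float 1.0f

tensorflow::Tensor t;
if (t.FromProto(proto)) {
  // t is a length-4 DT_BFLOAT16 tensor whose elements are all 1.0.
}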
diff --git a/tensorflow/core/framework/tensor.proto b/tensorflow/core/framework/tensor.proto
index 6dab325969..abbf16e810 100644
--- a/tensorflow/core/framework/tensor.proto
+++ b/tensorflow/core/framework/tensor.proto
@@ -40,8 +40,8 @@ message TensorProto {
   // be set. The values hold the flattened representation of the tensor in
   // row major order.
 
-  // DT_HALF. Note that since protobuf has no int16 type, we'll have some
-  // pointless zero padding for each value here.
+  // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
+  // have some pointless zero padding for each value here.
   repeated int32 half_val = 13 [packed = true];
 
   // DT_FLOAT.
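For reference, a hedged sketch of how the shared half_val field is populated for the two 16-bit float types after this change: each 16-bit pattern occupies one entry of the packed int32 field, with the upper 16 bits left zero (the bit patterns shown are the standard IEEE fp16 and bfloat16 encodings of 1.0 and 2.0):

// Sketch: DT_HALF and DT_BFLOAT16 both store raw 16-bit patterns in half_val.
tensorflow::TensorProto fp16_proto;
fp16_proto.set_dtype(tensorflow::DT_HALF);
fp16_proto.add_half_val(0x3C00);  // Eigen::half(1.0f)
fp16_proto.add_half_val(0x4000);  // Eigen::half(2.0f)

tensorflow::TensorProto bf16_proto;
bf16_proto.set_dtype(tensorflow::DT_BFLOAT16);
bf16_proto.add_half_val(0x3F80);  // bfloat16(1.0f)
bf16_proto.add_half_val(0x4000);  // bfloat16(2.0f)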
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index cbc921ccd0..1482880428 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -175,6 +175,28 @@ void TestCopies(const Tensor& t) {
   }
 }
 
+TEST(Tensor_Half, Simple) {
+  Tensor t(DT_HALF, TensorShape({5, 7}));
+  EXPECT_TRUE(t.shape().IsSameSize(TensorShape({5, 7})));
+  for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+    for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+      t.matrix<Eigen::half>()(a, b) = static_cast<Eigen::half>(a * b);
+    }
+  }
+  TestCopies<Eigen::half>(t);
+}
+
+TEST(Tensor_Bfloat16, Simple) {
+  Tensor t(DT_BFLOAT16, TensorShape({5, 7}));
+  EXPECT_TRUE(t.shape().IsSameSize(TensorShape({5, 7})));
+  for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+    for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+      t.matrix<bfloat16>()(a, b) = static_cast<bfloat16>(a * b);
+    }
+  }
+  TestCopies<bfloat16>(t);
+}
+
 TEST(Tensor_Float, Simple) {
   Tensor t(DT_FLOAT, TensorShape({10, 20}));
   EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 20})));
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 68817cc256..030c690167 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -44,7 +44,8 @@ class ConstantTest(test.TestCase):
     np_ans = np.array(x)
     with self.test_session(use_gpu=False):
       tf_ans = ops.convert_to_tensor(x).eval()
-    if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+    dtype = dtypes_lib.as_dtype(np_ans.dtype)
+    if dtype.is_floating or dtype.is_complex:
       self.assertAllClose(np_ans, tf_ans)
     else:
       self.assertAllEqual(np_ans, tf_ans)
@@ -53,7 +54,8 @@ class ConstantTest(test.TestCase):
     np_ans = np.array(x)
     with self.test_session(use_gpu=True):
       tf_ans = ops.convert_to_tensor(x).eval()
-    if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+    dtype = dtypes_lib.as_dtype(np_ans.dtype)
+    if dtype.is_floating or dtype.is_complex:
       self.assertAllClose(np_ans, tf_ans)
     else:
       self.assertAllEqual(np_ans, tf_ans)
@@ -62,6 +64,19 @@ class ConstantTest(test.TestCase):
     self._testCpu(x)
     self._testGpu(x)
 
+  def testBFloat16(self):
+    bfloat16 = dtypes_lib.bfloat16.as_numpy_dtype
+    self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(bfloat16))
+    self._testAll(
+        np.random.normal(size=30).reshape([2, 3, 5]).astype(bfloat16))
+    self._testAll(np.empty((2, 0, 5)).astype(bfloat16))
+
+  def testHalf(self):
+    self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float16))
+    self._testAll(
+        np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float16))
+    self._testAll(np.empty((2, 0, 5)).astype(np.float16))
+
   def testFloat(self):
     self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32))
     self._testAll(