From b8c86c3bbd8271ed968087f24e7fb704103bc733 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Sep 2018 15:50:41 -0700 Subject: Support saving/restoring of string tensors with lengths greater than 2^32. PiperOrigin-RevId: 214849978 --- tensorflow/core/util/tensor_bundle/BUILD | 1 + .../core/util/tensor_bundle/tensor_bundle.cc | 52 +++++++++++------ .../core/util/tensor_bundle/tensor_bundle_test.cc | 64 ++++++++++++++++++++- .../testdata/old_string_tensors/README | 3 + .../old_string_tensors/foo.data-00000-of-00001 | Bin 0 -> 1080 bytes .../testdata/old_string_tensors/foo.index | Bin 0 -> 211 bytes 6 files changed, 100 insertions(+), 20 deletions(-) create mode 100644 tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README create mode 100644 tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 create mode 100644 tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index (limited to 'tensorflow/core/util') diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD index 648358606c..4d4db86df2 100644 --- a/tensorflow/core/util/tensor_bundle/BUILD +++ b/tensorflow/core/util/tensor_bundle/BUILD @@ -64,6 +64,7 @@ cc_library( tf_cc_test( name = "tensor_bundle_test", srcs = ["tensor_bundle_test.cc"], + data = glob(["testdata/**"]), deps = [ ":tensor_bundle", "//tensorflow/core:framework", diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index ea8a259d1a..2dcb57a1f9 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -64,27 +64,36 @@ namespace { // Reads "num_elements" string elements from file[offset, offset+size) into the // length-N "destination". Discards the original content of "destination". // -// Checksums the string lengths (as restored uint32, not varint32 bytes) and -// string bytes, and stores it into "actual_crc32c". +// Checksums the string lengths (as restored uint32 or uint64, not varint64 +// bytes) and string bytes, and stores it into "actual_crc32c". Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements, size_t offset, size_t size, string* destination, uint32* actual_crc32c) { if (size == 0) return Status::OK(); CHECK_GT(size, 0); - // Reads "num_elements" varint32's from "buffered_file". + // Reads "num_elements" varint64's from "buffered_file". TF_RETURN_IF_ERROR(buffered_file->Seek(offset)); - std::vector string_lengths(num_elements); + std::vector string_lengths(num_elements); for (size_t i = 0; i < num_elements; ++i) { - TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i])); + TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i])); + if (string_lengths[i] <= UINT32_MAX) { + // We need to do this because older checkpoints only used uint32s and we + // should still support them. + const uint32 elem_size_uint32 = static_cast(string_lengths[i]); + *actual_crc32c = crc32c::Extend( + *actual_crc32c, reinterpret_cast(&elem_size_uint32), + sizeof(uint32)); + } else { + *actual_crc32c = crc32c::Extend( + *actual_crc32c, reinterpret_cast(&string_lengths[i]), + sizeof(uint64)); + } } if (offset + size < buffered_file->Tell()) { return errors::DataLoss("String lengths longer than expected offset ", offset + size); } - *actual_crc32c = - crc32c::Value(reinterpret_cast(string_lengths.data()), - sizeof(uint32) * num_elements); // Reads the length-checksum. uint32 length_checksum = 0; @@ -104,7 +113,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements, // Reads the actual string bytes. for (size_t i = 0; i < num_elements; ++i) { - const uint32 string_length = string_lengths[i]; + const uint64 string_length = string_lengths[i]; string* buffer = &destination[i]; buffer->resize(string_length); @@ -218,8 +227,8 @@ Status WriteTensor(const Tensor& val, FileOutputBuffer* out, Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out, size_t* bytes_written, uint32* crc32c) { // On-disk format: - // [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes] - // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes), + // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes] + // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes), // the length-checksum, and all the string bytes. DCHECK_EQ(val.dtype(), DT_STRING); const string* strings = GetStringBackingBuffer(val); @@ -230,12 +239,21 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out, *crc32c = 0; for (int64 i = 0; i < val.NumElements(); ++i) { const string* elem = &strings[i]; - DCHECK_EQ(elem->size(), static_cast(elem->size())); - const uint32 elem_size = static_cast(elem->size()); - - core::PutVarint32(&lengths, elem_size); - *crc32c = crc32c::Extend(*crc32c, reinterpret_cast(&elem_size), - sizeof(uint32)); + DCHECK_EQ(elem->size(), static_cast(elem->size())); + const uint64 elem_size = static_cast(elem->size()); + + core::PutVarint64(&lengths, elem_size); + if (elem_size <= UINT32_MAX) { + // We need to do this because older checkpoints only used uint32s and we + // should still support them. + const uint32 elem_size_uint32 = static_cast(elem_size); + *crc32c = crc32c::Extend(*crc32c, + reinterpret_cast(&elem_size_uint32), + sizeof(uint32)); + } else { + *crc32c = crc32c::Extend( + *crc32c, reinterpret_cast(&elem_size), sizeof(uint64)); + } } TF_RETURN_IF_ERROR(out->Append(lengths)); *bytes_written = lengths.size(); diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 59c42baa06..9567e4750b 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -39,6 +39,11 @@ string Prefix(const string& prefix) { return strings::StrCat(testing::TmpDir(), "/", prefix); } +string TestdataPrefix(const string& prefix) { + return strings::StrCat(testing::TensorFlowSrcRoot(), + "/core/util/tensor_bundle/testdata/", prefix); +} + template Tensor Constant(T v, TensorShape shape) { Tensor ret(DataTypeToEnum::value, shape); @@ -458,7 +463,26 @@ TEST(TensorBundleTest, NonStandardShapes) { TestNonStandardShapes(); } +TEST(TensorBundleTest, StringTensorsOldFormat) { + // Test string tensor bundle made with previous version of code that use + // varint32s to store string lengths (we now use varint64s). + BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo")); + TF_ASSERT_OK(reader.status()); + EXPECT_EQ(AllTensorKeys(&reader), + std::vector({"floats", "scalar", "string_tensor", "strs"})); + + Expect(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1}))); + Expect(&reader, "scalar", test::AsTensor({"hello"})); + Expect( + &reader, "strs", + test::AsTensor({"hello", "", "x01", string(1 << 10, 'c')})); + Expect(&reader, "floats", Constant_2x3(16.18)); +} + TEST(TensorBundleTest, StringTensors) { + constexpr size_t kLongLength = static_cast(UINT32_MAX) + 1; + Tensor long_string_tensor(DT_STRING, TensorShape({1})); + { BundleWriter writer(Env::Default(), Prefix("foo")); TF_EXPECT_OK(writer.Add("string_tensor", @@ -467,6 +491,12 @@ TEST(TensorBundleTest, StringTensors) { TF_EXPECT_OK(writer.Add( "strs", test::AsTensor({"hello", "", "x01", string(1 << 25, 'c')}))); + + // Requires a 64-bit length. + string* backing_string = long_string_tensor.flat().data(); + backing_string->assign(kLongLength, 'd'); + TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor)); + // Mixes in some floats. TF_EXPECT_OK(writer.Add("floats", Constant_2x3(16.18))); TF_ASSERT_OK(writer.Finish()); @@ -474,9 +504,9 @@ TEST(TensorBundleTest, StringTensors) { { BundleReader reader(Env::Default(), Prefix("foo")); TF_ASSERT_OK(reader.status()); - EXPECT_EQ( - AllTensorKeys(&reader), - std::vector({"floats", "scalar", "string_tensor", "strs"})); + EXPECT_EQ(AllTensorKeys(&reader), + std::vector({"floats", "long_scalar", "scalar", + "string_tensor", "strs"})); Expect(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1}))); @@ -484,7 +514,35 @@ TEST(TensorBundleTest, StringTensors) { Expect( &reader, "strs", test::AsTensor({"hello", "", "x01", string(1 << 25, 'c')})); + Expect(&reader, "floats", Constant_2x3(16.18)); + + // We don't use the Expect function so we can re-use the + // `long_string_tensor` buffer for reading out long_scalar to keep memory + // usage reasonable. + EXPECT_TRUE(reader.Contains("long_scalar")); + DataType dtype; + TensorShape shape; + TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape)); + EXPECT_EQ(DT_STRING, dtype); + EXPECT_EQ(TensorShape({1}), shape); + + // Zero-out the string so that we can be sure the new one is read in. + string* backing_string = long_string_tensor.flat().data(); + backing_string->assign(""); + + // Read long_scalar and check it contains kLongLength 'd's. + TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor)); + ASSERT_EQ(backing_string, long_string_tensor.flat().data()); + EXPECT_EQ(kLongLength, backing_string->length()); + for (char c : *backing_string) { + // Not using ASSERT_EQ('d', c) because this way is twice as fast due to + // compiler optimizations. + if (c != 'd') { + FAIL() << "long_scalar is not full of 'd's as expected."; + break; + } + } } } diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README new file mode 100644 index 0000000000..428d3ef79e --- /dev/null +++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README @@ -0,0 +1,3 @@ +This tensor bundle was generated from cl/214343133, before string tensor +lengths were written as varint64s. This is here to check backwards +compatibility between the new code and old checkpoints. diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 new file mode 100644 index 0000000000..23b488e5fe Binary files /dev/null and b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 differ diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index new file mode 100644 index 0000000000..a22a69e6e1 Binary files /dev/null and b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index differ -- cgit v1.2.3