aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 15:50:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 15:59:09 -0700
commitb8c86c3bbd8271ed968087f24e7fb704103bc733 (patch)
tree032797b0dfe233d999a158442cb4c95da4d0856c /tensorflow/core/util
parentb56164c72b8f123bfc675f930111af8801fe034f (diff)
Support saving/restoring of string tensors with lengths of 2^32 bytes or more.
PiperOrigin-RevId: 214849978
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/tensor_bundle/BUILD1
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.cc52
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc64
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README3
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001bin0 -> 1080 bytes
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.indexbin0 -> 211 bytes
6 files changed, 100 insertions, 20 deletions
diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD
index 648358606c..4d4db86df2 100644
--- a/tensorflow/core/util/tensor_bundle/BUILD
+++ b/tensorflow/core/util/tensor_bundle/BUILD
@@ -64,6 +64,7 @@ cc_library(
tf_cc_test(
name = "tensor_bundle_test",
srcs = ["tensor_bundle_test.cc"],
+ data = glob(["testdata/**"]),
deps = [
":tensor_bundle",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index ea8a259d1a..2dcb57a1f9 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -64,27 +64,36 @@ namespace {
// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination". Discards the original content of "destination".
//
-// Checksums the string lengths (as restored uint32, not varint32 bytes) and
-// string bytes, and stores it into "actual_crc32c".
+// Checksums the string lengths (as restored uint32 or uint64, not varint64
+// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
size_t offset, size_t size, string* destination,
uint32* actual_crc32c) {
if (size == 0) return Status::OK();
CHECK_GT(size, 0);
- // Reads "num_elements" varint32's from "buffered_file".
+ // Reads "num_elements" varint64's from "buffered_file".
TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
- std::vector<uint32> string_lengths(num_elements);
+ std::vector<uint64> string_lengths(num_elements);
for (size_t i = 0; i < num_elements; ++i) {
- TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i]));
+ TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
+ if (string_lengths[i] <= UINT32_MAX) {
+ // Checksum the length as a uint32 when it fits: older checkpoints only
+ // used uint32 lengths, and their checksums must still validate.
+ const uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&string_lengths[i]),
+ sizeof(uint64));
+ }
}
if (offset + size < buffered_file->Tell()) {
return errors::DataLoss("String lengths longer than expected offset ",
offset + size);
}
- *actual_crc32c =
- crc32c::Value(reinterpret_cast<const char*>(string_lengths.data()),
- sizeof(uint32) * num_elements);
// Reads the length-checksum.
uint32 length_checksum = 0;
@@ -104,7 +113,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
// Reads the actual string bytes.
for (size_t i = 0; i < num_elements; ++i) {
- const uint32 string_length = string_lengths[i];
+ const uint64 string_length = string_lengths[i];
string* buffer = &destination[i];
buffer->resize(string_length);
@@ -218,8 +227,8 @@ Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
size_t* bytes_written, uint32* crc32c) {
// On-disk format:
- // [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes]
- // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes),
+ // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
+ // Var "crc32c" checksums the string lengths (uint32 or uint64, not varints),
// the length-checksum, and all the string bytes.
DCHECK_EQ(val.dtype(), DT_STRING);
const string* strings = GetStringBackingBuffer(val);
@@ -230,12 +239,21 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
*crc32c = 0;
for (int64 i = 0; i < val.NumElements(); ++i) {
const string* elem = &strings[i];
- DCHECK_EQ(elem->size(), static_cast<uint32>(elem->size()));
- const uint32 elem_size = static_cast<uint32>(elem->size());
-
- core::PutVarint32(&lengths, elem_size);
- *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
- sizeof(uint32));
+ DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
+ const uint64 elem_size = static_cast<uint64>(elem->size());
+
+ core::PutVarint64(&lengths, elem_size);
+ if (elem_size <= UINT32_MAX) {
+ // Checksum the length as a uint32 when it fits: older checkpoints only
+ // used uint32 lengths, and readers must compute the same checksum.
+ const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
+ *crc32c = crc32c::Extend(*crc32c,
+ reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *crc32c = crc32c::Extend(
+ *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
+ }
}
TF_RETURN_IF_ERROR(out->Append(lengths));
*bytes_written = lengths.size();
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 59c42baa06..9567e4750b 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -39,6 +39,11 @@ string Prefix(const string& prefix) {
return strings::StrCat(testing::TmpDir(), "/", prefix);
}
+string TestdataPrefix(const string& prefix) {
+ return strings::StrCat(testing::TensorFlowSrcRoot(),
+ "/core/util/tensor_bundle/testdata/", prefix);
+}
+
template <typename T>
Tensor Constant(T v, TensorShape shape) {
Tensor ret(DataTypeToEnum<T>::value, shape);
@@ -458,7 +463,26 @@ TEST(TensorBundleTest, NonStandardShapes) {
TestNonStandardShapes<qint8>();
}
+TEST(TensorBundleTest, StringTensorsOldFormat) {
+ // Tests a string tensor bundle made with a previous version of the code
+ // that used varint32s to store string lengths (we now use varint64s).
+ BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
+ TF_ASSERT_OK(reader.status());
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+
+ Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1})));
+ Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
+ Expect<string>(
+ &reader, "strs",
+ test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')}));
+ Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+}
+
TEST(TensorBundleTest, StringTensors) {
+ constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
+ Tensor long_string_tensor(DT_STRING, TensorShape({1}));
+
{
BundleWriter writer(Env::Default(), Prefix("foo"));
TF_EXPECT_OK(writer.Add("string_tensor",
@@ -467,6 +491,12 @@ TEST(TensorBundleTest, StringTensors) {
TF_EXPECT_OK(writer.Add(
"strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
+
+ // Requires a 64-bit length.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign(kLongLength, 'd');
+ TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
+
// Mixes in some floats.
TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
TF_ASSERT_OK(writer.Finish());
@@ -474,9 +504,9 @@ TEST(TensorBundleTest, StringTensors) {
{
BundleReader reader(Env::Default(), Prefix("foo"));
TF_ASSERT_OK(reader.status());
- EXPECT_EQ(
- AllTensorKeys(&reader),
- std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "long_scalar", "scalar",
+ "string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1})));
@@ -484,7 +514,35 @@ TEST(TensorBundleTest, StringTensors) {
Expect<string>(
&reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
+
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+
+ // We don't use the Expect function so we can re-use the
+ // `long_string_tensor` buffer for reading out long_scalar to keep memory
+ // usage reasonable.
+ EXPECT_TRUE(reader.Contains("long_scalar"));
+ DataType dtype;
+ TensorShape shape;
+ TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
+ EXPECT_EQ(DT_STRING, dtype);
+ EXPECT_EQ(TensorShape({1}), shape);
+
+ // Zero-out the string so that we can be sure the new one is read in.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign("");
+
+ // Read long_scalar and check it contains kLongLength 'd's.
+ TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
+ ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
+ EXPECT_EQ(kLongLength, backing_string->length());
+ for (char c : *backing_string) {
+ // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
+ // compiler optimizations.
+ if (c != 'd') {
+ FAIL() << "long_scalar is not full of 'd's as expected.";
+ break;
+ }
+ }
}
}
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
new file mode 100644
index 0000000000..428d3ef79e
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
@@ -0,0 +1,3 @@
+This tensor bundle was generated from cl/214343133, before string tensor
+lengths were written as varint64s. This is here to check backwards
+compatibility between the new code and old checkpoints.
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
new file mode 100644
index 0000000000..23b488e5fe
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
new file mode 100644
index 0000000000..a22a69e6e1
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
Binary files differ