diff options
author | 2017-10-16 13:40:08 -0700 | |
---|---|---|
committer | 2017-10-16 13:44:52 -0700 | |
commit | 24f9c6e0dbd449624aa1db543550ec412975492e (patch) | |
tree | f58676006e9c78a3d6963adea4a105d7a5254c45 /tensorflow/core/lib/io | |
parent | b1128a402d473cc6a43c99a081446c1b45305dd9 (diff) |
Add support for saving DT_VARIANT tensors in TensorBundle.
Add support for reading Varint64 to InputBuffer.
PiperOrigin-RevId: 172371104
Diffstat (limited to 'tensorflow/core/lib/io')
-rw-r--r-- | tensorflow/core/lib/io/inputbuffer.cc | 26 | ||||
-rw-r--r-- | tensorflow/core/lib/io/inputbuffer.h | 26 | ||||
-rw-r--r-- | tensorflow/core/lib/io/inputbuffer_test.cc | 39 |
3 files changed, 87 insertions, 4 deletions
diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc
index 7efe2dc543..4d35af49b2 100644
--- a/tensorflow/core/lib/io/inputbuffer.cc
+++ b/tensorflow/core/lib/io/inputbuffer.cc
@@ -116,17 +116,35 @@ Status InputBuffer::ReadNBytes(int64 bytes_to_read, char* result,
 }
 
 Status InputBuffer::ReadVarint32Fallback(uint32* result) {
+  Status s = ReadVarintFallback(result, core::kMaxVarint32Bytes);
+  if (errors::IsDataLoss(s)) {
+    return errors::DataLoss("Stored data is too large to be a varint32.");
+  }
+  return s;
+}
+
+Status InputBuffer::ReadVarint64Fallback(uint64* result) {
+  Status s = ReadVarintFallback(result, core::kMaxVarint64Bytes);
+  if (errors::IsDataLoss(s)) {
+    return errors::DataLoss("Stored data is too large to be a varint64.");
+  }
+  return s;
+}
+
+template <typename T>
+Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) {
   uint8 scratch = 0;
-  char* p = reinterpret_cast<char*>(&scratch);
+  auto* p = reinterpret_cast<char*>(&scratch);
   size_t unused_bytes_read = 0;
 
   *result = 0;
-  for (int shift = 0; shift <= 28; shift += 7) {
+  for (int index = 0; index < max_bytes; index++) {
+    int shift = 7 * index;
     TF_RETURN_IF_ERROR(ReadNBytes(1, p, &unused_bytes_read));
-    *result |= (scratch & 127) << shift;
+    *result |= (static_cast<T>(scratch) & 127) << shift;
     if (!(scratch & 128)) return Status::OK();
   }
-  return errors::DataLoss("Stored data is too large to be a varint32.");
+  return errors::DataLoss("Stored data longer than ", max_bytes, " bytes.");
 }
 
 Status InputBuffer::SkipNBytes(int64 bytes_to_skip) {
diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h
index 94a8cfd39b..b3740f396c 100644
--- a/tensorflow/core/lib/io/inputbuffer.h
+++ b/tensorflow/core/lib/io/inputbuffer.h
@@ -60,6 +60,9 @@ class InputBuffer {
   // Reads a single varint32.
   Status ReadVarint32(uint32* result);
 
+  // Reads a single varint64.
+  Status ReadVarint64(uint64* result);
+
   // Like ReadNBytes() without returning the bytes read.
   Status SkipNBytes(int64 bytes_to_skip);
 
@@ -82,6 +85,15 @@ class InputBuffer {
   // Internal slow-path routine used by ReadVarint32().
   Status ReadVarint32Fallback(uint32* result);
 
+  // Internal slow-path routine used by ReadVarint64().
+  Status ReadVarint64Fallback(uint64* result);
+
+  // Helper method for reading a varint which can span at max `max_bytes`.
+  // If the varint is longer, a DataLoss error status is returned.
+  // If end of file is reached while reading, OutOfRange error is returned.
+  template <typename T>
+  Status ReadVarintFallback(T* result, int max_bytes);
+
   RandomAccessFile* file_;  // Not owned
   int64 file_pos_;          // Next position to read from in "file_"
   size_t size_;             // Size of "buf_"
@@ -109,6 +121,20 @@ inline Status InputBuffer::ReadVarint32(uint32* result) {
   }
 }
 
+// Inlined for performance.
+inline Status InputBuffer::ReadVarint64(uint64* result) {
+  if (pos_ + core::kMaxVarint64Bytes <= limit_) {
+    // Fast path: directly parse from buffered data.
+    // Reads strictly from the range [pos_, limit_).
+    const char* offset = core::GetVarint64Ptr(pos_, limit_, result);
+    if (offset == nullptr) return errors::OutOfRange("Parsed past limit.");
+    pos_ = const_cast<char*>(offset);
+    return Status::OK();
+  } else {
+    return ReadVarint64Fallback(result);
+  }
+}
+
 }  // namespace io
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc
index 6771697a16..6be1f819c2 100644
--- a/tensorflow/core/lib/io/inputbuffer_test.cc
+++ b/tensorflow/core/lib/io/inputbuffer_test.cc
@@ -329,5 +329,44 @@ TEST(InputBuffer, ReadVarint32) {
   }
 }
 
+TEST(InputBuffer, ReadVarint64) {
+  Env* env = Env::Default();
+  string fname = testing::TmpDir() + "/inputbuffer_test";
+
+  // Generates data.
+  std::vector<uint64> data;
+  uint64 i = 0;
+  for (; i < (1U << 10); i += 1) data.push_back(i);
+  for (; i < (1U << 15); i += 5) data.push_back(i);
+  for (; i < (1U << 31); i += 164817) data.push_back(i);
+  for (; i < (1ULL << 63); i += 16481797854795663UL) data.push_back(i);
+  data.push_back(std::numeric_limits<uint64>::max());
+
+  // Writes the varints.
+  {
+    std::unique_ptr<WritableFile> file;
+    TF_CHECK_OK(env->NewWritableFile(fname, &file));
+    string varint;
+    for (uint64 number : data) {
+      varint.clear();
+      core::PutVarint64(&varint, number);
+      TF_CHECK_OK(file->Append(StringPiece(varint)));
+    }
+  }
+
+  for (auto buf_size : BufferSizes()) {
+    std::unique_ptr<RandomAccessFile> file;
+    TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+    io::InputBuffer in(file.get(), buf_size);
+    uint64 result = 0;
+
+    for (uint64 expected : data) {
+      TF_ASSERT_OK(in.ReadVarint64(&result));
+      EXPECT_EQ(expected, result);
+    }
+    EXPECT_TRUE(errors::IsOutOfRange(in.ReadVarint64(&result)));
+  }
+}
+
 }  // namespace
 }  // namespace tensorflow