about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/lib/io
diff options
context:
space:
mode:
author Saurabh Saxena <srbs@google.com> 2017-10-16 13:40:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-16 13:44:52 -0700
commit 24f9c6e0dbd449624aa1db543550ec412975492e (patch)
tree f58676006e9c78a3d6963adea4a105d7a5254c45 /tensorflow/core/lib/io
parent b1128a402d473cc6a43c99a081446c1b45305dd9 (diff)
Add support for saving DT_VARIANT tensors in TensorBundle.
Add support for reading Varint64 to InputBuffer.
PiperOrigin-RevId: 172371104
Diffstat (limited to 'tensorflow/core/lib/io')
-rw-r--r-- tensorflow/core/lib/io/inputbuffer.cc 26
-rw-r--r-- tensorflow/core/lib/io/inputbuffer.h 26
-rw-r--r-- tensorflow/core/lib/io/inputbuffer_test.cc 39
3 files changed, 87 insertions, 4 deletions
diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc
index 7efe2dc543..4d35af49b2 100644
--- a/tensorflow/core/lib/io/inputbuffer.cc
+++ b/tensorflow/core/lib/io/inputbuffer.cc
@@ -116,17 +116,35 @@ Status InputBuffer::ReadNBytes(int64 bytes_to_read, char* result,
}
Status InputBuffer::ReadVarint32Fallback(uint32* result) {
+ Status s = ReadVarintFallback(result, core::kMaxVarint32Bytes);
+ if (errors::IsDataLoss(s)) {
+ return errors::DataLoss("Stored data is too large to be a varint32.");
+ }
+ return s;
+}
+
+Status InputBuffer::ReadVarint64Fallback(uint64* result) {
+ Status s = ReadVarintFallback(result, core::kMaxVarint64Bytes);
+ if (errors::IsDataLoss(s)) {
+ return errors::DataLoss("Stored data is too large to be a varint64.");
+ }
+ return s;
+}
+
+template <typename T>
+Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) {
uint8 scratch = 0;
- char* p = reinterpret_cast<char*>(&scratch);
+ auto* p = reinterpret_cast<char*>(&scratch);
size_t unused_bytes_read = 0;
*result = 0;
- for (int shift = 0; shift <= 28; shift += 7) {
+ for (int index = 0; index < max_bytes; index++) {
+ int shift = 7 * index;
TF_RETURN_IF_ERROR(ReadNBytes(1, p, &unused_bytes_read));
- *result |= (scratch & 127) << shift;
+ *result |= (static_cast<T>(scratch) & 127) << shift;
if (!(scratch & 128)) return Status::OK();
}
- return errors::DataLoss("Stored data is too large to be a varint32.");
+ return errors::DataLoss("Stored data longer than ", max_bytes, " bytes.");
}
Status InputBuffer::SkipNBytes(int64 bytes_to_skip) {
diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h
index 94a8cfd39b..b3740f396c 100644
--- a/tensorflow/core/lib/io/inputbuffer.h
+++ b/tensorflow/core/lib/io/inputbuffer.h
@@ -60,6 +60,9 @@ class InputBuffer {
// Reads a single varint32.
Status ReadVarint32(uint32* result);
+ // Reads a single varint64.
+ Status ReadVarint64(uint64* result);
+
// Like ReadNBytes() without returning the bytes read.
Status SkipNBytes(int64 bytes_to_skip);
@@ -82,6 +85,15 @@ class InputBuffer {
// Internal slow-path routine used by ReadVarint32().
Status ReadVarint32Fallback(uint32* result);
+ // Internal slow-path routine used by ReadVarint64().
+ Status ReadVarint64Fallback(uint64* result);
+
+ // Helper method for reading a varint which can span at most `max_bytes` bytes.
+ // If the varint is longer, a DataLoss error status is returned.
+ // If end of file is reached while reading, OutOfRange error is returned.
+ template <typename T>
+ Status ReadVarintFallback(T* result, int max_bytes);
+
RandomAccessFile* file_; // Not owned
int64 file_pos_; // Next position to read from in "file_"
size_t size_; // Size of "buf_"
@@ -109,6 +121,20 @@ inline Status InputBuffer::ReadVarint32(uint32* result) {
}
}
+// Inlined for performance.
+inline Status InputBuffer::ReadVarint64(uint64* result) {
+ if (pos_ + core::kMaxVarint64Bytes <= limit_) {
+ // Fast path: directly parse from buffered data.
+ // Reads strictly from the range [pos_, limit_).
+ const char* offset = core::GetVarint64Ptr(pos_, limit_, result);
+ if (offset == nullptr) return errors::OutOfRange("Parsed past limit.");
+ pos_ = const_cast<char*>(offset);
+ return Status::OK();
+ } else {
+ return ReadVarint64Fallback(result);
+ }
+}
+
} // namespace io
} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc
index 6771697a16..6be1f819c2 100644
--- a/tensorflow/core/lib/io/inputbuffer_test.cc
+++ b/tensorflow/core/lib/io/inputbuffer_test.cc
@@ -329,5 +329,44 @@ TEST(InputBuffer, ReadVarint32) {
}
}
+TEST(InputBuffer, ReadVarint64) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+
+ // Generates data.
+ std::vector<uint64> data;
+ uint64 i = 0;
+ for (; i < (1U << 10); i += 1) data.push_back(i);
+ for (; i < (1U << 15); i += 5) data.push_back(i);
+ for (; i < (1U << 31); i += 164817) data.push_back(i);
+ for (; i < (1ULL << 63); i += 16481797854795663UL) data.push_back(i);
+ data.push_back(std::numeric_limits<uint64>::max());
+
+ // Writes the varints.
+ {
+ std::unique_ptr<WritableFile> file;
+ TF_CHECK_OK(env->NewWritableFile(fname, &file));
+ string varint;
+ for (uint64 number : data) {
+ varint.clear();
+ core::PutVarint64(&varint, number);
+ TF_CHECK_OK(file->Append(StringPiece(varint)));
+ }
+ }
+
+ for (auto buf_size : BufferSizes()) {
+ std::unique_ptr<RandomAccessFile> file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ io::InputBuffer in(file.get(), buf_size);
+ uint64 result = 0;
+
+ for (uint64 expected : data) {
+ TF_ASSERT_OK(in.ReadVarint64(&result));
+ EXPECT_EQ(expected, result);
+ }
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadVarint64(&result)));
+ }
+}
+
} // namespace
} // namespace tensorflow