diff options
Diffstat (limited to 'tensorflow/core/lib/io')
31 files changed, 3387 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc new file mode 100644 index 0000000000..1ddaa2eb78 --- /dev/null +++ b/tensorflow/core/lib/io/block.cc @@ -0,0 +1,236 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// Decodes the blocks generated by block_builder.cc. + +#include "tensorflow/core/lib/io/block.h" + +#include <vector> +#include <algorithm> +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +inline uint32 Block::NumRestarts() const { + assert(size_ >= sizeof(uint32)); + return core::DecodeFixed32(data_ + size_ - sizeof(uint32)); +} + +Block::Block(const BlockContents& contents) + : data_(contents.data.data()), + size_(contents.data.size()), + owned_(contents.heap_allocated) { + if (size_ < sizeof(uint32)) { + size_ = 0; // Error marker + } else { + size_t max_restarts_allowed = (size_ - sizeof(uint32)) / sizeof(uint32); + if (NumRestarts() > max_restarts_allowed) { + // The size is too small for NumRestarts() + size_ = 0; + } else { + restart_offset_ = size_ - (1 + NumRestarts()) * sizeof(uint32); + } + } +} + +Block::~Block() { + if (owned_) { + delete[] data_; + } +} + +// Helper routine: decode the next block entry starting at "p", +// storing the number of shared key bytes, non_shared key bytes, +// and the length of the value in "*shared", "*non_shared", and +// "*value_length", respectively. Will not dereference past "limit". +// +// If any errors are detected, returns NULL. Otherwise, returns a +// pointer to the key delta (just past the three decoded values). 
+static inline const char* DecodeEntry(const char* p, const char* limit, + uint32* shared, uint32* non_shared, + uint32* value_length) { + if (limit - p < 3) return NULL; + *shared = reinterpret_cast<const unsigned char*>(p)[0]; + *non_shared = reinterpret_cast<const unsigned char*>(p)[1]; + *value_length = reinterpret_cast<const unsigned char*>(p)[2]; + if ((*shared | *non_shared | *value_length) < 128) { + // Fast path: all three values are encoded in one byte each + p += 3; + } else { + if ((p = core::GetVarint32Ptr(p, limit, shared)) == NULL) return NULL; + if ((p = core::GetVarint32Ptr(p, limit, non_shared)) == NULL) return NULL; + if ((p = core::GetVarint32Ptr(p, limit, value_length)) == NULL) return NULL; + } + + if (static_cast<uint32>(limit - p) < (*non_shared + *value_length)) { + return NULL; + } + return p; +} + +class Block::Iter : public Iterator { + private: + const char* const data_; // underlying block contents + uint32 const restarts_; // Offset of restart array (list of fixed32) + uint32 const num_restarts_; // Number of uint32 entries in restart array + + // current_ is offset in data_ of current entry. >= restarts_ if !Valid + uint32 current_; + uint32 restart_index_; // Index of restart block in which current_ falls + string key_; + StringPiece value_; + Status status_; + + inline int Compare(const StringPiece& a, const StringPiece& b) const { + return a.compare(b); + } + + // Return the offset in data_ just past the end of the current entry. 
+ inline uint32 NextEntryOffset() const { + return (value_.data() + value_.size()) - data_; + } + + uint32 GetRestartPoint(uint32 index) { + assert(index < num_restarts_); + return core::DecodeFixed32(data_ + restarts_ + index * sizeof(uint32)); + } + + void SeekToRestartPoint(uint32 index) { + key_.clear(); + restart_index_ = index; + // current_ will be fixed by ParseNextKey(); + + // ParseNextKey() starts at the end of value_, so set value_ accordingly + uint32 offset = GetRestartPoint(index); + value_ = StringPiece(data_ + offset, 0); + } + + public: + Iter(const char* data, uint32 restarts, uint32 num_restarts) + : data_(data), + restarts_(restarts), + num_restarts_(num_restarts), + current_(restarts_), + restart_index_(num_restarts_) { + assert(num_restarts_ > 0); + } + + virtual bool Valid() const { return current_ < restarts_; } + virtual Status status() const { return status_; } + virtual StringPiece key() const { + assert(Valid()); + return key_; + } + virtual StringPiece value() const { + assert(Valid()); + return value_; + } + + virtual void Next() { + assert(Valid()); + ParseNextKey(); + } + + virtual void Seek(const StringPiece& target) { + // Binary search in restart array to find the last restart point + // with a key < target + uint32 left = 0; + uint32 right = num_restarts_ - 1; + while (left < right) { + uint32 mid = (left + right + 1) / 2; + uint32 region_offset = GetRestartPoint(mid); + uint32 shared, non_shared, value_length; + const char* key_ptr = + DecodeEntry(data_ + region_offset, data_ + restarts_, &shared, + &non_shared, &value_length); + if (key_ptr == NULL || (shared != 0)) { + CorruptionError(); + return; + } + StringPiece mid_key(key_ptr, non_shared); + if (Compare(mid_key, target) < 0) { + // Key at "mid" is smaller than "target". Therefore all + // blocks before "mid" are uninteresting. + left = mid; + } else { + // Key at "mid" is >= "target". Therefore all blocks at or + // after "mid" are uninteresting. 
+ right = mid - 1; + } + } + + // Linear search (within restart block) for first key >= target + SeekToRestartPoint(left); + while (true) { + if (!ParseNextKey()) { + return; + } + if (Compare(key_, target) >= 0) { + return; + } + } + } + + virtual void SeekToFirst() { + SeekToRestartPoint(0); + ParseNextKey(); + } + + private: + void CorruptionError() { + current_ = restarts_; + restart_index_ = num_restarts_; + status_ = errors::DataLoss("bad entry in block"); + key_.clear(); + value_.clear(); + } + + bool ParseNextKey() { + current_ = NextEntryOffset(); + const char* p = data_ + current_; + const char* limit = data_ + restarts_; // Restarts come right after data + if (p >= limit) { + // No more entries to return. Mark as invalid. + current_ = restarts_; + restart_index_ = num_restarts_; + return false; + } + + // Decode next entry + uint32 shared, non_shared, value_length; + p = DecodeEntry(p, limit, &shared, &non_shared, &value_length); + if (p == NULL || key_.size() < shared) { + CorruptionError(); + return false; + } else { + key_.resize(shared); + key_.append(p, non_shared); + value_ = StringPiece(p + non_shared, value_length); + while (restart_index_ + 1 < num_restarts_ && + GetRestartPoint(restart_index_ + 1) < current_) { + ++restart_index_; + } + return true; + } + } +}; + +Iterator* Block::NewIterator() { + if (size_ < sizeof(uint32)) { + return NewErrorIterator(errors::DataLoss("bad block contents")); + } + const uint32 num_restarts = NumRestarts(); + if (num_restarts == 0) { + return NewEmptyIterator(); + } else { + return new Iter(data_, restart_offset_, num_restarts); + } +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block.h b/tensorflow/core/lib/io/block.h new file mode 100644 index 0000000000..bf53245b8d --- /dev/null +++ b/tensorflow/core/lib/io/block.h @@ -0,0 +1,45 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_BLOCK_H_ +#define TENSORFLOW_LIB_IO_BLOCK_H_ + +#include <stddef.h> +#include <stdint.h> +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +struct BlockContents; + +class Block { + public: + // Initialize the block with the specified contents. + explicit Block(const BlockContents& contents); + + ~Block(); + + size_t size() const { return size_; } + Iterator* NewIterator(); + + private: + uint32 NumRestarts() const; + + const char* data_; + size_t size_; + uint32 restart_offset_; // Offset in data_ of restart array + bool owned_; // Block owns data_[] + + // No copying allowed + Block(const Block&); + void operator=(const Block&); + + class Iter; +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_BLOCK_H_ diff --git a/tensorflow/core/lib/io/block_builder.cc b/tensorflow/core/lib/io/block_builder.cc new file mode 100644 index 0000000000..d94048d744 --- /dev/null +++ b/tensorflow/core/lib/io/block_builder.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// BlockBuilder generates blocks where keys are prefix-compressed: +// +// When we store a key, we drop the prefix shared with the previous +// string. This helps reduce the space requirement significantly. +// Furthermore, once every K keys, we do not apply the prefix +// compression and store the entire key. We call this a "restart +// point". The tail end of the block stores the offsets of all of the +// restart points, and can be used to do a binary search when looking +// for a particular key. 
Values are stored as-is (without compression) +// immediately following the corresponding key. +// +// An entry for a particular key-value pair has the form: +// shared_bytes: varint32 +// unshared_bytes: varint32 +// value_length: varint32 +// key_delta: char[unshared_bytes] +// value: char[value_length] +// shared_bytes == 0 for restart points. +// +// The trailer of the block has the form: +// restarts: uint32[num_restarts] +// num_restarts: uint32 +// restarts[i] contains the offset within the block of the ith restart point. + +#include "tensorflow/core/lib/io/block_builder.h" + +#include <algorithm> +#include <assert.h> +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/core/coding.h" + +namespace tensorflow { +namespace table { + +BlockBuilder::BlockBuilder(const Options* options) + : options_(options), restarts_(), counter_(0), finished_(false) { + assert(options->block_restart_interval >= 1); + restarts_.push_back(0); // First restart point is at offset 0 +} + +void BlockBuilder::Reset() { + buffer_.clear(); + restarts_.clear(); + restarts_.push_back(0); // First restart point is at offset 0 + counter_ = 0; + finished_ = false; + last_key_.clear(); +} + +size_t BlockBuilder::CurrentSizeEstimate() const { + return (buffer_.size() + // Raw data buffer + restarts_.size() * sizeof(uint32) + // Restart array + sizeof(uint32)); // Restart array length +} + +StringPiece BlockBuilder::Finish() { + // Append restart array + for (size_t i = 0; i < restarts_.size(); i++) { + core::PutFixed32(&buffer_, restarts_[i]); + } + core::PutFixed32(&buffer_, restarts_.size()); + finished_ = true; + return StringPiece(buffer_); +} + +void BlockBuilder::Add(const StringPiece& key, const StringPiece& value) { + StringPiece last_key_piece(last_key_); + assert(!finished_); + assert(counter_ <= options_->block_restart_interval); + assert(buffer_.empty() // No values yet? 
+ || key.compare(last_key_piece) > 0); + size_t shared = 0; + if (counter_ < options_->block_restart_interval) { + // See how much sharing to do with previous string + const size_t min_length = std::min(last_key_piece.size(), key.size()); + while ((shared < min_length) && (last_key_piece[shared] == key[shared])) { + shared++; + } + } else { + // Restart compression + restarts_.push_back(buffer_.size()); + counter_ = 0; + } + const size_t non_shared = key.size() - shared; + + // Add "<shared><non_shared><value_size>" to buffer_ + core::PutVarint32(&buffer_, shared); + core::PutVarint32(&buffer_, non_shared); + core::PutVarint32(&buffer_, value.size()); + + // Add string delta to buffer_ followed by value + buffer_.append(key.data() + shared, non_shared); + buffer_.append(value.data(), value.size()); + + // Update state + last_key_.resize(shared); + last_key_.append(key.data() + shared, non_shared); + assert(StringPiece(last_key_) == key); + counter_++; +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h new file mode 100644 index 0000000000..e07a647805 --- /dev/null +++ b/tensorflow/core/lib/io/block_builder.h @@ -0,0 +1,57 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ +#define TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ + +#include <vector> + +#include <stdint.h> +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace table { + +struct Options; + +class BlockBuilder { + public: + explicit BlockBuilder(const Options* options); + + // Reset the contents as if the BlockBuilder was just constructed. + void Reset(); + + // REQUIRES: Finish() has not been called since the last call to Reset(). 
+ // REQUIRES: key is larger than any previously added key + void Add(const StringPiece& key, const StringPiece& value); + + // Finish building the block and return a slice that refers to the + // block contents. The returned slice will remain valid for the + // lifetime of this builder or until Reset() is called. + StringPiece Finish(); + + // Returns an estimate of the current (uncompressed) size of the block + // we are building. + size_t CurrentSizeEstimate() const; + + // Return true iff no entries have been added since the last Reset() + bool empty() const { return buffer_.empty(); } + + private: + const Options* options_; + string buffer_; // Destination buffer + std::vector<uint32> restarts_; // Restart points + int counter_; // Number of entries emitted since restart + bool finished_; // Has Finish() been called? + string last_key_; + + // No copying allowed + BlockBuilder(const BlockBuilder&); + void operator=(const BlockBuilder&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ diff --git a/tensorflow/core/lib/io/format.cc b/tensorflow/core/lib/io/format.cc new file mode 100644 index 0000000000..259cfc13dc --- /dev/null +++ b/tensorflow/core/lib/io/format.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
+ +#include "tensorflow/core/lib/io/format.h" + +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +void BlockHandle::EncodeTo(string* dst) const { + // Sanity check that all fields have been set + assert(offset_ != ~static_cast<uint64>(0)); + assert(size_ != ~static_cast<uint64>(0)); + core::PutVarint64(dst, offset_); + core::PutVarint64(dst, size_); +} + +Status BlockHandle::DecodeFrom(StringPiece* input) { + if (core::GetVarint64(input, &offset_) && core::GetVarint64(input, &size_)) { + return Status::OK(); + } else { + return errors::DataLoss("bad block handle"); + } +} + +void Footer::EncodeTo(string* dst) const { +#ifndef NDEBUG + const size_t original_size = dst->size(); +#endif + metaindex_handle_.EncodeTo(dst); + index_handle_.EncodeTo(dst); + dst->resize(2 * BlockHandle::kMaxEncodedLength); // Padding + core::PutFixed32(dst, static_cast<uint32>(kTableMagicNumber & 0xffffffffu)); + core::PutFixed32(dst, static_cast<uint32>(kTableMagicNumber >> 32)); + assert(dst->size() == original_size + kEncodedLength); +} + +Status Footer::DecodeFrom(StringPiece* input) { + const char* magic_ptr = input->data() + kEncodedLength - 8; + const uint32 magic_lo = core::DecodeFixed32(magic_ptr); + const uint32 magic_hi = core::DecodeFixed32(magic_ptr + 4); + const uint64 magic = + ((static_cast<uint64>(magic_hi) << 32) | (static_cast<uint64>(magic_lo))); + if (magic != kTableMagicNumber) { + return errors::DataLoss("not an sstable (bad magic number)"); + } + + Status result = metaindex_handle_.DecodeFrom(input); + if (result.ok()) { + result = index_handle_.DecodeFrom(input); + } + if (result.ok()) { + // We skip over any leftover data (just padding for now) in "input" + const char* end = magic_ptr + 8; + *input = 
StringPiece(end, input->data() + input->size() - end); + } + return result; +} + +Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result) { + result->data = StringPiece(); + result->cachable = false; + result->heap_allocated = false; + + // Read the block contents as well as the type/crc footer. + // See table_builder.cc for the code that built this structure. + size_t n = static_cast<size_t>(handle.size()); + char* buf = new char[n + kBlockTrailerSize]; + StringPiece contents; + Status s = + file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf); + if (!s.ok()) { + delete[] buf; + return s; + } + if (contents.size() != n + kBlockTrailerSize) { + delete[] buf; + return errors::DataLoss("truncated block read"); + } + + // Check the crc of the type and the block contents + const char* data = contents.data(); // Pointer to where Read put the data + // This checksum verification is optional. We leave it on for now + const bool verify_checksum = true; + if (verify_checksum) { + const uint32 crc = crc32c::Unmask(core::DecodeFixed32(data + n + 1)); + const uint32 actual = crc32c::Value(data, n + 1); + if (actual != crc) { + delete[] buf; + s = errors::DataLoss("block checksum mismatch"); + return s; + } + } + + switch (data[n]) { + case kNoCompression: + if (data != buf) { + // File implementation gave us pointer to some other data. + // Use it directly under the assumption that it will be live + // while the file is open. 
+ delete[] buf; + result->data = StringPiece(data, n); + result->heap_allocated = false; + result->cachable = false; // Do not double-cache + } else { + result->data = StringPiece(buf, n); + result->heap_allocated = true; + result->cachable = true; + } + + // Ok + break; + case kSnappyCompression: { + size_t ulength = 0; + if (!port::Snappy_GetUncompressedLength(data, n, &ulength)) { + delete[] buf; + return errors::DataLoss("corrupted compressed block contents"); + } + char* ubuf = new char[ulength]; + if (!port::Snappy_Uncompress(data, n, ubuf)) { + delete[] buf; + delete[] ubuf; + return errors::DataLoss("corrupted compressed block contents"); + } + delete[] buf; + result->data = StringPiece(ubuf, ulength); + result->heap_allocated = true; + result->cachable = true; + break; + } + default: + delete[] buf; + return errors::DataLoss("bad block type"); + } + + return Status::OK(); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/format.h b/tensorflow/core/lib/io/format.h new file mode 100644 index 0000000000..3121c41bb8 --- /dev/null +++ b/tensorflow/core/lib/io/format.h @@ -0,0 +1,99 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_FORMAT_H_ +#define TENSORFLOW_LIB_IO_FORMAT_H_ + +#include <string> +#include <stdint.h> +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/io/table_builder.h" + +namespace tensorflow { +class RandomAccessFile; +namespace table { + +class Block; + +// BlockHandle is a pointer to the extent of a file that stores a data +// block or a meta block. +class BlockHandle { + public: + BlockHandle(); + + // The offset of the block in the file. 
+ uint64 offset() const { return offset_; } + void set_offset(uint64 offset) { offset_ = offset; } + + // The size of the stored block + uint64 size() const { return size_; } + void set_size(uint64 size) { size_ = size; } + + void EncodeTo(string* dst) const; + Status DecodeFrom(StringPiece* input); + + // Maximum encoding length of a BlockHandle + enum { kMaxEncodedLength = 10 + 10 }; + + private: + uint64 offset_; + uint64 size_; +}; + +// Footer encapsulates the fixed information stored at the tail +// end of every table file. +class Footer { + public: + Footer() {} + + // The block handle for the metaindex block of the table + const BlockHandle& metaindex_handle() const { return metaindex_handle_; } + void set_metaindex_handle(const BlockHandle& h) { metaindex_handle_ = h; } + + // The block handle for the index block of the table + const BlockHandle& index_handle() const { return index_handle_; } + void set_index_handle(const BlockHandle& h) { index_handle_ = h; } + + void EncodeTo(string* dst) const; + Status DecodeFrom(StringPiece* input); + + // Encoded length of a Footer. Note that the serialization of a + // Footer will always occupy exactly this many bytes. It consists + // of two block handles and a magic number. + enum { kEncodedLength = 2 * BlockHandle::kMaxEncodedLength + 8 }; + + private: + BlockHandle metaindex_handle_; + BlockHandle index_handle_; +}; + +// kTableMagicNumber was picked by running +// echo http://code.google.com/p/leveldb/ | sha1sum +// and taking the leading 64 bits. +static const uint64 kTableMagicNumber = 0xdb4775248b80fb57ull; + +// 1-byte type + 32-bit crc +static const size_t kBlockTrailerSize = 5; + +struct BlockContents { + StringPiece data; // Actual contents of data + bool cachable; // True iff data can be cached + bool heap_allocated; // True iff caller should delete[] data.data() +}; + +// Read the block identified by "handle" from "file". On failure +// return non-OK. On success fill *result and return OK. 
+extern Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result); + +// Implementation details follow. Clients should ignore, + +inline BlockHandle::BlockHandle() + : offset_(~static_cast<uint64>(0)), size_(~static_cast<uint64>(0)) {} + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_FORMAT_H_ diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc new file mode 100644 index 0000000000..8fa245a546 --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer.cc @@ -0,0 +1,112 @@ +#include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace io { + +InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes) + : file_(file), + file_pos_(0), + size_(buffer_bytes), + buf_(new char[size_]), + pos_(buf_), + limit_(buf_) {} + +InputBuffer::~InputBuffer() { + delete file_; + delete[] buf_; +} + +Status InputBuffer::FillBuffer() { + StringPiece data; + Status s = file_->Read(file_pos_, size_, &data, buf_); + if (data.data() != buf_) { + memmove(buf_, data.data(), data.size()); + } + pos_ = buf_; + limit_ = pos_ + data.size(); + file_pos_ += data.size(); + return s; +} + +Status InputBuffer::ReadLine(string* result) { + result->clear(); + int i; + Status s; + for (i = 0;; i++) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + char c = *pos_++; + if (c == '\n') { + // We don't append the '\n' to *result + return Status::OK(); + } + *result += c; + } + if (errors::IsOutOfRange(s) && !result->empty()) { + return Status::OK(); + } + return s; +} + +Status InputBuffer::ReadNBytes(int64 bytes_to_read, string* result) { + result->clear(); + if (bytes_to_read < 0) { + return errors::InvalidArgument("Can't read a negative number of bytes: ", + bytes_to_read); + } + result->reserve(bytes_to_read); + Status s; + while (result->size() < 
static_cast<size_t>(bytes_to_read)) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + const int64 bytes_to_copy = + std::min<int64>(limit_ - pos_, bytes_to_read - result->size()); + result->insert(result->size(), pos_, bytes_to_copy); + pos_ += bytes_to_copy; + } + if (errors::IsOutOfRange(s) && + (result->size() == static_cast<size_t>(bytes_to_read))) { + return Status::OK(); + } + return s; +} + +Status InputBuffer::SkipNBytes(int64 bytes_to_skip) { + if (bytes_to_skip < 0) { + return errors::InvalidArgument("Can only skip forward, not ", + bytes_to_skip); + } + int64 bytes_skipped = 0; + Status s; + while (bytes_skipped < bytes_to_skip) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + const int64 bytes_to_advance = + std::min<int64>(limit_ - pos_, bytes_to_skip - bytes_skipped); + bytes_skipped += bytes_to_advance; + pos_ += bytes_to_advance; + } + if (errors::IsOutOfRange(s) && bytes_skipped == bytes_to_skip) { + return Status::OK(); + } + return s; +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h new file mode 100644 index 0000000000..6879f30567 --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer.h @@ -0,0 +1,62 @@ +#ifndef TENSORFLOW_LIB_IO_INPUTBUFFER_H_ +#define TENSORFLOW_LIB_IO_INPUTBUFFER_H_ + +#include <string> +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace io { + +// An InputBuffer provides a buffer on top of a RandomAccessFile. +// A given instance of an InputBuffer is NOT safe for concurrent use +// by multiple threads +class InputBuffer { + public: + // Create an InputBuffer for "file" with a buffer size of + // "buffer_bytes" bytes. 
Takes ownership of "file" and will + // delete it when the InputBuffer is destroyed. + InputBuffer(RandomAccessFile* file, size_t buffer_bytes); + ~InputBuffer(); + + // Read one text line of data into "*result" until end-of-file or a + // \n is read. (The \n is not included in the result.) Overwrites + // any existing data in *result. + // + // If successful, returns OK. If we are already at the end of the + // file, we return an OUT_OF_RANGE error. Otherwise, we return + // some other non-OK status. + Status ReadLine(string* result); + + // Reads bytes_to_read bytes into *result, overwriting *result. + // + // If successful, returns OK. If we there are not enough bytes to + // read before the end of the file, we return an OUT_OF_RANGE error. + // Otherwise, we return some other non-OK status. + Status ReadNBytes(int64 bytes_to_read, string* result); + + // Like ReadNBytes() without returning the bytes read. + Status SkipNBytes(int64 bytes_to_skip); + + // Returns the position in the file. 
+ int64 Tell() const { return file_pos_ - (limit_ - pos_); } + + private: + Status FillBuffer(); + + RandomAccessFile* file_; // Owned + int64 file_pos_; // Next position to read from in "file_" + size_t size_; // Size of "buf_" + char* buf_; // The buffer itself + // [pos_,limit_) hold the "limit_ - pos_" bytes just before "file_pos_" + char* pos_; // Current position in "buf" + char* limit_; // Just past end of valid data in "buf" + + TF_DISALLOW_COPY_AND_ASSIGN(InputBuffer); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_INPUTBUFFER_H_ diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc new file mode 100644 index 0000000000..34094f018c --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer_test.cc @@ -0,0 +1,174 @@ +#include "tensorflow/core/lib/io/inputbuffer.h" + +#include "tensorflow/core/public/env.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include <gtest/gtest.h> +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +static std::vector<int> BufferSizes() { + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 65536}; +} + +TEST(InputBuffer, ReadLine_Empty) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, ""); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadLine1) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\nline two\nline three\n"); + + for (auto buf_size : 
BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadLine_NoTrailingNewLine) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\nline two\nline three"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadLine_EmptyLines) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\n\n\nline two\nline three"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, ""); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, ""); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + 
EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadNBytes) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "0123456789"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string read; + io::InputBuffer in(file, buf_size); + EXPECT_EQ(0, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(3, &read)); + EXPECT_EQ(read, "012"); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(4, &read)); + EXPECT_EQ(read, "3456"); + EXPECT_EQ(7, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(7, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, "789"); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + } +} + +TEST(InputBuffer, SkipNBytes) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "0123456789"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string read; + io::InputBuffer in(file, buf_size); + EXPECT_EQ(0, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(3)); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(0)); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(2, &read)); + EXPECT_EQ(read, "34"); + EXPECT_EQ(5, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(0)); + EXPECT_EQ(5, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(2)); + EXPECT_EQ(7, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(1, &read)); + EXPECT_EQ(read, 
"7"); + EXPECT_EQ(8, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5))); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5))); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/iterator.cc b/tensorflow/core/lib/io/iterator.cc new file mode 100644 index 0000000000..878e93a911 --- /dev/null +++ b/tensorflow/core/lib/io/iterator.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +Iterator::Iterator() { + cleanup_.function = NULL; + cleanup_.next = NULL; +} + +Iterator::~Iterator() { + if (cleanup_.function != NULL) { + (*cleanup_.function)(cleanup_.arg1, cleanup_.arg2); + for (Cleanup* c = cleanup_.next; c != NULL;) { + (*c->function)(c->arg1, c->arg2); + Cleanup* next = c->next; + delete c; + c = next; + } + } +} + +void Iterator::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) { + assert(func != NULL); + Cleanup* c; + if (cleanup_.function == NULL) { + c = &cleanup_; + } else { + c = new Cleanup; + c->next = cleanup_.next; + cleanup_.next = c; + } + c->function = func; + c->arg1 = arg1; + c->arg2 = arg2; +} + +namespace { +class EmptyIterator : public Iterator { + public: + EmptyIterator(const Status& s) : status_(s) {} + virtual bool Valid() const { return false; } + virtual void Seek(const StringPiece& target) {} + virtual void SeekToFirst() {} + virtual void Next() { assert(false); } + StringPiece key() const { + assert(false); + return StringPiece(); + } + StringPiece value() const { + assert(false); + return StringPiece(); + } + virtual Status status() const { return 
status_; } + + private: + Status status_; +}; +} // namespace + +Iterator* NewEmptyIterator() { return new EmptyIterator(Status::OK()); } + +Iterator* NewErrorIterator(const Status& status) { + return new EmptyIterator(status); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/iterator.h b/tensorflow/core/lib/io/iterator.h new file mode 100644 index 0000000000..603a2f95fe --- /dev/null +++ b/tensorflow/core/lib/io/iterator.h @@ -0,0 +1,93 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// An iterator yields a sequence of key/value pairs from a source. +// The following class defines the interface. Multiple implementations +// are provided by this library. In particular, iterators are provided +// to access the contents of a Table or a DB. +// +// Multiple threads can invoke const methods on an Iterator without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same Iterator must use +// external synchronization. + +#ifndef TENSORFLOW_LIB_IO_ITERATOR_H_ +#define TENSORFLOW_LIB_IO_ITERATOR_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace table { + +class Iterator { + public: + Iterator(); + virtual ~Iterator(); + + // An iterator is either positioned at a key/value pair, or + // not valid. This method returns true iff the iterator is valid. + virtual bool Valid() const = 0; + + // Position at the first key in the source. The iterator is Valid() + // after this call iff the source is not empty. + virtual void SeekToFirst() = 0; + + // Position at the first key in the source that is at or past target. 
+ // The iterator is Valid() after this call iff the source contains + // an entry that comes at or past target. + virtual void Seek(const StringPiece& target) = 0; + + // Moves to the next entry in the source. After this call, Valid() is + // true iff the iterator was not positioned at the last entry in the source. + // REQUIRES: Valid() + virtual void Next() = 0; + + // Return the key for the current entry. The underlying storage for + // the returned slice is valid only until the next modification of + // the iterator. + // REQUIRES: Valid() + virtual StringPiece key() const = 0; + + // Return the value for the current entry. The underlying storage for + // the returned slice is valid only until the next modification of + // the iterator. + // REQUIRES: Valid() + virtual StringPiece value() const = 0; + + // If an error has occurred, return it. Else return an ok status. + virtual Status status() const = 0; + + // Clients are allowed to register function/arg1/arg2 triples that + // will be invoked when this iterator is destroyed. + // + // Note that unlike all of the preceding methods, this method is + // not abstract and therefore clients should not override it. + typedef void (*CleanupFunction)(void* arg1, void* arg2); + void RegisterCleanup(CleanupFunction function, void* arg1, void* arg2); + + private: + struct Cleanup { + CleanupFunction function; + void* arg1; + void* arg2; + Cleanup* next; + }; + Cleanup cleanup_; + + // No copying allowed + Iterator(const Iterator&); + void operator=(const Iterator&); +}; + +// Return an empty iterator (yields nothing). +extern Iterator* NewEmptyIterator(); + +// Return an empty iterator with the specified status. 
+extern Iterator* NewErrorIterator(const Status& status); + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_ITERATOR_H_ diff --git a/tensorflow/core/lib/io/match.cc b/tensorflow/core/lib/io/match.cc new file mode 100644 index 0000000000..1563642d0b --- /dev/null +++ b/tensorflow/core/lib/io/match.cc @@ -0,0 +1,31 @@ +#include "tensorflow/core/lib/io/match.h" +#include <fnmatch.h> +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace io { + +Status GetMatchingFiles(Env* env, const string& pattern, + std::vector<string>* results) { + results->clear(); + std::vector<string> all_files; + string dir = Dirname(pattern).ToString(); + if (dir.empty()) dir = "."; + string basename_pattern = Basename(pattern).ToString(); + Status s = env->GetChildren(dir, &all_files); + if (!s.ok()) { + return s; + } + for (const auto& f : all_files) { + int flags = 0; + if (fnmatch(basename_pattern.c_str(), Basename(f).ToString().c_str(), + flags) == 0) { + results->push_back(JoinPath(dir, f)); + } + } + return Status::OK(); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/match.h b/tensorflow/core/lib/io/match.h new file mode 100644 index 0000000000..fd194178e7 --- /dev/null +++ b/tensorflow/core/lib/io/match.h @@ -0,0 +1,24 @@ +#ifndef TENSORFLOW_LIB_IO_MATCH_H_ +#define TENSORFLOW_LIB_IO_MATCH_H_ + +#include <vector> +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +class Env; +namespace io { + +// Given a pattern, return the set of files that match the pattern. +// Note that this routine only supports wildcard characters in the +// basename portion of the pattern, not in the directory portion. If +// successful, return Status::OK and store the matching files in +// "*results". Otherwise, return a non-OK status. 
+Status GetMatchingFiles(Env* env, const string& pattern, + std::vector<string>* results); + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_MATCH_H_ diff --git a/tensorflow/core/lib/io/match_test.cc b/tensorflow/core/lib/io/match_test.cc new file mode 100644 index 0000000000..aaa56e4e7e --- /dev/null +++ b/tensorflow/core/lib/io/match_test.cc @@ -0,0 +1,51 @@ +#include <algorithm> +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/match.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace io { + +static string Match(Env* env, const string& suffix_pattern) { + std::vector<string> results; + Status s = GetMatchingFiles(env, JoinPath(testing::TmpDir(), suffix_pattern), + &results); + if (!s.ok()) { + return s.ToString(); + } else { + string r; + std::sort(results.begin(), results.end()); + for (size_t i = 0; i < results.size(); i++) { + strings::StrAppend(&r, (i > 0) ? 
"," : "", Basename(results[i])); + } + return r; + } +} +TEST(GetMatchingFiles, Simple) { + Env* env = Env::Default(); + EXPECT_EQ(Match(env, "thereisnosuchfile"), ""); + EXPECT_EQ(Match(env, "thereisnosuchfile*"), ""); + + // Populate a few files + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-00"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-0a"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-01"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-aaa"), "")); + + EXPECT_EQ(Match(env, "match-*"), "match-00,match-01,match-0a,match-aaa"); + EXPECT_EQ(Match(env, "match-0[0-9]"), "match-00,match-01"); + EXPECT_EQ(Match(env, "match-?[0-9]"), "match-00,match-01"); + EXPECT_EQ(Match(env, "match-?a*"), "match-0a,match-aaa"); + EXPECT_EQ(Match(env, "match-??"), "match-00,match-01,match-0a"); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc new file mode 100644 index 0000000000..1359ded0f0 --- /dev/null +++ b/tensorflow/core/lib/io/path.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace io { + +string JoinPath(StringPiece part1, StringPiece part2) { + string result; + + StringPiece paths[2] = {part1, part2}; + for (StringPiece path : paths) { + if (path.empty()) continue; + + if (result.empty()) { + result = path.ToString(); + continue; + } + + if (result[result.size() - 1] == '/') { + if (IsAbsolutePath(path)) { + strings::StrAppend(&result, path.substr(1)); + } else { + strings::StrAppend(&result, path); + } + } else { + if (IsAbsolutePath(path)) { + strings::StrAppend(&result, path); + } else { + strings::StrAppend(&result, "/", path); + } + } + } + + return result; +} + +namespace internal { + +// Return the parts of the path, split on 
the final "/". If there is no +// "/" in the path, the first part of the output is empty and the second +// is the input. If the only "/" in the path is the first character, it is +// the first part of the output. +std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) { + auto pos = path.rfind('/'); + + // Handle the case with no '/' in 'path'. + if (pos == StringPiece::npos) + return std::make_pair(StringPiece(path.data(), 0), path); + + // Handle the case with a single leading '/' in 'path'. + if (pos == 0) + return std::make_pair(StringPiece(path.data(), 1), + StringPiece(path.data() + 1, path.size() - 1)); + + return std::make_pair( + StringPiece(path.data(), pos), + StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); +} + +// Return the parts of the basename of path, split on the final ".". +// If there is no "." in the basename or "." is the final character in the +// basename, the second value will be empty. +std::pair<StringPiece, StringPiece> SplitBasename(StringPiece path) { + path = Basename(path); + + auto pos = path.rfind('.'); + if (pos == StringPiece::npos) + return std::make_pair(path, StringPiece(path.data() + path.size(), 0)); + return std::make_pair( + StringPiece(path.data(), pos), + StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); +} +} // namespace internal + +bool IsAbsolutePath(StringPiece path) { + return !path.empty() && path[0] == '/'; +} + +StringPiece Dirname(StringPiece path) { + return internal::SplitPath(path).first; +} + +StringPiece Basename(StringPiece path) { + return internal::SplitPath(path).second; +} + +StringPiece Extension(StringPiece path) { + return internal::SplitBasename(path).second; +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h new file mode 100644 index 0000000000..01483f1702 --- /dev/null +++ b/tensorflow/core/lib/io/path.h @@ -0,0 +1,47 @@ +#ifndef TENSORFLOW_LIB_IO_PATH_H_ +#define 
TENSORFLOW_LIB_IO_PATH_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +class StringPiece; +namespace io { + +// Utility routines for processing filenames + +// Join multiple paths together, without introducing unnecessary path +// separators. +// For example: +// +// Arguments | JoinPath +// ---------------------------+---------- +// '/foo', 'bar' | /foo/bar +// '/foo/', 'bar' | /foo/bar +// '/foo', '/bar' | /foo/bar +// +// Usage: +// string path = io::JoinPath("/mydir", filename); +// string path = io::JoinPath(FLAGS_test_srcdir, filename); +string JoinPath(StringPiece part1, StringPiece part2); + +// Return true if path is absolute. +bool IsAbsolutePath(StringPiece path); + +// Returns the part of the path before the final "/". If there is a single +// leading "/" in the path, the result will be the leading "/". If there is +// no "/" in the path, the result is the empty prefix of the input. +StringPiece Dirname(StringPiece path); + +// Returns the part of the path after the final "/". If there is no +// "/" in the path, the result is the same as the input. +StringPiece Basename(StringPiece path); + +// Returns the part of the basename of path after the final ".". If +// there is no "." in the basename, the result is empty. 
+StringPiece Extension(StringPiece path); + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_PATH_H_ diff --git a/tensorflow/core/lib/io/path_test.cc b/tensorflow/core/lib/io/path_test.cc new file mode 100644 index 0000000000..b670e44f1f --- /dev/null +++ b/tensorflow/core/lib/io/path_test.cc @@ -0,0 +1,65 @@ +#include "tensorflow/core/lib/io/path.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace io { + +TEST(PathTest, JoinPath) { + EXPECT_EQ("/foo/bar", JoinPath("/foo", "bar")); + EXPECT_EQ("foo/bar", JoinPath("foo", "bar")); + EXPECT_EQ("foo/bar", JoinPath("foo", "/bar")); + EXPECT_EQ("/foo/bar", JoinPath("/foo", "/bar")); + + EXPECT_EQ("/bar", JoinPath("", "/bar")); + EXPECT_EQ("bar", JoinPath("", "bar")); + EXPECT_EQ("/foo", JoinPath("/foo", "")); + + EXPECT_EQ("/foo/bar/baz/blah/blink/biz", + JoinPath("/foo/bar/baz/", "/blah/blink/biz")); +} + +TEST(PathTest, IsAbsolutePath) { + EXPECT_FALSE(IsAbsolutePath("")); + EXPECT_FALSE(IsAbsolutePath("../foo")); + EXPECT_FALSE(IsAbsolutePath("foo")); + EXPECT_FALSE(IsAbsolutePath("./foo")); + EXPECT_FALSE(IsAbsolutePath("foo/bar/baz/")); + EXPECT_TRUE(IsAbsolutePath("/foo")); + EXPECT_TRUE(IsAbsolutePath("/foo/bar/../baz")); +} + +TEST(PathTest, Dirname) { + EXPECT_EQ("/hello", Dirname("/hello/")); + EXPECT_EQ("/", Dirname("/hello")); + EXPECT_EQ("hello", Dirname("hello/world")); + EXPECT_EQ("hello", Dirname("hello/")); + EXPECT_EQ("", Dirname("world")); + EXPECT_EQ("/", Dirname("/")); + EXPECT_EQ("", Dirname("")); +} + +TEST(PathTest, Basename) { + EXPECT_EQ("", Basename("/hello/")); + EXPECT_EQ("hello", Basename("/hello")); + EXPECT_EQ("world", Basename("hello/world")); + EXPECT_EQ("", Basename("hello/")); + EXPECT_EQ("world", Basename("world")); + EXPECT_EQ("", Basename("/")); + EXPECT_EQ("", Basename("")); +} + +TEST(PathTest, Extension) { + EXPECT_EQ("gif", Extension("foo.gif")); + EXPECT_EQ("", Extension("foo.")); + EXPECT_EQ("", Extension("")); + EXPECT_EQ("", 
Extension("/")); + EXPECT_EQ("", Extension("foo")); + EXPECT_EQ("", Extension("foo/")); + EXPECT_EQ("gif", Extension("/a/path/to/foo.gif")); + EXPECT_EQ("html", Extension("/a/path.bar/to/foo.html")); + EXPECT_EQ("", Extension("/a/path.bar/to/foo")); + EXPECT_EQ("baz", Extension("/a/path.bar/to/foo.bar.baz")); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc new file mode 100644 index 0000000000..2f0fabff63 --- /dev/null +++ b/tensorflow/core/lib/io/record_reader.cc @@ -0,0 +1,80 @@ +#include "tensorflow/core/lib/io/record_reader.h" + +#include <limits.h> +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace io { + +RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {} + +RecordReader::~RecordReader() {} + +// Read n+4 bytes from file, verify that checksum of first n bytes is +// stored in the last 4 bytes and store the first n bytes in *result. +// May use *storage as backing store. 
+static Status ReadChecksummed(RandomAccessFile* file, uint64 offset, + size_t n, StringPiece* result, + string* storage) { + if (n >= SIZE_MAX - sizeof(uint32)) { + return errors::DataLoss("record size too large"); + } + + const size_t expected = n + sizeof(uint32); + storage->resize(expected); + StringPiece data; + Status s = file->Read(offset, expected, &data, &(*storage)[0]); + if (!s.ok()) { + return s; + } + if (data.size() != expected) { + if (data.size() == 0) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } + } + uint32 masked_crc = core::DecodeFixed32(data.data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(data.data(), n); + return Status::OK(); +} + +Status RecordReader::ReadRecord(uint64* offset, string* record) { + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + + // Read length + StringPiece lbuf; + Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record); + if (!s.ok()) { + return s; + } + const uint64 length = core::DecodeFixed64(lbuf.data()); + + // Read data + StringPiece data; + s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record); + if (!s.ok()) { + if (errors::IsOutOfRange(s)) { + s = errors::DataLoss("truncated record at ", *offset); + } + return s; + } + if (record->data() != data.data()) { + // RandomAccessFile placed the data in some other location. 
+ memmove(&(*record)[0], data.data(), data.size()); + } + + record->resize(data.size()); + *offset += kHeaderSize + length + kFooterSize; + return Status::OK(); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h new file mode 100644 index 0000000000..a8c1b0dd5d --- /dev/null +++ b/tensorflow/core/lib/io/record_reader.h @@ -0,0 +1,36 @@ +#ifndef TENSORFLOW_LIB_IO_RECORD_READER_H_ +#define TENSORFLOW_LIB_IO_RECORD_READER_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class RandomAccessFile; + +namespace io { + +class RecordReader { + public: + // Create a reader that will return log records from "*file". + // "*file" must remain live while this Reader is in use. + explicit RecordReader(RandomAccessFile* file); + + ~RecordReader(); + + // Read the record at "*offset" into *record and update *offset to + // point to the offset of the next record. Returns OK on success, + // OUT_OF_RANGE for end of file, or something else for an error. 
+ Status ReadRecord(uint64* offset, string* record); + + private: + RandomAccessFile* src_; + + TF_DISALLOW_COPY_AND_ASSIGN(RecordReader); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_RECORD_READER_H_ diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc new file mode 100644 index 0000000000..3d7f1509ab --- /dev/null +++ b/tensorflow/core/lib/io/record_writer.cc @@ -0,0 +1,42 @@ +#include "tensorflow/core/lib/io/record_writer.h" + +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" + +namespace tensorflow { +namespace io { + +RecordWriter::RecordWriter(WritableFile* dest) : dest_(dest) {} + +RecordWriter::~RecordWriter() {} + +static uint32 MaskedCrc(const char* data, size_t n) { + return crc32c::Mask(crc32c::Value(data, n)); +} + +Status RecordWriter::WriteRecord(StringPiece data) { + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + char header[sizeof(uint64) + sizeof(uint32)]; + core::EncodeFixed64(header + 0, data.size()); + core::EncodeFixed32(header + sizeof(uint64), + MaskedCrc(header, sizeof(uint64))); + Status s = dest_->Append(StringPiece(header, sizeof(header))); + if (!s.ok()) { + return s; + } + s = dest_->Append(data); + if (!s.ok()) { + return s; + } + char footer[sizeof(uint32)]; + core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size())); + return dest_->Append(StringPiece(footer, sizeof(footer))); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h new file mode 100644 index 0000000000..c7af00e5ae --- /dev/null +++ b/tensorflow/core/lib/io/record_writer.h @@ -0,0 +1,34 @@ +#ifndef TENSORFLOW_LIB_IO_RECORD_WRITER_H_ +#define TENSORFLOW_LIB_IO_RECORD_WRITER_H_ + +#include 
"tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class WritableFile; + +namespace io { + +class RecordWriter { + public: + // Create a writer that will append data to "*dest". + // "*dest" must be initially empty. + // "*dest" must remain live while this Writer is in use. + explicit RecordWriter(WritableFile* dest); + + ~RecordWriter(); + + Status WriteRecord(StringPiece slice); + + private: + WritableFile* const dest_; + + TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_RECORD_WRITER_H_ diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc new file mode 100644 index 0000000000..3e9c816443 --- /dev/null +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -0,0 +1,245 @@ +#include "tensorflow/core/lib/io/record_reader.h" +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace io { + +// Construct a string of the specified length made out of the supplied +// partial string. 
+static string BigString(const string& partial_string, size_t n) { + string result; + while (result.size() < n) { + result.append(partial_string); + } + result.resize(n); + return result; +} + +// Construct a string from a number +static string NumberString(int n) { + char buf[50]; + snprintf(buf, sizeof(buf), "%d.", n); + return string(buf); +} + +// Return a skewed potentially long string +static string RandomSkewedString(int i, random::SimplePhilox* rnd) { + return BigString(NumberString(i), rnd->Skewed(17)); +} + +class RecordioTest : public testing::Test { + private: + class StringDest : public WritableFile { + public: + string contents_; + + Status Close() override { return Status::OK(); } + Status Flush() override { return Status::OK(); } + Status Sync() override { return Status::OK(); } + Status Append(const StringPiece& slice) override { + contents_.append(slice.data(), slice.size()); + return Status::OK(); + } + }; + + class StringSource : public RandomAccessFile { + public: + StringPiece contents_; + mutable bool force_error_; + mutable bool returned_partial_; + StringSource() : force_error_(false), returned_partial_(false) {} + + Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error"; + + if (force_error_) { + force_error_ = false; + returned_partial_ = true; + return errors::DataLoss("read error"); + } + + if (offset >= contents_.size()) { + return errors::OutOfRange("end of file"); + } + + if (contents_.size() < offset + n) { + n = contents_.size() - offset; + returned_partial_ = true; + } + *result = StringPiece(contents_.data() + offset, n); + return Status::OK(); + } + }; + + StringDest dest_; + StringSource source_; + bool reading_; + uint64 readpos_; + RecordWriter* writer_; + RecordReader* reader_; + + public: + RecordioTest() + : reading_(false), + readpos_(0), + writer_(new RecordWriter(&dest_)), + reader_(new RecordReader(&source_)) {} 
+ + ~RecordioTest() override { + delete writer_; + delete reader_; + } + + void Write(const string& msg) { + ASSERT_TRUE(!reading_) << "Write() after starting to read"; + ASSERT_OK(writer_->WriteRecord(StringPiece(msg))); + } + + size_t WrittenBytes() const { return dest_.contents_.size(); } + + string Read() { + if (!reading_) { + reading_ = true; + source_.contents_ = StringPiece(dest_.contents_); + } + string record; + Status s = reader_->ReadRecord(&readpos_, &record); + if (s.ok()) { + return record; + } else if (errors::IsOutOfRange(s)) { + return "EOF"; + } else { + return s.ToString(); + } + } + + void IncrementByte(int offset, int delta) { + dest_.contents_[offset] += delta; + } + + void SetByte(int offset, char new_byte) { + dest_.contents_[offset] = new_byte; + } + + void ShrinkSize(int bytes) { + dest_.contents_.resize(dest_.contents_.size() - bytes); + } + + void FixChecksum(int header_offset, int len) { + // Compute crc of type/len/data + uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len); + crc = crc32c::Mask(crc); + core::EncodeFixed32(&dest_.contents_[header_offset], crc); + } + + void ForceError() { source_.force_error_ = true; } + + void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; } + + void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) { + Write("foo"); + Write("bar"); + Write(BigString("x", 10000)); + reading_ = true; + source_.contents_ = StringPiece(dest_.contents_); + uint64 offset = WrittenBytes() + offset_past_end; + string record; + Status s = reader_->ReadRecord(&offset, &record); + ASSERT_TRUE(errors::IsOutOfRange(s)) << s; + } +}; + +TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); } + +TEST_F(RecordioTest, ReadWrite) { + Write("foo"); + Write("bar"); + Write(""); + Write("xxxx"); + ASSERT_EQ("foo", Read()); + ASSERT_EQ("bar", Read()); + ASSERT_EQ("", Read()); + ASSERT_EQ("xxxx", Read()); + ASSERT_EQ("EOF", Read()); + ASSERT_EQ("EOF", Read()); // Make sure reads at 
eof work +} + +TEST_F(RecordioTest, ManyRecords) { + for (int i = 0; i < 100000; i++) { + Write(NumberString(i)); + } + for (int i = 0; i < 100000; i++) { + ASSERT_EQ(NumberString(i), Read()); + } + ASSERT_EQ("EOF", Read()); +} + +TEST_F(RecordioTest, RandomRead) { + const int N = 500; + { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int i = 0; i < N; i++) { + Write(RandomSkewedString(i, &rnd)); + } + } + { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int i = 0; i < N; i++) { + ASSERT_EQ(RandomSkewedString(i, &rnd), Read()); + } + } + ASSERT_EQ("EOF", Read()); +} + +// Tests of all the error paths in log_reader.cc follow: +static void AssertHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " + << expected; +} + +TEST_F(RecordioTest, ReadError) { + Write("foo"); + ForceError(); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptLength) { + Write("foo"); + IncrementByte(6, 100); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptLengthCrc) { + Write("foo"); + IncrementByte(10, 100); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptData) { + Write("foo"); + IncrementByte(14, 10); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptDataCrc) { + Write("foo"); + IncrementByte(WrittenBytes() - 1, 10); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); } + +TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); } + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc new file mode 100644 index 0000000000..769d7e72a5 --- /dev/null +++ b/tensorflow/core/lib/io/table.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/table.h" + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/lib/io/two_level_iterator.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace table { + +struct Table::Rep { + ~Rep() { delete index_block; } + + Options options; + Status status; + RandomAccessFile* file; + // XXX uint64 cache_id; + + BlockHandle metaindex_handle; // Handle to metaindex_block: saved from footer + Block* index_block; +}; + +Status Table::Open(const Options& options, RandomAccessFile* file, + uint64 size, Table** table) { + *table = NULL; + if (size < Footer::kEncodedLength) { + return errors::DataLoss("file is too short to be an sstable"); + } + + char footer_space[Footer::kEncodedLength]; + StringPiece footer_input; + Status s = + file->Read(size - Footer::kEncodedLength, Footer::kEncodedLength, + &footer_input, footer_space); + if (!s.ok()) return s; + + Footer footer; + s = footer.DecodeFrom(&footer_input); + if (!s.ok()) return s; + + // Read the index block + BlockContents contents; + Block* index_block = NULL; + if (s.ok()) { + s = ReadBlock(file, footer.index_handle(), &contents); + if (s.ok()) { + index_block = new Block(contents); + } + } + + if (s.ok()) { + // We've successfully read the footer and the index block: we're + // ready to serve requests. + Rep* rep = new Table::Rep; + rep->options = options; + rep->file = file; + rep->metaindex_handle = footer.metaindex_handle(); + rep->index_block = index_block; + // XXX rep->cache_id = (options.block_cache ? 
+ // options.block_cache->NewId() : 0); + *table = new Table(rep); + } else { + if (index_block) delete index_block; + } + + return s; +} + +Table::~Table() { delete rep_; } + +static void DeleteBlock(void* arg, void* ignored) { + delete reinterpret_cast<Block*>(arg); +} + +// Convert an index iterator value (i.e., an encoded BlockHandle) +// into an iterator over the contents of the corresponding block. +Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) { + Table* table = reinterpret_cast<Table*>(arg); + // Cache* block_cache = table->rep_->options.block_cache; + Block* block = NULL; + // Cache::Handle* cache_handle = NULL; + + BlockHandle handle; + StringPiece input = index_value; + Status s = handle.DecodeFrom(&input); + // We intentionally allow extra stuff in index_value so that we + // can add more features in the future. + + if (s.ok()) { + BlockContents contents; + s = ReadBlock(table->rep_->file, handle, &contents); + if (s.ok()) { + block = new Block(contents); + } + } + + Iterator* iter; + if (block != NULL) { + iter = block->NewIterator(); + iter->RegisterCleanup(&DeleteBlock, block, NULL); + } else { + iter = NewErrorIterator(s); + } + return iter; +} + +Iterator* Table::NewIterator() const { + return NewTwoLevelIterator(rep_->index_block->NewIterator(), + &Table::BlockReader, const_cast<Table*>(this)); +} + +Status Table::InternalGet(const StringPiece& k, void* arg, + void (*saver)(void*, const StringPiece&, + const StringPiece&)) { + Status s; + Iterator* iiter = rep_->index_block->NewIterator(); + iiter->Seek(k); + if (iiter->Valid()) { + BlockHandle handle; + Iterator* block_iter = BlockReader(this, iiter->value()); + block_iter->Seek(k); + if (block_iter->Valid()) { + (*saver)(arg, block_iter->key(), block_iter->value()); + } + s = block_iter->status(); + delete block_iter; + } + if (s.ok()) { + s = iiter->status(); + } + delete iiter; + return s; +} + +uint64 Table::ApproximateOffsetOf(const StringPiece& key) const { + 
Iterator* index_iter = rep_->index_block->NewIterator(); + index_iter->Seek(key); + uint64 result; + if (index_iter->Valid()) { + BlockHandle handle; + StringPiece input = index_iter->value(); + Status s = handle.DecodeFrom(&input); + if (s.ok()) { + result = handle.offset(); + } else { + // Strange: we can't decode the block handle in the index block. + // We'll just return the offset of the metaindex block, which is + // close to the whole file size for this case. + result = rep_->metaindex_handle.offset(); + } + } else { + // key is past the last key in the file. Approximate the offset + // by returning the offset of the metaindex block (which is + // right near the end of the file). + result = rep_->metaindex_handle.offset(); + } + delete index_iter; + return result; +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h new file mode 100644 index 0000000000..230dded2d4 --- /dev/null +++ b/tensorflow/core/lib/io/table.h @@ -0,0 +1,76 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_TABLE_H_ +#define TENSORFLOW_LIB_IO_TABLE_H_ + +#include <stdint.h> +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +class RandomAccessFile; + +namespace table { + +class Block; +class BlockHandle; +class Footer; +struct Options; + +// A Table is a sorted map from strings to strings. Tables are +// immutable and persistent. A Table may be safely accessed from +// multiple threads without external synchronization. +class Table { + public: + // Attempt to open the table that is stored in bytes [0..file_size) + // of "file", and read the metadata entries necessary to allow + // retrieving data from the table. 
+ // + // If successful, returns ok and sets "*table" to the newly opened + // table. The client should delete "*table" when no longer needed. + // If there was an error while initializing the table, sets "*table" + // to NULL and returns a non-ok status. Does not take ownership of + // "*file", but the client must ensure that "file" remains live + // for the duration of the returned table's lifetime. + static Status Open(const Options& options, RandomAccessFile* file, + uint64 file_size, Table** table); + + ~Table(); + + // Returns a new iterator over the table contents. + // The result of NewIterator() is initially invalid (caller must + // call one of the Seek methods on the iterator before using it). + Iterator* NewIterator() const; + + // Given a key, return an approximate byte offset in the file where + // the data for that key begins (or would begin if the key were + // present in the file). The returned value is in terms of file + // bytes, and so includes effects like compression of the underlying data. + // E.g., the approximate offset of the last key in the table will + // be close to the file length. + uint64 ApproximateOffsetOf(const StringPiece& key) const; + + private: + struct Rep; + Rep* rep_; + + explicit Table(Rep* rep) { rep_ = rep; } + static Iterator* BlockReader(void*, const StringPiece&); + + // Calls (*handle_result)(arg, ...) with the entry found after a call + // to Seek(key). May not make such a call if filter policy says + // that key is not present. 
+ Status InternalGet(const StringPiece& key, void* arg, + void (*handle_result)(void* arg, const StringPiece& k, + const StringPiece& v)); + + // No copying allowed + Table(const Table&); + void operator=(const Table&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_H_ diff --git a/tensorflow/core/lib/io/table_builder.cc b/tensorflow/core/lib/io/table_builder.cc new file mode 100644 index 0000000000..b786888b30 --- /dev/null +++ b/tensorflow/core/lib/io/table_builder.cc @@ -0,0 +1,263 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/table_builder.h" + +#include <assert.h> +#include "tensorflow/core/lib/io/block_builder.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +namespace { + +void FindShortestSeparator(string* start, const StringPiece& limit) { + // Find length of common prefix + size_t min_length = std::min(start->size(), limit.size()); + size_t diff_index = 0; + while ((diff_index < min_length) && + ((*start)[diff_index] == limit[diff_index])) { + diff_index++; + } + + if (diff_index >= min_length) { + // Do not shorten if one string is a prefix of the other + } else { + uint8 diff_byte = static_cast<uint8>((*start)[diff_index]); + if (diff_byte < static_cast<uint8>(0xff) && + diff_byte + 1 < static_cast<uint8>(limit[diff_index])) { + (*start)[diff_index]++; + start->resize(diff_index + 1); + assert(StringPiece(*start).compare(limit) < 0); + } + } +} + +void FindShortSuccessor(string* key) { + // Find first character that can be 
incremented + size_t n = key->size(); + for (size_t i = 0; i < n; i++) { + const uint8 byte = (*key)[i]; + if (byte != static_cast<uint8>(0xff)) { + (*key)[i] = byte + 1; + key->resize(i + 1); + return; + } + } + // *key is a run of 0xffs. Leave it alone. +} +} // namespace + +struct TableBuilder::Rep { + Options options; + Options index_block_options; + WritableFile* file; + uint64 offset; + Status status; + BlockBuilder data_block; + BlockBuilder index_block; + string last_key; + int64 num_entries; + bool closed; // Either Finish() or Abandon() has been called. + + // We do not emit the index entry for a block until we have seen the + // first key for the next data block. This allows us to use shorter + // keys in the index block. For example, consider a block boundary + // between the keys "the quick brown fox" and "the who". We can use + // "the r" as the key for the index block entry since it is >= all + // entries in the first block and < all entries in subsequent + // blocks. + // + // Invariant: r->pending_index_entry is true only if data_block is empty. 
+ bool pending_index_entry; + BlockHandle pending_handle; // Handle to add to index block + + string compressed_output; + + Rep(const Options& opt, WritableFile* f) + : options(opt), + index_block_options(opt), + file(f), + offset(0), + data_block(&options), + index_block(&index_block_options), + num_entries(0), + closed(false), + pending_index_entry(false) { + index_block_options.block_restart_interval = 1; + } +}; + +TableBuilder::TableBuilder(const Options& options, WritableFile* file) + : rep_(new Rep(options, file)) {} + +TableBuilder::~TableBuilder() { + assert(rep_->closed); // Catch errors where caller forgot to call Finish() + delete rep_; +} + +void TableBuilder::Add(const StringPiece& key, const StringPiece& value) { + Rep* r = rep_; + assert(!r->closed); + if (!ok()) return; + if (r->num_entries > 0) { + assert(key.compare(StringPiece(r->last_key)) > 0); + // See if this key+value would make our current block overly large. If + // so, emit the current block before adding this key/value + const int kOverlyLargeBlockRatio = 2; + const size_t this_entry_bytes = key.size() + value.size(); + if (this_entry_bytes >= kOverlyLargeBlockRatio * r->options.block_size) { + Flush(); + } + } + + if (r->pending_index_entry) { + assert(r->data_block.empty()); + FindShortestSeparator(&r->last_key, key); + string handle_encoding; + r->pending_handle.EncodeTo(&handle_encoding); + r->index_block.Add(r->last_key, StringPiece(handle_encoding)); + r->pending_index_entry = false; + } + + r->last_key.assign(key.data(), key.size()); + r->num_entries++; + r->data_block.Add(key, value); + + const size_t estimated_block_size = r->data_block.CurrentSizeEstimate(); + if (estimated_block_size >= r->options.block_size) { + Flush(); + } +} + +void TableBuilder::Flush() { + Rep* r = rep_; + assert(!r->closed); + if (!ok()) return; + if (r->data_block.empty()) return; + assert(!r->pending_index_entry); + WriteBlock(&r->data_block, &r->pending_handle); + if (ok()) { + 
r->pending_index_entry = true; + r->status = r->file->Flush(); + } +} + +void TableBuilder::WriteBlock(BlockBuilder* block, BlockHandle* handle) { + // File format contains a sequence of blocks where each block has: + // block_data: uint8[n] + // type: uint8 + // crc: uint32 + assert(ok()); + Rep* r = rep_; + StringPiece raw = block->Finish(); + + StringPiece block_contents; + CompressionType type = r->options.compression; + // TODO(postrelease): Support more compression options: zlib? + switch (type) { + case kNoCompression: + block_contents = raw; + break; + + case kSnappyCompression: { + string* compressed = &r->compressed_output; + if (port::Snappy_Compress(raw.data(), raw.size(), compressed) && + compressed->size() < raw.size() - (raw.size() / 8u)) { + block_contents = *compressed; + } else { + // Snappy not supported, or compressed less than 12.5%, so just + // store uncompressed form + block_contents = raw; + type = kNoCompression; + } + break; + } + } + WriteRawBlock(block_contents, type, handle); + r->compressed_output.clear(); + block->Reset(); +} + +void TableBuilder::WriteRawBlock(const StringPiece& block_contents, + CompressionType type, BlockHandle* handle) { + Rep* r = rep_; + handle->set_offset(r->offset); + handle->set_size(block_contents.size()); + r->status = r->file->Append(block_contents); + if (r->status.ok()) { + char trailer[kBlockTrailerSize]; + trailer[0] = type; + uint32 crc = crc32c::Value(block_contents.data(), block_contents.size()); + crc = crc32c::Extend(crc, trailer, 1); // Extend crc to cover block type + core::EncodeFixed32(trailer + 1, crc32c::Mask(crc)); + r->status = r->file->Append(StringPiece(trailer, kBlockTrailerSize)); + if (r->status.ok()) { + r->offset += block_contents.size() + kBlockTrailerSize; + } + } +} + +Status TableBuilder::status() const { return rep_->status; } + +Status TableBuilder::Finish() { + Rep* r = rep_; + Flush(); + assert(!r->closed); + r->closed = true; + + BlockHandle metaindex_block_handle, 
index_block_handle; + + // Write metaindex block + if (ok()) { + BlockBuilder meta_index_block(&r->options); + // TODO(postrelease): Add stats and other meta blocks + WriteBlock(&meta_index_block, &metaindex_block_handle); + } + + // Write index block + if (ok()) { + if (r->pending_index_entry) { + FindShortSuccessor(&r->last_key); + string handle_encoding; + r->pending_handle.EncodeTo(&handle_encoding); + r->index_block.Add(r->last_key, StringPiece(handle_encoding)); + r->pending_index_entry = false; + } + WriteBlock(&r->index_block, &index_block_handle); + } + + // Write footer + if (ok()) { + Footer footer; + footer.set_metaindex_handle(metaindex_block_handle); + footer.set_index_handle(index_block_handle); + string footer_encoding; + footer.EncodeTo(&footer_encoding); + r->status = r->file->Append(footer_encoding); + if (r->status.ok()) { + r->offset += footer_encoding.size(); + } + } + return r->status; +} + +void TableBuilder::Abandon() { + Rep* r = rep_; + assert(!r->closed); + r->closed = true; +} + +uint64 TableBuilder::NumEntries() const { return rep_->num_entries; } + +uint64 TableBuilder::FileSize() const { return rep_->offset; } + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h new file mode 100644 index 0000000000..cebf4d8e0c --- /dev/null +++ b/tensorflow/core/lib/io/table_builder.h @@ -0,0 +1,87 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// TableBuilder provides the interface used to build a Table +// (an immutable and sorted map from keys to values). 
+// +// Multiple threads can invoke const methods on a TableBuilder without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same TableBuilder must use +// external synchronization. + +#ifndef TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ +#define TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ + +#include <stdint.h> +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +class WritableFile; +namespace table { + +class BlockBuilder; +class BlockHandle; + +class TableBuilder { + public: + // Create a builder that will store the contents of the table it is + // building in *file. Does not close the file. It is up to the + // caller to close the file after calling Finish(). + TableBuilder(const Options& options, WritableFile* file); + + // REQUIRES: Either Finish() or Abandon() has been called. + ~TableBuilder(); + + // Add key,value to the table being constructed. + // REQUIRES: key is after any previously added key in lexicographic order. + // REQUIRES: Finish(), Abandon() have not been called + void Add(const StringPiece& key, const StringPiece& value); + + // Advanced operation: flush any buffered key/value pairs to file. + // Can be used to ensure that two adjacent entries never live in + // the same data block. Most clients should not need to use this method. + // REQUIRES: Finish(), Abandon() have not been called + void Flush(); + + // Return non-ok iff some error has been detected. + Status status() const; + + // Finish building the table. Stops using the file passed to the + // constructor after this function returns. + // REQUIRES: Finish(), Abandon() have not been called + Status Finish(); + + // Indicate that the contents of this builder should be abandoned. Stops + // using the file passed to the constructor after this function returns. + // If the caller is not going to call Finish(), it must call Abandon() + // before destroying this builder. 
+ // REQUIRES: Finish(), Abandon() have not been called + void Abandon(); + + // Number of calls to Add() so far. + uint64 NumEntries() const; + + // Size of the file generated so far. If invoked after a successful + // Finish() call, returns the size of the final generated file. + uint64 FileSize() const; + + private: + bool ok() const { return status().ok(); } + void WriteBlock(BlockBuilder* block, BlockHandle* handle); + void WriteRawBlock(const StringPiece& data, CompressionType, + BlockHandle* handle); + + struct Rep; + Rep* rep_; + + // No copying allowed + TableBuilder(const TableBuilder&); + void operator=(const TableBuilder&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ diff --git a/tensorflow/core/lib/io/table_format.txt b/tensorflow/core/lib/io/table_format.txt new file mode 100644 index 0000000000..7edb9fb121 --- /dev/null +++ b/tensorflow/core/lib/io/table_format.txt @@ -0,0 +1,8 @@ +File format +=========== + +The table format is heavily based on the table format for the LevelDB +open source key/value store, with the exception that our tables +do not support "filter" meta blocks (Bloom Filters). See: + +https://code.google.com/p/leveldb/source/browse/doc/table_format.txt diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h new file mode 100644 index 0000000000..45b061b03b --- /dev/null +++ b/tensorflow/core/lib/io/table_options.h @@ -0,0 +1,53 @@ +#ifndef TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ +#define TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ + +#include <stddef.h> + +namespace tensorflow { +namespace table { + +// DB contents are stored in a set of blocks, each of which holds a +// sequence of key,value pairs. Each block may be compressed before +// being stored in a file. The following enum describes which +// compression method (if any) is used to compress a block. 
+enum CompressionType { + // NOTE: do not change the values of existing entries, as these are + // part of the persistent format on disk. + kNoCompression = 0x0, + kSnappyCompression = 0x1 +}; + +// Options to control the behavior of a table (passed to Table::Open) +struct Options { + // Approximate size of user data packed per block. Note that the + // block size specified here corresponds to uncompressed data. The + // actual size of the unit read from disk may be smaller if + // compression is enabled. This parameter can be changed dynamically. + size_t block_size = 262144; + + // Number of keys between restart points for delta encoding of keys. + // This parameter can be changed dynamically. Most clients should + // leave this parameter alone. + int block_restart_interval = 16; + + // Compress blocks using the specified compression algorithm. This + // parameter can be changed dynamically. + // + // Default: kSnappyCompression, which gives lightweight but fast + // compression. + // + // Typical speeds of kSnappyCompression on an Intel(R) Core(TM)2 2.4GHz: + // ~200-500MB/s compression + // ~400-800MB/s decompression + // Note that these speeds are significantly faster than most + // persistent storage speeds, and therefore it is typically never + // worth switching to kNoCompression. Even if the input data is + // incompressible, the kSnappyCompression implementation will + // efficiently detect that and will switch to uncompressed mode. + CompressionType compression = kSnappyCompression; +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc new file mode 100644 index 0000000000..66e90ac64e --- /dev/null +++ b/tensorflow/core/lib/io/table_test.cc @@ -0,0 +1,601 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/table.h" + +#include <map> +#include <string> +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/block_builder.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/iterator.h" +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace table { + +namespace test { +static StringPiece RandomString(random::SimplePhilox* rnd, int len, + string* dst) { + dst->resize(len); + for (int i = 0; i < len; i++) { + (*dst)[i] = static_cast<char>(' ' + rnd->Uniform(95)); // ' ' .. '~' + } + return StringPiece(*dst); +} +static string RandomKey(random::SimplePhilox* rnd, int len) { + // Make sure to generate a wide variety of characters so we + // test the boundary conditions for short-key optimizations. 
+ static const char kTestChars[] = {'\0', '\1', 'a', 'b', 'c', + 'd', 'e', '\xfd', '\xfe', '\xff'}; + string result; + for (int i = 0; i < len; i++) { + result += kTestChars[rnd->Uniform(sizeof(kTestChars))]; + } + return result; +} +static StringPiece CompressibleString(random::SimplePhilox* rnd, + double compressed_fraction, size_t len, + string* dst) { + int raw = static_cast<int>(len * compressed_fraction); + if (raw < 1) raw = 1; + string raw_data; + RandomString(rnd, raw, &raw_data); + + // Duplicate the random data until we have filled "len" bytes + dst->clear(); + while (dst->size() < len) { + dst->append(raw_data); + } + dst->resize(len); + return StringPiece(*dst); +} +} + +static void Increment(string* key) { key->push_back('\0'); } + +// An STL comparator that compares two StringPieces +namespace { +struct STLLessThan { + STLLessThan() {} + bool operator()(const string& a, const string& b) const { + return StringPiece(a).compare(StringPiece(b)) < 0; + } +}; +} // namespace + +class StringSink : public WritableFile { + public: + ~StringSink() {} + + const string& contents() const { return contents_; } + + virtual Status Close() { return Status::OK(); } + virtual Status Flush() { return Status::OK(); } + virtual Status Sync() { return Status::OK(); } + + virtual Status Append(const StringPiece& data) { + contents_.append(data.data(), data.size()); + return Status::OK(); + } + + private: + string contents_; +}; + +class StringSource : public RandomAccessFile { + public: + StringSource(const StringPiece& contents) + : contents_(contents.data(), contents.size()), bytes_read_(0) {} + + virtual ~StringSource() {} + + uint64 Size() const { return contents_.size(); } + + virtual Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const { + if (offset > contents_.size()) { + return errors::InvalidArgument("invalid Read offset"); + } + if (offset + n > contents_.size()) { + n = contents_.size() - offset; + } + memcpy(scratch, 
&contents_[offset], n); + *result = StringPiece(scratch, n); + bytes_read_ += n; + return Status::OK(); + } + + uint64 BytesRead() const { return bytes_read_; } + + private: + string contents_; + mutable uint64 bytes_read_; +}; + +typedef std::map<string, string, STLLessThan> KVMap; + +// Helper class for tests to unify the interface between +// BlockBuilder/TableBuilder and Block/Table. +class Constructor { + public: + explicit Constructor() : data_(STLLessThan()) {} + virtual ~Constructor() {} + + void Add(const string& key, const StringPiece& value) { + data_[key] = value.ToString(); + } + + // Finish constructing the data structure with all the keys that have + // been added so far. Returns the keys in sorted order in "*keys" + // and stores the key/value pairs in "*kvmap" + void Finish(const Options& options, std::vector<string>* keys, KVMap* kvmap) { + *kvmap = data_; + keys->clear(); + for (KVMap::const_iterator it = data_.begin(); it != data_.end(); ++it) { + keys->push_back(it->first); + } + data_.clear(); + Status s = FinishImpl(options, *kvmap); + ASSERT_TRUE(s.ok()) << s.ToString(); + } + + // Construct the data structure from the data in "data" + virtual Status FinishImpl(const Options& options, const KVMap& data) = 0; + + virtual Iterator* NewIterator() const = 0; + + virtual const KVMap& data() { return data_; } + + private: + KVMap data_; +}; + +class BlockConstructor : public Constructor { + public: + BlockConstructor() : block_(NULL) {} + ~BlockConstructor() { delete block_; } + virtual Status FinishImpl(const Options& options, const KVMap& data) { + delete block_; + block_ = NULL; + BlockBuilder builder(&options); + + for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) { + builder.Add(it->first, it->second); + } + // Open the block + data_ = builder.Finish().ToString(); + BlockContents contents; + contents.data = data_; + contents.cachable = false; + contents.heap_allocated = false; + block_ = new Block(contents); + return 
Status::OK(); + } + virtual Iterator* NewIterator() const { return block_->NewIterator(); } + + private: + string data_; + Block* block_; +}; + +class TableConstructor : public Constructor { + public: + TableConstructor() : source_(NULL), table_(NULL) {} + ~TableConstructor() { Reset(); } + virtual Status FinishImpl(const Options& options, const KVMap& data) { + Reset(); + StringSink sink; + TableBuilder builder(options, &sink); + + for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) { + builder.Add(it->first, it->second); + TF_CHECK_OK(builder.status()); + } + Status s = builder.Finish(); + TF_CHECK_OK(s) << s.ToString(); + + CHECK_EQ(sink.contents().size(), builder.FileSize()); + + // Open the table + source_ = new StringSource(sink.contents()); + Options table_options; + return Table::Open(table_options, source_, sink.contents().size(), &table_); + } + + virtual Iterator* NewIterator() const { return table_->NewIterator(); } + + uint64 ApproximateOffsetOf(const StringPiece& key) const { + return table_->ApproximateOffsetOf(key); + } + + uint64 BytesRead() const { return source_->BytesRead(); } + + private: + void Reset() { + delete table_; + delete source_; + table_ = NULL; + source_ = NULL; + } + + StringSource* source_; + Table* table_; +}; + +enum TestType { TABLE_TEST, BLOCK_TEST }; + +struct TestArgs { + TestType type; + int restart_interval; +}; + +static const TestArgs kTestArgList[] = { + {TABLE_TEST, 16}, {TABLE_TEST, 1}, {TABLE_TEST, 1024}, + {BLOCK_TEST, 16}, {BLOCK_TEST, 1}, {BLOCK_TEST, 1024}, +}; +static const int kNumTestArgs = sizeof(kTestArgList) / sizeof(kTestArgList[0]); + +class Harness : public ::testing::Test { + public: + Harness() : constructor_(NULL) {} + + void Init(const TestArgs& args) { + delete constructor_; + constructor_ = NULL; + options_ = Options(); + + options_.block_restart_interval = args.restart_interval; + // Use shorter block size for tests to exercise block boundary + // conditions more. 
+ options_.block_size = 256; + switch (args.type) { + case TABLE_TEST: + constructor_ = new TableConstructor(); + break; + case BLOCK_TEST: + constructor_ = new BlockConstructor(); + break; + } + } + + ~Harness() { delete constructor_; } + + void Add(const string& key, const string& value) { + constructor_->Add(key, value); + } + + void Test(random::SimplePhilox* rnd) { + std::vector<string> keys; + KVMap data; + constructor_->Finish(options_, &keys, &data); + + TestForwardScan(keys, data); + TestRandomAccess(rnd, keys, data); + } + + void TestForwardScan(const std::vector<string>& keys, const KVMap& data) { + Iterator* iter = constructor_->NewIterator(); + ASSERT_TRUE(!iter->Valid()); + iter->SeekToFirst(); + for (KVMap::const_iterator model_iter = data.begin(); + model_iter != data.end(); ++model_iter) { + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + iter->Next(); + } + ASSERT_TRUE(!iter->Valid()); + delete iter; + } + + void TestRandomAccess(random::SimplePhilox* rnd, + const std::vector<string>& keys, const KVMap& data) { + static const bool kVerbose = false; + Iterator* iter = constructor_->NewIterator(); + ASSERT_TRUE(!iter->Valid()); + KVMap::const_iterator model_iter = data.begin(); + if (kVerbose) fprintf(stderr, "---\n"); + for (int i = 0; i < 200; i++) { + const int toss = rnd->Uniform(3); + switch (toss) { + case 0: { + if (iter->Valid()) { + if (kVerbose) fprintf(stderr, "Next\n"); + iter->Next(); + ++model_iter; + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + } + break; + } + + case 1: { + if (kVerbose) fprintf(stderr, "SeekToFirst\n"); + iter->SeekToFirst(); + model_iter = data.begin(); + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + break; + } + + case 2: { + string key = PickRandomKey(rnd, keys); + model_iter = data.lower_bound(key); + if (kVerbose) + fprintf(stderr, "Seek '%s'\n", str_util::CEscape(key).c_str()); + iter->Seek(StringPiece(key)); + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + break; + } 
+ } + } + delete iter; + } + + string ToString(const KVMap& data, const KVMap::const_iterator& it) { + if (it == data.end()) { + return "END"; + } else { + return "'" + it->first + "->" + it->second + "'"; + } + } + + string ToString(const KVMap& data, const KVMap::const_reverse_iterator& it) { + if (it == data.rend()) { + return "END"; + } else { + return "'" + it->first + "->" + it->second + "'"; + } + } + + string ToString(const Iterator* it) { + if (!it->Valid()) { + return "END"; + } else { + return "'" + it->key().ToString() + "->" + it->value().ToString() + "'"; + } + } + + string PickRandomKey(random::SimplePhilox* rnd, + const std::vector<string>& keys) { + if (keys.empty()) { + return "foo"; + } else { + const int index = rnd->Uniform(keys.size()); + string result = keys[index]; + switch (rnd->Uniform(3)) { + case 0: + // Return an existing key + break; + case 1: { + // Attempt to return something smaller than an existing key + if (result.size() > 0 && result[result.size() - 1] > '\0') { + result[result.size() - 1]--; + } + break; + } + case 2: { + // Return something larger than an existing key + Increment(&result); + break; + } + } + return result; + } + } + + private: + Options options_; + Constructor* constructor_; +}; + +// Test empty table/block. +TEST_F(Harness, Empty) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 1, 17); + random::SimplePhilox rnd(&philox); + Test(&rnd); + } +} + +// Special test for a block with no restart entries. The C++ leveldb +// code never generates such blocks, but the Java version of leveldb +// seems to. 
// Tests for the block/table readers.  The Harness fixture's Init/Add/Test
// methods build a container from kTestArgList configurations, insert keys,
// and cross-check iteration/seek behavior; TableConstructor-based tests
// below query approximate file offsets and read accounting.

// A block consisting only of a zeroed 4-byte restart-count trailer holds no
// entries: every positioning operation must leave the iterator invalid.
TEST_F(Harness, ZeroRestartPointsInBlock) {
  char data[sizeof(uint32)];
  memset(data, 0, sizeof(data));
  BlockContents contents;
  contents.data = StringPiece(data, sizeof(data));
  contents.cachable = false;
  contents.heap_allocated = false;
  Block block(contents);
  Iterator* iter = block.NewIterator();
  iter->SeekToFirst();
  ASSERT_TRUE(!iter->Valid());
  iter->Seek("foo");
  ASSERT_TRUE(!iter->Valid());
  delete iter;
}

// Test the empty key
TEST_F(Harness, SimpleEmptyKey) {
  for (int i = 0; i < kNumTestArgs; i++) {
    Init(kTestArgList[i]);
    random::PhiloxRandom philox(testing::RandomSeed() + 1, 17);
    random::SimplePhilox rnd(&philox);
    Add("", "v");
    Test(&rnd);
  }
}

// A single short key/value pair.
TEST_F(Harness, SimpleSingle) {
  for (int i = 0; i < kNumTestArgs; i++) {
    Init(kTestArgList[i]);
    random::PhiloxRandom philox(testing::RandomSeed() + 2, 17);
    random::SimplePhilox rnd(&philox);
    Add("abc", "v");
    Test(&rnd);
  }
}

// Several keys that exercise shared-prefix encoding ("abc" vs "abcd").
TEST_F(Harness, SimpleMulti) {
  for (int i = 0; i < kNumTestArgs; i++) {
    Init(kTestArgList[i]);
    random::PhiloxRandom philox(testing::RandomSeed() + 3, 17);
    random::SimplePhilox rnd(&philox);
    Add("abc", "v");
    Add("abcd", "v");
    Add("ac", "v2");
    Test(&rnd);
  }
}

// Mix of tiny values and ~10MB values to exercise large-value handling.
TEST_F(Harness, SimpleMultiBigValues) {
  for (int i = 0; i < kNumTestArgs; i++) {
    Init(kTestArgList[i]);
    random::PhiloxRandom philox(testing::RandomSeed() + 3, 17);
    random::SimplePhilox rnd(&philox);
    Add("ainitial", "tiny");
    Add("anext", string(10000000, 'a'));
    Add("anext2", string(10000000, 'b'));
    Add("azz", "tiny");
    Test(&rnd);
  }
}

// Keys containing 0xff bytes must not confuse comparisons or encoding.
TEST_F(Harness, SimpleSpecialKey) {
  for (int i = 0; i < kNumTestArgs; i++) {
    Init(kTestArgList[i]);
    random::PhiloxRandom philox(testing::RandomSeed() + 4, 17);
    random::SimplePhilox rnd(&philox);
    Add("\xff\xff", "v3");
    Test(&rnd);
  }
}

// Randomized stress: grows the entry count (densely up to 50, then in
// strides of 200 up to 2000) and inserts skewed-length random keys/values.
TEST_F(Harness, Randomized) {
  for (int i = 0; i < kNumTestArgs; i++) {
    Init(kTestArgList[i]);
    random::PhiloxRandom philox(testing::RandomSeed() + 5, 17);
    random::SimplePhilox rnd(&philox);
    for (int num_entries = 0; num_entries < 2000;
         num_entries += (num_entries < 50 ? 1 : 200)) {
      if ((num_entries % 10) == 0) {
        // Progress output: these cases can be slow.
        fprintf(stderr, "case %d of %d: num_entries = %d\n", (i + 1),
                int(kNumTestArgs), num_entries);
      }
      for (int e = 0; e < num_entries; e++) {
        string v;
        Add(test::RandomKey(&rnd, rnd.Skewed(4)),
            test::RandomString(&rnd, rnd.Skewed(5), &v).ToString());
      }
      Test(&rnd);
    }
  }
}

// Returns true iff low <= val <= high; logs the violation otherwise so the
// enclosing ASSERT_TRUE failure message has context.
static bool Between(uint64 val, uint64 low, uint64 high) {
  bool result = (val >= low) && (val <= high);
  if (!result) {
    fprintf(stderr, "Value %llu is not in range [%llu, %llu]\n",
            (unsigned long long)(val), (unsigned long long)(low),
            (unsigned long long)(high));
  }
  return result;
}

// Empty tag class grouping the TableConstructor-based tests below.
class TableTest {};

// With compression disabled, ApproximateOffsetOf should track the
// cumulative size of the values written before each key.
TEST(TableTest, ApproximateOffsetOfPlain) {
  TableConstructor c;
  c.Add("k01", "hello");
  c.Add("k02", "hello2");
  c.Add("k03", string(10000, 'x'));
  c.Add("k04", string(200000, 'x'));
  c.Add("k05", string(300000, 'x'));
  c.Add("k06", "hello3");
  c.Add("k07", string(100000, 'x'));
  std::vector<string> keys;
  KVMap kvmap;
  Options options;
  options.block_size = 1024;
  options.compression = kNoCompression;
  c.Finish(options, &keys, &kvmap);

  ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01a"), 0, 0));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 0, 0));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 10, 500));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 10000, 11000));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04a"), 210000, 211000));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k05"), 210000, 211000));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k06"), 510000, 511000));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k07"), 510000, 511000));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 610000, 612000));
}

// Snappy may be compiled out; probe it so compression tests can be skipped.
static bool SnappyCompressionSupported() {
  string out;
  StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
  return port::Snappy_Compress(in.data(), in.size(), &out);
}

// With Snappy enabled, offsets reflect compressed (smaller) block sizes;
// the 0.25-compressibility strings shrink the 10000-byte values to ~2-3KB.
TEST(TableTest, ApproximateOffsetOfCompressed) {
  if (!SnappyCompressionSupported()) {
    fprintf(stderr, "skipping compression tests\n");
    return;
  }

  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  TableConstructor c;
  string tmp;
  c.Add("k01", "hello");
  c.Add("k02", test::CompressibleString(&rnd, 0.25, 10000, &tmp));
  c.Add("k03", "hello3");
  c.Add("k04", test::CompressibleString(&rnd, 0.25, 10000, &tmp));
  std::vector<string> keys;
  KVMap kvmap;
  Options options;
  options.block_size = 1024;
  options.compression = kSnappyCompression;
  c.Finish(options, &keys, &kvmap);
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 10, 100));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 2000, 3000));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 2000, 3000));
  ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 4000, 6000));
}

// Seeking to a key in the first block must not pull the ~1MB second block
// off disk; BytesRead() is the fixture's read-accounting hook.
TEST(TableTest, SeekToFirstKeyDoesNotReadTooMuch) {
  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  string tmp;
  TableConstructor c;
  c.Add("k01", "firstvalue");
  c.Add("k03", test::CompressibleString(&rnd, 0.25, 1000000, &tmp));
  c.Add("k04", "abc");
  std::vector<string> keys;
  KVMap kvmap;
  Options options;
  options.block_size = 1024;
  options.compression = kNoCompression;
  c.Finish(options, &keys, &kvmap);

  Iterator* iter = c.NewIterator();
  iter->Seek("k01");
  delete iter;
  // Make sure we don't read the big second block when just trying to
  // retrieve the data in the first key
  EXPECT_LT(c.BytesRead(), 200);
}

} // namespace table
} // namespace tensorflow
// diff --git a/tensorflow/core/lib/io/two_level_iterator.cc
b/tensorflow/core/lib/io/two_level_iterator.cc new file mode 100644 index 0000000000..409baade6d --- /dev/null +++ b/tensorflow/core/lib/io/two_level_iterator.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/two_level_iterator.h" + +#include "tensorflow/core/lib/io/table.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +namespace { + +typedef Iterator* (*BlockFunction)(void*, const StringPiece&); + +class TwoLevelIterator : public Iterator { + public: + TwoLevelIterator(Iterator* index_iter, BlockFunction block_function, + void* arg); + + virtual ~TwoLevelIterator(); + + virtual void Seek(const StringPiece& target); + virtual void SeekToFirst(); + virtual void Next(); + + virtual bool Valid() const { + return (data_iter_ == nullptr) ? 
false : data_iter_->Valid(); + } + virtual StringPiece key() const { + assert(Valid()); + return data_iter_->key(); + } + virtual StringPiece value() const { + assert(Valid()); + return data_iter_->value(); + } + virtual Status status() const { + // It'd be nice if status() returned a const Status& instead of a + // Status + if (!index_iter_->status().ok()) { + return index_iter_->status(); + } else if (data_iter_ != NULL && !data_iter_->status().ok()) { + return data_iter_->status(); + } else { + return status_; + } + } + + private: + void SaveError(const Status& s) { + if (status_.ok() && !s.ok()) status_ = s; + } + void SkipEmptyDataBlocksForward(); + void SetDataIterator(Iterator* data_iter); + void InitDataBlock(); + + BlockFunction block_function_; + void* arg_; + Status status_; + Iterator* index_iter_; + Iterator* data_iter_; // May be NULL + // If data_iter_ is non-NULL, then "data_block_handle_" holds the + // "index_value" passed to block_function_ to create the data_iter_. + string data_block_handle_; +}; + +TwoLevelIterator::TwoLevelIterator(Iterator* index_iter, + BlockFunction block_function, void* arg) + : block_function_(block_function), + arg_(arg), + index_iter_(index_iter), + data_iter_(NULL) {} + +TwoLevelIterator::~TwoLevelIterator() { + delete index_iter_; + delete data_iter_; +} + +void TwoLevelIterator::Seek(const StringPiece& target) { + index_iter_->Seek(target); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->Seek(target); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::SeekToFirst() { + index_iter_->SeekToFirst(); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->SeekToFirst(); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::Next() { + assert(Valid()); + data_iter_->Next(); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::SkipEmptyDataBlocksForward() { + while (data_iter_ == NULL || !data_iter_->Valid()) { + // Move to next block + if (!index_iter_->Valid()) { + 
SetDataIterator(NULL); + return; + } + index_iter_->Next(); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->SeekToFirst(); + } +} + +void TwoLevelIterator::SetDataIterator(Iterator* data_iter) { + if (data_iter_ != NULL) { + SaveError(data_iter_->status()); + delete data_iter_; + } + data_iter_ = data_iter; +} + +void TwoLevelIterator::InitDataBlock() { + if (!index_iter_->Valid()) { + SetDataIterator(NULL); + } else { + StringPiece handle = index_iter_->value(); + if (data_iter_ != NULL && handle.compare(data_block_handle_) == 0) { + // data_iter_ is already constructed with this iterator, so + // no need to change anything + } else { + Iterator* iter = (*block_function_)(arg_, handle); + data_block_handle_.assign(handle.data(), handle.size()); + SetDataIterator(iter); + } + } +} + +} // namespace + +Iterator* NewTwoLevelIterator(Iterator* index_iter, + BlockFunction block_function, void* arg) { + return new TwoLevelIterator(index_iter, block_function, arg); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/two_level_iterator.h b/tensorflow/core/lib/io/two_level_iterator.h new file mode 100644 index 0000000000..1cc5d2f921 --- /dev/null +++ b/tensorflow/core/lib/io/two_level_iterator.h @@ -0,0 +1,30 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#define TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ + +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +// Return a new two level iterator. A two-level iterator contains an +// index iterator whose values point to a sequence of blocks where +// each block is itself a sequence of key,value pairs. The returned +// two-level iterator yields the concatenation of all key/value pairs +// in the sequence of blocks. 
Takes ownership of "index_iter" and +// will delete it when no longer needed. +// +// Uses a supplied function to convert an index_iter value into +// an iterator over the contents of the corresponding block. +extern Iterator* NewTwoLevelIterator( + Iterator* index_iter, + Iterator* (*block_function)(void* arg, const StringPiece& index_value), + void* arg); + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ |