diff options
Diffstat (limited to 'tensorflow/core/lib/io/block.cc')
-rw-r--r-- | tensorflow/core/lib/io/block.cc | 236 |
1 files changed, 236 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc new file mode 100644 index 0000000000..1ddaa2eb78 --- /dev/null +++ b/tensorflow/core/lib/io/block.cc @@ -0,0 +1,236 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// Decodes the blocks generated by block_builder.cc. + +#include "tensorflow/core/lib/io/block.h" + +#include <vector> +#include <algorithm> +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +inline uint32 Block::NumRestarts() const { + assert(size_ >= sizeof(uint32)); + return core::DecodeFixed32(data_ + size_ - sizeof(uint32)); +} + +Block::Block(const BlockContents& contents) + : data_(contents.data.data()), + size_(contents.data.size()), + owned_(contents.heap_allocated) { + if (size_ < sizeof(uint32)) { + size_ = 0; // Error marker + } else { + size_t max_restarts_allowed = (size_ - sizeof(uint32)) / sizeof(uint32); + if (NumRestarts() > max_restarts_allowed) { + // The size is too small for NumRestarts() + size_ = 0; + } else { + restart_offset_ = size_ - (1 + NumRestarts()) * sizeof(uint32); + } + } +} + +Block::~Block() { + if (owned_) { + delete[] data_; + } +} + +// Helper routine: decode the next block entry starting at "p", +// storing the number of shared key bytes, non_shared key bytes, +// and the length of the value in "*shared", "*non_shared", and +// "*value_length", respectively. Will not dereference past "limit". +// +// If any errors are detected, returns NULL. Otherwise, returns a +// pointer to the key delta (just past the three decoded values). +static inline const char* DecodeEntry(const char* p, const char* limit, + uint32* shared, uint32* non_shared, + uint32* value_length) { + if (limit - p < 3) return NULL; + *shared = reinterpret_cast<const unsigned char*>(p)[0]; + *non_shared = reinterpret_cast<const unsigned char*>(p)[1]; + *value_length = reinterpret_cast<const unsigned char*>(p)[2]; + if ((*shared | *non_shared | *value_length) < 128) { + // Fast path: all three values are encoded in one byte each + p += 3; + } else { + if ((p = core::GetVarint32Ptr(p, limit, shared)) == NULL) return NULL; + if ((p = core::GetVarint32Ptr(p, limit, non_shared)) == NULL) return NULL; + if ((p = core::GetVarint32Ptr(p, limit, value_length)) == NULL) return NULL; + } + + if (static_cast<uint32>(limit - p) < (*non_shared + *value_length)) { + return NULL; + } + return p; +} + +class Block::Iter : public Iterator { + private: + const char* const data_; // underlying block contents + uint32 const restarts_; // Offset of restart array (list of fixed32) + uint32 const num_restarts_; // Number of uint32 entries in restart array + + // current_ is offset in data_ of current entry. >= restarts_ if !Valid + uint32 current_; + uint32 restart_index_; // Index of restart block in which current_ falls + string key_; + StringPiece value_; + Status status_; + + inline int Compare(const StringPiece& a, const StringPiece& b) const { + return a.compare(b); + } + + // Return the offset in data_ just past the end of the current entry. + inline uint32 NextEntryOffset() const { + return (value_.data() + value_.size()) - data_; + } + + uint32 GetRestartPoint(uint32 index) { + assert(index < num_restarts_); + return core::DecodeFixed32(data_ + restarts_ + index * sizeof(uint32)); + } + + void SeekToRestartPoint(uint32 index) { + key_.clear(); + restart_index_ = index; + // current_ will be fixed by ParseNextKey(); + + // ParseNextKey() starts at the end of value_, so set value_ accordingly + uint32 offset = GetRestartPoint(index); + value_ = StringPiece(data_ + offset, 0); + } + + public: + Iter(const char* data, uint32 restarts, uint32 num_restarts) + : data_(data), + restarts_(restarts), + num_restarts_(num_restarts), + current_(restarts_), + restart_index_(num_restarts_) { + assert(num_restarts_ > 0); + } + + virtual bool Valid() const { return current_ < restarts_; } + virtual Status status() const { return status_; } + virtual StringPiece key() const { + assert(Valid()); + return key_; + } + virtual StringPiece value() const { + assert(Valid()); + return value_; + } + + virtual void Next() { + assert(Valid()); + ParseNextKey(); + } + + virtual void Seek(const StringPiece& target) { + // Binary search in restart array to find the last restart point + // with a key < target + uint32 left = 0; + uint32 right = num_restarts_ - 1; + while (left < right) { + uint32 mid = (left + right + 1) / 2; + uint32 region_offset = GetRestartPoint(mid); + uint32 shared, non_shared, value_length; + const char* key_ptr = + DecodeEntry(data_ + region_offset, data_ + restarts_, &shared, + &non_shared, &value_length); + if (key_ptr == NULL || (shared != 0)) { + CorruptionError(); + return; + } + StringPiece mid_key(key_ptr, non_shared); + if (Compare(mid_key, target) < 0) { + // Key at "mid" is smaller than "target". Therefore all + // blocks before "mid" are uninteresting. + left = mid; + } else { + // Key at "mid" is >= "target". Therefore all blocks at or + // after "mid" are uninteresting. + right = mid - 1; + } + } + + // Linear search (within restart block) for first key >= target + SeekToRestartPoint(left); + while (true) { + if (!ParseNextKey()) { + return; + } + if (Compare(key_, target) >= 0) { + return; + } + } + } + + virtual void SeekToFirst() { + SeekToRestartPoint(0); + ParseNextKey(); + } + + private: + void CorruptionError() { + current_ = restarts_; + restart_index_ = num_restarts_; + status_ = errors::DataLoss("bad entry in block"); + key_.clear(); + value_.clear(); + } + + bool ParseNextKey() { + current_ = NextEntryOffset(); + const char* p = data_ + current_; + const char* limit = data_ + restarts_; // Restarts come right after data + if (p >= limit) { + // No more entries to return. Mark as invalid. + current_ = restarts_; + restart_index_ = num_restarts_; + return false; + } + + // Decode next entry + uint32 shared, non_shared, value_length; + p = DecodeEntry(p, limit, &shared, &non_shared, &value_length); + if (p == NULL || key_.size() < shared) { + CorruptionError(); + return false; + } else { + key_.resize(shared); + key_.append(p, non_shared); + value_ = StringPiece(p + non_shared, value_length); + while (restart_index_ + 1 < num_restarts_ && + GetRestartPoint(restart_index_ + 1) < current_) { + ++restart_index_; + } + return true; + } + } +}; + +Iterator* Block::NewIterator() { + if (size_ < sizeof(uint32)) { + return NewErrorIterator(errors::DataLoss("bad block contents")); + } + const uint32 num_restarts = NumRestarts(); + if (num_restarts == 0) { + return NewEmptyIterator(); + } else { + return new Iter(data_, restart_offset_, num_restarts); + } +} + +} // namespace table +} // namespace tensorflow |