// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. See the AUTHORS file for names of contributors.
//
// Decodes the blocks generated by block_builder.cc.

#include "tensorflow/core/lib/io/block.h"

#include <algorithm>
#include <vector>

#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/format.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace table {

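// Block layout (as written by block_builder.cc):
//   entry 0 .. entry N-1          prefix-compressed key/value records
//   restart[0] .. restart[R-1]    fixed32 offsets of entries whose keys are
//                                 stored without prefix compression
//   num_restarts (R)              fixed32 count, the last 4 bytes of the block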
inline uint32 Block::NumRestarts() const {
  assert(size_ >= sizeof(uint32));
  return core::DecodeFixed32(data_ + size_ - sizeof(uint32));
}

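// Sanity-check the raw contents: the block must be large enough to hold the
// restart count, and the restart array it advertises must fit inside the
// block.  On failure, size_ is set to 0 so that NewIterator() reports a
// corruption error instead of reading out of bounds.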
Block::Block(const BlockContents& contents)
    : data_(contents.data.data()),
      size_(contents.data.size()),
      owned_(contents.heap_allocated) {
  if (size_ < sizeof(uint32)) {
    size_ = 0;  // Error marker
  } else {
    size_t max_restarts_allowed = (size_ - sizeof(uint32)) / sizeof(uint32);
    if (NumRestarts() > max_restarts_allowed) {
      // The size is too small for NumRestarts()
      size_ = 0;
    } else {
      restart_offset_ = size_ - (1 + NumRestarts()) * sizeof(uint32);
    }
  }
}

Block::~Block() {
  if (owned_) {
    delete[] data_;
  }
}

// Helper routine: decode the next block entry starting at "p",
// storing the number of shared key bytes, non_shared key bytes,
// and the length of the value in "*shared", "*non_shared", and
// "*value_length", respectively.  Will not dereference past "limit".
//
// If any errors are detected, returns NULL.  Otherwise, returns a
// pointer to the key delta (just past the three decoded values).
static inline const char* DecodeEntry(const char* p, const char* limit,
                                      uint32* shared, uint32* non_shared,
                                      uint32* value_length) {
  if (limit - p < 3) return NULL;
  *shared = reinterpret_cast<const unsigned char*>(p)[0];
  *non_shared = reinterpret_cast<const unsigned char*>(p)[1];
  *value_length = reinterpret_cast<const unsigned char*>(p)[2];
  if ((*shared | *non_shared | *value_length) < 128) {
    // Fast path: all three values are encoded in one byte each
    p += 3;
  } else {
    if ((p = core::GetVarint32Ptr(p, limit, shared)) == NULL) return NULL;
    if ((p = core::GetVarint32Ptr(p, limit, non_shared)) == NULL) return NULL;
    if ((p = core::GetVarint32Ptr(p, limit, value_length)) == NULL) return NULL;
  }

  if (static_cast<uint32>(limit - p) < (*non_shared + *value_length)) {
    return NULL;
  }
  return p;
}

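// Iterator over the entries of a single block.  key() is materialized into
// key_ by applying each entry's prefix-compressed delta to the previous key;
// value() aliases the block's underlying data, so it remains valid only as
// long as the block itself is alive.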
class Block::Iter : public Iterator {
 private:
  const char* const data_;     // underlying block contents
  uint32 const restarts_;      // Offset of restart array (list of fixed32)
  uint32 const num_restarts_;  // Number of uint32 entries in restart array

  // current_ is offset in data_ of current entry.  >= restarts_ if !Valid
  uint32 current_;
  uint32 restart_index_;  // Index of restart block in which current_ falls
  string key_;
  StringPiece value_;
  Status status_;

  inline int Compare(const StringPiece& a, const StringPiece& b) const {
    return a.compare(b);
  }

  // Return the offset in data_ just past the end of the current entry.
  inline uint32 NextEntryOffset() const {
    return (value_.data() + value_.size()) - data_;
  }

  uint32 GetRestartPoint(uint32 index) {
    assert(index < num_restarts_);
    return core::DecodeFixed32(data_ + restarts_ + index * sizeof(uint32));
  }

  void SeekToRestartPoint(uint32 index) {
    key_.clear();
    restart_index_ = index;
    // current_ will be fixed by ParseNextKey();

    // ParseNextKey() starts at the end of value_, so set value_ accordingly
    uint32 offset = GetRestartPoint(index);
    value_ = StringPiece(data_ + offset, 0);
  }

 public:
  Iter(const char* data, uint32 restarts, uint32 num_restarts)
      : data_(data),
        restarts_(restarts),
        num_restarts_(num_restarts),
        current_(restarts_),
        restart_index_(num_restarts_) {
    assert(num_restarts_ > 0);
  }

  virtual bool Valid() const { return current_ < restarts_; }
  virtual Status status() const { return status_; }
  virtual StringPiece key() const {
    assert(Valid());
    return key_;
  }
  virtual StringPiece value() const {
    assert(Valid());
    return value_;
  }

  virtual void Next() {
    assert(Valid());
    ParseNextKey();
  }

  virtual void Seek(const StringPiece& target) {
    // Binary search in restart array to find the last restart point
    // with a key < target
    uint32 left = 0;
    uint32 right = num_restarts_ - 1;
    while (left < right) {
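      // Round the midpoint up: "left" may be assigned "mid" below, so
      // rounding down could loop forever once left + 1 == right.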
      uint32 mid = (left + right + 1) / 2;
      uint32 region_offset = GetRestartPoint(mid);
      uint32 shared, non_shared, value_length;
      const char* key_ptr =
          DecodeEntry(data_ + region_offset, data_ + restarts_, &shared,
                      &non_shared, &value_length);
      if (key_ptr == NULL || (shared != 0)) {
        CorruptionError();
        return;
      }
      StringPiece mid_key(key_ptr, non_shared);
      if (Compare(mid_key, target) < 0) {
        // Key at "mid" is smaller than "target".  Therefore all
        // blocks before "mid" are uninteresting.
        left = mid;
      } else {
        // Key at "mid" is >= "target".  Therefore all blocks at or
        // after "mid" are uninteresting.
        right = mid - 1;
      }
    }

    // Linear search (within restart block) for first key >= target
    SeekToRestartPoint(left);
    while (true) {
      if (!ParseNextKey()) {
        return;
      }
      if (Compare(key_, target) >= 0) {
        return;
      }
    }
  }

  virtual void SeekToFirst() {
    SeekToRestartPoint(0);
    ParseNextKey();
  }

 private:
  void CorruptionError() {
    current_ = restarts_;
    restart_index_ = num_restarts_;
    status_ = errors::DataLoss("bad entry in block");
    key_.clear();
    value_.clear();
  }

  bool ParseNextKey() {
    current_ = NextEntryOffset();
    const char* p = data_ + current_;
    const char* limit = data_ + restarts_;  // Restarts come right after data
    if (p >= limit) {
      // No more entries to return.  Mark as invalid.
      current_ = restarts_;
      restart_index_ = num_restarts_;
      return false;
    }

    // Decode next entry
    uint32 shared, non_shared, value_length;
    p = DecodeEntry(p, limit, &shared, &non_shared, &value_length);
    if (p == NULL || key_.size() < shared) {
      CorruptionError();
      return false;
    } else {
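      // Reconstruct the full key: keep the first "shared" bytes of the
      // previous key and append the "non_shared" new bytes; the value
      // immediately follows the key delta.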
      key_.resize(shared);
      key_.append(p, non_shared);
      value_ = StringPiece(p + non_shared, value_length);
      while (restart_index_ + 1 < num_restarts_ &&
             GetRestartPoint(restart_index_ + 1) < current_) {
        ++restart_index_;
      }
      return true;
    }
  }
};

Iterator* Block::NewIterator() {
  if (size_ < sizeof(uint32)) {
    return NewErrorIterator(errors::DataLoss("bad block contents"));
  }
  const uint32 num_restarts = NumRestarts();
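  // The iterator requires at least one restart point; a block with an empty
  // restart array is treated as having no entries.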
  if (num_restarts == 0) {
    return NewEmptyIterator();
  } else {
    return new Iter(data_, restart_offset_, num_restarts);
  }
}

}  // namespace table
}  // namespace tensorflow