diff options
Diffstat (limited to 'tensorflow/core/lib/gtl/flatset.h')
-rw-r--r-- | tensorflow/core/lib/gtl/flatset.h | 277 |
1 files changed, 277 insertions, 0 deletions
diff --git a/tensorflow/core/lib/gtl/flatset.h b/tensorflow/core/lib/gtl/flatset.h new file mode 100644 index 0000000000..b94d88cbc6 --- /dev/null +++ b/tensorflow/core/lib/gtl/flatset.h @@ -0,0 +1,277 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ + +#include <stddef.h> +#include <utility> +#include "tensorflow/core/lib/gtl/flatrep.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gtl { + +// FlatSet<K,...> provides a set of K. +// +// The map is implemented using an open-addressed hash table. A +// single array holds entire map contents and collisions are resolved +// by probing at a sequence of locations in the array. +template <typename Key, class Hash, class Eq = std::equal_to<Key>> +class FlatSet { + private: + // Forward declare some internal types needed in public section. + struct Bucket; + + public: + typedef Key key_type; + typedef Key value_type; + typedef Hash hasher; + typedef Eq key_equal; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef value_type& reference; + typedef const value_type& const_reference; + + FlatSet() : FlatSet(1) {} + + explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) + : rep_(N, hf, eq) {} + + FlatSet(const FlatSet& src) : rep_(src.rep_) {} + + template <typename InputIter> + FlatSet(InputIter first, InputIter last, size_t N = 1, + const Hash& hf = Hash(), const Eq& eq = Eq()) + : FlatSet(N, hf, eq) { + insert(first, last); + } + + FlatSet& operator=(const FlatSet& src) { + rep_.CopyFrom(src.rep_); + return *this; + } + + ~FlatSet() {} + + void swap(FlatSet& x) { rep_.swap(x.rep_); } + void clear_no_resize() { rep_.clear_no_resize(); } + void clear() { rep_.clear(); } + void reserve(size_t N) { rep_.Resize(std::max(N, size())); } + void rehash(size_t N) { rep_.Resize(std::max(N, size())); } + void resize(size_t N) { rep_.Resize(std::max(N, size())); } + size_t size() const { return rep_.size(); } + bool empty() const { return size() == 0; } + size_t bucket_count() const { return rep_.bucket_count(); } + hasher hash_function() const { return rep_.hash_function(); } + key_equal key_eq() const { return rep_.key_eq(); } + + class iterator { + public: + iterator() : b_(nullptr), end_(nullptr), i_(0) {} + + // Make iterator pointing at first element at or after b. + explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { + SkipUnused(); + } + + // Make iterator pointing exactly at ith element in b, which must exist. + iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {} + + Key& operator*() { return key(); } + Key* operator->() { return &key(); } + bool operator==(const iterator& x) const { + return b_ == x.b_ && i_ == x.i_; + } + bool operator!=(const iterator& x) const { return !(*this == x); } + iterator& operator++() { + DCHECK(b_ != end_); + i_++; + SkipUnused(); + return *this; + } + + private: + friend class FlatSet; + Bucket* b_; + Bucket* end_; + uint32 i_; + + Key& key() const { return b_->key(i_); } + void SkipUnused() { + while (b_ < end_) { + if (i_ >= Rep::kWidth) { + i_ = 0; + b_++; + } else if (b_->marker[i_] < 2) { + i_++; + } else { + break; + } + } + } + }; + + class const_iterator { + private: + mutable iterator rep_; // Share state and logic with non-const iterator. + public: + const_iterator() : rep_() {} + explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {} + const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {} + + const Key& operator*() const { return rep_.key(); } + const Key* operator->() const { return &rep_.key(); } + bool operator==(const const_iterator& x) const { return rep_ == x.rep_; } + bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; } + const_iterator& operator++() { + ++rep_; + return *this; + } + }; + + iterator begin() { return iterator(rep_.start(), rep_.limit()); } + iterator end() { return iterator(rep_.limit(), rep_.limit()); } + const_iterator begin() const { + return const_iterator(rep_.start(), rep_.limit()); + } + const_iterator end() const { + return const_iterator(rep_.limit(), rep_.limit()); + } + + size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } + iterator find(const Key& k) { + auto r = rep_.Find(k); + return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); + } + const_iterator find(const Key& k) const { + auto r = rep_.Find(k); + return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); + } + + std::pair<iterator, bool> insert(const Key& k) { return Insert(k); } + template <typename InputIter> + void insert(InputIter first, InputIter last) { + for (; first != last; ++first) { + insert(*first); + } + } + + template <typename... Args> + std::pair<iterator, bool> emplace(Args&&... args) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(std::forward<Args>(args)...); + const bool inserted = !r.found; + return {iterator(r.b, rep_.limit(), r.index), inserted}; + } + + size_t erase(const Key& k) { + auto r = rep_.Find(k); + if (!r.found) return 0; + rep_.Erase(r.b, r.index); + return 1; + } + iterator erase(iterator pos) { + rep_.Erase(pos.b_, pos.i_); + ++pos; + return pos; + } + iterator erase(iterator pos, iterator last) { + for (; pos != last; ++pos) { + rep_.Erase(pos.b_, pos.i_); + } + return pos; + } + + std::pair<iterator, iterator> equal_range(const Key& k) { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + std::pair<const_iterator, const_iterator> equal_range(const Key& k) const { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + + bool operator==(const FlatSet& x) const { + if (size() != x.size()) return false; + for (const auto& elem : x) { + auto i = find(elem); + if (i == end()) return false; + } + return true; + } + bool operator!=(const FlatSet& x) const { return !(*this == x); } + + // If key exists in the table, prefetch it. This is a hint, and may + // have no effect. + void prefetch_value(const Key& key) const { rep_.Prefetch(key); } + + private: + using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>; + + // Bucket stores kWidth <marker, key, value> triples. + // The data is organized as three parallel arrays to reduce padding. + struct Bucket { + uint8 marker[Rep::kWidth]; + + // Wrap keys in union to control construction and destruction. + union Storage { + Key key[Rep::kWidth]; + Storage() {} + ~Storage() {} + } storage; + + Key& key(uint32 i) { + DCHECK_GE(marker[i], 2); + return storage.key[i]; + } + void Destroy(uint32 i) { storage.key[i].Key::~Key(); } + void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); + } + void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(src->storage.key[src_index]); + } + }; + + std::pair<iterator, bool> Insert(const Key& k) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(k); + const bool inserted = !r.found; + return {iterator(r.b, rep_.limit(), r.index), inserted}; + } + + Rep rep_; +}; + +} // namespace gtl +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ |