diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-27 08:10:54 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-27 09:21:12 -0700 |
commit | 80aec93166dadb2dc30250e1251ab3eb006c2d53 (patch) | |
tree | c34632f2f255eda56214559ff7d7d54878a6a6c9 | |
parent | e43eaf662db492c909e6cab8c954178b75f7b63d (diff) |
Added new tensorflow::gtl::FlatMap and tensorflow::gtl::FlatSet classes.
Mostly drop-in replacements for std::unordered_map and std::unordered_set,
but much faster (does not do an allocation per entry, and represents
entries in groups of 8 in a flat array, which is much more cache efficient).
Benchmarks not included in this cl show about 3X to 5X performance
improvements over the std::unordered_{set,map} for many kinds of
common maps, e.g. std::unordered_map<int64, int64> or
std::unordered_map<string, int64>.
Change: 137401863
-rw-r--r-- | tensorflow/core/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/core/lib/gtl/flatmap.h | 349 | ||||
-rw-r--r-- | tensorflow/core/lib/gtl/flatmap_test.cc | 576 | ||||
-rw-r--r-- | tensorflow/core/lib/gtl/flatrep.h | 332 | ||||
-rw-r--r-- | tensorflow/core/lib/gtl/flatset.h | 277 | ||||
-rw-r--r-- | tensorflow/core/lib/gtl/flatset_test.cc | 501 | ||||
-rw-r--r-- | tensorflow/core/lib/hash/hash.h | 18 |
7 files changed, 2057 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 3eea01363b..76e6ee7568 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -164,6 +164,8 @@ cc_library( "lib/core/threadpool.h", "lib/gtl/array_slice.h", "lib/gtl/cleanup.h", + "lib/gtl/flatmap.h", + "lib/gtl/flatset.h", "lib/gtl/inlined_vector.h", "lib/gtl/priority_queue_util.h", "lib/hash/crc32c.h", @@ -1447,6 +1449,8 @@ tf_cc_tests( "lib/gtl/array_slice_test.cc", "lib/gtl/cleanup_test.cc", "lib/gtl/edit_distance_test.cc", + "lib/gtl/flatmap_test.cc", + "lib/gtl/flatset_test.cc", "lib/gtl/inlined_vector_test.cc", "lib/gtl/int_type_test.cc", "lib/gtl/iterator_range_test.cc", diff --git a/tensorflow/core/lib/gtl/flatmap.h b/tensorflow/core/lib/gtl/flatmap.h new file mode 100644 index 0000000000..c66bc47168 --- /dev/null +++ b/tensorflow/core/lib/gtl/flatmap.h @@ -0,0 +1,349 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ + +#include <stddef.h> +#include <utility> +#include "tensorflow/core/lib/gtl/flatrep.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gtl { + +// FlatMap<K,V,...> provides a map from K to V. 
+// +// The map is implemented using an open-addressed hash table. A +// single array holds entire map contents and collisions are resolved +// by probing at a sequence of locations in the array. +template <typename Key, typename Val, class Hash, class Eq = std::equal_to<Key>> +class FlatMap { + private: + // Forward declare some internal types needed in public section. + struct Bucket; + + public: + typedef Key key_type; + typedef Val mapped_type; + typedef Hash hasher; + typedef Eq key_equal; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + // We cannot use std::pair<> since internal representation stores + // keys and values in separate arrays, so we make a custom struct + // that holds references to the internal key, value elements. + struct value_type { + typedef Key first_type; + typedef Val second_type; + + const Key& first; + Val& second; + value_type(const Key& k, Val& v) : first(k), second(v) {} + }; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef value_type& reference; + typedef const value_type& const_reference; + + FlatMap() : FlatMap(1) {} + + explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) + : rep_(N, hf, eq) {} + + FlatMap(const FlatMap& src) : rep_(src.rep_) {} + + template <typename InputIter> + FlatMap(InputIter first, InputIter last, size_t N = 1, + const Hash& hf = Hash(), const Eq& eq = Eq()) + : FlatMap(N, hf, eq) { + insert(first, last); + } + + FlatMap& operator=(const FlatMap& src) { + rep_.CopyFrom(src.rep_); + return *this; + } + + ~FlatMap() {} + + void swap(FlatMap& x) { rep_.swap(x.rep_); } + void clear_no_resize() { rep_.clear_no_resize(); } + void clear() { rep_.clear(); } + void reserve(size_t N) { rep_.Resize(std::max(N, size())); } + void rehash(size_t N) { rep_.Resize(std::max(N, size())); } + void resize(size_t N) { rep_.Resize(std::max(N, size())); } + size_t size() const { return rep_.size(); } + bool empty() const { return size() == 0; } + 
size_t bucket_count() const { return rep_.bucket_count(); } + hasher hash_function() const { return rep_.hash_function(); } + key_equal key_eq() const { return rep_.key_eq(); } + + class iterator { + public: + iterator() : b_(nullptr), end_(nullptr), i_(0) {} + + // Make iterator pointing at first element at or after b. + explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { + SkipUnused(); + } + + // Make iterator pointing exactly at ith element in b, which must exist. + iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) { + FillValue(); + } + + value_type& operator*() { return *val(); } + value_type* operator->() { return val(); } + bool operator==(const iterator& x) const { + return b_ == x.b_ && i_ == x.i_; + } + bool operator!=(const iterator& x) const { return !(*this == x); } + iterator& operator++() { + DCHECK(b_ != end_); + i_++; + SkipUnused(); + return *this; + } + + private: + friend class FlatMap; + Bucket* b_; + Bucket* end_; + uint32 i_; + char space_[sizeof(value_type)]; + + value_type* val() { return reinterpret_cast<value_type*>(space_); } + void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); } + void SkipUnused() { + while (b_ < end_) { + if (i_ >= Rep::kWidth) { + i_ = 0; + b_++; + } else if (b_->marker[i_] < 2) { + i_++; + } else { + FillValue(); + break; + } + } + } + }; + + class const_iterator { + private: + mutable iterator rep_; // Share state and logic with non-const iterator. 
+ public: + const_iterator() : rep_() {} + explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {} + const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {} + + const value_type& operator*() const { return *rep_.val(); } + const value_type* operator->() const { return rep_.val(); } + bool operator==(const const_iterator& x) const { return rep_ == x.rep_; } + bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; } + const_iterator& operator++() { + ++rep_; + return *this; + } + }; + + iterator begin() { return iterator(rep_.start(), rep_.limit()); } + iterator end() { return iterator(rep_.limit(), rep_.limit()); } + const_iterator begin() const { + return const_iterator(rep_.start(), rep_.limit()); + } + const_iterator end() const { + return const_iterator(rep_.limit(), rep_.limit()); + } + + size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } + iterator find(const Key& k) { + auto r = rep_.Find(k); + return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); + } + const_iterator find(const Key& k) const { + auto r = rep_.Find(k); + return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); + } + + Val& at(const Key& k) { + auto r = rep_.Find(k); + DCHECK(r.found); + return r.b->val(r.index); + } + const Val& at(const Key& k) const { + auto r = rep_.Find(k); + DCHECK(r.found); + return r.b->val(r.index); + } + + template <typename P> + std::pair<iterator, bool> insert(const P& p) { + return Insert(p.first, p.second); + } + std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) { + return Insert(p.first, p.second); + } + template <typename InputIter> + void insert(InputIter first, InputIter last) { + for (; first != last; ++first) { + insert(*first); + } + } + + Val& operator[](const Key& k) { return IndexOp(k); } + Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); } + + template <typename... Args> + std::pair<iterator, bool> emplace(Args&&... 
args) { + return InsertPair(std::make_pair(std::forward<Args>(args)...)); + } + + size_t erase(const Key& k) { + auto r = rep_.Find(k); + if (!r.found) return 0; + rep_.Erase(r.b, r.index); + return 1; + } + iterator erase(iterator pos) { + rep_.Erase(pos.b_, pos.i_); + ++pos; + return pos; + } + iterator erase(iterator pos, iterator last) { + for (; pos != last; ++pos) { + rep_.Erase(pos.b_, pos.i_); + } + return pos; + } + + std::pair<iterator, iterator> equal_range(const Key& k) { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + std::pair<const_iterator, const_iterator> equal_range(const Key& k) const { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + + bool operator==(const FlatMap& x) const { + if (size() != x.size()) return false; + for (auto& p : x) { + auto i = find(p.first); + if (i == end()) return false; + if (i->second != p.second) return false; + } + return true; + } + bool operator!=(const FlatMap& x) const { return !(*this == x); } + + // If key exists in the table, prefetch the associated value. This + // is a hint, and may have no effect. + void prefetch_value(const Key& key) const { rep_.Prefetch(key); } + + private: + using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>; + + // Bucket stores kWidth <marker, key, value> triples. + // The data is organized as three parallel arrays to reduce padding. + struct Bucket { + uint8 marker[Rep::kWidth]; + + // Wrap keys and values in union to control construction and destruction. 
+ union Storage { + struct { + Key key[Rep::kWidth]; + Val val[Rep::kWidth]; + }; + Storage() {} + ~Storage() {} + } storage; + + Key& key(uint32 i) { + DCHECK_GE(marker[i], 2); + return storage.key[i]; + } + Val& val(uint32 i) { + DCHECK_GE(marker[i], 2); + return storage.val[i]; + } + template <typename V> + void InitVal(uint32 i, V&& v) { + new (&storage.val[i]) Val(std::forward<V>(v)); + } + void Destroy(uint32 i) { + storage.key[i].Key::~Key(); + storage.val[i].Val::~Val(); + } + void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); + new (&storage.val[i]) Val(std::move(src->storage.val[src_index])); + } + void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(src->storage.key[src_index]); + new (&storage.val[i]) Val(src->storage.val[src_index]); + } + }; + + template <typename Pair> + std::pair<iterator, bool> InsertPair(Pair&& p) { + return Insert(std::forward<decltype(p.first)>(p.first), + std::forward<decltype(p.second)>(p.second)); + } + + template <typename K, typename V> + std::pair<iterator, bool> Insert(K&& k, V&& v) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(std::forward<K>(k)); + const bool inserted = !r.found; + if (inserted) { + r.b->InitVal(r.index, std::forward<V>(v)); + } + return {iterator(r.b, rep_.limit(), r.index), inserted}; + } + + template <typename K> + Val& IndexOp(K&& k) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(std::forward<K>(k)); + Val* vptr = &r.b->val(r.index); + if (!r.found) { + new (vptr) Val(); // Initialize value in new slot. 
+ } + return *vptr; + } + + Rep rep_; +}; + +} // namespace gtl +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ diff --git a/tensorflow/core/lib/gtl/flatmap_test.cc b/tensorflow/core/lib/gtl/flatmap_test.cc new file mode 100644 index 0000000000..2fa610b7e1 --- /dev/null +++ b/tensorflow/core/lib/gtl/flatmap_test.cc @@ -0,0 +1,576 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/gtl/flatmap.h" + +#include <algorithm> +#include <string> +#include <vector> +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gtl { +namespace { + +typedef FlatMap<int64, int32, HashInt64> NumMap; + +// If map has an entry for k, return the corresponding value, else return def. +int32 Get(const NumMap& map, int64 k, int32 def = -1) { + auto iter = map.find(k); + if (iter == map.end()) { + EXPECT_EQ(map.count(k), 0); + return def; + } else { + EXPECT_EQ(map.count(k), 1); + EXPECT_EQ(&map.at(k), &iter->second); + EXPECT_EQ(iter->first, k); + return iter->second; + } +} + +// Return contents of map as a sorted list of pairs. 
+typedef std::vector<std::pair<int64, int32>> NumMapContents; +NumMapContents Contents(const NumMap& map) { + NumMapContents result; + for (const auto& p : map) { + result.push_back({p.first, p.second}); + } + std::sort(result.begin(), result.end()); + return result; +} + +// Fill entries with keys [start,limit). +void Fill(NumMap* map, int64 start, int64 limit) { + for (int64 i = start; i < limit; i++) { + map->insert({i, i * 100}); + } +} + +TEST(FlatMapTest, Find) { + NumMap map; + EXPECT_EQ(Get(map, 1), -1); + map.insert({1, 100}); + map.insert({2, 200}); + EXPECT_EQ(Get(map, 1), 100); + EXPECT_EQ(Get(map, 2), 200); + EXPECT_EQ(Get(map, 3), -1); +} + +TEST(FlatMapTest, Insert) { + NumMap map; + EXPECT_EQ(Get(map, 1), -1); + + // New entry. + auto result = map.insert({1, 100}); + EXPECT_TRUE(result.second); + EXPECT_EQ(result.first->first, 1); + EXPECT_EQ(result.first->second, 100); + EXPECT_EQ(Get(map, 1), 100); + + // Attempt to insert over existing entry. + result = map.insert({1, 200}); + EXPECT_FALSE(result.second); + EXPECT_EQ(result.first->first, 1); + EXPECT_EQ(result.first->second, 100); + EXPECT_EQ(Get(map, 1), 100); + + // Overwrite through iterator. + result.first->second = 300; + EXPECT_EQ(result.first->second, 300); + EXPECT_EQ(Get(map, 1), 300); + + // Should get updated value. + result = map.insert({1, 400}); + EXPECT_FALSE(result.second); + EXPECT_EQ(result.first->first, 1); + EXPECT_EQ(result.first->second, 300); + EXPECT_EQ(Get(map, 1), 300); +} + +TEST(FlatMapTest, InsertGrowth) { + NumMap map; + const int n = 100; + Fill(&map, 0, 100); + EXPECT_EQ(map.size(), n); + for (int i = 0; i < n; i++) { + EXPECT_EQ(Get(map, i), i * 100) << i; + } +} + +TEST(FlatMapTest, Emplace) { + NumMap map; + + // New entry. + auto result = map.emplace(1, 100); + EXPECT_TRUE(result.second); + EXPECT_EQ(result.first->first, 1); + EXPECT_EQ(result.first->second, 100); + EXPECT_EQ(Get(map, 1), 100); + + // Attempt to insert over existing entry. 
+ result = map.emplace(1, 200); + EXPECT_FALSE(result.second); + EXPECT_EQ(result.first->first, 1); + EXPECT_EQ(result.first->second, 100); + EXPECT_EQ(Get(map, 1), 100); + + // Overwrite through iterator. + result.first->second = 300; + EXPECT_EQ(result.first->second, 300); + EXPECT_EQ(Get(map, 1), 300); + + // Update a second value + result = map.emplace(2, 400); + EXPECT_TRUE(result.second); + EXPECT_EQ(result.first->first, 2); + EXPECT_EQ(result.first->second, 400); + EXPECT_EQ(Get(map, 2), 400); +} + +TEST(FlatMapTest, EmplaceUniquePtr) { + FlatMap<int64, std::unique_ptr<string>, HashInt64> smap; + smap.emplace(1, std::unique_ptr<string>(new string("hello"))); +} + +TEST(FlatMapTest, Size) { + NumMap map; + EXPECT_EQ(map.size(), 0); + + map.insert({1, 100}); + map.insert({2, 200}); + EXPECT_EQ(map.size(), 2); +} + +TEST(FlatMapTest, Empty) { + NumMap map; + EXPECT_TRUE(map.empty()); + + map.insert({1, 100}); + map.insert({2, 200}); + EXPECT_FALSE(map.empty()); +} + +TEST(FlatMapTest, ArrayOperator) { + NumMap map; + + // Create new element if not found. + auto v1 = &map[1]; + EXPECT_EQ(*v1, 0); + EXPECT_EQ(Get(map, 1), 0); + + // Write through returned reference. + *v1 = 100; + EXPECT_EQ(map[1], 100); + EXPECT_EQ(Get(map, 1), 100); + + // Reuse existing element if found. + auto v1a = &map[1]; + EXPECT_EQ(v1, v1a); + EXPECT_EQ(*v1, 100); + + // Create another element. 
+ map[2] = 200; + EXPECT_EQ(Get(map, 1), 100); + EXPECT_EQ(Get(map, 2), 200); +} + +TEST(FlatMapTest, Count) { + NumMap map; + EXPECT_EQ(map.count(1), 0); + EXPECT_EQ(map.count(2), 0); + + map.insert({1, 100}); + EXPECT_EQ(map.count(1), 1); + EXPECT_EQ(map.count(2), 0); + + map.insert({2, 200}); + EXPECT_EQ(map.count(1), 1); + EXPECT_EQ(map.count(2), 1); +} + +TEST(FlatMapTest, Iter) { + NumMap map; + EXPECT_EQ(Contents(map), NumMapContents()); + + map.insert({1, 100}); + map.insert({2, 200}); + EXPECT_EQ(Contents(map), NumMapContents({{1, 100}, {2, 200}})); +} + +TEST(FlatMapTest, Erase) { + NumMap map; + EXPECT_EQ(map.erase(1), 0); + map[1] = 100; + map[2] = 200; + EXPECT_EQ(map.erase(3), 0); + EXPECT_EQ(map.erase(1), 1); + EXPECT_EQ(map.size(), 1); + EXPECT_EQ(Get(map, 2), 200); + EXPECT_EQ(Contents(map), NumMapContents({{2, 200}})); + EXPECT_EQ(map.erase(2), 1); + EXPECT_EQ(Contents(map), NumMapContents()); +} + +TEST(FlatMapTest, EraseIter) { + NumMap map; + Fill(&map, 1, 11); + size_t size = 10; + for (auto iter = map.begin(); iter != map.end();) { + iter = map.erase(iter); + size--; + EXPECT_EQ(map.size(), size); + } + EXPECT_EQ(Contents(map), NumMapContents()); +} + +TEST(FlatMapTest, EraseIterPair) { + NumMap map; + Fill(&map, 1, 11); + NumMap expected; + auto p1 = map.begin(); + expected.insert(*p1); + ++p1; + expected.insert(*p1); + ++p1; + auto p2 = map.end(); + EXPECT_EQ(map.erase(p1, p2), map.end()); + EXPECT_EQ(map.size(), 2); + EXPECT_EQ(Contents(map), Contents(expected)); +} + +TEST(FlatMapTest, EraseLongChains) { + // Make a map with lots of elements and erase a bunch of them to ensure + // that we are likely to hit them on future lookups. 
+ NumMap map; + const int num = 128; + Fill(&map, 0, num); + for (int i = 0; i < num; i += 3) { + EXPECT_EQ(map.erase(i), 1); + } + for (int i = 0; i < num; i++) { + if ((i % 3) != 0) { + EXPECT_EQ(Get(map, i), i * 100); + } else { + EXPECT_EQ(map.count(i), 0); + } + } + + // Erase remainder to trigger table shrinking. + const size_t orig_buckets = map.bucket_count(); + for (int i = 0; i < num; i++) { + map.erase(i); + } + EXPECT_TRUE(map.empty()); + EXPECT_EQ(map.bucket_count(), orig_buckets); + map[1] = 100; // Actual shrinking is triggered by an insert. + EXPECT_LT(map.bucket_count(), orig_buckets); +} + +TEST(FlatMap, AlternatingInsertRemove) { + NumMap map; + map.insert({1000, 1000}); + map.insert({2000, 1000}); + map.insert({3000, 1000}); + for (int i = 0; i < 10000; i++) { + map.insert({i, i}); + map.erase(i); + } +} + +TEST(FlatMap, ClearNoResize) { + NumMap map; + Fill(&map, 0, 100); + const size_t orig = map.bucket_count(); + map.clear_no_resize(); + EXPECT_EQ(map.size(), 0); + EXPECT_EQ(Contents(map), NumMapContents()); + EXPECT_EQ(map.bucket_count(), orig); +} + +TEST(FlatMap, Clear) { + NumMap map; + Fill(&map, 0, 100); + const size_t orig = map.bucket_count(); + map.clear(); + EXPECT_EQ(map.size(), 0); + EXPECT_EQ(Contents(map), NumMapContents()); + EXPECT_LT(map.bucket_count(), orig); +} + +TEST(FlatMap, Copy) { + for (int n = 0; n < 10; n++) { + NumMap src; + Fill(&src, 0, n); + NumMap copy = src; + EXPECT_EQ(Contents(src), Contents(copy)); + NumMap copy2; + copy2 = src; + EXPECT_EQ(Contents(src), Contents(copy2)); + copy2 = copy2; // Self-assignment + EXPECT_EQ(Contents(src), Contents(copy2)); + } +} + +TEST(FlatMap, InitFromIter) { + for (int n = 0; n < 10; n++) { + NumMap src; + Fill(&src, 0, n); + auto vec = Contents(src); + NumMap dst(vec.begin(), vec.end()); + EXPECT_EQ(Contents(dst), vec); + } +} + +TEST(FlatMap, InsertIter) { + NumMap a, b; + Fill(&a, 1, 10); + Fill(&b, 8, 20); + b[9] = 10000; // Should not get inserted into a since a 
already has 9 + a.insert(b.begin(), b.end()); + NumMap expected; + Fill(&expected, 1, 20); + EXPECT_EQ(Contents(a), Contents(expected)); +} + +TEST(FlatMap, Eq) { + NumMap empty; + + NumMap elems; + Fill(&elems, 0, 5); + EXPECT_FALSE(empty == elems); + EXPECT_TRUE(empty != elems); + + NumMap copy = elems; + EXPECT_TRUE(copy == elems); + EXPECT_FALSE(copy != elems); + + NumMap changed = elems; + changed[3] = 1; + EXPECT_FALSE(changed == elems); + EXPECT_TRUE(changed != elems); + + NumMap changed2 = elems; + changed2.erase(3); + EXPECT_FALSE(changed2 == elems); + EXPECT_TRUE(changed2 != elems); +} + +TEST(FlatMap, Swap) { + NumMap a, b; + Fill(&a, 1, 5); + Fill(&b, 100, 200); + NumMap c = a; + NumMap d = b; + EXPECT_EQ(c, a); + EXPECT_EQ(d, b); + c.swap(d); + EXPECT_EQ(c, b); + EXPECT_EQ(d, a); +} + +TEST(FlatMap, Reserve) { + NumMap src; + Fill(&src, 1, 100); + NumMap a = src; + a.reserve(10); + EXPECT_EQ(a, src); + NumMap b = src; + b.rehash(1000); + EXPECT_EQ(b, src); +} + +TEST(FlatMap, EqualRangeMutable) { + NumMap map; + Fill(&map, 1, 10); + + // Existing element + auto p1 = map.equal_range(3); + EXPECT_TRUE(p1.first != p1.second); + EXPECT_EQ(p1.first->first, 3); + EXPECT_EQ(p1.first->second, 300); + ++p1.first; + EXPECT_TRUE(p1.first == p1.second); + + // Missing element + auto p2 = map.equal_range(100); + EXPECT_TRUE(p2.first == p2.second); +} + +TEST(FlatMap, EqualRangeConst) { + NumMap tmp; + Fill(&tmp, 1, 10); + + const NumMap map = tmp; + + // Existing element + auto p1 = map.equal_range(3); + EXPECT_TRUE(p1.first != p1.second); + EXPECT_EQ(p1.first->first, 3); + EXPECT_EQ(p1.first->second, 300); + ++p1.first; + EXPECT_TRUE(p1.first == p1.second); + + // Missing element + auto p2 = map.equal_range(100); + EXPECT_TRUE(p2.first == p2.second); +} + +TEST(FlatMap, Prefetch) { + NumMap map; + Fill(&map, 0, 1000); + // Prefetch present and missing keys. + for (int i = 0; i < 2000; i++) { + map.prefetch_value(i); + } +} + +// Non-copyable values should work. 
+struct NC { + int64 value; + NC() : value(-1) {} + NC(int64 v) : value(v) {} + NC(const NC& x) : value(x.value) {} + bool operator==(const NC& x) const { return value == x.value; } +}; +struct HashNC { + size_t operator()(NC x) const { return x.value; } +}; + +TEST(FlatMap, NonCopyable) { + FlatMap<NC, NC, HashNC> map; + for (int i = 0; i < 100; i++) { + map[NC(i)] = NC(i * 100); + } + for (int i = 0; i < 100; i++) { + EXPECT_EQ(map.count(NC(i)), 1); + auto iter = map.find(NC(i)); + EXPECT_NE(iter, map.end()); + EXPECT_EQ(iter->first, NC(i)); + EXPECT_EQ(iter->second, NC(i * 100)); + EXPECT_EQ(map[NC(i)], NC(i * 100)); + } + map.erase(NC(10)); + EXPECT_EQ(map.count(NC(10)), 0); +} + +// Test with heap-allocated objects so that mismanaged constructions +// or destructions will show up as errors under a sanitizer or +// heap checker. +TEST(FlatMap, ConstructDestruct) { + FlatMap<string, string, HashStr> map; + string k1 = "the quick brown fox jumped over the lazy dog"; + string k2 = k1 + k1; + string k3 = k1 + k2; + map[k1] = k2; + map[k3] = k1; + EXPECT_EQ(k1, map.find(k1)->first); + EXPECT_EQ(k2, map.find(k1)->second); + EXPECT_EQ(k1, map[k3]); + map.erase(k3); + EXPECT_EQ(string(), map[k3]); + + map.clear(); + map[k1] = k2; + EXPECT_EQ(k2, map[k1]); + + map.reserve(100); + EXPECT_EQ(k2, map[k1]); +} + +// Type to use to ensure that custom equality operator is used +// that ignores extra value. +struct CustomCmpKey { + int64 a; + int64 b; + CustomCmpKey(int64 v1, int64 v2) : a(v1), b(v2) {} + bool operator==(const CustomCmpKey& x) const { return a == x.a && b == x.b; } +}; +struct HashA { + size_t operator()(CustomCmpKey x) const { return x.a; } +}; +struct EqA { + // Ignore b fields. 
+ bool operator()(CustomCmpKey x, CustomCmpKey y) const { return x.a == y.a; } +}; +TEST(FlatMap, CustomCmp) { + FlatMap<CustomCmpKey, int, HashA, EqA> map; + map[CustomCmpKey(100, 200)] = 300; + EXPECT_EQ(300, map[CustomCmpKey(100, 200)]); + EXPECT_EQ(300, map[CustomCmpKey(100, 500)]); // Differences in key.b ignored +} + +// Test unique_ptr handling. +typedef std::unique_ptr<int> UniqInt; +static UniqInt MakeUniq(int i) { return UniqInt(new int(i)); } + +struct HashUniq { + size_t operator()(const UniqInt& p) const { return *p; } +}; +struct EqUniq { + bool operator()(const UniqInt& a, const UniqInt& b) const { return *a == *b; } +}; +typedef FlatMap<UniqInt, UniqInt, HashUniq, EqUniq> UniqMap; + +TEST(FlatMap, UniqueMap) { + UniqMap map; + + // Fill map + const int N = 10; + for (int i = 0; i < N; i++) { + if ((i % 2) == 0) { + map[MakeUniq(i)] = MakeUniq(i + 100); + } else { + map.emplace(MakeUniq(i), MakeUniq(i + 100)); + } + } + EXPECT_EQ(map.size(), N); + + // Lookups + for (int i = 0; i < N; i++) { + EXPECT_EQ(*map.at(MakeUniq(i)), i + 100); + } + + // find+erase + EXPECT_EQ(map.count(MakeUniq(2)), 1); + map.erase(MakeUniq(2)); + EXPECT_EQ(map.count(MakeUniq(2)), 0); + + // clear + map.clear(); + EXPECT_EQ(map.size(), 0); +} + +TEST(FlatMap, UniqueMapIter) { + UniqMap map; + const int kCount = 10; + const int kValueDelta = 100; + for (int i = 1; i <= kCount; i++) { + map[MakeUniq(i)] = MakeUniq(i + kValueDelta); + } + int key_sum = 0; + int val_sum = 0; + for (const auto& p : map) { + key_sum += *p.first; + val_sum += *p.second; + } + EXPECT_EQ(key_sum, (kCount * (kCount + 1)) / 2); + EXPECT_EQ(val_sum, key_sum + (kCount * kValueDelta)); +} + +} // namespace +} // namespace gtl +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/flatrep.h b/tensorflow/core/lib/gtl/flatrep.h new file mode 100644 index 0000000000..ff590d4128 --- /dev/null +++ b/tensorflow/core/lib/gtl/flatrep.h @@ -0,0 +1,332 @@ +/* Copyright 2016 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ + +#include <string.h> +#include <utility> +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gtl { +namespace internal { + +// Internal representation for FlatMap and FlatSet. +// +// The representation is an open-addressed hash table. Conceptually, +// the representation is a flat array of entries. However we +// structure it as an array of of buckets where each bucket holds +// kWidth entries along with metadata for the kWidth entries. The +// metadata marker is +// +// (a) kEmpty: the entry is empty +// (b) kDeleted: the entry has been deleted +// (c) other: the entry is occupied and has low-8 bits of its hash. +// These hash bits can be used to avoid potentially expensive +// key comparisons. +// +// FlatMap passes in a bucket that contains keys and values, FlatSet +// passes in a bucket that does not contain values. +template <typename Key, typename Bucket, class Hash, class Eq> +class FlatRep { + public: + // kWidth is the number of entries stored in a bucket. 
+ static const uint32 kBase = 3; + static const uint32 kWidth = (1 << kBase); + + FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) { + Init(N); + } + explicit FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) { + Init(src.size()); + CopyEntries(src.array_, src.end_, CopyEntry()); + } + ~FlatRep() { + clear_no_resize(); + delete[] array_; + } + + // Simple accessors. + size_t size() const { return not_empty_ - deleted_; } + size_t bucket_count() const { return mask_ + 1; } + Bucket* start() const { return array_; } + Bucket* limit() const { return end_; } + const Hash& hash_function() const { return hash_; } + const Eq& key_eq() const { return equal_; } + + // Overwrite contents of *this with contents of src. + void CopyFrom(const FlatRep& src) { + if (this != &src) { + clear_no_resize(); + delete[] array_; + Init(src.size()); + CopyEntries(src.array_, src.end_, CopyEntry()); + } + } + + void clear_no_resize() { + for (Bucket* b = array_; b != end_; b++) { + for (uint32 i = 0; i < kWidth; i++) { + if (b->marker[i] >= 2) { + b->Destroy(i); + b->marker[i] = kEmpty; + } + } + } + not_empty_ = 0; + deleted_ = 0; + } + + void clear() { + clear_no_resize(); + grow_ = 0; // Consider shrinking in MaybeResize() + MaybeResize(); + } + + void swap(FlatRep& x) { + using std::swap; + swap(array_, x.array_); + swap(end_, x.end_); + swap(lglen_, x.lglen_); + swap(mask_, x.mask_); + swap(not_empty_, x.not_empty_); + swap(deleted_, x.deleted_); + swap(grow_, x.grow_); + swap(shrink_, x.shrink_); + } + + struct SearchResult { + bool found; + Bucket* b; + uint32 index; + }; + + // Hash value is partitioned as follows: + // 1. Bottom 8 bits are stored in bucket to help speed up comparisons. + // 2. Next 3 bits give index inside bucket. + // 3. Remaining bits give bucket number. + + // Find bucket/index for key k. 
+ SearchResult Find(const Key& k) const { + size_t h = hash_(k); + const uint32 marker = Marker(h & 0xff); + size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket + uint32 num_probes = 1; // Needed for quadratic probing + while (true) { + uint32 bi = index & (kWidth - 1); + Bucket* b = &array_[index >> kBase]; + const uint32 x = b->marker[bi]; + if (x == marker && equal_(b->key(bi), k)) { + return {true, b, bi}; + } else if (x == kEmpty) { + return {false, nullptr, 0}; + } + // Quadratic probing. + index = (index + num_probes) & mask_; + num_probes++; + } + } + + // Find bucket/index for key k, creating a new one if necessary. + // + // KeyType is a template parameter so that k's type is deduced and it + // becomes a universal reference which allows the key initialization + // below to use an rvalue constructor if available. + template <typename KeyType> + SearchResult FindOrInsert(KeyType&& k) { + size_t h = hash_(k); + const uint32 marker = Marker(h & 0xff); + size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket + uint32 num_probes = 1; // Needed for quadratic probing + Bucket* del = nullptr; // First encountered deletion for kInsert + uint32 di = 0; + while (true) { + uint32 bi = index & (kWidth - 1); + Bucket* b = &array_[index >> kBase]; + const uint32 x = b->marker[bi]; + if (x == marker && equal_(b->key(bi), k)) { + return {true, b, bi}; + } else if (!del && x == kDeleted) { + // Remember deleted index to use for insertion. + del = b; + di = bi; + } else if (x == kEmpty) { + if (del) { + // Store in the first deleted slot we encountered + b = del; + bi = di; + deleted_--; // not_empty_ does not change + } else { + not_empty_++; + } + b->marker[bi] = marker; + new (&b->key(bi)) Key(std::forward<KeyType>(k)); + return {false, b, bi}; + } + // Quadratic probing. 
+ index = (index + num_probes) & mask_; + num_probes++; + } + } + + void Erase(Bucket* b, uint32 i) { + b->Destroy(i); + b->marker[i] = kDeleted; + deleted_++; + grow_ = 0; // Consider shrinking on next insert + } + + void Prefetch(const Key& k) const { + size_t h = hash_(k); + size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket + uint32 bi = index & (kWidth - 1); + Bucket* b = &array_[index >> kBase]; + prefetch(&b->storage.key[bi]); + } + void prefetch(const void* ptr) const { + // TODO(jeff,sanjay): Remove this routine when we add a + // prefetch(...) call to platform so that the Prefetch routine + // actually does something + } + + inline void MaybeResize() { + if (not_empty_ < grow_) { + return; // Nothing to do + } + if (grow_ == 0) { + // Special value set by erase to cause shrink on next insert. + if (size() >= shrink_) { + // Not small enough to shrink. + grow_ = static_cast<size_t>(bucket_count() * 0.8); + if (not_empty_ < grow_) return; + } + } + Resize(size() + 1); + } + + void Resize(size_t N) { + Bucket* old = array_; + Bucket* old_end = end_; + Init(N); + CopyEntries(old, old_end, MoveEntry()); + delete[] old; + } + + private: + enum { kEmpty = 0, kDeleted = 1 }; // Special markers for an entry. + + Hash hash_; // User-supplied hasher + Eq equal_; // User-supplied comparator + uint8 lglen_; // lg(#buckets) + Bucket* array_; // array of length (1 << lglen_) + Bucket* end_; // Points just past last bucket in array_ + size_t mask_; // (# of entries in table) - 1 + size_t not_empty_; // Count of entries with marker != kEmpty + size_t deleted_; // Count of entries with marker == kDeleted + size_t grow_; // Grow array when not_empty_ >= grow_ + size_t shrink_; // Shrink array when size() < shrink_ + + // Avoid kEmpty and kDeleted markers when computing hash values to + // store in Bucket::marker[]. + static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); } + + void Init(size_t N) { + // Make enough room for N elements. 
+ size_t lg = 0; // Smallest table is just one bucket. + while (N >= 0.8 * ((1 << lg) * kWidth)) { + lg++; + } + const size_t n = (1 << lg); + Bucket* array = new Bucket[n]; + for (size_t i = 0; i < n; i++) { + Bucket* b = &array[i]; + memset(b->marker, kEmpty, kWidth); + } + const size_t capacity = (1 << lg) * kWidth; + lglen_ = lg; + mask_ = capacity - 1; + array_ = array; + end_ = array + n; + not_empty_ = 0; + deleted_ = 0; + grow_ = static_cast<size_t>(capacity * 0.8); + if (lg == 0) { + // Already down to one bucket; no more shrinking. + shrink_ = 0; + } else { + shrink_ = static_cast<size_t>(grow_ * 0.4); // Must be less than 0.5 + } + } + + // Used by FreshInsert when we should copy from source. + struct CopyEntry { + inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) { + dst->CopyFrom(dsti, src, srci); + } + }; + + // Used by FreshInsert when we should move from source. + struct MoveEntry { + inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) { + dst->MoveFrom(dsti, src, srci); + src->Destroy(srci); + src->marker[srci] = kDeleted; + } + }; + + template <typename Copier> + void CopyEntries(Bucket* start, Bucket* end, Copier copier) { + for (Bucket* b = start; b != end; b++) { + for (uint32 i = 0; i < kWidth; i++) { + if (b->marker[i] >= 2) { + FreshInsert(b, i, copier); + } + } + } + } + + // Create an entry for the key numbered src_index in *src and return + // its bucket/index. Used for insertion into a fresh table. We + // assume that there are no deletions, and k does not already exist + // in the table. 
+ template <typename Copier> + void FreshInsert(Bucket* src, uint32 src_index, Copier copier) { + size_t h = hash_(src->key(src_index)); + const uint32 marker = Marker(h & 0xff); + size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket + uint32 num_probes = 1; // Needed for quadratic probing + while (true) { + uint32 bi = index & (kWidth - 1); + Bucket* b = &array_[index >> kBase]; + const uint32 x = b->marker[bi]; + if (x == 0) { + b->marker[bi] = marker; + not_empty_++; + copier(b, bi, src, src_index); + return; + } + // Quadratic probing. + index = (index + num_probes) & mask_; + num_probes++; + } + } +}; + +} // namespace internal +} // namespace gtl +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ diff --git a/tensorflow/core/lib/gtl/flatset.h b/tensorflow/core/lib/gtl/flatset.h new file mode 100644 index 0000000000..b94d88cbc6 --- /dev/null +++ b/tensorflow/core/lib/gtl/flatset.h @@ -0,0 +1,277 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ + +#include <stddef.h> +#include <utility> +#include "tensorflow/core/lib/gtl/flatrep.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gtl { + +// FlatSet<K,...> provides a set of K. +// +// The map is implemented using an open-addressed hash table. A +// single array holds entire map contents and collisions are resolved +// by probing at a sequence of locations in the array. +template <typename Key, class Hash, class Eq = std::equal_to<Key>> +class FlatSet { + private: + // Forward declare some internal types needed in public section. + struct Bucket; + + public: + typedef Key key_type; + typedef Key value_type; + typedef Hash hasher; + typedef Eq key_equal; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef value_type& reference; + typedef const value_type& const_reference; + + FlatSet() : FlatSet(1) {} + + explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) + : rep_(N, hf, eq) {} + + FlatSet(const FlatSet& src) : rep_(src.rep_) {} + + template <typename InputIter> + FlatSet(InputIter first, InputIter last, size_t N = 1, + const Hash& hf = Hash(), const Eq& eq = Eq()) + : FlatSet(N, hf, eq) { + insert(first, last); + } + + FlatSet& operator=(const FlatSet& src) { + rep_.CopyFrom(src.rep_); + return *this; + } + + ~FlatSet() {} + + void swap(FlatSet& x) { rep_.swap(x.rep_); } + void clear_no_resize() { rep_.clear_no_resize(); } + void clear() { rep_.clear(); } + void reserve(size_t N) { rep_.Resize(std::max(N, size())); } + void rehash(size_t N) { rep_.Resize(std::max(N, size())); } + void resize(size_t N) { rep_.Resize(std::max(N, size())); } + size_t size() const { return 
rep_.size(); } + bool empty() const { return size() == 0; } + size_t bucket_count() const { return rep_.bucket_count(); } + hasher hash_function() const { return rep_.hash_function(); } + key_equal key_eq() const { return rep_.key_eq(); } + + class iterator { + public: + iterator() : b_(nullptr), end_(nullptr), i_(0) {} + + // Make iterator pointing at first element at or after b. + explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { + SkipUnused(); + } + + // Make iterator pointing exactly at ith element in b, which must exist. + iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {} + + Key& operator*() { return key(); } + Key* operator->() { return &key(); } + bool operator==(const iterator& x) const { + return b_ == x.b_ && i_ == x.i_; + } + bool operator!=(const iterator& x) const { return !(*this == x); } + iterator& operator++() { + DCHECK(b_ != end_); + i_++; + SkipUnused(); + return *this; + } + + private: + friend class FlatSet; + Bucket* b_; + Bucket* end_; + uint32 i_; + + Key& key() const { return b_->key(i_); } + void SkipUnused() { + while (b_ < end_) { + if (i_ >= Rep::kWidth) { + i_ = 0; + b_++; + } else if (b_->marker[i_] < 2) { + i_++; + } else { + break; + } + } + } + }; + + class const_iterator { + private: + mutable iterator rep_; // Share state and logic with non-const iterator. 
+ public: + const_iterator() : rep_() {} + explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {} + const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {} + + const Key& operator*() const { return rep_.key(); } + const Key* operator->() const { return &rep_.key(); } + bool operator==(const const_iterator& x) const { return rep_ == x.rep_; } + bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; } + const_iterator& operator++() { + ++rep_; + return *this; + } + }; + + iterator begin() { return iterator(rep_.start(), rep_.limit()); } + iterator end() { return iterator(rep_.limit(), rep_.limit()); } + const_iterator begin() const { + return const_iterator(rep_.start(), rep_.limit()); + } + const_iterator end() const { + return const_iterator(rep_.limit(), rep_.limit()); + } + + size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } + iterator find(const Key& k) { + auto r = rep_.Find(k); + return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); + } + const_iterator find(const Key& k) const { + auto r = rep_.Find(k); + return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); + } + + std::pair<iterator, bool> insert(const Key& k) { return Insert(k); } + template <typename InputIter> + void insert(InputIter first, InputIter last) { + for (; first != last; ++first) { + insert(*first); + } + } + + template <typename... Args> + std::pair<iterator, bool> emplace(Args&&... 
args) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(std::forward<Args>(args)...); + const bool inserted = !r.found; + return {iterator(r.b, rep_.limit(), r.index), inserted}; + } + + size_t erase(const Key& k) { + auto r = rep_.Find(k); + if (!r.found) return 0; + rep_.Erase(r.b, r.index); + return 1; + } + iterator erase(iterator pos) { + rep_.Erase(pos.b_, pos.i_); + ++pos; + return pos; + } + iterator erase(iterator pos, iterator last) { + for (; pos != last; ++pos) { + rep_.Erase(pos.b_, pos.i_); + } + return pos; + } + + std::pair<iterator, iterator> equal_range(const Key& k) { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + std::pair<const_iterator, const_iterator> equal_range(const Key& k) const { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + + bool operator==(const FlatSet& x) const { + if (size() != x.size()) return false; + for (const auto& elem : x) { + auto i = find(elem); + if (i == end()) return false; + } + return true; + } + bool operator!=(const FlatSet& x) const { return !(*this == x); } + + // If key exists in the table, prefetch it. This is a hint, and may + // have no effect. + void prefetch_value(const Key& key) const { rep_.Prefetch(key); } + + private: + using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>; + + // Bucket stores kWidth <marker, key, value> triples. + // The data is organized as three parallel arrays to reduce padding. + struct Bucket { + uint8 marker[Rep::kWidth]; + + // Wrap keys in union to control construction and destruction. 
+ union Storage { + Key key[Rep::kWidth]; + Storage() {} + ~Storage() {} + } storage; + + Key& key(uint32 i) { + DCHECK_GE(marker[i], 2); + return storage.key[i]; + } + void Destroy(uint32 i) { storage.key[i].Key::~Key(); } + void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); + } + void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(src->storage.key[src_index]); + } + }; + + std::pair<iterator, bool> Insert(const Key& k) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(k); + const bool inserted = !r.found; + return {iterator(r.b, rep_.limit(), r.index), inserted}; + } + + Rep rep_; +}; + +} // namespace gtl +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ diff --git a/tensorflow/core/lib/gtl/flatset_test.cc b/tensorflow/core/lib/gtl/flatset_test.cc new file mode 100644 index 0000000000..ea9c9c22b5 --- /dev/null +++ b/tensorflow/core/lib/gtl/flatset_test.cc @@ -0,0 +1,501 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/lib/gtl/flatset.h" + +#include <algorithm> +#include <string> +#include <vector> +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gtl { +namespace { + +typedef FlatSet<int64, HashInt64> NumSet; + +// Returns true iff set has an entry for k. +// Also verifies that find and count give consistent results. +bool Has(const NumSet& set, int64 k) { + auto iter = set.find(k); + if (iter == set.end()) { + EXPECT_EQ(set.count(k), 0); + return false; + } else { + EXPECT_EQ(set.count(k), 1); + EXPECT_EQ(*iter, k); + return true; + } +} + +// Return contents of set as a sorted list of numbers. +typedef std::vector<int64> NumSetContents; +NumSetContents Contents(const NumSet& set) { + NumSetContents result; + for (int64 n : set) { + result.push_back(n); + } + std::sort(result.begin(), result.end()); + return result; +} + +// Fill entries with keys [start,limit). +void Fill(NumSet* set, int64 start, int64 limit) { + for (int64 i = start; i < limit; i++) { + set->insert(i); + } +} + +TEST(FlatSetTest, Find) { + NumSet set; + EXPECT_FALSE(Has(set, 1)); + set.insert(1); + set.insert(2); + EXPECT_TRUE(Has(set, 1)); + EXPECT_TRUE(Has(set, 2)); + EXPECT_FALSE(Has(set, 3)); +} + +TEST(FlatSetTest, Insert) { + NumSet set; + EXPECT_FALSE(Has(set, 1)); + + // New entry. + auto result = set.insert(1); + EXPECT_TRUE(result.second); + EXPECT_EQ(*result.first, 1); + EXPECT_TRUE(Has(set, 1)); + + // Attempt to insert over existing entry. 
+ result = set.insert(1); + EXPECT_FALSE(result.second); + EXPECT_EQ(*result.first, 1); + EXPECT_TRUE(Has(set, 1)); +} + +TEST(FlatSetTest, InsertGrowth) { + NumSet set; + const int n = 100; + Fill(&set, 0, 100); + EXPECT_EQ(set.size(), n); + for (int i = 0; i < n; i++) { + EXPECT_TRUE(Has(set, i)) << i; + } +} + +TEST(FlatSetTest, Emplace) { + NumSet set; + + // New entry. + auto result = set.emplace(73); + EXPECT_TRUE(result.second); + EXPECT_EQ(*result.first, 73); + EXPECT_TRUE(Has(set, 73)); + + // Attempt to insert an existing entry. + result = set.emplace(73); + EXPECT_FALSE(result.second); + EXPECT_EQ(*result.first, 73); + EXPECT_TRUE(Has(set, 73)); + + // Add a second value + result = set.emplace(103); + EXPECT_TRUE(result.second); + EXPECT_EQ(*result.first, 103); + EXPECT_TRUE(Has(set, 103)); +} + +TEST(FlatSetTest, Size) { + NumSet set; + EXPECT_EQ(set.size(), 0); + + set.insert(1); + set.insert(2); + EXPECT_EQ(set.size(), 2); +} + +TEST(FlatSetTest, Empty) { + NumSet set; + EXPECT_TRUE(set.empty()); + + set.insert(1); + set.insert(2); + EXPECT_FALSE(set.empty()); +} + +TEST(FlatSetTest, Count) { + NumSet set; + EXPECT_EQ(set.count(1), 0); + EXPECT_EQ(set.count(2), 0); + + set.insert(1); + EXPECT_EQ(set.count(1), 1); + EXPECT_EQ(set.count(2), 0); + + set.insert(2); + EXPECT_EQ(set.count(1), 1); + EXPECT_EQ(set.count(2), 1); +} + +TEST(FlatSetTest, Iter) { + NumSet set; + EXPECT_EQ(Contents(set), NumSetContents()); + + set.insert(1); + set.insert(2); + EXPECT_EQ(Contents(set), NumSetContents({1, 2})); +} + +TEST(FlatSetTest, Erase) { + NumSet set; + EXPECT_EQ(set.erase(1), 0); + set.insert(1); + set.insert(2); + EXPECT_EQ(set.erase(3), 0); + EXPECT_EQ(set.erase(1), 1); + EXPECT_EQ(set.size(), 1); + EXPECT_TRUE(Has(set, 2)); + EXPECT_EQ(Contents(set), NumSetContents({2})); + EXPECT_EQ(set.erase(2), 1); + EXPECT_EQ(Contents(set), NumSetContents()); +} + +TEST(FlatSetTest, EraseIter) { + NumSet set; + Fill(&set, 1, 11); + size_t size = 10; + for (auto iter = 
set.begin(); iter != set.end();) { + iter = set.erase(iter); + size--; + EXPECT_EQ(set.size(), size); + } + EXPECT_EQ(Contents(set), NumSetContents()); +} + +TEST(FlatSetTest, EraseIterPair) { + NumSet set; + Fill(&set, 1, 11); + NumSet expected; + auto p1 = set.begin(); + expected.insert(*p1); + ++p1; + expected.insert(*p1); + ++p1; + auto p2 = set.end(); + EXPECT_EQ(set.erase(p1, p2), set.end()); + EXPECT_EQ(set.size(), 2); + EXPECT_EQ(Contents(set), Contents(expected)); +} + +TEST(FlatSetTest, EraseLongChains) { + // Make a set with lots of elements and erase a bunch of them to ensure + // that we are likely to hit them on future lookups. + NumSet set; + const int num = 128; + Fill(&set, 0, num); + for (int i = 0; i < num; i += 3) { + EXPECT_EQ(set.erase(i), 1); + } + for (int i = 0; i < num; i++) { + // Multiples of 3 should be not present. + EXPECT_EQ(Has(set, i), ((i % 3) != 0)) << i; + } + + // Erase remainder to trigger table shrinking. + const size_t orig_buckets = set.bucket_count(); + for (int i = 0; i < num; i++) { + set.erase(i); + } + EXPECT_TRUE(set.empty()); + EXPECT_EQ(set.bucket_count(), orig_buckets); + set.insert(1); // Actual shrinking is triggered by an insert. 
+ EXPECT_LT(set.bucket_count(), orig_buckets); +} + +TEST(FlatSet, ClearNoResize) { + NumSet set; + Fill(&set, 0, 100); + const size_t orig = set.bucket_count(); + set.clear_no_resize(); + EXPECT_EQ(set.size(), 0); + EXPECT_EQ(Contents(set), NumSetContents()); + EXPECT_EQ(set.bucket_count(), orig); +} + +TEST(FlatSet, Clear) { + NumSet set; + Fill(&set, 0, 100); + const size_t orig = set.bucket_count(); + set.clear(); + EXPECT_EQ(set.size(), 0); + EXPECT_EQ(Contents(set), NumSetContents()); + EXPECT_LT(set.bucket_count(), orig); +} + +TEST(FlatSet, Copy) { + for (int n = 0; n < 10; n++) { + NumSet src; + Fill(&src, 0, n); + NumSet copy = src; + EXPECT_EQ(Contents(src), Contents(copy)); + NumSet copy2; + copy2 = src; + EXPECT_EQ(Contents(src), Contents(copy2)); + copy2 = copy2; // Self-assignment + EXPECT_EQ(Contents(src), Contents(copy2)); + } +} + +TEST(FlatSet, InitFromIter) { + for (int n = 0; n < 10; n++) { + NumSet src; + Fill(&src, 0, n); + auto vec = Contents(src); + NumSet dst(vec.begin(), vec.end()); + EXPECT_EQ(Contents(dst), vec); + } +} + +TEST(FlatSet, InsertIter) { + NumSet a, b; + Fill(&a, 1, 10); + Fill(&b, 8, 20); + b.insert(9); // Should not get inserted into a since a already has 9 + a.insert(b.begin(), b.end()); + NumSet expected; + Fill(&expected, 1, 20); + EXPECT_EQ(Contents(a), Contents(expected)); +} + +TEST(FlatSet, Eq) { + NumSet empty; + + NumSet elems; + Fill(&elems, 0, 5); + EXPECT_FALSE(empty == elems); + EXPECT_TRUE(empty != elems); + + NumSet copy = elems; + EXPECT_TRUE(copy == elems); + EXPECT_FALSE(copy != elems); + + NumSet changed = elems; + changed.insert(7); + EXPECT_FALSE(changed == elems); + EXPECT_TRUE(changed != elems); + + NumSet changed2 = elems; + changed2.erase(3); + EXPECT_FALSE(changed2 == elems); + EXPECT_TRUE(changed2 != elems); +} + +TEST(FlatSet, Swap) { + NumSet a, b; + Fill(&a, 1, 5); + Fill(&b, 100, 200); + NumSet c = a; + NumSet d = b; + EXPECT_EQ(c, a); + EXPECT_EQ(d, b); + c.swap(d); + EXPECT_EQ(c, b); + 
EXPECT_EQ(d, a); +} + +TEST(FlatSet, Reserve) { + NumSet src; + Fill(&src, 1, 100); + NumSet a = src; + a.reserve(10); + EXPECT_EQ(a, src); + NumSet b = src; + b.rehash(1000); + EXPECT_EQ(b, src); +} + +TEST(FlatSet, EqualRangeMutable) { + NumSet set; + Fill(&set, 1, 10); + + // Existing element + auto p1 = set.equal_range(3); + EXPECT_TRUE(p1.first != p1.second); + EXPECT_EQ(*p1.first, 3); + ++p1.first; + EXPECT_TRUE(p1.first == p1.second); + + // Missing element + auto p2 = set.equal_range(100); + EXPECT_TRUE(p2.first == p2.second); +} + +TEST(FlatSet, EqualRangeConst) { + NumSet tmp; + Fill(&tmp, 1, 10); + + const NumSet set = tmp; + + // Existing element + auto p1 = set.equal_range(3); + EXPECT_TRUE(p1.first != p1.second); + EXPECT_EQ(*p1.first, 3); + ++p1.first; + EXPECT_TRUE(p1.first == p1.second); + + // Missing element + auto p2 = set.equal_range(100); + EXPECT_TRUE(p2.first == p2.second); +} + +TEST(FlatSet, Prefetch) { + NumSet set; + Fill(&set, 0, 1000); + // Prefetch present and missing keys. + for (int i = 0; i < 2000; i++) { + set.prefetch_value(i); + } +} + +// Non-copyable values should work. +struct NC { + int64 value; + NC() : value(-1) {} + NC(int64 v) : value(v) {} + NC(const NC& x) : value(x.value) {} + bool operator==(const NC& x) const { return value == x.value; } +}; +struct HashNC { + size_t operator()(NC x) const { return x.value; } +}; + +TEST(FlatSet, NonCopyable) { + FlatSet<NC, HashNC> set; + for (int i = 0; i < 100; i++) { + set.insert(NC(i)); + } + for (int i = 0; i < 100; i++) { + EXPECT_EQ(set.count(NC(i)), 1); + auto iter = set.find(NC(i)); + EXPECT_NE(iter, set.end()); + EXPECT_EQ(*iter, NC(i)); + } + set.erase(NC(10)); + EXPECT_EQ(set.count(NC(10)), 0); +} + +// Test with heap-allocated objects so that mismanaged constructions +// or destructions will show up as errors under a sanitizer or +// heap checker. 
+TEST(FlatSet, ConstructDestruct) { + FlatSet<string, HashStr> set; + string k1 = "the quick brown fox jumped over the lazy dog"; + string k2 = k1 + k1; + string k3 = k1 + k2; + set.insert(k1); + set.insert(k3); + EXPECT_EQ(set.count(k1), 1); + EXPECT_EQ(set.count(k2), 0); + EXPECT_EQ(set.count(k3), 1); + + set.erase(k3); + EXPECT_EQ(set.count(k3), 0); + + set.clear(); + set.insert(k1); + EXPECT_EQ(set.count(k1), 1); + EXPECT_EQ(set.count(k3), 0); + + set.reserve(100); + EXPECT_EQ(set.count(k1), 1); + EXPECT_EQ(set.count(k3), 0); +} + +// Type to use to ensure that custom equality operator is used +// that ignores extra value. +struct CustomCmpKey { + int64 a; + int64 b; + CustomCmpKey(int64 v1, int64 v2) : a(v1), b(v2) {} + bool operator==(const CustomCmpKey& x) const { return a == x.a && b == x.b; } +}; +struct HashA { + size_t operator()(CustomCmpKey x) const { return x.a; } +}; +struct EqA { + // Ignore b fields. + bool operator()(CustomCmpKey x, CustomCmpKey y) const { return x.a == y.a; } +}; +TEST(FlatSet, CustomCmp) { + FlatSet<CustomCmpKey, HashA, EqA> set; + set.insert(CustomCmpKey(100, 200)); + EXPECT_EQ(set.count(CustomCmpKey(100, 200)), 1); + EXPECT_EQ(set.count(CustomCmpKey(100, 500)), 1); // key.b ignored +} + +// Test unique_ptr handling. 
+typedef std::unique_ptr<int> UniqInt; +static UniqInt MakeUniq(int i) { return UniqInt(new int(i)); } + +struct HashUniq { + size_t operator()(const UniqInt& p) const { return *p; } +}; +struct EqUniq { + bool operator()(const UniqInt& a, const UniqInt& b) const { return *a == *b; } +}; +typedef FlatSet<UniqInt, HashUniq, EqUniq> UniqSet; + +TEST(FlatSet, UniqueSet) { + UniqSet set; + + // Fill set + const int N = 10; + for (int i = 0; i < N; i++) { + set.emplace(MakeUniq(i)); + } + EXPECT_EQ(set.size(), N); + + // Lookups + for (int i = 0; i < N; i++) { + EXPECT_EQ(set.count(MakeUniq(i)), 1); + } + + // erase + set.erase(MakeUniq(2)); + EXPECT_EQ(set.count(MakeUniq(2)), 0); + + // clear + set.clear(); + EXPECT_EQ(set.size(), 0); +} + +TEST(FlatSet, UniqueSetIter) { + UniqSet set; + const int kCount = 10; + for (int i = 1; i <= kCount; i++) { + set.emplace(MakeUniq(i)); + } + int sum = 0; + for (const auto& p : set) { + sum += *p; + } + EXPECT_EQ(sum, (kCount * (kCount + 1)) / 2); +} + +} // namespace +} // namespace gtl +} // namespace tensorflow diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h index 3c71e7d6cc..4e64c90d62 100644 --- a/tensorflow/core/lib/hash/hash.h +++ b/tensorflow/core/lib/hash/hash.h @@ -42,6 +42,24 @@ inline uint64 Hash64Combine(uint64 a, uint64 b) { return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4)); } +// Convenience Hash functors +struct HashInt64 { + size_t operator()(int64 x) const { return static_cast<size_t>(x); } +}; +struct HashStr { + size_t operator()(const string& s) const { + return static_cast<size_t>(Hash64(s)); + } +}; +template <typename PTR> +struct HashPtr { + size_t operator()(const PTR p) const { + // Hash pointers as integers, but bring more entropy to the lower bits. + size_t k = static_cast<size_t>(reinterpret_cast<uintptr_t>(p)); + return k + (k >> 6); + } +}; + } // namespace tensorflow #endif // TENSORFLOW_LIB_HASH_HASH_H_ |