aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-27 08:10:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-27 09:21:12 -0700
commit80aec93166dadb2dc30250e1251ab3eb006c2d53 (patch)
treec34632f2f255eda56214559ff7d7d54878a6a6c9
parente43eaf662db492c909e6cab8c954178b75f7b63d (diff)
Added new tensorflow::gtl::FlatMap and tensorflow::gtl::FlatSet classes.
Mostly drop-in replacements for std::unordered_map and std::unordered_set, but much faster (does not do an allocation per entry, and represents entries in groups of 8 in a flat array, which is much more cache efficient). Benchmarks (not included in this cl) show about 3X to 5X performance improvements over the std::unordered_{set,map} for many kinds of common maps e.g. std::unordered_map<int64, int64> or std::unordered_map<string, int64>. Change: 137401863
-rw-r--r--tensorflow/core/BUILD4
-rw-r--r--tensorflow/core/lib/gtl/flatmap.h349
-rw-r--r--tensorflow/core/lib/gtl/flatmap_test.cc576
-rw-r--r--tensorflow/core/lib/gtl/flatrep.h332
-rw-r--r--tensorflow/core/lib/gtl/flatset.h277
-rw-r--r--tensorflow/core/lib/gtl/flatset_test.cc501
-rw-r--r--tensorflow/core/lib/hash/hash.h18
7 files changed, 2057 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 3eea01363b..76e6ee7568 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -164,6 +164,8 @@ cc_library(
"lib/core/threadpool.h",
"lib/gtl/array_slice.h",
"lib/gtl/cleanup.h",
+ "lib/gtl/flatmap.h",
+ "lib/gtl/flatset.h",
"lib/gtl/inlined_vector.h",
"lib/gtl/priority_queue_util.h",
"lib/hash/crc32c.h",
@@ -1447,6 +1449,8 @@ tf_cc_tests(
"lib/gtl/array_slice_test.cc",
"lib/gtl/cleanup_test.cc",
"lib/gtl/edit_distance_test.cc",
+ "lib/gtl/flatmap_test.cc",
+ "lib/gtl/flatset_test.cc",
"lib/gtl/inlined_vector_test.cc",
"lib/gtl/int_type_test.cc",
"lib/gtl/iterator_range_test.cc",
diff --git a/tensorflow/core/lib/gtl/flatmap.h b/tensorflow/core/lib/gtl/flatmap.h
new file mode 100644
index 0000000000..c66bc47168
--- /dev/null
+++ b/tensorflow/core/lib/gtl/flatmap.h
@@ -0,0 +1,349 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
+
+#include <stddef.h>
+#include <utility>
+#include "tensorflow/core/lib/gtl/flatrep.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace gtl {
+
+// FlatMap<K,V,...> provides a map from K to V.
+//
+// The map is implemented using an open-addressed hash table. A
+// single array holds entire map contents and collisions are resolved
+// by probing at a sequence of locations in the array.
+template <typename Key, typename Val, class Hash, class Eq = std::equal_to<Key>>
+class FlatMap {
+ private:
+ // Forward declare some internal types needed in public section.
+ struct Bucket;
+
+ public:
+ typedef Key key_type;
+ typedef Val mapped_type;
+ typedef Hash hasher;
+ typedef Eq key_equal;
+ typedef size_t size_type;
+ typedef ptrdiff_t difference_type;
+
+ // We cannot use std::pair<> since internal representation stores
+ // keys and values in separate arrays, so we make a custom struct
+ // that holds references to the internal key, value elements.
+ struct value_type {
+ typedef Key first_type;
+ typedef Val second_type;
+
+ const Key& first;
+ Val& second;
+ value_type(const Key& k, Val& v) : first(k), second(v) {}
+ };
+ typedef value_type* pointer;
+ typedef const value_type* const_pointer;
+ typedef value_type& reference;
+ typedef const value_type& const_reference;
+
+ FlatMap() : FlatMap(1) {}
+
+ // N is a hint for the initial number of entries to size for.
+ explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
+ : rep_(N, hf, eq) {}
+
+ FlatMap(const FlatMap& src) : rep_(src.rep_) {}
+
+ template <typename InputIter>
+ FlatMap(InputIter first, InputIter last, size_t N = 1,
+ const Hash& hf = Hash(), const Eq& eq = Eq())
+ : FlatMap(N, hf, eq) {
+ insert(first, last);
+ }
+
+ FlatMap& operator=(const FlatMap& src) {
+ rep_.CopyFrom(src.rep_);
+ return *this;
+ }
+
+ ~FlatMap() {}
+
+ void swap(FlatMap& x) { rep_.swap(x.rep_); }
+ void clear_no_resize() { rep_.clear_no_resize(); }
+ void clear() { rep_.clear(); }
+ // reserve/rehash/resize never shrink below the current size.
+ void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
+ void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
+ void resize(size_t N) { rep_.Resize(std::max(N, size())); }
+ size_t size() const { return rep_.size(); }
+ bool empty() const { return size() == 0; }
+ size_t bucket_count() const { return rep_.bucket_count(); }
+ hasher hash_function() const { return rep_.hash_function(); }
+ key_equal key_eq() const { return rep_.key_eq(); }
+
+ class iterator {
+ public:
+ iterator() : b_(nullptr), end_(nullptr), i_(0) {}
+
+ // Make iterator pointing at first element at or after b.
+ explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) {
+ SkipUnused();
+ }
+
+ // Make iterator pointing exactly at ith element in b, which must exist.
+ iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {
+ FillValue();
+ }
+
+ value_type& operator*() { return *val(); }
+ value_type* operator->() { return val(); }
+ bool operator==(const iterator& x) const {
+ return b_ == x.b_ && i_ == x.i_;
+ }
+ bool operator!=(const iterator& x) const { return !(*this == x); }
+ iterator& operator++() {
+ DCHECK(b_ != end_);
+ i_++;
+ SkipUnused();
+ return *this;
+ }
+
+ private:
+ friend class FlatMap;
+ Bucket* b_;
+ Bucket* end_;
+ uint32 i_;
+ // Raw storage for the value_type proxy; (re)constructed in place by
+ // FillValue() each time the iterator lands on an occupied slot.
+ char space_[sizeof(value_type)];
+
+ value_type* val() { return reinterpret_cast<value_type*>(space_); }
+ // Build the proxy value_type referring to the current bucket slot.
+ void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); }
+ // Advance to the next occupied slot; marker values 0 and 1 denote
+ // empty and deleted entries respectively (see FlatRep).
+ void SkipUnused() {
+ while (b_ < end_) {
+ if (i_ >= Rep::kWidth) {
+ i_ = 0;
+ b_++;
+ } else if (b_->marker[i_] < 2) {
+ i_++;
+ } else {
+ FillValue();
+ break;
+ }
+ }
+ }
+ };
+
+ class const_iterator {
+ private:
+ mutable iterator rep_; // Share state and logic with non-const iterator.
+ public:
+ const_iterator() : rep_() {}
+ explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
+ const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
+
+ const value_type& operator*() const { return *rep_.val(); }
+ const value_type* operator->() const { return rep_.val(); }
+ bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
+ bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
+ const_iterator& operator++() {
+ ++rep_;
+ return *this;
+ }
+ };
+
+ iterator begin() { return iterator(rep_.start(), rep_.limit()); }
+ iterator end() { return iterator(rep_.limit(), rep_.limit()); }
+ const_iterator begin() const {
+ return const_iterator(rep_.start(), rep_.limit());
+ }
+ const_iterator end() const {
+ return const_iterator(rep_.limit(), rep_.limit());
+ }
+
+ size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
+ iterator find(const Key& k) {
+ auto r = rep_.Find(k);
+ return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
+ }
+ const_iterator find(const Key& k) const {
+ auto r = rep_.Find(k);
+ return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
+ }
+
+ // Unlike std::map::at, a missing key is a DCHECK failure, not a throw.
+ Val& at(const Key& k) {
+ auto r = rep_.Find(k);
+ DCHECK(r.found);
+ return r.b->val(r.index);
+ }
+ const Val& at(const Key& k) const {
+ auto r = rep_.Find(k);
+ DCHECK(r.found);
+ return r.b->val(r.index);
+ }
+
+ template <typename P>
+ std::pair<iterator, bool> insert(const P& p) {
+ return Insert(p.first, p.second);
+ }
+ std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) {
+ return Insert(p.first, p.second);
+ }
+ template <typename InputIter>
+ void insert(InputIter first, InputIter last) {
+ for (; first != last; ++first) {
+ insert(*first);
+ }
+ }
+
+ Val& operator[](const Key& k) { return IndexOp(k); }
+ Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); }
+
+ template <typename... Args>
+ std::pair<iterator, bool> emplace(Args&&... args) {
+ return InsertPair(std::make_pair(std::forward<Args>(args)...));
+ }
+
+ size_t erase(const Key& k) {
+ auto r = rep_.Find(k);
+ if (!r.found) return 0;
+ rep_.Erase(r.b, r.index);
+ return 1;
+ }
+ // Erase element at pos; returns iterator to the next live element.
+ iterator erase(iterator pos) {
+ rep_.Erase(pos.b_, pos.i_);
+ ++pos;
+ return pos;
+ }
+ iterator erase(iterator pos, iterator last) {
+ for (; pos != last; ++pos) {
+ rep_.Erase(pos.b_, pos.i_);
+ }
+ return pos;
+ }
+
+ // Keys are unique, so the range is empty or a single element.
+ std::pair<iterator, iterator> equal_range(const Key& k) {
+ auto pos = find(k);
+ if (pos == end()) {
+ return std::make_pair(pos, pos);
+ } else {
+ auto next = pos;
+ ++next;
+ return std::make_pair(pos, next);
+ }
+ }
+ std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
+ auto pos = find(k);
+ if (pos == end()) {
+ return std::make_pair(pos, pos);
+ } else {
+ auto next = pos;
+ ++next;
+ return std::make_pair(pos, next);
+ }
+ }
+
+ // Equality is order-independent: same size and same key->value pairs.
+ bool operator==(const FlatMap& x) const {
+ if (size() != x.size()) return false;
+ for (auto& p : x) {
+ auto i = find(p.first);
+ if (i == end()) return false;
+ if (i->second != p.second) return false;
+ }
+ return true;
+ }
+ bool operator!=(const FlatMap& x) const { return !(*this == x); }
+
+ // If key exists in the table, prefetch the associated value. This
+ // is a hint, and may have no effect.
+ void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
+
+ private:
+ using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
+
+ // Bucket stores kWidth <marker, key, value> triples.
+ // The data is organized as three parallel arrays to reduce padding.
+ struct Bucket {
+ uint8 marker[Rep::kWidth];
+
+ // Wrap keys and values in union to control construction and destruction.
+ union Storage {
+ struct {
+ Key key[Rep::kWidth];
+ Val val[Rep::kWidth];
+ };
+ Storage() {}
+ ~Storage() {}
+ } storage;
+
+ Key& key(uint32 i) {
+ DCHECK_GE(marker[i], 2);
+ return storage.key[i];
+ }
+ Val& val(uint32 i) {
+ DCHECK_GE(marker[i], 2);
+ return storage.val[i];
+ }
+ template <typename V>
+ void InitVal(uint32 i, V&& v) {
+ new (&storage.val[i]) Val(std::forward<V>(v));
+ }
+ // Explicit destructor calls pair with the placement-new construction
+ // above; the union suppresses automatic construction/destruction.
+ void Destroy(uint32 i) {
+ storage.key[i].Key::~Key();
+ storage.val[i].Val::~Val();
+ }
+ void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
+ new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
+ new (&storage.val[i]) Val(std::move(src->storage.val[src_index]));
+ }
+ void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
+ new (&storage.key[i]) Key(src->storage.key[src_index]);
+ new (&storage.val[i]) Val(src->storage.val[src_index]);
+ }
+ };
+
+ // Split a pair into key/value arguments for Insert().
+ template <typename Pair>
+ std::pair<iterator, bool> InsertPair(Pair&& p) {
+ return Insert(std::forward<decltype(p.first)>(p.first),
+ std::forward<decltype(p.second)>(p.second));
+ }
+
+ // Insert (k,v) if k is not already present. Returns {iterator, inserted};
+ // an existing value is left untouched when inserted is false.
+ template <typename K, typename V>
+ std::pair<iterator, bool> Insert(K&& k, V&& v) {
+ rep_.MaybeResize();
+ auto r = rep_.FindOrInsert(std::forward<K>(k));
+ const bool inserted = !r.found;
+ if (inserted) {
+ r.b->InitVal(r.index, std::forward<V>(v));
+ }
+ return {iterator(r.b, rep_.limit(), r.index), inserted};
+ }
+
+ // Find or default-construct the value for k (implements operator[]).
+ template <typename K>
+ Val& IndexOp(K&& k) {
+ rep_.MaybeResize();
+ auto r = rep_.FindOrInsert(std::forward<K>(k));
+ Val* vptr = &r.b->val(r.index);
+ if (!r.found) {
+ new (vptr) Val(); // Initialize value in new slot.
+ }
+ return *vptr;
+ }
+
+ Rep rep_;
+};
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
diff --git a/tensorflow/core/lib/gtl/flatmap_test.cc b/tensorflow/core/lib/gtl/flatmap_test.cc
new file mode 100644
index 0000000000..2fa610b7e1
--- /dev/null
+++ b/tensorflow/core/lib/gtl/flatmap_test.cc
@@ -0,0 +1,576 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace gtl {
+namespace {
+
+typedef FlatMap<int64, int32, HashInt64> NumMap;
+
+// If map has an entry for k, return the corresponding value, else return def.
+// Also cross-checks count()/at()/find() consistency as a side effect.
+int32 Get(const NumMap& map, int64 k, int32 def = -1) {
+ auto iter = map.find(k);
+ if (iter == map.end()) {
+ EXPECT_EQ(map.count(k), 0);
+ return def;
+ } else {
+ EXPECT_EQ(map.count(k), 1);
+ EXPECT_EQ(&map.at(k), &iter->second);
+ EXPECT_EQ(iter->first, k);
+ return iter->second;
+ }
+}
+
+// Return contents of map as a sorted list of pairs.
+typedef std::vector<std::pair<int64, int32>> NumMapContents;
+NumMapContents Contents(const NumMap& map) {
+ NumMapContents result;
+ for (const auto& p : map) {
+ result.push_back({p.first, p.second});
+ }
+ std::sort(result.begin(), result.end());
+ return result;
+}
+
+// Fill entries with keys [start,limit).
+// Each key i maps to value i * 100.
+void Fill(NumMap* map, int64 start, int64 limit) {
+ for (int64 i = start; i < limit; i++) {
+ map->insert({i, i * 100});
+ }
+}
+
+TEST(FlatMapTest, Find) {
+ NumMap map;
+ EXPECT_EQ(Get(map, 1), -1);
+ map.insert({1, 100});
+ map.insert({2, 200});
+ EXPECT_EQ(Get(map, 1), 100);
+ EXPECT_EQ(Get(map, 2), 200);
+ EXPECT_EQ(Get(map, 3), -1);
+}
+
+TEST(FlatMapTest, Insert) {
+ NumMap map;
+ EXPECT_EQ(Get(map, 1), -1);
+
+ // New entry.
+ auto result = map.insert({1, 100});
+ EXPECT_TRUE(result.second);
+ EXPECT_EQ(result.first->first, 1);
+ EXPECT_EQ(result.first->second, 100);
+ EXPECT_EQ(Get(map, 1), 100);
+
+ // Attempt to insert over existing entry.
+ result = map.insert({1, 200});
+ EXPECT_FALSE(result.second);
+ EXPECT_EQ(result.first->first, 1);
+ EXPECT_EQ(result.first->second, 100);
+ EXPECT_EQ(Get(map, 1), 100);
+
+ // Overwrite through iterator.
+ result.first->second = 300;
+ EXPECT_EQ(result.first->second, 300);
+ EXPECT_EQ(Get(map, 1), 300);
+
+ // Should get updated value.
+ result = map.insert({1, 400});
+ EXPECT_FALSE(result.second);
+ EXPECT_EQ(result.first->first, 1);
+ EXPECT_EQ(result.first->second, 300);
+ EXPECT_EQ(Get(map, 1), 300);
+}
+
+// Inserting past the initial capacity must grow the table without losing
+// or corrupting entries.
+TEST(FlatMapTest, InsertGrowth) {
+ NumMap map;
+ const int n = 100;
+ // NOTE(review): literal 100 duplicates n above — keep the two in sync.
+ Fill(&map, 0, 100);
+ EXPECT_EQ(map.size(), n);
+ for (int i = 0; i < n; i++) {
+ EXPECT_EQ(Get(map, i), i * 100) << i;
+ }
+}
+
+TEST(FlatMapTest, Emplace) {
+ NumMap map;
+
+ // New entry.
+ auto result = map.emplace(1, 100);
+ EXPECT_TRUE(result.second);
+ EXPECT_EQ(result.first->first, 1);
+ EXPECT_EQ(result.first->second, 100);
+ EXPECT_EQ(Get(map, 1), 100);
+
+ // Attempt to insert over existing entry.
+ result = map.emplace(1, 200);
+ EXPECT_FALSE(result.second);
+ EXPECT_EQ(result.first->first, 1);
+ EXPECT_EQ(result.first->second, 100);
+ EXPECT_EQ(Get(map, 1), 100);
+
+ // Overwrite through iterator.
+ result.first->second = 300;
+ EXPECT_EQ(result.first->second, 300);
+ EXPECT_EQ(Get(map, 1), 300);
+
+ // Update a second value
+ result = map.emplace(2, 400);
+ EXPECT_TRUE(result.second);
+ EXPECT_EQ(result.first->first, 2);
+ EXPECT_EQ(result.first->second, 400);
+ EXPECT_EQ(Get(map, 2), 400);
+}
+
+// emplace must support move-only mapped types.
+TEST(FlatMapTest, EmplaceUniquePtr) {
+ FlatMap<int64, std::unique_ptr<string>, HashInt64> smap;
+ smap.emplace(1, std::unique_ptr<string>(new string("hello")));
+}
+
+TEST(FlatMapTest, Size) {
+ NumMap map;
+ EXPECT_EQ(map.size(), 0);
+
+ map.insert({1, 100});
+ map.insert({2, 200});
+ EXPECT_EQ(map.size(), 2);
+}
+
+TEST(FlatMapTest, Empty) {
+ NumMap map;
+ EXPECT_TRUE(map.empty());
+
+ map.insert({1, 100});
+ map.insert({2, 200});
+ EXPECT_FALSE(map.empty());
+}
+
+TEST(FlatMapTest, ArrayOperator) {
+ NumMap map;
+
+ // Create new element if not found.
+ auto v1 = &map[1];
+ EXPECT_EQ(*v1, 0);
+ EXPECT_EQ(Get(map, 1), 0);
+
+ // Write through returned reference.
+ *v1 = 100;
+ EXPECT_EQ(map[1], 100);
+ EXPECT_EQ(Get(map, 1), 100);
+
+ // Reuse existing element if found.
+ auto v1a = &map[1];
+ EXPECT_EQ(v1, v1a);
+ EXPECT_EQ(*v1, 100);
+
+ // Create another element.
+ map[2] = 200;
+ EXPECT_EQ(Get(map, 1), 100);
+ EXPECT_EQ(Get(map, 2), 200);
+}
+
+TEST(FlatMapTest, Count) {
+ NumMap map;
+ EXPECT_EQ(map.count(1), 0);
+ EXPECT_EQ(map.count(2), 0);
+
+ map.insert({1, 100});
+ EXPECT_EQ(map.count(1), 1);
+ EXPECT_EQ(map.count(2), 0);
+
+ map.insert({2, 200});
+ EXPECT_EQ(map.count(1), 1);
+ EXPECT_EQ(map.count(2), 1);
+}
+
+TEST(FlatMapTest, Iter) {
+ NumMap map;
+ EXPECT_EQ(Contents(map), NumMapContents());
+
+ map.insert({1, 100});
+ map.insert({2, 200});
+ EXPECT_EQ(Contents(map), NumMapContents({{1, 100}, {2, 200}}));
+}
+
+TEST(FlatMapTest, Erase) {
+ NumMap map;
+ EXPECT_EQ(map.erase(1), 0);
+ map[1] = 100;
+ map[2] = 200;
+ EXPECT_EQ(map.erase(3), 0);
+ EXPECT_EQ(map.erase(1), 1);
+ EXPECT_EQ(map.size(), 1);
+ EXPECT_EQ(Get(map, 2), 200);
+ EXPECT_EQ(Contents(map), NumMapContents({{2, 200}}));
+ EXPECT_EQ(map.erase(2), 1);
+ EXPECT_EQ(Contents(map), NumMapContents());
+}
+
+// erase(iterator) must return a valid iterator to the next element so that
+// erase-while-iterating works.
+TEST(FlatMapTest, EraseIter) {
+ NumMap map;
+ Fill(&map, 1, 11);
+ size_t size = 10;
+ for (auto iter = map.begin(); iter != map.end();) {
+ iter = map.erase(iter);
+ size--;
+ EXPECT_EQ(map.size(), size);
+ }
+ EXPECT_EQ(Contents(map), NumMapContents());
+}
+
+// Range erase: remove everything except the first two iteration-order
+// elements.
+TEST(FlatMapTest, EraseIterPair) {
+ NumMap map;
+ Fill(&map, 1, 11);
+ NumMap expected;
+ auto p1 = map.begin();
+ expected.insert(*p1);
+ ++p1;
+ expected.insert(*p1);
+ ++p1;
+ auto p2 = map.end();
+ EXPECT_EQ(map.erase(p1, p2), map.end());
+ EXPECT_EQ(map.size(), 2);
+ EXPECT_EQ(Contents(map), Contents(expected));
+}
+
+TEST(FlatMapTest, EraseLongChains) {
+ // Make a map with lots of elements and erase a bunch of them to ensure
+ // that we are likely to hit them on future lookups.
+ NumMap map;
+ const int num = 128;
+ Fill(&map, 0, num);
+ for (int i = 0; i < num; i += 3) {
+ EXPECT_EQ(map.erase(i), 1);
+ }
+ for (int i = 0; i < num; i++) {
+ if ((i % 3) != 0) {
+ EXPECT_EQ(Get(map, i), i * 100);
+ } else {
+ EXPECT_EQ(map.count(i), 0);
+ }
+ }
+
+ // Erase remainder to trigger table shrinking.
+ const size_t orig_buckets = map.bucket_count();
+ for (int i = 0; i < num; i++) {
+ map.erase(i);
+ }
+ EXPECT_TRUE(map.empty());
+ EXPECT_EQ(map.bucket_count(), orig_buckets);
+ map[1] = 100; // Actual shrinking is triggered by an insert.
+ EXPECT_LT(map.bucket_count(), orig_buckets);
+}
+
+// Regression-style stress: repeated insert/erase must not degrade the table
+// (e.g. by accumulating deleted markers).
+TEST(FlatMap, AlternatingInsertRemove) {
+ NumMap map;
+ map.insert({1000, 1000});
+ map.insert({2000, 1000});
+ map.insert({3000, 1000});
+ for (int i = 0; i < 10000; i++) {
+ map.insert({i, i});
+ map.erase(i);
+ }
+}
+
+// clear_no_resize keeps the bucket array; clear may shrink it.
+TEST(FlatMap, ClearNoResize) {
+ NumMap map;
+ Fill(&map, 0, 100);
+ const size_t orig = map.bucket_count();
+ map.clear_no_resize();
+ EXPECT_EQ(map.size(), 0);
+ EXPECT_EQ(Contents(map), NumMapContents());
+ EXPECT_EQ(map.bucket_count(), orig);
+}
+
+TEST(FlatMap, Clear) {
+ NumMap map;
+ Fill(&map, 0, 100);
+ const size_t orig = map.bucket_count();
+ map.clear();
+ EXPECT_EQ(map.size(), 0);
+ EXPECT_EQ(Contents(map), NumMapContents());
+ EXPECT_LT(map.bucket_count(), orig);
+}
+
+// Copy construction, copy assignment, and self-assignment, across a range
+// of sizes (including empty).
+TEST(FlatMap, Copy) {
+ for (int n = 0; n < 10; n++) {
+ NumMap src;
+ Fill(&src, 0, n);
+ NumMap copy = src;
+ EXPECT_EQ(Contents(src), Contents(copy));
+ NumMap copy2;
+ copy2 = src;
+ EXPECT_EQ(Contents(src), Contents(copy2));
+ copy2 = copy2; // Self-assignment
+ EXPECT_EQ(Contents(src), Contents(copy2));
+ }
+}
+
+TEST(FlatMap, InitFromIter) {
+ for (int n = 0; n < 10; n++) {
+ NumMap src;
+ Fill(&src, 0, n);
+ auto vec = Contents(src);
+ NumMap dst(vec.begin(), vec.end());
+ EXPECT_EQ(Contents(dst), vec);
+ }
+}
+
+TEST(FlatMap, InsertIter) {
+ NumMap a, b;
+ Fill(&a, 1, 10);
+ Fill(&b, 8, 20);
+ b[9] = 10000; // Should not get inserted into a since a already has 9
+ a.insert(b.begin(), b.end());
+ NumMap expected;
+ Fill(&expected, 1, 20);
+ EXPECT_EQ(Contents(a), Contents(expected));
+}
+
+TEST(FlatMap, Eq) {
+ NumMap empty;
+
+ NumMap elems;
+ Fill(&elems, 0, 5);
+ EXPECT_FALSE(empty == elems);
+ EXPECT_TRUE(empty != elems);
+
+ NumMap copy = elems;
+ EXPECT_TRUE(copy == elems);
+ EXPECT_FALSE(copy != elems);
+
+ NumMap changed = elems;
+ changed[3] = 1;
+ EXPECT_FALSE(changed == elems);
+ EXPECT_TRUE(changed != elems);
+
+ NumMap changed2 = elems;
+ changed2.erase(3);
+ EXPECT_FALSE(changed2 == elems);
+ EXPECT_TRUE(changed2 != elems);
+}
+
+TEST(FlatMap, Swap) {
+ NumMap a, b;
+ Fill(&a, 1, 5);
+ Fill(&b, 100, 200);
+ NumMap c = a;
+ NumMap d = b;
+ EXPECT_EQ(c, a);
+ EXPECT_EQ(d, b);
+ c.swap(d);
+ EXPECT_EQ(c, b);
+ EXPECT_EQ(d, a);
+}
+
+// reserve/rehash must never drop entries, even when asked for fewer buckets
+// than the map currently holds.
+TEST(FlatMap, Reserve) {
+ NumMap src;
+ Fill(&src, 1, 100);
+ NumMap a = src;
+ a.reserve(10);
+ EXPECT_EQ(a, src);
+ NumMap b = src;
+ b.rehash(1000);
+ EXPECT_EQ(b, src);
+}
+
+// equal_range on a unique-key map yields a range of length 0 or 1.
+TEST(FlatMap, EqualRangeMutable) {
+ NumMap map;
+ Fill(&map, 1, 10);
+
+ // Existing element
+ auto p1 = map.equal_range(3);
+ EXPECT_TRUE(p1.first != p1.second);
+ EXPECT_EQ(p1.first->first, 3);
+ EXPECT_EQ(p1.first->second, 300);
+ ++p1.first;
+ EXPECT_TRUE(p1.first == p1.second);
+
+ // Missing element
+ auto p2 = map.equal_range(100);
+ EXPECT_TRUE(p2.first == p2.second);
+}
+
+TEST(FlatMap, EqualRangeConst) {
+ NumMap tmp;
+ Fill(&tmp, 1, 10);
+
+ const NumMap map = tmp;
+
+ // Existing element
+ auto p1 = map.equal_range(3);
+ EXPECT_TRUE(p1.first != p1.second);
+ EXPECT_EQ(p1.first->first, 3);
+ EXPECT_EQ(p1.first->second, 300);
+ ++p1.first;
+ EXPECT_TRUE(p1.first == p1.second);
+
+ // Missing element
+ auto p2 = map.equal_range(100);
+ EXPECT_TRUE(p2.first == p2.second);
+}
+
+// prefetch_value is only a hint; this just checks it is safe to call for
+// both present and absent keys.
+TEST(FlatMap, Prefetch) {
+ NumMap map;
+ Fill(&map, 0, 1000);
+ // Prefetch present and missing keys.
+ for (int i = 0; i < 2000; i++) {
+ map.prefetch_value(i);
+ }
+}
+
+// Non-copyable values should work.
+// NOTE(review): despite the name, NC defines a copy constructor and IS
+// copyable; it models a type with a custom hash rather than a move-only
+// type — confirm intent against the FlatSet version of this test.
+struct NC {
+ int64 value;
+ NC() : value(-1) {}
+ NC(int64 v) : value(v) {}
+ NC(const NC& x) : value(x.value) {}
+ bool operator==(const NC& x) const { return value == x.value; }
+};
+struct HashNC {
+ size_t operator()(NC x) const { return x.value; }
+};
+
+TEST(FlatMap, NonCopyable) {
+ FlatMap<NC, NC, HashNC> map;
+ for (int i = 0; i < 100; i++) {
+ map[NC(i)] = NC(i * 100);
+ }
+ for (int i = 0; i < 100; i++) {
+ EXPECT_EQ(map.count(NC(i)), 1);
+ auto iter = map.find(NC(i));
+ EXPECT_NE(iter, map.end());
+ EXPECT_EQ(iter->first, NC(i));
+ EXPECT_EQ(iter->second, NC(i * 100));
+ EXPECT_EQ(map[NC(i)], NC(i * 100));
+ }
+ map.erase(NC(10));
+ EXPECT_EQ(map.count(NC(10)), 0);
+}
+
+// Test with heap-allocated objects so that mismanaged constructions
+// or destructions will show up as errors under a sanitizer or
+// heap checker.
+TEST(FlatMap, ConstructDestruct) {
+ FlatMap<string, string, HashStr> map;
+ string k1 = "the quick brown fox jumped over the lazy dog";
+ string k2 = k1 + k1;
+ string k3 = k1 + k2;
+ map[k1] = k2;
+ map[k3] = k1;
+ EXPECT_EQ(k1, map.find(k1)->first);
+ EXPECT_EQ(k2, map.find(k1)->second);
+ EXPECT_EQ(k1, map[k3]);
+ map.erase(k3);
+ EXPECT_EQ(string(), map[k3]);
+
+ map.clear();
+ map[k1] = k2;
+ EXPECT_EQ(k2, map[k1]);
+
+ map.reserve(100);
+ EXPECT_EQ(k2, map[k1]);
+}
+
+// Type to use to ensure that custom equality operator is used
+// that ignores extra value.
+struct CustomCmpKey {
+ int64 a;
+ int64 b;
+ CustomCmpKey(int64 v1, int64 v2) : a(v1), b(v2) {}
+ bool operator==(const CustomCmpKey& x) const { return a == x.a && b == x.b; }
+};
+struct HashA {
+ size_t operator()(CustomCmpKey x) const { return x.a; }
+};
+struct EqA {
+ // Ignore b fields.
+ bool operator()(CustomCmpKey x, CustomCmpKey y) const { return x.a == y.a; }
+};
+TEST(FlatMap, CustomCmp) {
+ FlatMap<CustomCmpKey, int, HashA, EqA> map;
+ map[CustomCmpKey(100, 200)] = 300;
+ EXPECT_EQ(300, map[CustomCmpKey(100, 200)]);
+ EXPECT_EQ(300, map[CustomCmpKey(100, 500)]); // Differences in key.b ignored
+}
+
+// Test unique_ptr handling.
+// Keys and values are both move-only; hash/eq dereference the pointers so
+// distinct unique_ptr instances with equal pointees compare equal.
+typedef std::unique_ptr<int> UniqInt;
+static UniqInt MakeUniq(int i) { return UniqInt(new int(i)); }
+
+struct HashUniq {
+ size_t operator()(const UniqInt& p) const { return *p; }
+};
+struct EqUniq {
+ bool operator()(const UniqInt& a, const UniqInt& b) const { return *a == *b; }
+};
+typedef FlatMap<UniqInt, UniqInt, HashUniq, EqUniq> UniqMap;
+
+TEST(FlatMap, UniqueMap) {
+ UniqMap map;
+
+ // Fill map
+ const int N = 10;
+ for (int i = 0; i < N; i++) {
+ if ((i % 2) == 0) {
+ map[MakeUniq(i)] = MakeUniq(i + 100);
+ } else {
+ map.emplace(MakeUniq(i), MakeUniq(i + 100));
+ }
+ }
+ EXPECT_EQ(map.size(), N);
+
+ // Lookups
+ for (int i = 0; i < N; i++) {
+ EXPECT_EQ(*map.at(MakeUniq(i)), i + 100);
+ }
+
+ // find+erase
+ EXPECT_EQ(map.count(MakeUniq(2)), 1);
+ map.erase(MakeUniq(2));
+ EXPECT_EQ(map.count(MakeUniq(2)), 0);
+
+ // clear
+ map.clear();
+ EXPECT_EQ(map.size(), 0);
+}
+
+// Iteration must visit every element exactly once (checked via key/value
+// sums, since iteration order is unspecified).
+TEST(FlatMap, UniqueMapIter) {
+ UniqMap map;
+ const int kCount = 10;
+ const int kValueDelta = 100;
+ for (int i = 1; i <= kCount; i++) {
+ map[MakeUniq(i)] = MakeUniq(i + kValueDelta);
+ }
+ int key_sum = 0;
+ int val_sum = 0;
+ for (const auto& p : map) {
+ key_sum += *p.first;
+ val_sum += *p.second;
+ }
+ EXPECT_EQ(key_sum, (kCount * (kCount + 1)) / 2);
+ EXPECT_EQ(val_sum, key_sum + (kCount * kValueDelta));
+}
+
+} // namespace
+} // namespace gtl
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/flatrep.h b/tensorflow/core/lib/gtl/flatrep.h
new file mode 100644
index 0000000000..ff590d4128
--- /dev/null
+++ b/tensorflow/core/lib/gtl/flatrep.h
@@ -0,0 +1,332 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
+
+#include <string.h>
+#include <utility>
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace gtl {
+namespace internal {
+
+// Internal representation for FlatMap and FlatSet.
+//
+// The representation is an open-addressed hash table. Conceptually,
+// the representation is a flat array of entries. However we
+// structure it as an array of buckets where each bucket holds
+// kWidth entries along with metadata for the kWidth entries. The
+// metadata marker is
+//
+// (a) kEmpty: the entry is empty
+// (b) kDeleted: the entry has been deleted
+// (c) other: the entry is occupied and has low-8 bits of its hash.
+// These hash bits can be used to avoid potentially expensive
+// key comparisons.
+//
+// FlatMap passes in a bucket that contains keys and values, FlatSet
+// passes in a bucket that does not contain values.
+template <typename Key, typename Bucket, class Hash, class Eq>
+class FlatRep {
+ public:
+ // kWidth is the number of entries stored in a bucket.
+ static const uint32 kBase = 3;
+ static const uint32 kWidth = (1 << kBase);
+
+ // Construct a table sized for roughly N entries.
+ FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) {
+ Init(N);
+ }
+ explicit FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
+ Init(src.size());
+ CopyEntries(src.array_, src.end_, CopyEntry());
+ }
+ ~FlatRep() {
+ clear_no_resize();
+ delete[] array_;
+ }
+
+ // Simple accessors.
+ size_t size() const { return not_empty_ - deleted_; }
+ size_t bucket_count() const { return mask_ + 1; }
+ Bucket* start() const { return array_; }
+ Bucket* limit() const { return end_; }
+ const Hash& hash_function() const { return hash_; }
+ const Eq& key_eq() const { return equal_; }
+
+ // Overwrite contents of *this with contents of src.
+ void CopyFrom(const FlatRep& src) {
+ if (this != &src) {
+ clear_no_resize();
+ delete[] array_;
+ Init(src.size());
+ CopyEntries(src.array_, src.end_, CopyEntry());
+ }
+ }
+
+ // Destroy all occupied entries (marker >= 2) but keep the bucket array.
+ void clear_no_resize() {
+ for (Bucket* b = array_; b != end_; b++) {
+ for (uint32 i = 0; i < kWidth; i++) {
+ if (b->marker[i] >= 2) {
+ b->Destroy(i);
+ b->marker[i] = kEmpty;
+ }
+ }
+ }
+ not_empty_ = 0;
+ deleted_ = 0;
+ }
+
+ void clear() {
+ clear_no_resize();
+ grow_ = 0; // Consider shrinking in MaybeResize()
+ MaybeResize();
+ }
+
+ void swap(FlatRep& x) {
+ using std::swap;
+ swap(array_, x.array_);
+ swap(end_, x.end_);
+ swap(lglen_, x.lglen_);
+ swap(mask_, x.mask_);
+ swap(not_empty_, x.not_empty_);
+ swap(deleted_, x.deleted_);
+ swap(grow_, x.grow_);
+ swap(shrink_, x.shrink_);
+ }
+
+ // Result of a lookup: b/index are only meaningful when found is true
+ // (or, for FindOrInsert, always point at the entry's slot).
+ struct SearchResult {
+ bool found;
+ Bucket* b;
+ uint32 index;
+ };
+
+ // Hash value is partitioned as follows:
+ // 1. Bottom 8 bits are stored in bucket to help speed up comparisons.
+ // 2. Next 3 bits give index inside bucket.
+ // 3. Remaining bits give bucket number.
+
+ // Find bucket/index for key k.
+ // NOTE(review): the probe loop terminates only because the table is never
+ // completely full (there is always at least one kEmpty slot) — confirm
+ // that Init/MaybeResize maintain this invariant.
+ SearchResult Find(const Key& k) const {
+ size_t h = hash_(k);
+ const uint32 marker = Marker(h & 0xff);
+ size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
+ uint32 num_probes = 1; // Needed for quadratic probing
+ while (true) {
+ uint32 bi = index & (kWidth - 1);
+ Bucket* b = &array_[index >> kBase];
+ const uint32 x = b->marker[bi];
+ // Compare the stored low-8 hash bits first to skip most non-matching
+ // keys without invoking equal_().
+ if (x == marker && equal_(b->key(bi), k)) {
+ return {true, b, bi};
+ } else if (x == kEmpty) {
+ return {false, nullptr, 0};
+ }
+ // Quadratic probing.
+ index = (index + num_probes) & mask_;
+ num_probes++;
+ }
+ }
+
+ // Find bucket/index for key k, creating a new one if necessary.
+ //
+ // KeyType is a template parameter so that k's type is deduced and it
+ // becomes a universal reference which allows the key initialization
+ // below to use an rvalue constructor if available.
+ //
+ // On a miss, only the key is constructed here; the caller is responsible
+ // for constructing the value (see FlatMap::Insert/IndexOp).
+ template <typename KeyType>
+ SearchResult FindOrInsert(KeyType&& k) {
+ size_t h = hash_(k);
+ const uint32 marker = Marker(h & 0xff);
+ size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
+ uint32 num_probes = 1; // Needed for quadratic probing
+ Bucket* del = nullptr; // First encountered deletion for kInsert
+ uint32 di = 0;
+ while (true) {
+ uint32 bi = index & (kWidth - 1);
+ Bucket* b = &array_[index >> kBase];
+ const uint32 x = b->marker[bi];
+ if (x == marker && equal_(b->key(bi), k)) {
+ return {true, b, bi};
+ } else if (!del && x == kDeleted) {
+ // Remember deleted index to use for insertion.
+ del = b;
+ di = bi;
+ } else if (x == kEmpty) {
+ if (del) {
+ // Store in the first deleted slot we encountered
+ b = del;
+ bi = di;
+ deleted_--; // not_empty_ does not change
+ } else {
+ not_empty_++;
+ }
+ b->marker[bi] = marker;
+ new (&b->key(bi)) Key(std::forward<KeyType>(k));
+ return {false, b, bi};
+ }
+ // Quadratic probing.
+ index = (index + num_probes) & mask_;
+ num_probes++;
+ }
+ }
+
+  // Remove the entry at slot i of bucket b. The slot becomes a tombstone
+  // (kDeleted) rather than kEmpty so probe chains passing through it
+  // remain intact.
+  void Erase(Bucket* b, uint32 i) {
+    b->marker[i] = kDeleted;
+    b->Destroy(i);
+    ++deleted_;
+    grow_ = 0;  // Ask MaybeResize to consider shrinking on the next insert.
+  }
+
+  // Best-effort cache warm-up for key k's home slot (currently a no-op;
+  // see prefetch() below).
+  void Prefetch(const Key& k) const {
+    const size_t h = hash_(k);
+    const size_t pos = (h >> 8) & mask_;  // Bucket num and index-in-bucket
+    const uint32 slot = pos & (kWidth - 1);
+    Bucket* const bucket = &array_[pos >> kBase];
+    prefetch(&bucket->storage.key[slot]);
+  }
+  // Intentionally empty placeholder; keeps call sites in place until a
+  // real platform prefetch primitive exists.
+  void prefetch(const void* ptr) const {
+    // TODO(jeff,sanjay): Remove this routine when we add a
+    // prefetch(...) call to platform so that the Prefetch routine
+    // actually does something
+  }
+
+  // Grow (or, after erases, shrink) the table if the occupancy thresholds
+  // say so. Called before every insertion.
+  inline void MaybeResize() {
+    if (not_empty_ < grow_) return;  // Still below the growth threshold.
+    if (grow_ == 0 && size() >= shrink_) {
+      // grow_ == 0 is a sentinel stored by Erase requesting a shrink
+      // check. The table is not small enough to shrink, so restore the
+      // normal growth threshold and bail out if we are under it.
+      grow_ = static_cast<size_t>(bucket_count() * 0.8);
+      if (not_empty_ < grow_) return;
+    }
+    Resize(size() + 1);
+  }
+
+  // Rebuild the table with room for N elements, moving every live entry
+  // from the old array into the new one.
+  void Resize(size_t N) {
+    Bucket* const old_start = array_;
+    Bucket* const old_limit = end_;
+    Init(N);  // Allocates a fresh array_ and resets all bookkeeping.
+    CopyEntries(old_start, old_limit, MoveEntry());
+    delete[] old_start;
+  }
+
+ private:
+  enum { kEmpty = 0, kDeleted = 1 };  // Special markers for an entry.
+
+  Hash hash_;         // User-supplied hasher
+  Eq equal_;          // User-supplied comparator
+  uint8 lglen_;       // lg(#buckets)
+  Bucket* array_;     // array of length (1 << lglen_)
+  Bucket* end_;       // Points just past last bucket in array_
+  size_t mask_;       // (#buckets * kWidth) - 1; masks the probe position
+  size_t not_empty_;  // Count of entries with marker != kEmpty
+  size_t deleted_;    // Count of entries with marker == kDeleted
+  size_t grow_;       // Grow array when not_empty_ >= grow_
+  size_t shrink_;     // Shrink array when size() < shrink_ (0 = sentinel,
+                      // see MaybeResize)
+
+  // Avoid kEmpty and kDeleted markers when computing hash values to
+  // store in Bucket::marker[].
+  // Low bytes 0 and 1 are remapped to 2 and 3, so markers 2 and 3 occur
+  // twice as often; this is harmless because the full key is compared
+  // whenever markers match.
+  static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); }
+
+  // Allocate a fresh, empty table with room for at least N elements at
+  // under 80% load factor, and reset all bookkeeping counters.
+  void Init(size_t N) {
+    // Find the smallest power-of-two bucket count whose capacity
+    // (#buckets * kWidth) keeps N below 80% occupancy.
+    // Use an explicit size_t shift: a plain (1 << lg) is a signed int
+    // shift and overflows (UB) once lg >= 31 on 64-bit hosts.
+    size_t lg = 0;  // Smallest table is just one bucket.
+    while (N >= 0.8 * ((size_t{1} << lg) * kWidth)) {
+      lg++;
+    }
+    const size_t n = size_t{1} << lg;
+    Bucket* array = new Bucket[n];
+    for (size_t i = 0; i < n; i++) {
+      Bucket* b = &array[i];
+      // Mark every slot empty; keys remain raw (unconstructed) storage.
+      memset(b->marker, kEmpty, kWidth);
+    }
+    const size_t capacity = n * kWidth;
+    lglen_ = lg;
+    mask_ = capacity - 1;
+    array_ = array;
+    end_ = array + n;
+    not_empty_ = 0;
+    deleted_ = 0;
+    grow_ = static_cast<size_t>(capacity * 0.8);
+    if (lg == 0) {
+      // Already down to one bucket; no more shrinking.
+      shrink_ = 0;
+    } else {
+      shrink_ = static_cast<size_t>(grow_ * 0.4);  // Must be less than 0.5
+    }
+  }
+
+  // Used by FreshInsert when we should copy from source.
+  // The source entry is left untouched.
+  struct CopyEntry {
+    inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
+      dst->CopyFrom(dsti, src, srci);
+    }
+  };
+
+  // Used by FreshInsert when we should move from source.
+  // Destroys the source entry after moving and tombstones its slot so
+  // later scans over the old array skip it.
+  struct MoveEntry {
+    inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
+      dst->MoveFrom(dsti, src, srci);
+      src->Destroy(srci);
+      src->marker[srci] = kDeleted;
+    }
+  };
+
+  // Re-insert every live entry found in buckets [start,end) into the
+  // current table via FreshInsert.
+  template <typename Copier>
+  void CopyEntries(Bucket* start, Bucket* end, Copier copier) {
+    for (Bucket* b = start; b != end; ++b) {
+      for (uint32 i = 0; i < kWidth; ++i) {
+        const uint32 m = b->marker[i];
+        // Anything that is neither empty nor a tombstone is live.
+        if (m != kEmpty && m != kDeleted) {
+          FreshInsert(b, i, copier);
+        }
+      }
+    }
+  }
+
+  // Create an entry for the key numbered src_index in *src. Used for
+  // insertion into a fresh table. We assume that there are no deletions,
+  // and the key does not already exist in the table, so the first
+  // non-full slot on the probe chain is the right destination.
+  // (The old comment claimed a bucket/index return value; the function
+  // is void.)
+  template <typename Copier>
+  void FreshInsert(Bucket* src, uint32 src_index, Copier copier) {
+    size_t h = hash_(src->key(src_index));
+    const uint32 marker = Marker(h & 0xff);
+    size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
+    uint32 num_probes = 1;            // Needed for quadratic probing
+    while (true) {
+      uint32 bi = index & (kWidth - 1);
+      Bucket* b = &array_[index >> kBase];
+      const uint32 x = b->marker[bi];
+      if (x == kEmpty) {  // was the magic literal 0; kEmpty for consistency
+        b->marker[bi] = marker;
+        not_empty_++;
+        copier(b, bi, src, src_index);
+        return;
+      }
+      // Quadratic probing.
+      index = (index + num_probes) & mask_;
+      num_probes++;
+    }
+  }
+};
+
+} // namespace internal
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
diff --git a/tensorflow/core/lib/gtl/flatset.h b/tensorflow/core/lib/gtl/flatset.h
new file mode 100644
index 0000000000..b94d88cbc6
--- /dev/null
+++ b/tensorflow/core/lib/gtl/flatset.h
@@ -0,0 +1,277 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
+
+#include <stddef.h>
+#include <utility>
+#include "tensorflow/core/lib/gtl/flatrep.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace gtl {
+
+// FlatSet<K,...> provides a set of K.
+//
+// The set is implemented using an open-addressed hash table. A
+// single array holds the entire set's contents and collisions are
+// resolved by probing at a sequence of locations in the array.
+template <typename Key, class Hash, class Eq = std::equal_to<Key>>
+class FlatSet {
+ private:
+  // Forward declare some internal types needed in public section.
+  struct Bucket;
+
+ public:
+  // Standard container typedefs (subset of std::unordered_set's).
+  typedef Key key_type;
+  typedef Key value_type;
+  typedef Hash hasher;
+  typedef Eq key_equal;
+  typedef size_t size_type;
+  typedef ptrdiff_t difference_type;
+  typedef value_type* pointer;
+  typedef const value_type* const_pointer;
+  typedef value_type& reference;
+  typedef const value_type& const_reference;
+
+  FlatSet() : FlatSet(1) {}
+
+  // N is a hint of the number of elements that will be inserted.
+  explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
+      : rep_(N, hf, eq) {}
+
+  FlatSet(const FlatSet& src) : rep_(src.rep_) {}
+
+  // Construct from the elements in [first,last); N is a size hint.
+  template <typename InputIter>
+  FlatSet(InputIter first, InputIter last, size_t N = 1,
+          const Hash& hf = Hash(), const Eq& eq = Eq())
+      : FlatSet(N, hf, eq) {
+    insert(first, last);
+  }
+
+  FlatSet& operator=(const FlatSet& src) {
+    rep_.CopyFrom(src.rep_);
+    return *this;
+  }
+
+  ~FlatSet() {}
+
+  void swap(FlatSet& x) { rep_.swap(x.rep_); }
+  // clear_no_resize() keeps the current bucket array; clear() may shrink it.
+  void clear_no_resize() { rep_.clear_no_resize(); }
+  void clear() { rep_.clear(); }
+  // reserve/rehash/resize all take an element-count hint (not a bucket
+  // count) and never drop below the current size.
+  void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
+  void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
+  void resize(size_t N) { rep_.Resize(std::max(N, size())); }
+  size_t size() const { return rep_.size(); }
+  bool empty() const { return size() == 0; }
+  size_t bucket_count() const { return rep_.bucket_count(); }
+  hasher hash_function() const { return rep_.hash_function(); }
+  key_equal key_eq() const { return rep_.key_eq(); }
+
+  // Forward iterator over live entries. Invalidated by any operation
+  // that resizes or erases.
+  class iterator {
+   public:
+    iterator() : b_(nullptr), end_(nullptr), i_(0) {}
+
+    // Make iterator pointing at first element at or after b.
+    explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) {
+      SkipUnused();
+    }
+
+    // Make iterator pointing exactly at ith element in b, which must exist.
+    iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {}
+
+    // NOTE(review): this exposes a mutable reference to the stored key;
+    // mutating a key through it would break the table's hash invariant.
+    // Presumably callers treat it as read-only -- confirm.
+    Key& operator*() { return key(); }
+    Key* operator->() { return &key(); }
+    bool operator==(const iterator& x) const {
+      return b_ == x.b_ && i_ == x.i_;
+    }
+    bool operator!=(const iterator& x) const { return !(*this == x); }
+    iterator& operator++() {
+      DCHECK(b_ != end_);
+      i_++;
+      SkipUnused();
+      return *this;
+    }
+
+   private:
+    friend class FlatSet;
+    Bucket* b_;
+    Bucket* end_;
+    uint32 i_;
+
+    Key& key() const { return b_->key(i_); }
+    // Advance past empty/deleted slots (marker < 2) and bucket
+    // boundaries until a live entry or the table end is reached.
+    void SkipUnused() {
+      while (b_ < end_) {
+        if (i_ >= Rep::kWidth) {
+          i_ = 0;
+          b_++;
+        } else if (b_->marker[i_] < 2) {
+          i_++;
+        } else {
+          break;
+        }
+      }
+    }
+  };
+
+  class const_iterator {
+   private:
+    mutable iterator rep_;  // Share state and logic with non-const iterator.
+   public:
+    const_iterator() : rep_() {}
+    explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
+    const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
+
+    const Key& operator*() const { return rep_.key(); }
+    const Key* operator->() const { return &rep_.key(); }
+    bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
+    bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
+    const_iterator& operator++() {
+      ++rep_;
+      return *this;
+    }
+  };
+
+  iterator begin() { return iterator(rep_.start(), rep_.limit()); }
+  iterator end() { return iterator(rep_.limit(), rep_.limit()); }
+  const_iterator begin() const {
+    return const_iterator(rep_.start(), rep_.limit());
+  }
+  const_iterator end() const {
+    return const_iterator(rep_.limit(), rep_.limit());
+  }
+
+  // Returns 1 if k is present, else 0 (keys are unique).
+  size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
+  iterator find(const Key& k) {
+    auto r = rep_.Find(k);
+    return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
+  }
+  const_iterator find(const Key& k) const {
+    auto r = rep_.Find(k);
+    return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
+  }
+
+  // Returns {iterator to entry, true iff a new entry was inserted}.
+  std::pair<iterator, bool> insert(const Key& k) { return Insert(k); }
+  template <typename InputIter>
+  void insert(InputIter first, InputIter last) {
+    for (; first != last; ++first) {
+      insert(*first);
+    }
+  }
+
+  // Forwards args to the key constructor via Rep::FindOrInsert; allows
+  // move-only key types (see flatset_test's unique_ptr tests).
+  template <typename... Args>
+  std::pair<iterator, bool> emplace(Args&&... args) {
+    rep_.MaybeResize();
+    auto r = rep_.FindOrInsert(std::forward<Args>(args)...);
+    const bool inserted = !r.found;
+    return {iterator(r.b, rep_.limit(), r.index), inserted};
+  }
+
+  // Removes k if present. Returns the number of elements erased (0 or 1).
+  size_t erase(const Key& k) {
+    auto r = rep_.Find(k);
+    if (!r.found) return 0;
+    rep_.Erase(r.b, r.index);
+    return 1;
+  }
+  // Erase the entry at pos; returns an iterator to the next live entry.
+  iterator erase(iterator pos) {
+    rep_.Erase(pos.b_, pos.i_);
+    ++pos;
+    return pos;
+  }
+  iterator erase(iterator pos, iterator last) {
+    for (; pos != last; ++pos) {
+      rep_.Erase(pos.b_, pos.i_);
+    }
+    return pos;
+  }
+
+  // Keys are unique, so the range is empty or a single element.
+  std::pair<iterator, iterator> equal_range(const Key& k) {
+    auto pos = find(k);
+    if (pos == end()) {
+      return std::make_pair(pos, pos);
+    } else {
+      auto next = pos;
+      ++next;
+      return std::make_pair(pos, next);
+    }
+  }
+  std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
+    auto pos = find(k);
+    if (pos == end()) {
+      return std::make_pair(pos, pos);
+    } else {
+      auto next = pos;
+      ++next;
+      return std::make_pair(pos, next);
+    }
+  }
+
+  // Set equality: same size and every element of x present in *this
+  // (sufficient since keys are unique in both sets).
+  bool operator==(const FlatSet& x) const {
+    if (size() != x.size()) return false;
+    for (const auto& elem : x) {
+      auto i = find(elem);
+      if (i == end()) return false;
+    }
+    return true;
+  }
+  bool operator!=(const FlatSet& x) const { return !(*this == x); }
+
+  // If key exists in the table, prefetch it. This is a hint, and may
+  // have no effect.
+  void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
+
+ private:
+  using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
+
+  // Bucket stores kWidth <marker, key> pairs.
+  // The data is organized as two parallel arrays to reduce padding.
+  struct Bucket {
+    uint8 marker[Rep::kWidth];
+
+    // Wrap keys in union to control construction and destruction.
+    union Storage {
+      Key key[Rep::kWidth];
+      Storage() {}
+      ~Storage() {}
+    } storage;
+
+    Key& key(uint32 i) {
+      DCHECK_GE(marker[i], 2);
+      return storage.key[i];
+    }
+    // Keys live in raw union storage, so destruction is explicit.
+    void Destroy(uint32 i) { storage.key[i].Key::~Key(); }
+    void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
+      new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
+    }
+    void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
+      new (&storage.key[i]) Key(src->storage.key[src_index]);
+    }
+  };
+
+  // Shared implementation of insert(): grow if needed, then find or
+  // create the slot for k. second is true iff a new entry was made.
+  std::pair<iterator, bool> Insert(const Key& k) {
+    rep_.MaybeResize();
+    auto r = rep_.FindOrInsert(k);
+    const bool inserted = !r.found;
+    return {iterator(r.b, rep_.limit(), r.index), inserted};
+  }
+
+  Rep rep_;
+};
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
diff --git a/tensorflow/core/lib/gtl/flatset_test.cc b/tensorflow/core/lib/gtl/flatset_test.cc
new file mode 100644
index 0000000000..ea9c9c22b5
--- /dev/null
+++ b/tensorflow/core/lib/gtl/flatset_test.cc
@@ -0,0 +1,501 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace gtl {
+namespace {
+
+typedef FlatSet<int64, HashInt64> NumSet;
+
+// Returns true iff set has an entry for k.
+// Also verifies that find and count give consistent results.
+bool Has(const NumSet& set, int64 k) {
+  const auto iter = set.find(k);
+  const bool present = (iter != set.end());
+  if (present) {
+    EXPECT_EQ(set.count(k), 1);
+    EXPECT_EQ(*iter, k);
+  } else {
+    EXPECT_EQ(set.count(k), 0);
+  }
+  return present;
+}
+
+// Return contents of set as a sorted list of numbers.
+typedef std::vector<int64> NumSetContents;
+NumSetContents Contents(const NumSet& set) {
+  NumSetContents elems;
+  elems.reserve(set.size());
+  for (const int64 value : set) {
+    elems.push_back(value);
+  }
+  std::sort(elems.begin(), elems.end());
+  return elems;
+}
+
+// Fill entries with keys [start,limit).
+void Fill(NumSet* set, int64 start, int64 limit) {
+  for (int64 key = start; key < limit; ++key) {
+    set->insert(key);
+  }
+}
+
+TEST(FlatSetTest, Find) {
+ NumSet set;
+ EXPECT_FALSE(Has(set, 1));
+ set.insert(1);
+ set.insert(2);
+ EXPECT_TRUE(Has(set, 1));
+ EXPECT_TRUE(Has(set, 2));
+ EXPECT_FALSE(Has(set, 3));
+}
+
+TEST(FlatSetTest, Insert) {
+ NumSet set;
+ EXPECT_FALSE(Has(set, 1));
+
+ // New entry.
+ auto result = set.insert(1);
+ EXPECT_TRUE(result.second);
+ EXPECT_EQ(*result.first, 1);
+ EXPECT_TRUE(Has(set, 1));
+
+ // Attempt to insert over existing entry.
+ result = set.insert(1);
+ EXPECT_FALSE(result.second);
+ EXPECT_EQ(*result.first, 1);
+ EXPECT_TRUE(Has(set, 1));
+}
+
+TEST(FlatSetTest, InsertGrowth) {
+  NumSet set;
+  const int n = 100;
+  // Use n rather than a duplicated literal so the fill range and the
+  // assertions below cannot drift apart.
+  Fill(&set, 0, n);
+  EXPECT_EQ(set.size(), n);
+  for (int i = 0; i < n; i++) {
+    EXPECT_TRUE(Has(set, i)) << i;
+  }
+}
+
+TEST(FlatSetTest, Emplace) {
+ NumSet set;
+
+ // New entry.
+ auto result = set.emplace(73);
+ EXPECT_TRUE(result.second);
+ EXPECT_EQ(*result.first, 73);
+ EXPECT_TRUE(Has(set, 73));
+
+ // Attempt to insert an existing entry.
+ result = set.emplace(73);
+ EXPECT_FALSE(result.second);
+ EXPECT_EQ(*result.first, 73);
+ EXPECT_TRUE(Has(set, 73));
+
+ // Add a second value
+ result = set.emplace(103);
+ EXPECT_TRUE(result.second);
+ EXPECT_EQ(*result.first, 103);
+ EXPECT_TRUE(Has(set, 103));
+}
+
+TEST(FlatSetTest, Size) {
+ NumSet set;
+ EXPECT_EQ(set.size(), 0);
+
+ set.insert(1);
+ set.insert(2);
+ EXPECT_EQ(set.size(), 2);
+}
+
+TEST(FlatSetTest, Empty) {
+ NumSet set;
+ EXPECT_TRUE(set.empty());
+
+ set.insert(1);
+ set.insert(2);
+ EXPECT_FALSE(set.empty());
+}
+
+TEST(FlatSetTest, Count) {
+ NumSet set;
+ EXPECT_EQ(set.count(1), 0);
+ EXPECT_EQ(set.count(2), 0);
+
+ set.insert(1);
+ EXPECT_EQ(set.count(1), 1);
+ EXPECT_EQ(set.count(2), 0);
+
+ set.insert(2);
+ EXPECT_EQ(set.count(1), 1);
+ EXPECT_EQ(set.count(2), 1);
+}
+
+TEST(FlatSetTest, Iter) {
+ NumSet set;
+ EXPECT_EQ(Contents(set), NumSetContents());
+
+ set.insert(1);
+ set.insert(2);
+ EXPECT_EQ(Contents(set), NumSetContents({1, 2}));
+}
+
+TEST(FlatSetTest, Erase) {
+ NumSet set;
+ EXPECT_EQ(set.erase(1), 0);
+ set.insert(1);
+ set.insert(2);
+ EXPECT_EQ(set.erase(3), 0);
+ EXPECT_EQ(set.erase(1), 1);
+ EXPECT_EQ(set.size(), 1);
+ EXPECT_TRUE(Has(set, 2));
+ EXPECT_EQ(Contents(set), NumSetContents({2}));
+ EXPECT_EQ(set.erase(2), 1);
+ EXPECT_EQ(Contents(set), NumSetContents());
+}
+
+TEST(FlatSetTest, EraseIter) {
+ NumSet set;
+ Fill(&set, 1, 11);
+ size_t size = 10;
+ for (auto iter = set.begin(); iter != set.end();) {
+ iter = set.erase(iter);
+ size--;
+ EXPECT_EQ(set.size(), size);
+ }
+ EXPECT_EQ(Contents(set), NumSetContents());
+}
+
+TEST(FlatSetTest, EraseIterPair) {
+ NumSet set;
+ Fill(&set, 1, 11);
+ NumSet expected;
+ auto p1 = set.begin();
+ expected.insert(*p1);
+ ++p1;
+ expected.insert(*p1);
+ ++p1;
+ auto p2 = set.end();
+ EXPECT_EQ(set.erase(p1, p2), set.end());
+ EXPECT_EQ(set.size(), 2);
+ EXPECT_EQ(Contents(set), Contents(expected));
+}
+
+TEST(FlatSetTest, EraseLongChains) {
+ // Make a set with lots of elements and erase a bunch of them to ensure
+ // that we are likely to hit them on future lookups.
+ NumSet set;
+ const int num = 128;
+ Fill(&set, 0, num);
+ for (int i = 0; i < num; i += 3) {
+ EXPECT_EQ(set.erase(i), 1);
+ }
+ for (int i = 0; i < num; i++) {
+ // Multiples of 3 should be not present.
+ EXPECT_EQ(Has(set, i), ((i % 3) != 0)) << i;
+ }
+
+ // Erase remainder to trigger table shrinking.
+ const size_t orig_buckets = set.bucket_count();
+ for (int i = 0; i < num; i++) {
+ set.erase(i);
+ }
+ EXPECT_TRUE(set.empty());
+ EXPECT_EQ(set.bucket_count(), orig_buckets);
+ set.insert(1); // Actual shrinking is triggered by an insert.
+ EXPECT_LT(set.bucket_count(), orig_buckets);
+}
+
+TEST(FlatSet, ClearNoResize) {
+ NumSet set;
+ Fill(&set, 0, 100);
+ const size_t orig = set.bucket_count();
+ set.clear_no_resize();
+ EXPECT_EQ(set.size(), 0);
+ EXPECT_EQ(Contents(set), NumSetContents());
+ EXPECT_EQ(set.bucket_count(), orig);
+}
+
+TEST(FlatSet, Clear) {
+ NumSet set;
+ Fill(&set, 0, 100);
+ const size_t orig = set.bucket_count();
+ set.clear();
+ EXPECT_EQ(set.size(), 0);
+ EXPECT_EQ(Contents(set), NumSetContents());
+ EXPECT_LT(set.bucket_count(), orig);
+}
+
+TEST(FlatSet, Copy) {
+ for (int n = 0; n < 10; n++) {
+ NumSet src;
+ Fill(&src, 0, n);
+ NumSet copy = src;
+ EXPECT_EQ(Contents(src), Contents(copy));
+ NumSet copy2;
+ copy2 = src;
+ EXPECT_EQ(Contents(src), Contents(copy2));
+ copy2 = copy2; // Self-assignment
+ EXPECT_EQ(Contents(src), Contents(copy2));
+ }
+}
+
+TEST(FlatSet, InitFromIter) {
+ for (int n = 0; n < 10; n++) {
+ NumSet src;
+ Fill(&src, 0, n);
+ auto vec = Contents(src);
+ NumSet dst(vec.begin(), vec.end());
+ EXPECT_EQ(Contents(dst), vec);
+ }
+}
+
+TEST(FlatSet, InsertIter) {
+ NumSet a, b;
+ Fill(&a, 1, 10);
+ Fill(&b, 8, 20);
+ b.insert(9); // Should not get inserted into a since a already has 9
+ a.insert(b.begin(), b.end());
+ NumSet expected;
+ Fill(&expected, 1, 20);
+ EXPECT_EQ(Contents(a), Contents(expected));
+}
+
+TEST(FlatSet, Eq) {
+ NumSet empty;
+
+ NumSet elems;
+ Fill(&elems, 0, 5);
+ EXPECT_FALSE(empty == elems);
+ EXPECT_TRUE(empty != elems);
+
+ NumSet copy = elems;
+ EXPECT_TRUE(copy == elems);
+ EXPECT_FALSE(copy != elems);
+
+ NumSet changed = elems;
+ changed.insert(7);
+ EXPECT_FALSE(changed == elems);
+ EXPECT_TRUE(changed != elems);
+
+ NumSet changed2 = elems;
+ changed2.erase(3);
+ EXPECT_FALSE(changed2 == elems);
+ EXPECT_TRUE(changed2 != elems);
+}
+
+TEST(FlatSet, Swap) {
+ NumSet a, b;
+ Fill(&a, 1, 5);
+ Fill(&b, 100, 200);
+ NumSet c = a;
+ NumSet d = b;
+ EXPECT_EQ(c, a);
+ EXPECT_EQ(d, b);
+ c.swap(d);
+ EXPECT_EQ(c, b);
+ EXPECT_EQ(d, a);
+}
+
+TEST(FlatSet, Reserve) {
+ NumSet src;
+ Fill(&src, 1, 100);
+ NumSet a = src;
+ a.reserve(10);
+ EXPECT_EQ(a, src);
+ NumSet b = src;
+ b.rehash(1000);
+ EXPECT_EQ(b, src);
+}
+
+TEST(FlatSet, EqualRangeMutable) {
+ NumSet set;
+ Fill(&set, 1, 10);
+
+ // Existing element
+ auto p1 = set.equal_range(3);
+ EXPECT_TRUE(p1.first != p1.second);
+ EXPECT_EQ(*p1.first, 3);
+ ++p1.first;
+ EXPECT_TRUE(p1.first == p1.second);
+
+ // Missing element
+ auto p2 = set.equal_range(100);
+ EXPECT_TRUE(p2.first == p2.second);
+}
+
+TEST(FlatSet, EqualRangeConst) {
+ NumSet tmp;
+ Fill(&tmp, 1, 10);
+
+ const NumSet set = tmp;
+
+ // Existing element
+ auto p1 = set.equal_range(3);
+ EXPECT_TRUE(p1.first != p1.second);
+ EXPECT_EQ(*p1.first, 3);
+ ++p1.first;
+ EXPECT_TRUE(p1.first == p1.second);
+
+ // Missing element
+ auto p2 = set.equal_range(100);
+ EXPECT_TRUE(p2.first == p2.second);
+}
+
+TEST(FlatSet, Prefetch) {
+ NumSet set;
+ Fill(&set, 0, 1000);
+ // Prefetch present and missing keys.
+ for (int i = 0; i < 2000; i++) {
+ set.prefetch_value(i);
+ }
+}
+
+// Values with no std::hash specialization (custom hasher) should work.
+struct NC {
+ int64 value;
+ NC() : value(-1) {}
+ NC(int64 v) : value(v) {}
+ NC(const NC& x) : value(x.value) {}
+ bool operator==(const NC& x) const { return value == x.value; }
+};
+struct HashNC {
+ size_t operator()(NC x) const { return x.value; }
+};
+
+TEST(FlatSet, NonCopyable) {
+ FlatSet<NC, HashNC> set;
+ for (int i = 0; i < 100; i++) {
+ set.insert(NC(i));
+ }
+ for (int i = 0; i < 100; i++) {
+ EXPECT_EQ(set.count(NC(i)), 1);
+ auto iter = set.find(NC(i));
+ EXPECT_NE(iter, set.end());
+ EXPECT_EQ(*iter, NC(i));
+ }
+ set.erase(NC(10));
+ EXPECT_EQ(set.count(NC(10)), 0);
+}
+
+// Test with heap-allocated objects so that mismanaged constructions
+// or destructions will show up as errors under a sanitizer or
+// heap checker.
+TEST(FlatSet, ConstructDestruct) {
+ FlatSet<string, HashStr> set;
+ string k1 = "the quick brown fox jumped over the lazy dog";
+ string k2 = k1 + k1;
+ string k3 = k1 + k2;
+ set.insert(k1);
+ set.insert(k3);
+ EXPECT_EQ(set.count(k1), 1);
+ EXPECT_EQ(set.count(k2), 0);
+ EXPECT_EQ(set.count(k3), 1);
+
+ set.erase(k3);
+ EXPECT_EQ(set.count(k3), 0);
+
+ set.clear();
+ set.insert(k1);
+ EXPECT_EQ(set.count(k1), 1);
+ EXPECT_EQ(set.count(k3), 0);
+
+ set.reserve(100);
+ EXPECT_EQ(set.count(k1), 1);
+ EXPECT_EQ(set.count(k3), 0);
+}
+
+// Type to use to ensure that custom equality operator is used
+// that ignores extra value.
+struct CustomCmpKey {
+ int64 a;
+ int64 b;
+ CustomCmpKey(int64 v1, int64 v2) : a(v1), b(v2) {}
+ bool operator==(const CustomCmpKey& x) const { return a == x.a && b == x.b; }
+};
+struct HashA {
+ size_t operator()(CustomCmpKey x) const { return x.a; }
+};
+struct EqA {
+ // Ignore b fields.
+ bool operator()(CustomCmpKey x, CustomCmpKey y) const { return x.a == y.a; }
+};
+TEST(FlatSet, CustomCmp) {
+ FlatSet<CustomCmpKey, HashA, EqA> set;
+ set.insert(CustomCmpKey(100, 200));
+ EXPECT_EQ(set.count(CustomCmpKey(100, 200)), 1);
+ EXPECT_EQ(set.count(CustomCmpKey(100, 500)), 1); // key.b ignored
+}
+
+// Test unique_ptr handling.
+typedef std::unique_ptr<int> UniqInt;
+static UniqInt MakeUniq(int i) { return UniqInt(new int(i)); }
+
+struct HashUniq {
+ size_t operator()(const UniqInt& p) const { return *p; }
+};
+struct EqUniq {
+ bool operator()(const UniqInt& a, const UniqInt& b) const { return *a == *b; }
+};
+typedef FlatSet<UniqInt, HashUniq, EqUniq> UniqSet;
+
+TEST(FlatSet, UniqueSet) {
+ UniqSet set;
+
+ // Fill set
+ const int N = 10;
+ for (int i = 0; i < N; i++) {
+ set.emplace(MakeUniq(i));
+ }
+ EXPECT_EQ(set.size(), N);
+
+ // Lookups
+ for (int i = 0; i < N; i++) {
+ EXPECT_EQ(set.count(MakeUniq(i)), 1);
+ }
+
+ // erase
+ set.erase(MakeUniq(2));
+ EXPECT_EQ(set.count(MakeUniq(2)), 0);
+
+ // clear
+ set.clear();
+ EXPECT_EQ(set.size(), 0);
+}
+
+TEST(FlatSet, UniqueSetIter) {
+ UniqSet set;
+ const int kCount = 10;
+ for (int i = 1; i <= kCount; i++) {
+ set.emplace(MakeUniq(i));
+ }
+ int sum = 0;
+ for (const auto& p : set) {
+ sum += *p;
+ }
+ EXPECT_EQ(sum, (kCount * (kCount + 1)) / 2);
+}
+
+} // namespace
+} // namespace gtl
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
index 3c71e7d6cc..4e64c90d62 100644
--- a/tensorflow/core/lib/hash/hash.h
+++ b/tensorflow/core/lib/hash/hash.h
@@ -42,6 +42,24 @@ inline uint64 Hash64Combine(uint64 a, uint64 b) {
return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4));
}
+// Convenience Hash functors
+// Identity hash for int64 keys (relies on the table mixing bits as needed).
+struct HashInt64 {
+  size_t operator()(int64 x) const { return static_cast<size_t>(x); }
+};
+// Hashes a string's bytes via Hash64.
+struct HashStr {
+  size_t operator()(const string& s) const {
+    return static_cast<size_t>(Hash64(s));
+  }
+};
+template <typename PTR>
+struct HashPtr {
+  size_t operator()(const PTR p) const {
+    // Hash pointers as integers, but bring more entropy to the lower bits.
+    size_t k = static_cast<size_t>(reinterpret_cast<uintptr_t>(p));
+    return k + (k >> 6);
+  }
+};
+
+
} // namespace tensorflow
#endif // TENSORFLOW_LIB_HASH_HASH_H_