author    Evan Brown <ezb@google.com>    2022-12-19 11:53:21 -0800
committer Copybara-Service <copybara-worker@google.com>    2022-12-19 11:54:08 -0800
commit    a1ec5d62e70994d4d488d827f4e44a9a4165fd36
tree      1c8e0e5b8d8a351f50005da33254a6531b98882e /absl/container/internal/raw_hash_set.h
parent    dbc61b490c5c259df33af59f9922a7224341397b
In sanitizer mode, add generations to swisstable iterators and backing arrays so that we can detect invalid iterator use.
PiperOrigin-RevId: 496455788
Change-Id: I83df92828098a3ef1181b4e454f3ac5d3ac7a2f2
Diffstat (limited to 'absl/container/internal/raw_hash_set.h')
-rw-r--r--   absl/container/internal/raw_hash_set.h   260
1 file changed, 223 insertions, 37 deletions
diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h
index fb945c6c..ad4c2cc5 100644
--- a/absl/container/internal/raw_hash_set.h
+++ b/absl/container/internal/raw_hash_set.h
@@ -186,6 +186,7 @@
#include "absl/base/config.h"
#include "absl/base/internal/endian.h"
#include "absl/base/internal/prefetch.h"
+#include "absl/base/internal/raw_logging.h"
#include "absl/base/optimization.h"
#include "absl/base/port.h"
#include "absl/container/internal/common.h"
@@ -219,6 +220,29 @@ namespace absl {
ABSL_NAMESPACE_BEGIN
namespace container_internal {
+#ifdef ABSL_SWISSTABLE_ENABLE_GENERATIONS
+#error ABSL_SWISSTABLE_ENABLE_GENERATIONS cannot be directly set
+#elif defined(ABSL_HAVE_ADDRESS_SANITIZER) || \
+ defined(ABSL_HAVE_MEMORY_SANITIZER)
+// When compiled in sanitizer mode, we add generation integers to the backing
+// array and iterators. In the backing array, we store the generation between
+// the control bytes and the slots. When iterators are dereferenced, we assert
+// that the container has not been mutated in a way that could cause iterator
+// invalidation since the iterator was initialized.
+#define ABSL_SWISSTABLE_ENABLE_GENERATIONS
+#endif
+
+// We use uint8_t so we don't need to worry about padding.
+using GenerationType = uint8_t;
+
+#ifdef ABSL_SWISSTABLE_ENABLE_GENERATIONS
+constexpr bool SwisstableGenerationsEnabled() { return true; }
+constexpr size_t NumGenerationBytes() { return sizeof(GenerationType); }
+#else
+constexpr bool SwisstableGenerationsEnabled() { return false; }
+constexpr size_t NumGenerationBytes() { return 0; }
+#endif
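
A hedged sketch of the backing-array ordering these helpers describe when generations are enabled (exact offsets are computed by GenerationOffset() and SlotOffset() further down in this diff; the diagram is illustrative, not code from the patch):

    // [ control bytes: capacity + 1 sentinel + cloned bytes ]
    // [ generation byte: one GenerationType                 ]
    // [ padding up to the slot alignment                    ]
    // [ slots ...                                           ]
    //
    // With generations disabled, NumGenerationBytes() is 0 and the layout is
    // identical to the pre-change layout.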
+
template <typename AllocType>
void SwapAlloc(AllocType& lhs, AllocType& rhs,
std::true_type /* propagate_on_container_swap */) {
@@ -451,7 +475,7 @@ static_assert(ctrl_t::kDeleted == static_cast<ctrl_t>(-2),
"ctrl_t::kDeleted must be -2 to make the implementation of "
"ConvertSpecialToEmptyAndFullToDeleted efficient");
-ABSL_DLL extern const ctrl_t kEmptyGroup[16];
+ABSL_DLL extern const ctrl_t kEmptyGroup[17];
// Returns a pointer to a control byte group that can be used by empty tables.
inline ctrl_t* EmptyGroup() {
@@ -460,6 +484,12 @@ inline ctrl_t* EmptyGroup() {
return const_cast<ctrl_t*>(kEmptyGroup);
}
+// Returns a pointer to the generation byte at the end of the empty group, if it
+// exists.
+inline GenerationType* EmptyGeneration() {
+ return reinterpret_cast<GenerationType*>(EmptyGroup() + 16);
+}
+
// Mixes a randomly generated per-process seed with `hash` and `ctrl` to
// randomize insertion order within groups.
bool ShouldInsertBackwards(size_t hash, const ctrl_t* ctrl);
@@ -719,13 +749,116 @@ using Group = GroupAArch64Impl;
using Group = GroupPortableImpl;
#endif
+class CommonFieldsGenerationInfoEnabled {
+ public:
+ CommonFieldsGenerationInfoEnabled() = default;
+ CommonFieldsGenerationInfoEnabled(CommonFieldsGenerationInfoEnabled&& that)
+ : reserved_growth_(that.reserved_growth_), generation_(that.generation_) {
+ that.reserved_growth_ = 0;
+ that.generation_ = EmptyGeneration();
+ }
+ CommonFieldsGenerationInfoEnabled& operator=(
+ CommonFieldsGenerationInfoEnabled&&) = default;
+
+ void maybe_increment_generation_on_insert() {
+ if (reserved_growth_ > 0) {
+ --reserved_growth_;
+ } else {
+ ++*generation_;
+ }
+ }
+ void reset_reserved_growth(size_t reservation, size_t size) {
+ reserved_growth_ = reservation - size;
+ }
+ size_t reserved_growth() const { return reserved_growth_; }
+ void set_reserved_growth(size_t r) { reserved_growth_ = r; }
+ GenerationType generation() const { return *generation_; }
+ void set_generation(GenerationType g) { *generation_ = g; }
+ GenerationType* generation_ptr() const { return generation_; }
+ void set_generation_ptr(GenerationType* g) { generation_ = g; }
+
+ private:
+ // The number of insertions remaining that are guaranteed to not rehash due to
+ // a prior call to reserve. Note: we store reserved growth rather than
+ // reservation size because calls to erase() decrease size_ but don't decrease
+ // reserved growth.
+ // TODO(b/254649633): we can use reserved_growth_ to find more bugs by doing
+ // extra rehashes in sanitizer mode when reserved_growth_ is 0. We could
+ // potentially do a rehash with low probability whenever reserved_growth_ is
+ // zero, but also add a deterministic rehash the first insert after
+ // reserved_growth_ is zero after a call to reserve. This would detect cases
+ // of invalid references (as opposed to invalid iterators).
+ size_t reserved_growth_ = 0;
+ // Pointer to the generation counter, which is used to validate iterators and
+ // is stored in the backing array between the control bytes and the slots.
+ // Note that we can't store the generation inside the container itself and
+ // keep a pointer to the container in the iterators because iterators must
+ // remain valid when the container is moved.
+ // Note: we could derive this pointer from the control pointer, but it makes
+ // the code more complicated, and there's a benefit in having the sizes of
+ // raw_hash_set in sanitizer mode and non-sanitizer mode a bit more different,
+ // which is that tests are less likely to rely on the size remaining the same.
+ GenerationType* generation_ = EmptyGeneration();
+};
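
A minimal trace of how reserved_growth_ and generation_ interact on insertion, assuming the public reserve()/insert() path wires into these hooks as shown later in this diff (the numbers are illustrative):

    // t.reserve(8) while t.size() == 3   =>  reserved_growth_ = 8 - 3 = 5
    // inserts 1..5: reserved_growth_ > 0, so it is decremented and
    //               *generation_ is left untouched (reserve() guarantees
    //               these inserts cannot rehash)
    // insert 6+:    reserved_growth_ == 0, so ++*generation_, and any
    //               iterator captured before this insert fails the
    //               generation check on its next use.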
+
+class CommonFieldsGenerationInfoDisabled {
+ public:
+ CommonFieldsGenerationInfoDisabled() = default;
+ CommonFieldsGenerationInfoDisabled(CommonFieldsGenerationInfoDisabled&&) =
+ default;
+ CommonFieldsGenerationInfoDisabled& operator=(
+ CommonFieldsGenerationInfoDisabled&&) = default;
+
+ void maybe_increment_generation_on_insert() {}
+ void reset_reserved_growth(size_t, size_t) {}
+ size_t reserved_growth() const { return 0; }
+ void set_reserved_growth(size_t) {}
+ GenerationType generation() const { return 0; }
+ void set_generation(GenerationType) {}
+ GenerationType* generation_ptr() const { return nullptr; }
+ void set_generation_ptr(GenerationType*) {}
+};
+
+class HashSetIteratorGenerationInfoEnabled {
+ public:
+ HashSetIteratorGenerationInfoEnabled() = default;
+ explicit HashSetIteratorGenerationInfoEnabled(
+ const GenerationType* generation_ptr)
+ : generation_ptr_(generation_ptr), generation_(*generation_ptr) {}
+
+ GenerationType generation() const { return generation_; }
+ void reset_generation() { generation_ = *generation_ptr_; }
+ const GenerationType* generation_ptr() const { return generation_ptr_; }
+ void set_generation_ptr(const GenerationType* ptr) { generation_ptr_ = ptr; }
+
+ private:
+ const GenerationType* generation_ptr_ = nullptr;
+ GenerationType generation_ = 0;
+};
+
+class HashSetIteratorGenerationInfoDisabled {
+ public:
+ HashSetIteratorGenerationInfoDisabled() = default;
+ explicit HashSetIteratorGenerationInfoDisabled(const GenerationType*) {}
+
+ GenerationType generation() const { return 0; }
+ void reset_generation() {}
+ const GenerationType* generation_ptr() const { return nullptr; }
+ void set_generation_ptr(const GenerationType*) {}
+};
+
+#ifdef ABSL_SWISSTABLE_ENABLE_GENERATIONS
+using CommonFieldsGenerationInfo = CommonFieldsGenerationInfoEnabled;
+using HashSetIteratorGenerationInfo = HashSetIteratorGenerationInfoEnabled;
+#else
+using CommonFieldsGenerationInfo = CommonFieldsGenerationInfoDisabled;
+using HashSetIteratorGenerationInfo = HashSetIteratorGenerationInfoDisabled;
+#endif
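
One way to see that the disabled variants are free: they have no data members, so the classes that inherit from them below get the empty-base optimization. A hedged compile-time check (requires <type_traits>; not part of the patch):

    static_assert(std::is_empty<CommonFieldsGenerationInfoDisabled>::value &&
                      std::is_empty<HashSetIteratorGenerationInfoDisabled>::value,
                  "disabled generation-info bases add no storage");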
+
// CommonFields hold the fields in raw_hash_set that do not depend
// on template parameters. This allows us to conveniently pass all
// of this state to helper functions as a single argument.
-//
-// We make HashtablezInfoHandle a base class to take advantage of
-// the empty base-class optimization when sampling is turned off.
-class CommonFields : public HashtablezInfoHandle {
+class CommonFields : public CommonFieldsGenerationInfo {
public:
CommonFields() = default;
@@ -735,25 +868,34 @@ class CommonFields : public HashtablezInfoHandle {
// Movable
CommonFields(CommonFields&& that)
- : HashtablezInfoHandle(
- std::move(static_cast<HashtablezInfoHandle&&>(that))),
+ : CommonFieldsGenerationInfo(
+ std::move(static_cast<CommonFieldsGenerationInfo&&>(that))),
// Explicitly copying fields into "this" and then resetting "that"
// fields generates less code than calling absl::exchange per field.
control_(that.control_),
slots_(that.slots_),
size_(that.size_),
capacity_(that.capacity_),
- growth_left_(that.growth_left_) {
+ compressed_tuple_(that.growth_left(), std::move(that.infoz())) {
that.control_ = EmptyGroup();
that.slots_ = nullptr;
that.size_ = 0;
that.capacity_ = 0;
- that.growth_left_ = 0;
+ that.growth_left() = 0;
}
CommonFields& operator=(CommonFields&&) = default;
- HashtablezInfoHandle& infoz() { return *this; }
- const HashtablezInfoHandle& infoz() const { return *this; }
+ // The number of slots we can still fill without needing to rehash.
+ size_t& growth_left() { return compressed_tuple_.template get<0>(); }
+
+ HashtablezInfoHandle& infoz() { return compressed_tuple_.template get<1>(); }
+ const HashtablezInfoHandle& infoz() const {
+ return compressed_tuple_.template get<1>();
+ }
+
+ void reset_reserved_growth(size_t reservation) {
+ CommonFieldsGenerationInfo::reset_reserved_growth(reservation, size_);
+ }
// TODO(b/259599413): Investigate removing some of these fields:
// - control/slots can be derived from each other
@@ -775,8 +917,10 @@ class CommonFields : public HashtablezInfoHandle {
// The total number of available slots.
size_t capacity_ = 0;
- // The number of slots we can still fill without needing to rehash.
- size_t growth_left_ = 0;
+ // Bundle together growth_left and HashtablezInfoHandle to ensure EBO for
+ // HashtablezInfoHandle when sampling is turned off.
+ absl::container_internal::CompressedTuple<size_t, HashtablezInfoHandle>
+ compressed_tuple_{0u, HashtablezInfoHandle{}};
};
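
Why the CompressedTuple: when sampling is compiled out, HashtablezInfoHandle carries no data, and CompressedTuple applies the empty-base optimization so the bundle is no larger than the size_t alone. A hedged illustration with a stand-in empty type (Empty is hypothetical, not from the patch):

    #include "absl/container/internal/compressed_tuple.h"

    struct Empty {};
    static_assert(
        sizeof(absl::container_internal::CompressedTuple<size_t, Empty>) ==
            sizeof(size_t),
        "the empty element is collapsed away");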
// Returns the number of "cloned control bytes".
@@ -859,19 +1003,26 @@ size_t SelectBucketCountForIterRange(InputIter first, InputIter last,
return 0;
}
-#define ABSL_INTERNAL_ASSERT_IS_FULL(ctrl, operation) \
- do { \
- ABSL_HARDENING_ASSERT( \
- (ctrl != nullptr) && operation \
- " called on invalid iterator. The iterator might be an end() " \
- "iterator or may have been default constructed."); \
- ABSL_HARDENING_ASSERT( \
- (IsFull(*ctrl)) && operation \
- " called on invalid iterator. The element might have been erased or " \
- "the table might have rehashed."); \
+#define ABSL_INTERNAL_ASSERT_IS_FULL(ctrl, generation, generation_ptr, \
+ operation) \
+ do { \
+ ABSL_HARDENING_ASSERT( \
+ (ctrl != nullptr) && operation \
+ " called on invalid iterator. The iterator might be an end() " \
+ "iterator or may have been default constructed."); \
+ if (SwisstableGenerationsEnabled() && generation != *generation_ptr) \
+ ABSL_INTERNAL_LOG(FATAL, operation \
+ " called on invalidated iterator. The table could " \
+ "have rehashed since this iterator was initialized."); \
+ ABSL_HARDENING_ASSERT( \
+ (IsFull(*ctrl)) && operation \
+ " called on invalid iterator. The element might have been erased or " \
+ "the table might have rehashed."); \
} while (0)
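
In user terms, the new generation clause is aimed at code like the following hedged example (given #include "absl/container/flat_hash_set.h"; the fatal log fires only in sanitizer builds, and only if the insert actually rehashed):

    absl::flat_hash_set<int> s = {1, 2, 3};
    auto it = s.find(2);
    s.insert(4);  // may grow and rehash, invalidating `it`
    *it;          // generation mismatch: "operator*() called on invalidated
                  // iterator. The table could have rehashed since this
                  // iterator was initialized."

Without generations, such a use may only be caught if the stale control byte happens to be non-full, or if the sanitizer itself flags the dangling read.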
// Note that for comparisons, null/end iterators are valid.
+// TODO(b/254649633): when generations are enabled, detect cases of invalid
+// iterators being compared.
inline void AssertIsValidForComparison(const ctrl_t* ctrl) {
ABSL_HARDENING_ASSERT((ctrl == nullptr || IsFull(*ctrl)) &&
"Invalid iterator comparison. The element might have "
@@ -900,6 +1051,9 @@ inline bool AreItersFromSameContainer(const ctrl_t* ctrl_a,
// Asserts that two iterators come from the same container.
// Note: we take slots by reference so that it's not UB if they're uninitialized
// as long as we don't read them (when ctrl is null).
+// TODO(b/254649633): when generations are enabled, we can detect more cases of
+// different containers by comparing the pointers to the generations - this
+// can cover cases of end iterators that we would otherwise miss.
inline void AssertSameContainer(const ctrl_t* ctrl_a, const ctrl_t* ctrl_b,
const void* const& slot_a,
const void* const& slot_b) {
@@ -976,7 +1130,7 @@ extern template FindInfo find_first_non_full(const CommonFields&, size_t);
FindInfo find_first_non_full_outofline(const CommonFields&, size_t);
inline void ResetGrowthLeft(CommonFields& common) {
- common.growth_left_ = CapacityToGrowth(common.capacity_) - common.size_;
+ common.growth_left() = CapacityToGrowth(common.capacity_) - common.size_;
}
// Sets `ctrl` to `{kEmpty, kSentinel, ..., kEmpty}`, marking the entire
@@ -1019,11 +1173,20 @@ inline void SetCtrl(const CommonFields& common, size_t i, h2_t h,
}
// Given the capacity of a table, computes the offset (from the start of the
+// backing allocation) of the generation counter (if it exists).
+inline size_t GenerationOffset(size_t capacity) {
+ assert(IsValidCapacity(capacity));
+ const size_t num_control_bytes = capacity + 1 + NumClonedBytes();
+ return num_control_bytes;
+}
+
+// Given the capacity of a table, computes the offset (from the start of the
// backing allocation) at which the slots begin.
inline size_t SlotOffset(size_t capacity, size_t slot_align) {
assert(IsValidCapacity(capacity));
const size_t num_control_bytes = capacity + 1 + NumClonedBytes();
- return (num_control_bytes + slot_align - 1) & (~slot_align + 1);
+ return (num_control_bytes + NumGenerationBytes() + slot_align - 1) &
+ (~slot_align + 1);
}
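
A worked instance of these two offsets, assuming the 16-wide group (NumClonedBytes() == 15), capacity 15, 8-byte slot alignment, and generations enabled (NumGenerationBytes() == 1):

    // num_control_bytes    = 15 + 1 + 15                 = 31
    // GenerationOffset(15)                               = 31
    // SlotOffset(15, 8)    = (31 + 1 + 8 - 1) & (~8 + 1) = 39 & -8 = 32
    //
    // i.e. the generation byte sits immediately after the sentinel and cloned
    // control bytes, and the slots start at the next 8-byte boundary.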
// Given the capacity of a table, computes the total size of the backing
@@ -1048,6 +1211,10 @@ ABSL_ATTRIBUTE_NOINLINE void InitializeSlots(CommonFields& c, Alloc alloc) {
const size_t cap = c.capacity_;
char* mem = static_cast<char*>(
Allocate<AlignOfSlot>(&alloc, AllocSize(cap, SizeOfSlot, AlignOfSlot)));
+ const GenerationType old_generation = c.generation();
+ c.set_generation_ptr(
+ reinterpret_cast<GenerationType*>(mem + GenerationOffset(cap)));
+ c.set_generation(old_generation + 1);
c.control_ = reinterpret_cast<ctrl_t*>(mem);
c.slots_ = mem + SlotOffset(cap, AlignOfSlot);
ResetCtrl(c, SizeOfSlot);
@@ -1213,7 +1380,7 @@ class raw_hash_set {
static_assert(std::is_same<const_pointer, const value_type*>::value,
"Allocators with custom pointer types are not supported");
- class iterator {
+ class iterator : private HashSetIteratorGenerationInfo {
friend class raw_hash_set;
public:
@@ -1229,19 +1396,22 @@ class raw_hash_set {
// PRECONDITION: not an end() iterator.
reference operator*() const {
- ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, "operator*()");
+ ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, generation(), generation_ptr(),
+ "operator*()");
return PolicyTraits::element(slot_);
}
// PRECONDITION: not an end() iterator.
pointer operator->() const {
- ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, "operator->");
+ ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, generation(), generation_ptr(),
+ "operator->");
return &operator*();
}
// PRECONDITION: not an end() iterator.
iterator& operator++() {
- ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, "operator++");
+ ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, generation(), generation_ptr(),
+ "operator++");
++ctrl_;
++slot_;
skip_empty_or_deleted();
@@ -1265,7 +1435,11 @@ class raw_hash_set {
}
private:
- iterator(ctrl_t* ctrl, slot_type* slot) : ctrl_(ctrl), slot_(slot) {
+ iterator(ctrl_t* ctrl, slot_type* slot,
+ const GenerationType* generation_ptr)
+ : HashSetIteratorGenerationInfo(generation_ptr),
+ ctrl_(ctrl),
+ slot_(slot) {
// This assumption helps the compiler know that any non-end iterator is
// not equal to any end iterator.
ABSL_ASSUME(ctrl != nullptr);
@@ -1323,8 +1497,10 @@ class raw_hash_set {
}
private:
- const_iterator(const ctrl_t* ctrl, const slot_type* slot)
- : inner_(const_cast<ctrl_t*>(ctrl), const_cast<slot_type*>(slot)) {}
+ const_iterator(const ctrl_t* ctrl, const slot_type* slot,
+ const GenerationType* gen)
+ : inner_(const_cast<ctrl_t*>(ctrl), const_cast<slot_type*>(slot), gen) {
+ }
iterator inner_;
};
@@ -1455,6 +1631,7 @@ class raw_hash_set {
auto target = find_first_non_full_outofline(common(), hash);
SetCtrl(common(), target.offset, H2(hash), sizeof(slot_type));
emplace_at(target.offset, v);
+ common().maybe_increment_generation_on_insert();
infoz().RecordInsert(hash, target.probe_length);
}
common().size_ = that.size();
@@ -1553,6 +1730,7 @@ class raw_hash_set {
ClearBackingArray(common(), GetPolicyFunctions(),
/*reuse=*/cap < 128);
}
+ common().set_reserved_growth(0);
}
inline void destroy_slots() {
@@ -1793,7 +1971,8 @@ class raw_hash_set {
// This overload is necessary because otherwise erase<K>(const K&) would be
// a better match if non-const iterator is passed as an argument.
void erase(iterator it) {
- ABSL_INTERNAL_ASSERT_IS_FULL(it.ctrl_, "erase()");
+ ABSL_INTERNAL_ASSERT_IS_FULL(it.ctrl_, it.generation(), it.generation_ptr(),
+ "erase()");
PolicyTraits::destroy(&alloc_ref(), it.slot_);
erase_meta_only(it);
}
@@ -1827,7 +2006,9 @@ class raw_hash_set {
}
node_type extract(const_iterator position) {
- ABSL_INTERNAL_ASSERT_IS_FULL(position.inner_.ctrl_, "extract()");
+ ABSL_INTERNAL_ASSERT_IS_FULL(position.inner_.ctrl_,
+ position.inner_.generation(),
+ position.inner_.generation_ptr(), "extract()");
auto node =
CommonAccess::Transfer<node_type>(alloc_ref(), position.inner_.slot_);
erase_meta_only(position);
@@ -1883,6 +2064,7 @@ class raw_hash_set {
// This is after resize, to ensure that we have completed the allocation
// and have potentially sampled the hashtable.
infoz().RecordReservation(n);
+ common().reset_reserved_growth(n);
}
}
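
The flip side of the invalidation check: inserts covered by a reservation must not trip it. A hedged usage example (given #include "absl/container/flat_hash_set.h"):

    absl::flat_hash_set<int> s;
    s.reserve(10);
    s.insert(1);
    auto it = s.find(1);
    s.insert(2);  // within the reserved growth: no rehash, generation unchanged
    *it;          // still valid, passes the generation check

reset_reserved_growth(n) above records that headroom, so maybe_increment_generation_on_insert() can skip the generation bump for these inserts.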
@@ -2268,6 +2450,7 @@ class raw_hash_set {
++common().size_;
growth_left() -= IsEmpty(control()[target.offset]);
SetCtrl(common(), target.offset, H2(hash), sizeof(slot_type));
+ common().maybe_increment_generation_on_insert();
infoz().RecordInsert(hash, target.probe_length);
return target.offset;
}
@@ -2290,9 +2473,11 @@ class raw_hash_set {
"constructed value does not match the lookup key");
}
- iterator iterator_at(size_t i) { return {control() + i, slot_array() + i}; }
+ iterator iterator_at(size_t i) {
+ return {control() + i, slot_array() + i, common().generation_ptr()};
+ }
const_iterator iterator_at(size_t i) const {
- return {control() + i, slot_array() + i};
+ return {control() + i, slot_array() + i, common().generation_ptr()};
}
private:
@@ -2308,7 +2493,7 @@ class raw_hash_set {
// side-effect.
//
// See `CapacityToGrowth()`.
- size_t& growth_left() { return common().growth_left_; }
+ size_t& growth_left() { return common().growth_left(); }
// Prefetch the heap-allocated memory region to resolve potential TLB misses.
// This is intended to overlap with execution of calculating the hash for a
@@ -2460,6 +2645,7 @@ struct HashtableDebugAccess<Set, absl::void_t<typename Set::raw_hash_set>> {
ABSL_NAMESPACE_END
} // namespace absl
+#undef ABSL_SWISSTABLE_ENABLE_GENERATIONS
#undef ABSL_INTERNAL_ASSERT_IS_FULL
#endif // ABSL_CONTAINER_INTERNAL_RAW_HASH_SET_H_