From a1ec5d62e70994d4d488d827f4e44a9a4165fd36 Mon Sep 17 00:00:00 2001 From: Evan Brown Date: Mon, 19 Dec 2022 11:53:21 -0800 Subject: In sanitizer mode, add generations to swisstable iterators and backing arrays so that we can detect invalid iterator use. PiperOrigin-RevId: 496455788 Change-Id: I83df92828098a3ef1181b4e454f3ac5d3ac7a2f2 --- absl/container/internal/raw_hash_set.h | 260 ++++++++++++++++++++++++++++----- 1 file changed, 223 insertions(+), 37 deletions(-) (limited to 'absl/container/internal/raw_hash_set.h') diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h index fb945c6c..ad4c2cc5 100644 --- a/absl/container/internal/raw_hash_set.h +++ b/absl/container/internal/raw_hash_set.h @@ -186,6 +186,7 @@ #include "absl/base/config.h" #include "absl/base/internal/endian.h" #include "absl/base/internal/prefetch.h" +#include "absl/base/internal/raw_logging.h" #include "absl/base/optimization.h" #include "absl/base/port.h" #include "absl/container/internal/common.h" @@ -219,6 +220,29 @@ namespace absl { ABSL_NAMESPACE_BEGIN namespace container_internal { +#ifdef ABSL_SWISSTABLE_ENABLE_GENERATIONS +#error ABSL_SWISSTABLE_ENABLE_GENERATIONS cannot be directly set +#elif defined(ABSL_HAVE_ADDRESS_SANITIZER) || \ + defined(ABSL_HAVE_MEMORY_SANITIZER) +// When compiled in sanitizer mode, we add generation integers to the backing +// array and iterators. In the backing array, we store the generation between +// the control bytes and the slots. When iterators are dereferenced, we assert +// that the container has not been mutated in a way that could cause iterator +// invalidation since the iterator was initialized. +#define ABSL_SWISSTABLE_ENABLE_GENERATIONS +#endif + +// We use uint8_t so we don't need to worry about padding. +using GenerationType = uint8_t; + +#ifdef ABSL_SWISSTABLE_ENABLE_GENERATIONS +constexpr bool SwisstableGenerationsEnabled() { return true; } +constexpr size_t NumGenerationBytes() { return sizeof(GenerationType); } +#else +constexpr bool SwisstableGenerationsEnabled() { return false; } +constexpr size_t NumGenerationBytes() { return 0; } +#endif + template void SwapAlloc(AllocType& lhs, AllocType& rhs, std::true_type /* propagate_on_container_swap */) { @@ -451,7 +475,7 @@ static_assert(ctrl_t::kDeleted == static_cast(-2), "ctrl_t::kDeleted must be -2 to make the implementation of " "ConvertSpecialToEmptyAndFullToDeleted efficient"); -ABSL_DLL extern const ctrl_t kEmptyGroup[16]; +ABSL_DLL extern const ctrl_t kEmptyGroup[17]; // Returns a pointer to a control byte group that can be used by empty tables. inline ctrl_t* EmptyGroup() { @@ -460,6 +484,12 @@ inline ctrl_t* EmptyGroup() { return const_cast(kEmptyGroup); } +// Returns a pointer to the generation byte at the end of the empty group, if it +// exists. +inline GenerationType* EmptyGeneration() { + return reinterpret_cast(EmptyGroup() + 16); +} + // Mixes a randomly generated per-process seed with `hash` and `ctrl` to // randomize insertion order within groups. 
bool ShouldInsertBackwards(size_t hash, const ctrl_t* ctrl); @@ -719,13 +749,116 @@ using Group = GroupAArch64Impl; using Group = GroupPortableImpl; #endif +class CommonFieldsGenerationInfoEnabled { + public: + CommonFieldsGenerationInfoEnabled() = default; + CommonFieldsGenerationInfoEnabled(CommonFieldsGenerationInfoEnabled&& that) + : reserved_growth_(that.reserved_growth_), generation_(that.generation_) { + that.reserved_growth_ = 0; + that.generation_ = EmptyGeneration(); + } + CommonFieldsGenerationInfoEnabled& operator=( + CommonFieldsGenerationInfoEnabled&&) = default; + + void maybe_increment_generation_on_insert() { + if (reserved_growth_ > 0) { + --reserved_growth_; + } else { + ++*generation_; + } + } + void reset_reserved_growth(size_t reservation, size_t size) { + reserved_growth_ = reservation - size; + } + size_t reserved_growth() const { return reserved_growth_; } + void set_reserved_growth(size_t r) { reserved_growth_ = r; } + GenerationType generation() const { return *generation_; } + void set_generation(GenerationType g) { *generation_ = g; } + GenerationType* generation_ptr() const { return generation_; } + void set_generation_ptr(GenerationType* g) { generation_ = g; } + + private: + // The number of insertions remaining that are guaranteed to not rehash due to + // a prior call to reserve. Note: we store reserved growth rather than + // reservation size because calls to erase() decrease size_ but don't decrease + // reserved growth. + // TODO(b/254649633): we can use reserved_growth_ to find more bugs by doing + // extra rehashes in sanitizer mode when reserved_growth_ is 0. We could + // potentially do a rehash with low probability whenever reserved_growth_ is + // zero, but also add a deterministic rehash the first insert after + // reserved_growth_ is zero after a call to reserve. This would detect cases + // of invalid references (as opposed to invalid iterators). + size_t reserved_growth_ = 0; + // Pointer to the generation counter, which is used to validate iterators and + // is stored in the backing array between the control bytes and the slots. + // Note that we can't store the generation inside the container itself and + // keep a pointer to the container in the iterators because iterators must + // remain valid when the container is moved. + // Note: we could derive this pointer from the control pointer, but it makes + // the code more complicated, and there's a benefit in having the sizes of + // raw_hash_set in sanitizer mode and non-sanitizer mode a bit more different, + // which is that tests are less likely to rely on the size remaining the same. 
+ GenerationType* generation_ = EmptyGeneration(); +}; + +class CommonFieldsGenerationInfoDisabled { + public: + CommonFieldsGenerationInfoDisabled() = default; + CommonFieldsGenerationInfoDisabled(CommonFieldsGenerationInfoDisabled&&) = + default; + CommonFieldsGenerationInfoDisabled& operator=( + CommonFieldsGenerationInfoDisabled&&) = default; + + void maybe_increment_generation_on_insert() {} + void reset_reserved_growth(size_t, size_t) {} + size_t reserved_growth() const { return 0; } + void set_reserved_growth(size_t) {} + GenerationType generation() const { return 0; } + void set_generation(GenerationType) {} + GenerationType* generation_ptr() const { return nullptr; } + void set_generation_ptr(GenerationType*) {} +}; + +class HashSetIteratorGenerationInfoEnabled { + public: + HashSetIteratorGenerationInfoEnabled() = default; + explicit HashSetIteratorGenerationInfoEnabled( + const GenerationType* generation_ptr) + : generation_ptr_(generation_ptr), generation_(*generation_ptr) {} + + GenerationType generation() const { return generation_; } + void reset_generation() { generation_ = *generation_ptr_; } + const GenerationType* generation_ptr() const { return generation_ptr_; } + void set_generation_ptr(const GenerationType* ptr) { generation_ptr_ = ptr; } + + private: + const GenerationType* generation_ptr_ = nullptr; + GenerationType generation_ = 0; +}; + +class HashSetIteratorGenerationInfoDisabled { + public: + HashSetIteratorGenerationInfoDisabled() = default; + explicit HashSetIteratorGenerationInfoDisabled(const GenerationType*) {} + + GenerationType generation() const { return 0; } + void reset_generation() {} + const GenerationType* generation_ptr() const { return nullptr; } + void set_generation_ptr(const GenerationType*) {} +}; + +#ifdef ABSL_SWISSTABLE_ENABLE_GENERATIONS +using CommonFieldsGenerationInfo = CommonFieldsGenerationInfoEnabled; +using HashSetIteratorGenerationInfo = HashSetIteratorGenerationInfoEnabled; +#else +using CommonFieldsGenerationInfo = CommonFieldsGenerationInfoDisabled; +using HashSetIteratorGenerationInfo = HashSetIteratorGenerationInfoDisabled; +#endif + // CommonFields hold the fields in raw_hash_set that do not depend // on template parameters. This allows us to conveniently pass all // of this state to helper functions as a single argument. -// -// We make HashtablezInfoHandle a base class to take advantage of -// the empty base-class optimization when sampling is turned off. -class CommonFields : public HashtablezInfoHandle { +class CommonFields : public CommonFieldsGenerationInfo { public: CommonFields() = default; @@ -735,25 +868,34 @@ class CommonFields : public HashtablezInfoHandle { // Movable CommonFields(CommonFields&& that) - : HashtablezInfoHandle( - std::move(static_cast(that))), + : CommonFieldsGenerationInfo( + std::move(static_cast(that))), // Explicitly copying fields into "this" and then resetting "that" // fields generates less code then calling absl::exchange per field. 
control_(that.control_), slots_(that.slots_), size_(that.size_), capacity_(that.capacity_), - growth_left_(that.growth_left_) { + compressed_tuple_(that.growth_left(), std::move(that.infoz())) { that.control_ = EmptyGroup(); that.slots_ = nullptr; that.size_ = 0; that.capacity_ = 0; - that.growth_left_ = 0; + that.growth_left() = 0; } CommonFields& operator=(CommonFields&&) = default; - HashtablezInfoHandle& infoz() { return *this; } - const HashtablezInfoHandle& infoz() const { return *this; } + // The number of slots we can still fill without needing to rehash. + size_t& growth_left() { return compressed_tuple_.template get<0>(); } + + HashtablezInfoHandle& infoz() { return compressed_tuple_.template get<1>(); } + const HashtablezInfoHandle& infoz() const { + return compressed_tuple_.template get<1>(); + } + + void reset_reserved_growth(size_t reservation) { + CommonFieldsGenerationInfo::reset_reserved_growth(reservation, size_); + } // TODO(b/259599413): Investigate removing some of these fields: // - control/slots can be derived from each other @@ -775,8 +917,10 @@ class CommonFields : public HashtablezInfoHandle { // The total number of available slots. size_t capacity_ = 0; - // The number of slots we can still fill without needing to rehash. - size_t growth_left_ = 0; + // Bundle together growth_left and HashtablezInfoHandle to ensure EBO for + // HashtablezInfoHandle when sampling is turned off. + absl::container_internal::CompressedTuple + compressed_tuple_{0u, HashtablezInfoHandle{}}; }; // Returns he number of "cloned control bytes". @@ -859,19 +1003,26 @@ size_t SelectBucketCountForIterRange(InputIter first, InputIter last, return 0; } -#define ABSL_INTERNAL_ASSERT_IS_FULL(ctrl, operation) \ - do { \ - ABSL_HARDENING_ASSERT( \ - (ctrl != nullptr) && operation \ - " called on invalid iterator. The iterator might be an end() " \ - "iterator or may have been default constructed."); \ - ABSL_HARDENING_ASSERT( \ - (IsFull(*ctrl)) && operation \ - " called on invalid iterator. The element might have been erased or " \ - "the table might have rehashed."); \ +#define ABSL_INTERNAL_ASSERT_IS_FULL(ctrl, generation, generation_ptr, \ + operation) \ + do { \ + ABSL_HARDENING_ASSERT( \ + (ctrl != nullptr) && operation \ + " called on invalid iterator. The iterator might be an end() " \ + "iterator or may have been default constructed."); \ + if (SwisstableGenerationsEnabled() && generation != *generation_ptr) \ + ABSL_INTERNAL_LOG(FATAL, operation \ + " called on invalidated iterator. The table could " \ + "have rehashed since this iterator was initialized."); \ + ABSL_HARDENING_ASSERT( \ + (IsFull(*ctrl)) && operation \ + " called on invalid iterator. The element might have been erased or " \ + "the table might have rehashed."); \ } while (0) // Note that for comparisons, null/end iterators are valid. +// TODO(b/254649633): when generations are enabled, detect cases of invalid +// iterators being compared. inline void AssertIsValidForComparison(const ctrl_t* ctrl) { ABSL_HARDENING_ASSERT((ctrl == nullptr || IsFull(*ctrl)) && "Invalid iterator comparison. The element might have " @@ -900,6 +1051,9 @@ inline bool AreItersFromSameContainer(const ctrl_t* ctrl_a, // Asserts that two iterators come from the same container. // Note: we take slots by reference so that it's not UB if they're uninitialized // as long as we don't read them (when ctrl is null). 
+// TODO(b/254649633): when generations are enabled, we can detect more cases of +// different containers by comparing the pointers to the generations - this +// can cover cases of end iterators that we would otherwise miss. inline void AssertSameContainer(const ctrl_t* ctrl_a, const ctrl_t* ctrl_b, const void* const& slot_a, const void* const& slot_b) { @@ -976,7 +1130,7 @@ extern template FindInfo find_first_non_full(const CommonFields&, size_t); FindInfo find_first_non_full_outofline(const CommonFields&, size_t); inline void ResetGrowthLeft(CommonFields& common) { - common.growth_left_ = CapacityToGrowth(common.capacity_) - common.size_; + common.growth_left() = CapacityToGrowth(common.capacity_) - common.size_; } // Sets `ctrl` to `{kEmpty, kSentinel, ..., kEmpty}`, marking the entire @@ -1018,12 +1172,21 @@ inline void SetCtrl(const CommonFields& common, size_t i, h2_t h, SetCtrl(common, i, static_cast(h), slot_size); } +// Given the capacity of a table, computes the offset (from the start of the +// backing allocation) of the generation counter (if it exists). +inline size_t GenerationOffset(size_t capacity) { + assert(IsValidCapacity(capacity)); + const size_t num_control_bytes = capacity + 1 + NumClonedBytes(); + return num_control_bytes; +} + // Given the capacity of a table, computes the offset (from the start of the // backing allocation) at which the slots begin. inline size_t SlotOffset(size_t capacity, size_t slot_align) { assert(IsValidCapacity(capacity)); const size_t num_control_bytes = capacity + 1 + NumClonedBytes(); - return (num_control_bytes + slot_align - 1) & (~slot_align + 1); + return (num_control_bytes + NumGenerationBytes() + slot_align - 1) & + (~slot_align + 1); } // Given the capacity of a table, computes the total size of the backing @@ -1048,6 +1211,10 @@ ABSL_ATTRIBUTE_NOINLINE void InitializeSlots(CommonFields& c, Alloc alloc) { const size_t cap = c.capacity_; char* mem = static_cast( Allocate(&alloc, AllocSize(cap, SizeOfSlot, AlignOfSlot))); + const GenerationType old_generation = c.generation(); + c.set_generation_ptr( + reinterpret_cast(mem + GenerationOffset(cap))); + c.set_generation(old_generation + 1); c.control_ = reinterpret_cast(mem); c.slots_ = mem + SlotOffset(cap, AlignOfSlot); ResetCtrl(c, SizeOfSlot); @@ -1213,7 +1380,7 @@ class raw_hash_set { static_assert(std::is_same::value, "Allocators with custom pointer types are not supported"); - class iterator { + class iterator : private HashSetIteratorGenerationInfo { friend class raw_hash_set; public: @@ -1229,19 +1396,22 @@ class raw_hash_set { // PRECONDITION: not an end() iterator. reference operator*() const { - ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, "operator*()"); + ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, generation(), generation_ptr(), + "operator*()"); return PolicyTraits::element(slot_); } // PRECONDITION: not an end() iterator. pointer operator->() const { - ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, "operator->"); + ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, generation(), generation_ptr(), + "operator->"); return &operator*(); } // PRECONDITION: not an end() iterator. 
iterator& operator++() { - ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, "operator++"); + ABSL_INTERNAL_ASSERT_IS_FULL(ctrl_, generation(), generation_ptr(), + "operator++"); ++ctrl_; ++slot_; skip_empty_or_deleted(); @@ -1265,7 +1435,11 @@ class raw_hash_set { } private: - iterator(ctrl_t* ctrl, slot_type* slot) : ctrl_(ctrl), slot_(slot) { + iterator(ctrl_t* ctrl, slot_type* slot, + const GenerationType* generation_ptr) + : HashSetIteratorGenerationInfo(generation_ptr), + ctrl_(ctrl), + slot_(slot) { // This assumption helps the compiler know that any non-end iterator is // not equal to any end iterator. ABSL_ASSUME(ctrl != nullptr); @@ -1323,8 +1497,10 @@ class raw_hash_set { } private: - const_iterator(const ctrl_t* ctrl, const slot_type* slot) - : inner_(const_cast(ctrl), const_cast(slot)) {} + const_iterator(const ctrl_t* ctrl, const slot_type* slot, + const GenerationType* gen) + : inner_(const_cast(ctrl), const_cast(slot), gen) { + } iterator inner_; }; @@ -1455,6 +1631,7 @@ class raw_hash_set { auto target = find_first_non_full_outofline(common(), hash); SetCtrl(common(), target.offset, H2(hash), sizeof(slot_type)); emplace_at(target.offset, v); + common().maybe_increment_generation_on_insert(); infoz().RecordInsert(hash, target.probe_length); } common().size_ = that.size(); @@ -1553,6 +1730,7 @@ class raw_hash_set { ClearBackingArray(common(), GetPolicyFunctions(), /*reuse=*/cap < 128); } + common().set_reserved_growth(0); } inline void destroy_slots() { @@ -1793,7 +1971,8 @@ class raw_hash_set { // This overload is necessary because otherwise erase(const K&) would be // a better match if non-const iterator is passed as an argument. void erase(iterator it) { - ABSL_INTERNAL_ASSERT_IS_FULL(it.ctrl_, "erase()"); + ABSL_INTERNAL_ASSERT_IS_FULL(it.ctrl_, it.generation(), it.generation_ptr(), + "erase()"); PolicyTraits::destroy(&alloc_ref(), it.slot_); erase_meta_only(it); } @@ -1827,7 +2006,9 @@ class raw_hash_set { } node_type extract(const_iterator position) { - ABSL_INTERNAL_ASSERT_IS_FULL(position.inner_.ctrl_, "extract()"); + ABSL_INTERNAL_ASSERT_IS_FULL(position.inner_.ctrl_, + position.inner_.generation(), + position.inner_.generation_ptr(), "extract()"); auto node = CommonAccess::Transfer(alloc_ref(), position.inner_.slot_); erase_meta_only(position); @@ -1883,6 +2064,7 @@ class raw_hash_set { // This is after resize, to ensure that we have completed the allocation // and have potentially sampled the hashtable. infoz().RecordReservation(n); + common().reset_reserved_growth(n); } } @@ -2268,6 +2450,7 @@ class raw_hash_set { ++common().size_; growth_left() -= IsEmpty(control()[target.offset]); SetCtrl(common(), target.offset, H2(hash), sizeof(slot_type)); + common().maybe_increment_generation_on_insert(); infoz().RecordInsert(hash, target.probe_length); return target.offset; } @@ -2290,9 +2473,11 @@ class raw_hash_set { "constructed value does not match the lookup key"); } - iterator iterator_at(size_t i) { return {control() + i, slot_array() + i}; } + iterator iterator_at(size_t i) { + return {control() + i, slot_array() + i, common().generation_ptr()}; + } const_iterator iterator_at(size_t i) const { - return {control() + i, slot_array() + i}; + return {control() + i, slot_array() + i, common().generation_ptr()}; } private: @@ -2308,7 +2493,7 @@ class raw_hash_set { // side-effect. // // See `CapacityToGrowth()`. 
- size_t& growth_left() { return common().growth_left_; } + size_t& growth_left() { return common().growth_left(); } // Prefetch the heap-allocated memory region to resolve potential TLB misses. // This is intended to overlap with execution of calculating the hash for a // key and to reduce the memory needed to read the actual memory access. @@ -2460,6 +2645,7 @@ struct HashtableDebugAccess> { ABSL_NAMESPACE_END } // namespace absl +#undef ABSL_SWISSTABLE_ENABLE_GENERATIONS #undef ABSL_INTERNAL_ASSERT_IS_FULL #endif // ABSL_CONTAINER_INTERNAL_RAW_HASH_SET_H_
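
For context on what the new checks catch: in ASan/MSan builds of this revision, ABSL_SWISSTABLE_ENABLE_GENERATIONS is defined, every insert without reserved growth bumps the table's generation (see maybe_increment_generation_on_insert above), and a later dereference of a stale iterator fails the generation comparison in ABSL_INTERNAL_ASSERT_IS_FULL. The snippet below is a hypothetical usage sketch illustrating that behavior, not a test from this change; the function name UseAfterInsert is made up.

#include "absl/container/flat_hash_set.h"

void UseAfterInsert() {
  absl::flat_hash_set<int> s = {1, 2, 3};
  auto it = s.begin();
  s.insert(4);  // no prior reserve(), so this bumps the generation counter
                // stored in the backing array
  // *it;       // with generations enabled, this would fail the generation
                // check in ABSL_INTERNAL_ASSERT_IS_FULL and log FATAL:
                // "operator*() called on invalidated iterator. The table
                //  could have rehashed since this iterator was initialized."
}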
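
The underlying technique is not specific to swisstable: the container keeps a small generation counter next to its storage (in this patch, between the control bytes and the slots, so iterators remain checkable after the container itself is moved), bumps it on mutations that may invalidate iterators, and each iterator records the counter's address and its value at construction so a later dereference can compare the two. Below is a minimal self-contained sketch of that idea with hypothetical names (TinySet, checked_iterator) and the counter kept in the container body rather than in the backing allocation.

#include <cassert>
#include <cstdint>
#include <vector>

using GenerationType = std::uint8_t;  // same width the patch uses

class TinySet {
 public:
  class checked_iterator {
   public:
    checked_iterator(const int* p, const GenerationType* gen_ptr)
        : p_(p), generation_ptr_(gen_ptr), generation_(*gen_ptr) {}

    int operator*() const {
      // Analogue of the generation check in ABSL_INTERNAL_ASSERT_IS_FULL: if
      // the container has mutated since this iterator was created, the
      // snapshot no longer matches the live counter.
      assert(generation_ == *generation_ptr_ &&
             "iterator invalidated by a container mutation");
      return *p_;
    }

   private:
    const int* p_;
    const GenerationType* generation_ptr_;
    GenerationType generation_;  // value observed at construction
  };

  void insert(int v) {
    values_.push_back(v);
    ++generation_;  // analogue of maybe_increment_generation_on_insert()
  }

  checked_iterator begin() const {
    return checked_iterator(values_.data(), &generation_);
  }

 private:
  std::vector<int> values_;
  GenerationType generation_ = 0;
};

int main() {
  TinySet s;
  s.insert(1);
  TinySet::checked_iterator it = s.begin();
  int first = *it;  // OK: generation unchanged since `it` was constructed
  s.insert(2);      // invalidates `it` by bumping the generation
  // *it;           // would now fail the assert (and also dangle here)
  return first;
}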