diff options
author | Abseil Team <absl-team@google.com> | 2022-06-06 09:28:31 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2022-06-06 09:29:27 -0700 |
commit | 6481443560a92d0a3a55a31807de0cd712cd4f88 (patch) | |
tree | b13d0a400f72cc4d0acc3a35f2ff73b2499a127f /absl | |
parent | 48419595d31609762985a6b08be504ebe6d593e7 (diff) |
Optimize SwissMap for ARM by 3-8% for all operations
https://pastebin.com/CmnzwUFN
The key idea is to avoid using 16 byte NEON and use 8 byte NEON which has lower latency for BitMask::Match. Even though 16 byte NEON achieves higher throughput, in SwissMap it's very important to catch these Matches with low latency as probing on average happens at most once.
I also introduced NonIterableMask as ARM has really great cbnz instructions and additional AND on scalar mask had 1 extra latency cycle
PiperOrigin-RevId: 453216147
Change-Id: I842c50d323954f8383ae156491232ced55aacb78
Diffstat (limited to 'absl')
-rw-r--r-- | absl/base/config.h | 9 | ||||
-rw-r--r-- | absl/container/internal/raw_hash_set.h | 227 | ||||
-rw-r--r-- | absl/container/internal/raw_hash_set_benchmark.cc | 14 | ||||
-rw-r--r-- | absl/container/internal/raw_hash_set_test.cc | 24 |
4 files changed, 176 insertions, 98 deletions
diff --git a/absl/base/config.h b/absl/base/config.h index 4223629e..802529fc 100644 --- a/absl/base/config.h +++ b/absl/base/config.h @@ -898,4 +898,13 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' || #define ABSL_INTERNAL_HAVE_ARM_ACLE 1 #endif +// ABSL_INTERNAL_HAVE_ARM_NEON is used for compile-time detection of NEON (ARM +// SIMD). +#ifdef ABSL_INTERNAL_HAVE_ARM_NEON +#error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set +#elif defined(__ARM_NEON) +#define ABSL_INTERNAL_HAVE_ARM_NEON 1 +#endif + + #endif // ABSL_BASE_CONFIG_H_ diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h index d503bc00..2756ce1b 100644 --- a/absl/container/internal/raw_hash_set.h +++ b/absl/container/internal/raw_hash_set.h @@ -184,6 +184,14 @@ #include <intrin.h> #endif +#ifdef __ARM_NEON +#include <arm_neon.h> +#endif + +#ifdef __ARM_ACLE +#include <arm_acle.h> +#endif + #include <algorithm> #include <cmath> #include <cstdint> @@ -211,10 +219,6 @@ #include "absl/numeric/bits.h" #include "absl/utility/utility.h" -#ifdef ABSL_INTERNAL_HAVE_ARM_ACLE -#include <arm_acle.h> -#endif - namespace absl { ABSL_NAMESPACE_BEGIN namespace container_internal { @@ -323,36 +327,15 @@ uint32_t TrailingZeros(T x) { // controlled by `SignificantBits` and `Shift`. `SignificantBits` is the number // of abstract bits in the bitset, while `Shift` is the log-base-two of the // width of an abstract bit in the representation. -// -// For example, when `SignificantBits` is 16 and `Shift` is zero, this is just -// an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When -// `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as -// the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask. -// -// For example: -// for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2 -// for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3 +// This mask provides operations for any number of real bits set in an abstract +// bit. To add iteration on top of that, implementation must guarantee no more +// than one real bit is set in an abstract bit. template <class T, int SignificantBits, int Shift = 0> -class BitMask { - static_assert(std::is_unsigned<T>::value, ""); - static_assert(Shift == 0 || Shift == 3, ""); - +class NonIterableBitMask { public: - // BitMask is an iterator over the indices of its abstract bits. - using value_type = int; - using iterator = BitMask; - using const_iterator = BitMask; - - explicit BitMask(T mask) : mask_(mask) {} - BitMask& operator++() { - mask_ &= (mask_ - 1); - return *this; - } - explicit operator bool() const { return mask_ != 0; } - uint32_t operator*() const { return LowestBitSet(); } + explicit NonIterableBitMask(T mask) : mask_(mask) {} - BitMask begin() const { return *this; } - BitMask end() const { return BitMask(0); } + explicit operator bool() const { return this->mask_ != 0; } // Returns the index of the lowest *abstract* bit set in `self`. uint32_t LowestBitSet() const { @@ -376,6 +359,42 @@ class BitMask { return static_cast<uint32_t>(countl_zero(mask_ << extra_bits)) >> Shift; } + T mask_; +}; + +// Mask that can be iterable +// +// For example, when `SignificantBits` is 16 and `Shift` is zero, this is just +// an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When +// `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as +// the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask. +// +// For example: +// for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2 +// for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3 +template <class T, int SignificantBits, int Shift = 0> +class BitMask : public NonIterableBitMask<T, SignificantBits, Shift> { + using Base = NonIterableBitMask<T, SignificantBits, Shift>; + static_assert(std::is_unsigned<T>::value, ""); + static_assert(Shift == 0 || Shift == 3, ""); + + public: + explicit BitMask(T mask) : Base(mask) {} + // BitMask is an iterator over the indices of its abstract bits. + using value_type = int; + using iterator = BitMask; + using const_iterator = BitMask; + + BitMask& operator++() { + this->mask_ &= (this->mask_ - 1); + return *this; + } + + uint32_t operator*() const { return Base::LowestBitSet(); } + + BitMask begin() const { return *this; } + BitMask end() const { return BitMask(0); } + private: friend bool operator==(const BitMask& a, const BitMask& b) { return a.mask_ == b.mask_; @@ -383,8 +402,6 @@ class BitMask { friend bool operator!=(const BitMask& a, const BitMask& b) { return a.mask_ != b.mask_; } - - T mask_; }; using h2_t = uint8_t; @@ -433,7 +450,7 @@ static_assert( static_cast<int8_t>(ctrl_t::kSentinel) & 0x7F) != 0, "ctrl_t::kEmpty and ctrl_t::kDeleted must share an unset bit that is not " "shared by ctrl_t::kSentinel to make the scalar test for " - "MatchEmptyOrDeleted() efficient"); + "MaskEmptyOrDeleted() efficient"); static_assert(ctrl_t::kDeleted == static_cast<ctrl_t>(-2), "ctrl_t::kDeleted must be -2 to make the implementation of " "ConvertSpecialToEmptyAndFullToDeleted efficient"); @@ -538,20 +555,22 @@ struct GroupSse2Impl { } // Returns a bitmask representing the positions of empty slots. - BitMask<uint32_t, kWidth> MatchEmpty() const { + NonIterableBitMask<uint32_t, kWidth> MaskEmpty() const { #ifdef ABSL_INTERNAL_HAVE_SSSE3 // This only works because ctrl_t::kEmpty is -128. - return BitMask<uint32_t, kWidth>( + return NonIterableBitMask<uint32_t, kWidth>( static_cast<uint32_t>(_mm_movemask_epi8(_mm_sign_epi8(ctrl, ctrl)))); #else - return Match(static_cast<h2_t>(ctrl_t::kEmpty)); + auto match = _mm_set1_epi8(static_cast<h2_t>(ctrl_t::kEmpty)); + return NonIterableBitMask<uint32_t, kWidth>( + static_cast<uint32_t>(_mm_movemask_epi8(_mm_cmpeq_epi8(match, ctrl)))); #endif } // Returns a bitmask representing the positions of empty or deleted slots. - BitMask<uint32_t, kWidth> MatchEmptyOrDeleted() const { + NonIterableBitMask<uint32_t, kWidth> MaskEmptyOrDeleted() const { auto special = _mm_set1_epi8(static_cast<uint8_t>(ctrl_t::kSentinel)); - return BitMask<uint32_t, kWidth>(static_cast<uint32_t>( + return NonIterableBitMask<uint32_t, kWidth>(static_cast<uint32_t>( _mm_movemask_epi8(_mm_cmpgt_epi8_fixed(special, ctrl)))); } @@ -579,6 +598,80 @@ struct GroupSse2Impl { }; #endif // ABSL_INTERNAL_RAW_HASH_SET_HAVE_SSE2 +#if defined(ABSL_INTERNAL_HAVE_ARM_NEON) && defined(ABSL_IS_LITTLE_ENDIAN) +struct GroupAArch64Impl { + static constexpr size_t kWidth = 8; + + explicit GroupAArch64Impl(const ctrl_t* pos) { + ctrl = vld1_u8(reinterpret_cast<const uint8_t*>(pos)); + } + + BitMask<uint64_t, kWidth, 3> Match(h2_t hash) const { + uint8x8_t dup = vdup_n_u8(hash); + auto mask = vceq_u8(ctrl, dup); + constexpr uint64_t msbs = 0x8080808080808080ULL; + return BitMask<uint64_t, kWidth, 3>( + vget_lane_u64(vreinterpret_u64_u8(mask), 0) & msbs); + } + + NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const { + uint64_t mask = + vget_lane_u64(vreinterpret_u64_u8( + vceq_s8(vdup_n_s8(static_cast<h2_t>(ctrl_t::kEmpty)), + vreinterpret_s8_u8(ctrl))), + 0); + return NonIterableBitMask<uint64_t, kWidth, 3>(mask); + } + + NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const { + uint64_t mask = + vget_lane_u64(vreinterpret_u64_u8(vcgt_s8( + vdup_n_s8(static_cast<int8_t>(ctrl_t::kSentinel)), + vreinterpret_s8_u8(ctrl))), + 0); + return NonIterableBitMask<uint64_t, kWidth, 3>(mask); + } + + uint32_t CountLeadingEmptyOrDeleted() const { + uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0); + assert(IsEmptyOrDeleted(static_cast<ctrl_t>(mask & 0xff))); + constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL; +#if defined(ABSL_INTERNAL_HAVE_ARM_ACLE) + // cls: Count leading sign bits. + // clsll(1ull << 63) -> 0 + // clsll((1ull << 63) | (1ull << 62)) -> 1 + // clsll((1ull << 63) | (1ull << 61)) -> 0 + // clsll(~0ull) -> 63 + // clsll(1) -> 62 + // clsll(3) -> 61 + // clsll(5) -> 60 + // Note that CountLeadingEmptyOrDeleted is called when first control block + // is kDeleted or kEmpty. The implementation is similar to GroupPortableImpl + // but avoids +1 and __clsll returns result not including the high bit. Thus + // saves one cycle. + // kEmpty = -128, // 0b10000000 + // kDeleted = -2, // 0b11111110 + // ~ctrl & (ctrl >> 7) will have the lowest bit set to 1. After rbit, + // it will the highest one. + return (__clsll(__rbitll((~mask & (mask >> 7)) | gaps)) + 8) >> 3; +#else + return (TrailingZeros(((~mask & (mask >> 7)) | gaps) + 1) + 7) >> 3; +#endif // ABSL_INTERNAL_HAVE_ARM_ACLE + } + + void ConvertSpecialToEmptyAndFullToDeleted(ctrl_t* dst) const { + uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0); + constexpr uint64_t msbs = 0x8080808080808080ULL; + constexpr uint64_t lsbs = 0x0101010101010101ULL; + auto x = mask & msbs; + auto res = (~x + (x >> 7)) & ~lsbs; + little_endian::Store64(dst, res); + } + + uint8x8_t ctrl; +}; +#endif // ABSL_INTERNAL_HAVE_ARM_NEON && ABSL_IS_LITTLE_ENDIAN + struct GroupPortableImpl { static constexpr size_t kWidth = 8; @@ -605,14 +698,16 @@ struct GroupPortableImpl { return BitMask<uint64_t, kWidth, 3>((x - lsbs) & ~x & msbs); } - BitMask<uint64_t, kWidth, 3> MatchEmpty() const { + NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const { constexpr uint64_t msbs = 0x8080808080808080ULL; - return BitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 6)) & msbs); + return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 6)) & + msbs); } - BitMask<uint64_t, kWidth, 3> MatchEmptyOrDeleted() const { + NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const { constexpr uint64_t msbs = 0x8080808080808080ULL; - return BitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 7)) & msbs); + return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 7)) & + msbs); } uint32_t CountLeadingEmptyOrDeleted() const { @@ -631,39 +726,9 @@ struct GroupPortableImpl { uint64_t ctrl; }; -#ifdef ABSL_INTERNAL_HAVE_ARM_ACLE -struct GroupAArch64Impl : public GroupPortableImpl { - static constexpr size_t kWidth = GroupPortableImpl::kWidth; - - using GroupPortableImpl::GroupPortableImpl; - - uint32_t CountLeadingEmptyOrDeleted() const { - assert(IsEmptyOrDeleted(static_cast<ctrl_t>(ctrl & 0xff))); - constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL; - // cls: Count leading sign bits. - // clsll(1ull << 63) -> 0 - // clsll((1ull << 63) | (1ull << 62)) -> 1 - // clsll((1ull << 63) | (1ull << 61)) -> 0 - // clsll(~0ull) -> 63 - // clsll(1) -> 62 - // clsll(3) -> 61 - // clsll(5) -> 60 - // Note that CountLeadingEmptyOrDeleted is called when first control block - // is kDeleted or kEmpty. The implementation is similar to GroupPortableImpl - // but avoids +1 and __clsll returns result not including the high bit. Thus - // saves one cycle. - // kEmpty = -128, // 0b10000000 - // kDeleted = -2, // 0b11111110 - // ~ctrl & (ctrl >> 7) will have the lowest bit set to 1. After rbit, - // it will the highest one. - return (__clsll(__rbitll((~ctrl & (ctrl >> 7)) | gaps)) + 8) >> 3; - } -}; -#endif - #ifdef ABSL_INTERNAL_HAVE_SSE2 using Group = GroupSse2Impl; -#elif defined(ABSL_INTERNAL_HAVE_ARM_ACLE) +#elif defined(ABSL_INTERNAL_HAVE_ARM_NEON) && defined(ABSL_IS_LITTLE_ENDIAN) using Group = GroupAArch64Impl; #else using Group = GroupPortableImpl; @@ -798,7 +863,7 @@ inline FindInfo find_first_non_full(const ctrl_t* ctrl, size_t hash, auto seq = probe(ctrl, hash, capacity); while (true) { Group g{ctrl + seq.offset()}; - auto mask = g.MatchEmptyOrDeleted(); + auto mask = g.MaskEmptyOrDeleted(); if (mask) { #if !defined(NDEBUG) // We want to add entropy even when ASLR is not enabled. @@ -1700,7 +1765,7 @@ class raw_hash_set { PolicyTraits::element(slots_ + seq.offset(i))))) return iterator_at(seq.offset(i)); } - if (ABSL_PREDICT_TRUE(g.MatchEmpty())) return end(); + if (ABSL_PREDICT_TRUE(g.MaskEmpty())) return end(); seq.next(); assert(seq.index() <= capacity_ && "full table!"); } @@ -1849,8 +1914,8 @@ class raw_hash_set { --size_; const size_t index = static_cast<size_t>(it.inner_.ctrl_ - ctrl_); const size_t index_before = (index - Group::kWidth) & capacity_; - const auto empty_after = Group(it.inner_.ctrl_).MatchEmpty(); - const auto empty_before = Group(ctrl_ + index_before).MatchEmpty(); + const auto empty_after = Group(it.inner_.ctrl_).MaskEmpty(); + const auto empty_before = Group(ctrl_ + index_before).MaskEmpty(); // We count how many consecutive non empties we have to the right and to the // left of `it`. If the sum is >= kWidth then there is at least one probe @@ -2091,7 +2156,7 @@ class raw_hash_set { elem)) return true; } - if (ABSL_PREDICT_TRUE(g.MatchEmpty())) return false; + if (ABSL_PREDICT_TRUE(g.MaskEmpty())) return false; seq.next(); assert(seq.index() <= capacity_ && "full table!"); } @@ -2127,7 +2192,7 @@ class raw_hash_set { PolicyTraits::element(slots_ + seq.offset(i))))) return {seq.offset(i), false}; } - if (ABSL_PREDICT_TRUE(g.MatchEmpty())) break; + if (ABSL_PREDICT_TRUE(g.MaskEmpty())) break; seq.next(); assert(seq.index() <= capacity_ && "full table!"); } @@ -2272,7 +2337,7 @@ struct HashtableDebugAccess<Set, absl::void_t<typename Set::raw_hash_set>> { return num_probes; ++num_probes; } - if (g.MatchEmpty()) return num_probes; + if (g.MaskEmpty()) return num_probes; seq.next(); ++num_probes; } diff --git a/absl/container/internal/raw_hash_set_benchmark.cc b/absl/container/internal/raw_hash_set_benchmark.cc index 146ef433..47dc9048 100644 --- a/absl/container/internal/raw_hash_set_benchmark.cc +++ b/absl/container/internal/raw_hash_set_benchmark.cc @@ -336,27 +336,27 @@ void BM_Group_Match(benchmark::State& state) { } BENCHMARK(BM_Group_Match); -void BM_Group_MatchEmpty(benchmark::State& state) { +void BM_Group_MaskEmpty(benchmark::State& state) { std::array<ctrl_t, Group::kWidth> group; Iota(group.begin(), group.end(), -4); Group g{group.data()}; for (auto _ : state) { ::benchmark::DoNotOptimize(g); - ::benchmark::DoNotOptimize(g.MatchEmpty()); + ::benchmark::DoNotOptimize(g.MaskEmpty()); } } -BENCHMARK(BM_Group_MatchEmpty); +BENCHMARK(BM_Group_MaskEmpty); -void BM_Group_MatchEmptyOrDeleted(benchmark::State& state) { +void BM_Group_MaskEmptyOrDeleted(benchmark::State& state) { std::array<ctrl_t, Group::kWidth> group; Iota(group.begin(), group.end(), -4); Group g{group.data()}; for (auto _ : state) { ::benchmark::DoNotOptimize(g); - ::benchmark::DoNotOptimize(g.MatchEmptyOrDeleted()); + ::benchmark::DoNotOptimize(g.MaskEmptyOrDeleted()); } } -BENCHMARK(BM_Group_MatchEmptyOrDeleted); +BENCHMARK(BM_Group_MaskEmptyOrDeleted); void BM_Group_CountLeadingEmptyOrDeleted(benchmark::State& state) { std::array<ctrl_t, Group::kWidth> group; @@ -375,7 +375,7 @@ void BM_Group_MatchFirstEmptyOrDeleted(benchmark::State& state) { Group g{group.data()}; for (auto _ : state) { ::benchmark::DoNotOptimize(g); - ::benchmark::DoNotOptimize(*g.MatchEmptyOrDeleted()); + ::benchmark::DoNotOptimize(g.MaskEmptyOrDeleted().LowestBitSet()); } } BENCHMARK(BM_Group_MatchFirstEmptyOrDeleted); diff --git a/absl/container/internal/raw_hash_set_test.cc b/absl/container/internal/raw_hash_set_test.cc index c79f8641..dc6bc3d2 100644 --- a/absl/container/internal/raw_hash_set_test.cc +++ b/absl/container/internal/raw_hash_set_test.cc @@ -195,35 +195,39 @@ TEST(Group, Match) { } } -TEST(Group, MatchEmpty) { +TEST(Group, MaskEmpty) { if (Group::kWidth == 16) { ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), ctrl_t::kDeleted, CtrlT(3), ctrl_t::kEmpty, CtrlT(5), ctrl_t::kSentinel, CtrlT(7), CtrlT(7), CtrlT(5), CtrlT(3), CtrlT(1), CtrlT(1), CtrlT(1), CtrlT(1), CtrlT(1)}; - EXPECT_THAT(Group{group}.MatchEmpty(), ElementsAre(0, 4)); + EXPECT_THAT(Group{group}.MaskEmpty().LowestBitSet(), 0); + EXPECT_THAT(Group{group}.MaskEmpty().HighestBitSet(), 4); } else if (Group::kWidth == 8) { ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), CtrlT(2), ctrl_t::kDeleted, CtrlT(2), CtrlT(1), ctrl_t::kSentinel, CtrlT(1)}; - EXPECT_THAT(Group{group}.MatchEmpty(), ElementsAre(0)); + EXPECT_THAT(Group{group}.MaskEmpty().LowestBitSet(), 0); + EXPECT_THAT(Group{group}.MaskEmpty().HighestBitSet(), 0); } else { FAIL() << "No test coverage for Group::kWidth==" << Group::kWidth; } } -TEST(Group, MatchEmptyOrDeleted) { +TEST(Group, MaskEmptyOrDeleted) { if (Group::kWidth == 16) { - ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), ctrl_t::kDeleted, CtrlT(3), - ctrl_t::kEmpty, CtrlT(5), ctrl_t::kSentinel, CtrlT(7), - CtrlT(7), CtrlT(5), CtrlT(3), CtrlT(1), - CtrlT(1), CtrlT(1), CtrlT(1), CtrlT(1)}; - EXPECT_THAT(Group{group}.MatchEmptyOrDeleted(), ElementsAre(0, 2, 4)); + ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), ctrl_t::kEmpty, CtrlT(3), + ctrl_t::kDeleted, CtrlT(5), ctrl_t::kSentinel, CtrlT(7), + CtrlT(7), CtrlT(5), CtrlT(3), CtrlT(1), + CtrlT(1), CtrlT(1), CtrlT(1), CtrlT(1)}; + EXPECT_THAT(Group{group}.MaskEmptyOrDeleted().LowestBitSet(), 0); + EXPECT_THAT(Group{group}.MaskEmptyOrDeleted().HighestBitSet(), 4); } else if (Group::kWidth == 8) { ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), CtrlT(2), ctrl_t::kDeleted, CtrlT(2), CtrlT(1), ctrl_t::kSentinel, CtrlT(1)}; - EXPECT_THAT(Group{group}.MatchEmptyOrDeleted(), ElementsAre(0, 3)); + EXPECT_THAT(Group{group}.MaskEmptyOrDeleted().LowestBitSet(), 0); + EXPECT_THAT(Group{group}.MaskEmptyOrDeleted().HighestBitSet(), 3); } else { FAIL() << "No test coverage for Group::kWidth==" << Group::kWidth; } |