summaryrefslogtreecommitdiff
path: root/absl/container/internal/raw_hash_set.h
diff options
context:
space:
mode:
authorGravatar Abseil Team <absl-team@google.com>2022-06-06 09:28:31 -0700
committerGravatar Copybara-Service <copybara-worker@google.com>2022-06-06 09:29:27 -0700
commit6481443560a92d0a3a55a31807de0cd712cd4f88 (patch)
treeb13d0a400f72cc4d0acc3a35f2ff73b2499a127f /absl/container/internal/raw_hash_set.h
parent48419595d31609762985a6b08be504ebe6d593e7 (diff)
Optimize SwissMap for ARM by 3-8% for all operations
https://pastebin.com/CmnzwUFN The key idea is to avoid using 16 byte NEON and use 8 byte NEON which has lower latency for BitMask::Match. Even though 16 byte NEON achieves higher throughput, in SwissMap it's very important to catch these Matches with low latency as probing on average happens at most once. I also introduced NonIterableMask as ARM has really great cbnz instructions and additional AND on scalar mask had 1 extra latency cycle PiperOrigin-RevId: 453216147 Change-Id: I842c50d323954f8383ae156491232ced55aacb78
Diffstat (limited to 'absl/container/internal/raw_hash_set.h')
-rw-r--r--absl/container/internal/raw_hash_set.h227
1 files changed, 146 insertions, 81 deletions
diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h
index d503bc00..2756ce1b 100644
--- a/absl/container/internal/raw_hash_set.h
+++ b/absl/container/internal/raw_hash_set.h
@@ -184,6 +184,14 @@
#include <intrin.h>
#endif
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+
+#ifdef __ARM_ACLE
+#include <arm_acle.h>
+#endif
+
#include <algorithm>
#include <cmath>
#include <cstdint>
@@ -211,10 +219,6 @@
#include "absl/numeric/bits.h"
#include "absl/utility/utility.h"
-#ifdef ABSL_INTERNAL_HAVE_ARM_ACLE
-#include <arm_acle.h>
-#endif
-
namespace absl {
ABSL_NAMESPACE_BEGIN
namespace container_internal {
@@ -323,36 +327,15 @@ uint32_t TrailingZeros(T x) {
// controlled by `SignificantBits` and `Shift`. `SignificantBits` is the number
// of abstract bits in the bitset, while `Shift` is the log-base-two of the
// width of an abstract bit in the representation.
-//
-// For example, when `SignificantBits` is 16 and `Shift` is zero, this is just
-// an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When
-// `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as
-// the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask.
-//
-// For example:
-// for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2
-// for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3
+// This mask provides operations for any number of real bits set in an abstract
+// bit. To add iteration on top of that, implementation must guarantee no more
+// than one real bit is set in an abstract bit.
template <class T, int SignificantBits, int Shift = 0>
-class BitMask {
- static_assert(std::is_unsigned<T>::value, "");
- static_assert(Shift == 0 || Shift == 3, "");
-
+class NonIterableBitMask {
public:
- // BitMask is an iterator over the indices of its abstract bits.
- using value_type = int;
- using iterator = BitMask;
- using const_iterator = BitMask;
-
- explicit BitMask(T mask) : mask_(mask) {}
- BitMask& operator++() {
- mask_ &= (mask_ - 1);
- return *this;
- }
- explicit operator bool() const { return mask_ != 0; }
- uint32_t operator*() const { return LowestBitSet(); }
+ explicit NonIterableBitMask(T mask) : mask_(mask) {}
- BitMask begin() const { return *this; }
- BitMask end() const { return BitMask(0); }
+ explicit operator bool() const { return this->mask_ != 0; }
// Returns the index of the lowest *abstract* bit set in `self`.
uint32_t LowestBitSet() const {
@@ -376,6 +359,42 @@ class BitMask {
return static_cast<uint32_t>(countl_zero(mask_ << extra_bits)) >> Shift;
}
+ T mask_;
+};
+
+// Mask that can be iterable
+//
+// For example, when `SignificantBits` is 16 and `Shift` is zero, this is just
+// an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When
+// `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as
+// the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask.
+//
+// For example:
+// for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2
+// for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3
+template <class T, int SignificantBits, int Shift = 0>
+class BitMask : public NonIterableBitMask<T, SignificantBits, Shift> {
+ using Base = NonIterableBitMask<T, SignificantBits, Shift>;
+ static_assert(std::is_unsigned<T>::value, "");
+ static_assert(Shift == 0 || Shift == 3, "");
+
+ public:
+ explicit BitMask(T mask) : Base(mask) {}
+ // BitMask is an iterator over the indices of its abstract bits.
+ using value_type = int;
+ using iterator = BitMask;
+ using const_iterator = BitMask;
+
+ BitMask& operator++() {
+ this->mask_ &= (this->mask_ - 1);
+ return *this;
+ }
+
+ uint32_t operator*() const { return Base::LowestBitSet(); }
+
+ BitMask begin() const { return *this; }
+ BitMask end() const { return BitMask(0); }
+
private:
friend bool operator==(const BitMask& a, const BitMask& b) {
return a.mask_ == b.mask_;
@@ -383,8 +402,6 @@ class BitMask {
friend bool operator!=(const BitMask& a, const BitMask& b) {
return a.mask_ != b.mask_;
}
-
- T mask_;
};
using h2_t = uint8_t;
@@ -433,7 +450,7 @@ static_assert(
static_cast<int8_t>(ctrl_t::kSentinel) & 0x7F) != 0,
"ctrl_t::kEmpty and ctrl_t::kDeleted must share an unset bit that is not "
"shared by ctrl_t::kSentinel to make the scalar test for "
- "MatchEmptyOrDeleted() efficient");
+ "MaskEmptyOrDeleted() efficient");
static_assert(ctrl_t::kDeleted == static_cast<ctrl_t>(-2),
"ctrl_t::kDeleted must be -2 to make the implementation of "
"ConvertSpecialToEmptyAndFullToDeleted efficient");
@@ -538,20 +555,22 @@ struct GroupSse2Impl {
}
// Returns a bitmask representing the positions of empty slots.
- BitMask<uint32_t, kWidth> MatchEmpty() const {
+ NonIterableBitMask<uint32_t, kWidth> MaskEmpty() const {
#ifdef ABSL_INTERNAL_HAVE_SSSE3
// This only works because ctrl_t::kEmpty is -128.
- return BitMask<uint32_t, kWidth>(
+ return NonIterableBitMask<uint32_t, kWidth>(
static_cast<uint32_t>(_mm_movemask_epi8(_mm_sign_epi8(ctrl, ctrl))));
#else
- return Match(static_cast<h2_t>(ctrl_t::kEmpty));
+ auto match = _mm_set1_epi8(static_cast<h2_t>(ctrl_t::kEmpty));
+ return NonIterableBitMask<uint32_t, kWidth>(
+ static_cast<uint32_t>(_mm_movemask_epi8(_mm_cmpeq_epi8(match, ctrl))));
#endif
}
// Returns a bitmask representing the positions of empty or deleted slots.
- BitMask<uint32_t, kWidth> MatchEmptyOrDeleted() const {
+ NonIterableBitMask<uint32_t, kWidth> MaskEmptyOrDeleted() const {
auto special = _mm_set1_epi8(static_cast<uint8_t>(ctrl_t::kSentinel));
- return BitMask<uint32_t, kWidth>(static_cast<uint32_t>(
+ return NonIterableBitMask<uint32_t, kWidth>(static_cast<uint32_t>(
_mm_movemask_epi8(_mm_cmpgt_epi8_fixed(special, ctrl))));
}
@@ -579,6 +598,80 @@ struct GroupSse2Impl {
};
#endif // ABSL_INTERNAL_RAW_HASH_SET_HAVE_SSE2
+#if defined(ABSL_INTERNAL_HAVE_ARM_NEON) && defined(ABSL_IS_LITTLE_ENDIAN)
+struct GroupAArch64Impl {
+ static constexpr size_t kWidth = 8;
+
+ explicit GroupAArch64Impl(const ctrl_t* pos) {
+ ctrl = vld1_u8(reinterpret_cast<const uint8_t*>(pos));
+ }
+
+ BitMask<uint64_t, kWidth, 3> Match(h2_t hash) const {
+ uint8x8_t dup = vdup_n_u8(hash);
+ auto mask = vceq_u8(ctrl, dup);
+ constexpr uint64_t msbs = 0x8080808080808080ULL;
+ return BitMask<uint64_t, kWidth, 3>(
+ vget_lane_u64(vreinterpret_u64_u8(mask), 0) & msbs);
+ }
+
+ NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const {
+ uint64_t mask =
+ vget_lane_u64(vreinterpret_u64_u8(
+ vceq_s8(vdup_n_s8(static_cast<h2_t>(ctrl_t::kEmpty)),
+ vreinterpret_s8_u8(ctrl))),
+ 0);
+ return NonIterableBitMask<uint64_t, kWidth, 3>(mask);
+ }
+
+ NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const {
+ uint64_t mask =
+ vget_lane_u64(vreinterpret_u64_u8(vcgt_s8(
+ vdup_n_s8(static_cast<int8_t>(ctrl_t::kSentinel)),
+ vreinterpret_s8_u8(ctrl))),
+ 0);
+ return NonIterableBitMask<uint64_t, kWidth, 3>(mask);
+ }
+
+ uint32_t CountLeadingEmptyOrDeleted() const {
+ uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0);
+ assert(IsEmptyOrDeleted(static_cast<ctrl_t>(mask & 0xff)));
+ constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL;
+#if defined(ABSL_INTERNAL_HAVE_ARM_ACLE)
+ // cls: Count leading sign bits.
+ // clsll(1ull << 63) -> 0
+ // clsll((1ull << 63) | (1ull << 62)) -> 1
+ // clsll((1ull << 63) | (1ull << 61)) -> 0
+ // clsll(~0ull) -> 63
+ // clsll(1) -> 62
+ // clsll(3) -> 61
+ // clsll(5) -> 60
+ // Note that CountLeadingEmptyOrDeleted is called when first control block
+ // is kDeleted or kEmpty. The implementation is similar to GroupPortableImpl
+ // but avoids +1 and __clsll returns result not including the high bit. Thus
+ // saves one cycle.
+ // kEmpty = -128, // 0b10000000
+ // kDeleted = -2, // 0b11111110
+ // ~ctrl & (ctrl >> 7) will have the lowest bit set to 1. After rbit,
+ // it will the highest one.
+ return (__clsll(__rbitll((~mask & (mask >> 7)) | gaps)) + 8) >> 3;
+#else
+ return (TrailingZeros(((~mask & (mask >> 7)) | gaps) + 1) + 7) >> 3;
+#endif // ABSL_INTERNAL_HAVE_ARM_ACLE
+ }
+
+ void ConvertSpecialToEmptyAndFullToDeleted(ctrl_t* dst) const {
+ uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0);
+ constexpr uint64_t msbs = 0x8080808080808080ULL;
+ constexpr uint64_t lsbs = 0x0101010101010101ULL;
+ auto x = mask & msbs;
+ auto res = (~x + (x >> 7)) & ~lsbs;
+ little_endian::Store64(dst, res);
+ }
+
+ uint8x8_t ctrl;
+};
+#endif // ABSL_INTERNAL_HAVE_ARM_NEON && ABSL_IS_LITTLE_ENDIAN
+
struct GroupPortableImpl {
static constexpr size_t kWidth = 8;
@@ -605,14 +698,16 @@ struct GroupPortableImpl {
return BitMask<uint64_t, kWidth, 3>((x - lsbs) & ~x & msbs);
}
- BitMask<uint64_t, kWidth, 3> MatchEmpty() const {
+ NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const {
constexpr uint64_t msbs = 0x8080808080808080ULL;
- return BitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 6)) & msbs);
+ return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 6)) &
+ msbs);
}
- BitMask<uint64_t, kWidth, 3> MatchEmptyOrDeleted() const {
+ NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const {
constexpr uint64_t msbs = 0x8080808080808080ULL;
- return BitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 7)) & msbs);
+ return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 7)) &
+ msbs);
}
uint32_t CountLeadingEmptyOrDeleted() const {
@@ -631,39 +726,9 @@ struct GroupPortableImpl {
uint64_t ctrl;
};
-#ifdef ABSL_INTERNAL_HAVE_ARM_ACLE
-struct GroupAArch64Impl : public GroupPortableImpl {
- static constexpr size_t kWidth = GroupPortableImpl::kWidth;
-
- using GroupPortableImpl::GroupPortableImpl;
-
- uint32_t CountLeadingEmptyOrDeleted() const {
- assert(IsEmptyOrDeleted(static_cast<ctrl_t>(ctrl & 0xff)));
- constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL;
- // cls: Count leading sign bits.
- // clsll(1ull << 63) -> 0
- // clsll((1ull << 63) | (1ull << 62)) -> 1
- // clsll((1ull << 63) | (1ull << 61)) -> 0
- // clsll(~0ull) -> 63
- // clsll(1) -> 62
- // clsll(3) -> 61
- // clsll(5) -> 60
- // Note that CountLeadingEmptyOrDeleted is called when first control block
- // is kDeleted or kEmpty. The implementation is similar to GroupPortableImpl
- // but avoids +1 and __clsll returns result not including the high bit. Thus
- // saves one cycle.
- // kEmpty = -128, // 0b10000000
- // kDeleted = -2, // 0b11111110
- // ~ctrl & (ctrl >> 7) will have the lowest bit set to 1. After rbit,
- // it will the highest one.
- return (__clsll(__rbitll((~ctrl & (ctrl >> 7)) | gaps)) + 8) >> 3;
- }
-};
-#endif
-
#ifdef ABSL_INTERNAL_HAVE_SSE2
using Group = GroupSse2Impl;
-#elif defined(ABSL_INTERNAL_HAVE_ARM_ACLE)
+#elif defined(ABSL_INTERNAL_HAVE_ARM_NEON) && defined(ABSL_IS_LITTLE_ENDIAN)
using Group = GroupAArch64Impl;
#else
using Group = GroupPortableImpl;
@@ -798,7 +863,7 @@ inline FindInfo find_first_non_full(const ctrl_t* ctrl, size_t hash,
auto seq = probe(ctrl, hash, capacity);
while (true) {
Group g{ctrl + seq.offset()};
- auto mask = g.MatchEmptyOrDeleted();
+ auto mask = g.MaskEmptyOrDeleted();
if (mask) {
#if !defined(NDEBUG)
// We want to add entropy even when ASLR is not enabled.
@@ -1700,7 +1765,7 @@ class raw_hash_set {
PolicyTraits::element(slots_ + seq.offset(i)))))
return iterator_at(seq.offset(i));
}
- if (ABSL_PREDICT_TRUE(g.MatchEmpty())) return end();
+ if (ABSL_PREDICT_TRUE(g.MaskEmpty())) return end();
seq.next();
assert(seq.index() <= capacity_ && "full table!");
}
@@ -1849,8 +1914,8 @@ class raw_hash_set {
--size_;
const size_t index = static_cast<size_t>(it.inner_.ctrl_ - ctrl_);
const size_t index_before = (index - Group::kWidth) & capacity_;
- const auto empty_after = Group(it.inner_.ctrl_).MatchEmpty();
- const auto empty_before = Group(ctrl_ + index_before).MatchEmpty();
+ const auto empty_after = Group(it.inner_.ctrl_).MaskEmpty();
+ const auto empty_before = Group(ctrl_ + index_before).MaskEmpty();
// We count how many consecutive non empties we have to the right and to the
// left of `it`. If the sum is >= kWidth then there is at least one probe
@@ -2091,7 +2156,7 @@ class raw_hash_set {
elem))
return true;
}
- if (ABSL_PREDICT_TRUE(g.MatchEmpty())) return false;
+ if (ABSL_PREDICT_TRUE(g.MaskEmpty())) return false;
seq.next();
assert(seq.index() <= capacity_ && "full table!");
}
@@ -2127,7 +2192,7 @@ class raw_hash_set {
PolicyTraits::element(slots_ + seq.offset(i)))))
return {seq.offset(i), false};
}
- if (ABSL_PREDICT_TRUE(g.MatchEmpty())) break;
+ if (ABSL_PREDICT_TRUE(g.MaskEmpty())) break;
seq.next();
assert(seq.index() <= capacity_ && "full table!");
}
@@ -2272,7 +2337,7 @@ struct HashtableDebugAccess<Set, absl::void_t<typename Set::raw_hash_set>> {
return num_probes;
++num_probes;
}
- if (g.MatchEmpty()) return num_probes;
+ if (g.MaskEmpty()) return num_probes;
seq.next();
++num_probes;
}