 absl/base/config.h                                |   9
 absl/container/internal/raw_hash_set.h            | 227
 absl/container/internal/raw_hash_set_benchmark.cc |  14
 absl/container/internal/raw_hash_set_test.cc      |  24
 4 files changed, 176 insertions(+), 98 deletions(-)
diff --git a/absl/base/config.h b/absl/base/config.h
index 4223629e..802529fc 100644
--- a/absl/base/config.h
+++ b/absl/base/config.h
@@ -898,4 +898,13 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' ||
#define ABSL_INTERNAL_HAVE_ARM_ACLE 1
#endif
+// ABSL_INTERNAL_HAVE_ARM_NEON is used for compile-time detection of NEON (ARM
+// SIMD).
+#ifdef ABSL_INTERNAL_HAVE_ARM_NEON
+#error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set
+#elif defined(__ARM_NEON)
+#define ABSL_INTERNAL_HAVE_ARM_NEON 1
+#endif
+
+
#endif // ABSL_BASE_CONFIG_H_
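
A minimal sketch of how a translation unit might consume the new detection macro (illustrative only; the consumer actually added by this commit is raw_hash_set.h below, which selects its group implementation with the same test):

    #include "absl/base/config.h"

    #if defined(ABSL_INTERNAL_HAVE_ARM_NEON)
    #include <arm_neon.h>
    // NEON-specific implementation goes here.
    #else
    // Portable fallback goes here.
    #endif
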
diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h
index d503bc00..2756ce1b 100644
--- a/absl/container/internal/raw_hash_set.h
+++ b/absl/container/internal/raw_hash_set.h
@@ -184,6 +184,14 @@
#include <intrin.h>
#endif
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+
+#ifdef __ARM_ACLE
+#include <arm_acle.h>
+#endif
+
#include <algorithm>
#include <cmath>
#include <cstdint>
@@ -211,10 +219,6 @@
#include "absl/numeric/bits.h"
#include "absl/utility/utility.h"
-#ifdef ABSL_INTERNAL_HAVE_ARM_ACLE
-#include <arm_acle.h>
-#endif
-
namespace absl {
ABSL_NAMESPACE_BEGIN
namespace container_internal {
@@ -323,36 +327,15 @@ uint32_t TrailingZeros(T x) {
// controlled by `SignificantBits` and `Shift`. `SignificantBits` is the number
// of abstract bits in the bitset, while `Shift` is the log-base-two of the
// width of an abstract bit in the representation.
-//
-// For example, when `SignificantBits` is 16 and `Shift` is zero, this is just
-// an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When
-// `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as
-// the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask.
-//
-// For example:
-// for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2
-// for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3
+// This mask provides operations that work for any number of real bits set in
+// an abstract bit. To add iteration on top of that, the implementation must
+// guarantee that no more than one real bit is set in each abstract bit.
template <class T, int SignificantBits, int Shift = 0>
-class BitMask {
- static_assert(std::is_unsigned<T>::value, "");
- static_assert(Shift == 0 || Shift == 3, "");
-
+class NonIterableBitMask {
public:
- // BitMask is an iterator over the indices of its abstract bits.
- using value_type = int;
- using iterator = BitMask;
- using const_iterator = BitMask;
-
- explicit BitMask(T mask) : mask_(mask) {}
- BitMask& operator++() {
- mask_ &= (mask_ - 1);
- return *this;
- }
- explicit operator bool() const { return mask_ != 0; }
- uint32_t operator*() const { return LowestBitSet(); }
+ explicit NonIterableBitMask(T mask) : mask_(mask) {}
- BitMask begin() const { return *this; }
- BitMask end() const { return BitMask(0); }
+ explicit operator bool() const { return this->mask_ != 0; }
// Returns the index of the lowest *abstract* bit set in `self`.
uint32_t LowestBitSet() const {
@@ -376,6 +359,42 @@ class BitMask {
return static_cast<uint32_t>(countl_zero(mask_ << extra_bits)) >> Shift;
}
+ T mask_;
+};
+
+// A mask that can be iterated over.
+//
+// For example, when `SignificantBits` is 16 and `Shift` is zero, this is just
+// an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When
+// `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as
+// the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask.
+//
+// For example:
+// for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2
+// for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3
+template <class T, int SignificantBits, int Shift = 0>
+class BitMask : public NonIterableBitMask<T, SignificantBits, Shift> {
+ using Base = NonIterableBitMask<T, SignificantBits, Shift>;
+ static_assert(std::is_unsigned<T>::value, "");
+ static_assert(Shift == 0 || Shift == 3, "");
+
+ public:
+ explicit BitMask(T mask) : Base(mask) {}
+ // BitMask is an iterator over the indices of its abstract bits.
+ using value_type = int;
+ using iterator = BitMask;
+ using const_iterator = BitMask;
+
+ BitMask& operator++() {
+ this->mask_ &= (this->mask_ - 1);
+ return *this;
+ }
+
+ uint32_t operator*() const { return Base::LowestBitSet(); }
+
+ BitMask begin() const { return *this; }
+ BitMask end() const { return BitMask(0); }
+
private:
friend bool operator==(const BitMask& a, const BitMask& b) {
return a.mask_ == b.mask_;
@@ -383,8 +402,6 @@ class BitMask {
friend bool operator!=(const BitMask& a, const BitMask& b) {
return a.mask_ != b.mask_;
}
-
- T mask_;
};
using h2_t = uint8_t;
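
To make the abstract-bit convention described in the comments above concrete, here is a small self-contained sketch (my own illustration, not Abseil code, assuming C++20 for std::countr_zero): with Shift == 0 each abstract bit is a single real bit, and with Shift == 3 each abstract bit is one byte whose high bit (0x80) may be set, so an index is recovered by shifting the real bit position right by 3.

    #include <bit>       // std::countr_zero (assumes C++20)
    #include <cstdint>
    #include <cstdio>

    // Hypothetical helper mirroring LowestBitSet(): index of the lowest
    // *abstract* bit, where each abstract bit is (1 << Shift) real bits wide.
    template <int Shift>
    uint32_t LowestAbstractBit(uint64_t mask) {
      return static_cast<uint32_t>(std::countr_zero(mask)) >> Shift;
    }

    int main() {
      // Shift == 0: an ordinary bitset; 0b101 has abstract bits 0 and 2 set.
      std::printf("%u\n", LowestAbstractBit<0>(0b101u));  // prints 0
      // Shift == 3: one byte per abstract bit, marked by 0x80.
      // 0x0000000080800000 has bytes 2 and 3 marked, i.e. abstract bits 2, 3.
      std::printf("%u\n", LowestAbstractBit<3>(0x0000000080800000ull));  // prints 2
    }
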
@@ -433,7 +450,7 @@ static_assert(
static_cast<int8_t>(ctrl_t::kSentinel) & 0x7F) != 0,
"ctrl_t::kEmpty and ctrl_t::kDeleted must share an unset bit that is not "
"shared by ctrl_t::kSentinel to make the scalar test for "
- "MatchEmptyOrDeleted() efficient");
+ "MaskEmptyOrDeleted() efficient");
static_assert(ctrl_t::kDeleted == static_cast<ctrl_t>(-2),
"ctrl_t::kDeleted must be -2 to make the implementation of "
"ConvertSpecialToEmptyAndFullToDeleted efficient");
@@ -538,20 +555,22 @@ struct GroupSse2Impl {
}
// Returns a bitmask representing the positions of empty slots.
- BitMask<uint32_t, kWidth> MatchEmpty() const {
+ NonIterableBitMask<uint32_t, kWidth> MaskEmpty() const {
#ifdef ABSL_INTERNAL_HAVE_SSSE3
// This only works because ctrl_t::kEmpty is -128.
- return BitMask<uint32_t, kWidth>(
+ return NonIterableBitMask<uint32_t, kWidth>(
static_cast<uint32_t>(_mm_movemask_epi8(_mm_sign_epi8(ctrl, ctrl))));
#else
- return Match(static_cast<h2_t>(ctrl_t::kEmpty));
+ auto match = _mm_set1_epi8(static_cast<h2_t>(ctrl_t::kEmpty));
+ return NonIterableBitMask<uint32_t, kWidth>(
+ static_cast<uint32_t>(_mm_movemask_epi8(_mm_cmpeq_epi8(match, ctrl))));
#endif
}
// Returns a bitmask representing the positions of empty or deleted slots.
- BitMask<uint32_t, kWidth> MatchEmptyOrDeleted() const {
+ NonIterableBitMask<uint32_t, kWidth> MaskEmptyOrDeleted() const {
auto special = _mm_set1_epi8(static_cast<uint8_t>(ctrl_t::kSentinel));
- return BitMask<uint32_t, kWidth>(static_cast<uint32_t>(
+ return NonIterableBitMask<uint32_t, kWidth>(static_cast<uint32_t>(
_mm_movemask_epi8(_mm_cmpgt_epi8_fixed(special, ctrl))));
}
@@ -579,6 +598,80 @@ struct GroupSse2Impl {
};
#endif // ABSL_INTERNAL_RAW_HASH_SET_HAVE_SSE2
+#if defined(ABSL_INTERNAL_HAVE_ARM_NEON) && defined(ABSL_IS_LITTLE_ENDIAN)
+struct GroupAArch64Impl {
+ static constexpr size_t kWidth = 8;
+
+ explicit GroupAArch64Impl(const ctrl_t* pos) {
+ ctrl = vld1_u8(reinterpret_cast<const uint8_t*>(pos));
+ }
+
+ BitMask<uint64_t, kWidth, 3> Match(h2_t hash) const {
+ uint8x8_t dup = vdup_n_u8(hash);
+ auto mask = vceq_u8(ctrl, dup);
+ constexpr uint64_t msbs = 0x8080808080808080ULL;
+ return BitMask<uint64_t, kWidth, 3>(
+ vget_lane_u64(vreinterpret_u64_u8(mask), 0) & msbs);
+ }
+
+ NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const {
+ uint64_t mask =
+ vget_lane_u64(vreinterpret_u64_u8(
+ vceq_s8(vdup_n_s8(static_cast<h2_t>(ctrl_t::kEmpty)),
+ vreinterpret_s8_u8(ctrl))),
+ 0);
+ return NonIterableBitMask<uint64_t, kWidth, 3>(mask);
+ }
+
+ NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const {
+ uint64_t mask =
+ vget_lane_u64(vreinterpret_u64_u8(vcgt_s8(
+ vdup_n_s8(static_cast<int8_t>(ctrl_t::kSentinel)),
+ vreinterpret_s8_u8(ctrl))),
+ 0);
+ return NonIterableBitMask<uint64_t, kWidth, 3>(mask);
+ }
+
+ uint32_t CountLeadingEmptyOrDeleted() const {
+ uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0);
+ assert(IsEmptyOrDeleted(static_cast<ctrl_t>(mask & 0xff)));
+ constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL;
+#if defined(ABSL_INTERNAL_HAVE_ARM_ACLE)
+ // cls: Count leading sign bits.
+ // clsll(1ull << 63) -> 0
+ // clsll((1ull << 63) | (1ull << 62)) -> 1
+ // clsll((1ull << 63) | (1ull << 61)) -> 0
+ // clsll(~0ull) -> 63
+ // clsll(1) -> 62
+ // clsll(3) -> 61
+ // clsll(5) -> 60
+  // Note that CountLeadingEmptyOrDeleted is called when the first control
+  // byte is kDeleted or kEmpty. The implementation is similar to
+  // GroupPortableImpl but avoids the +1, because __clsll does not count the
+  // high bit itself; this saves one cycle.
+  // kEmpty = -128, // 0b10000000
+  // kDeleted = -2, // 0b11111110
+  // ~ctrl & (ctrl >> 7) will have its lowest bit set to 1. After rbit,
+  // it will be the highest one.
+ return (__clsll(__rbitll((~mask & (mask >> 7)) | gaps)) + 8) >> 3;
+#else
+ return (TrailingZeros(((~mask & (mask >> 7)) | gaps) + 1) + 7) >> 3;
+#endif // ABSL_INTERNAL_HAVE_ARM_ACLE
+ }
+
+ void ConvertSpecialToEmptyAndFullToDeleted(ctrl_t* dst) const {
+ uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0);
+ constexpr uint64_t msbs = 0x8080808080808080ULL;
+ constexpr uint64_t lsbs = 0x0101010101010101ULL;
+ auto x = mask & msbs;
+ auto res = (~x + (x >> 7)) & ~lsbs;
+ little_endian::Store64(dst, res);
+ }
+
+ uint8x8_t ctrl;
+};
+#endif // ABSL_INTERNAL_HAVE_ARM_NEON && ABSL_IS_LITTLE_ENDIAN
+
struct GroupPortableImpl {
static constexpr size_t kWidth = 8;
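
Stepping back from the hunk above: unlike the SSE2 path, which compresses comparison results to one bit per byte via _mm_movemask_epi8, NEON's vceq_u8 leaves 0xFF in every matching byte, so a raw comparison lane has up to eight real bits per abstract bit. That is exactly the case the new NonIterableBitMask tolerates, and it is why Match() ANDs the lane with 0x8080808080808080 before wrapping it in an iterable BitMask. A hedged, illustration-only sketch (guarded so it only builds where NEON is available):

    #ifdef __ARM_NEON
    #include <arm_neon.h>
    #include <cstdint>
    #include <cstdio>

    int main() {
      // Eight control bytes for one group; 0x12 plays the role of the H2 hash.
      const uint8_t ctrl[8] = {0x12, 0x34, 0x12, 0x56, 0x12, 0x78, 0x01, 0x02};
      uint8x8_t group = vld1_u8(ctrl);
      uint8x8_t dup = vdup_n_u8(0x12);
      uint64_t raw =
          vget_lane_u64(vreinterpret_u64_u8(vceq_u8(group, dup)), 0);
      // raw holds 0xFF in bytes 0, 2 and 4: many real bits per abstract bit,
      // fine for NonIterableBitMask-style queries but not for iteration.
      constexpr uint64_t msbs = 0x8080808080808080ULL;
      uint64_t iterable = raw & msbs;  // 0x80 in bytes 0, 2 and 4 only
      std::printf("%016llx -> %016llx\n",
                  static_cast<unsigned long long>(raw),
                  static_cast<unsigned long long>(iterable));
    }
    #endif  // __ARM_NEON
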
@@ -605,14 +698,16 @@ struct GroupPortableImpl {
return BitMask<uint64_t, kWidth, 3>((x - lsbs) & ~x & msbs);
}
- BitMask<uint64_t, kWidth, 3> MatchEmpty() const {
+ NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const {
constexpr uint64_t msbs = 0x8080808080808080ULL;
- return BitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 6)) & msbs);
+ return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 6)) &
+ msbs);
}
- BitMask<uint64_t, kWidth, 3> MatchEmptyOrDeleted() const {
+ NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const {
constexpr uint64_t msbs = 0x8080808080808080ULL;
- return BitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 7)) & msbs);
+ return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 7)) &
+ msbs);
}
uint32_t CountLeadingEmptyOrDeleted() const {
@@ -631,39 +726,9 @@ struct GroupPortableImpl {
uint64_t ctrl;
};
-#ifdef ABSL_INTERNAL_HAVE_ARM_ACLE
-struct GroupAArch64Impl : public GroupPortableImpl {
- static constexpr size_t kWidth = GroupPortableImpl::kWidth;
-
- using GroupPortableImpl::GroupPortableImpl;
-
- uint32_t CountLeadingEmptyOrDeleted() const {
- assert(IsEmptyOrDeleted(static_cast<ctrl_t>(ctrl & 0xff)));
- constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL;
- // cls: Count leading sign bits.
- // clsll(1ull << 63) -> 0
- // clsll((1ull << 63) | (1ull << 62)) -> 1
- // clsll((1ull << 63) | (1ull << 61)) -> 0
- // clsll(~0ull) -> 63
- // clsll(1) -> 62
- // clsll(3) -> 61
- // clsll(5) -> 60
- // Note that CountLeadingEmptyOrDeleted is called when first control block
- // is kDeleted or kEmpty. The implementation is similar to GroupPortableImpl
- // but avoids +1 and __clsll returns result not including the high bit. Thus
- // saves one cycle.
- // kEmpty = -128, // 0b10000000
- // kDeleted = -2, // 0b11111110
- // ~ctrl & (ctrl >> 7) will have the lowest bit set to 1. After rbit,
- // it will the highest one.
- return (__clsll(__rbitll((~ctrl & (ctrl >> 7)) | gaps)) + 8) >> 3;
- }
-};
-#endif
-
#ifdef ABSL_INTERNAL_HAVE_SSE2
using Group = GroupSse2Impl;
-#elif defined(ABSL_INTERNAL_HAVE_ARM_ACLE)
+#elif defined(ABSL_INTERNAL_HAVE_ARM_NEON) && defined(ABSL_IS_LITTLE_ENDIAN)
using Group = GroupAArch64Impl;
#else
using Group = GroupPortableImpl;
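
The __clsll/__rbitll comment above is easiest to check against the portable fallback whose +1 it avoids. A small worked sketch (my own illustration, using C++20 std::countr_zero in place of Abseil's TrailingZeros): the low bit of each byte of ~mask & (mask >> 7) flags an empty-or-deleted control byte, the gaps constant fills the remaining bit positions so the +1 carries across exactly the flagged prefix, and the trailing-zero count then lands in the first byte that is not empty or deleted.

    #include <bit>
    #include <cstdint>
    #include <cstdio>

    // Illustration of the portable CountLeadingEmptyOrDeleted arithmetic.
    // `mask` is the 64-bit little-endian load of 8 control bytes.
    uint32_t CountLeadingEmptyOrDeleted(uint64_t mask) {
      constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL;
      return (static_cast<uint32_t>(
                  std::countr_zero(((~mask & (mask >> 7)) | gaps) + 1)) +
              7) >> 3;
    }

    int main() {
      // Slot 0 = kEmpty (0x80), slot 1 = kDeleted (0xFE), slot 2 = full (0x05),
      // slots 3..7 full (0x00); slot 0 is the lowest byte.
      std::printf("%u\n", CountLeadingEmptyOrDeleted(0x000000000005FE80ULL));
      // Prints 2: two leading empty-or-deleted slots.
    }
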
@@ -798,7 +863,7 @@ inline FindInfo find_first_non_full(const ctrl_t* ctrl, size_t hash,
auto seq = probe(ctrl, hash, capacity);
while (true) {
Group g{ctrl + seq.offset()};
- auto mask = g.MatchEmptyOrDeleted();
+ auto mask = g.MaskEmptyOrDeleted();
if (mask) {
#if !defined(NDEBUG)
// We want to add entropy even when ASLR is not enabled.
@@ -1700,7 +1765,7 @@ class raw_hash_set {
PolicyTraits::element(slots_ + seq.offset(i)))))
return iterator_at(seq.offset(i));
}
- if (ABSL_PREDICT_TRUE(g.MatchEmpty())) return end();
+ if (ABSL_PREDICT_TRUE(g.MaskEmpty())) return end();
seq.next();
assert(seq.index() <= capacity_ && "full table!");
}
@@ -1849,8 +1914,8 @@ class raw_hash_set {
--size_;
const size_t index = static_cast<size_t>(it.inner_.ctrl_ - ctrl_);
const size_t index_before = (index - Group::kWidth) & capacity_;
- const auto empty_after = Group(it.inner_.ctrl_).MatchEmpty();
- const auto empty_before = Group(ctrl_ + index_before).MatchEmpty();
+ const auto empty_after = Group(it.inner_.ctrl_).MaskEmpty();
+ const auto empty_before = Group(ctrl_ + index_before).MaskEmpty();
// We count how many consecutive non empties we have to the right and to the
// left of `it`. If the sum is >= kWidth then there is at least one probe
@@ -2091,7 +2156,7 @@ class raw_hash_set {
elem))
return true;
}
- if (ABSL_PREDICT_TRUE(g.MatchEmpty())) return false;
+ if (ABSL_PREDICT_TRUE(g.MaskEmpty())) return false;
seq.next();
assert(seq.index() <= capacity_ && "full table!");
}
@@ -2127,7 +2192,7 @@ class raw_hash_set {
PolicyTraits::element(slots_ + seq.offset(i)))))
return {seq.offset(i), false};
}
- if (ABSL_PREDICT_TRUE(g.MatchEmpty())) break;
+ if (ABSL_PREDICT_TRUE(g.MaskEmpty())) break;
seq.next();
assert(seq.index() <= capacity_ && "full table!");
}
@@ -2272,7 +2337,7 @@ struct HashtableDebugAccess<Set, absl::void_t<typename Set::raw_hash_set>> {
return num_probes;
++num_probes;
}
- if (g.MatchEmpty()) return num_probes;
+ if (g.MaskEmpty()) return num_probes;
seq.next();
++num_probes;
}
diff --git a/absl/container/internal/raw_hash_set_benchmark.cc b/absl/container/internal/raw_hash_set_benchmark.cc
index 146ef433..47dc9048 100644
--- a/absl/container/internal/raw_hash_set_benchmark.cc
+++ b/absl/container/internal/raw_hash_set_benchmark.cc
@@ -336,27 +336,27 @@ void BM_Group_Match(benchmark::State& state) {
}
BENCHMARK(BM_Group_Match);
-void BM_Group_MatchEmpty(benchmark::State& state) {
+void BM_Group_MaskEmpty(benchmark::State& state) {
std::array<ctrl_t, Group::kWidth> group;
Iota(group.begin(), group.end(), -4);
Group g{group.data()};
for (auto _ : state) {
::benchmark::DoNotOptimize(g);
- ::benchmark::DoNotOptimize(g.MatchEmpty());
+ ::benchmark::DoNotOptimize(g.MaskEmpty());
}
}
-BENCHMARK(BM_Group_MatchEmpty);
+BENCHMARK(BM_Group_MaskEmpty);
-void BM_Group_MatchEmptyOrDeleted(benchmark::State& state) {
+void BM_Group_MaskEmptyOrDeleted(benchmark::State& state) {
std::array<ctrl_t, Group::kWidth> group;
Iota(group.begin(), group.end(), -4);
Group g{group.data()};
for (auto _ : state) {
::benchmark::DoNotOptimize(g);
- ::benchmark::DoNotOptimize(g.MatchEmptyOrDeleted());
+ ::benchmark::DoNotOptimize(g.MaskEmptyOrDeleted());
}
}
-BENCHMARK(BM_Group_MatchEmptyOrDeleted);
+BENCHMARK(BM_Group_MaskEmptyOrDeleted);
void BM_Group_CountLeadingEmptyOrDeleted(benchmark::State& state) {
std::array<ctrl_t, Group::kWidth> group;
@@ -375,7 +375,7 @@ void BM_Group_MatchFirstEmptyOrDeleted(benchmark::State& state) {
Group g{group.data()};
for (auto _ : state) {
::benchmark::DoNotOptimize(g);
- ::benchmark::DoNotOptimize(*g.MatchEmptyOrDeleted());
+ ::benchmark::DoNotOptimize(g.MaskEmptyOrDeleted().LowestBitSet());
}
}
BENCHMARK(BM_Group_MatchFirstEmptyOrDeleted);
diff --git a/absl/container/internal/raw_hash_set_test.cc b/absl/container/internal/raw_hash_set_test.cc
index c79f8641..dc6bc3d2 100644
--- a/absl/container/internal/raw_hash_set_test.cc
+++ b/absl/container/internal/raw_hash_set_test.cc
@@ -195,35 +195,39 @@ TEST(Group, Match) {
}
}
-TEST(Group, MatchEmpty) {
+TEST(Group, MaskEmpty) {
if (Group::kWidth == 16) {
ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), ctrl_t::kDeleted, CtrlT(3),
ctrl_t::kEmpty, CtrlT(5), ctrl_t::kSentinel, CtrlT(7),
CtrlT(7), CtrlT(5), CtrlT(3), CtrlT(1),
CtrlT(1), CtrlT(1), CtrlT(1), CtrlT(1)};
- EXPECT_THAT(Group{group}.MatchEmpty(), ElementsAre(0, 4));
+ EXPECT_THAT(Group{group}.MaskEmpty().LowestBitSet(), 0);
+ EXPECT_THAT(Group{group}.MaskEmpty().HighestBitSet(), 4);
} else if (Group::kWidth == 8) {
ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), CtrlT(2),
ctrl_t::kDeleted, CtrlT(2), CtrlT(1),
ctrl_t::kSentinel, CtrlT(1)};
- EXPECT_THAT(Group{group}.MatchEmpty(), ElementsAre(0));
+ EXPECT_THAT(Group{group}.MaskEmpty().LowestBitSet(), 0);
+ EXPECT_THAT(Group{group}.MaskEmpty().HighestBitSet(), 0);
} else {
FAIL() << "No test coverage for Group::kWidth==" << Group::kWidth;
}
}
-TEST(Group, MatchEmptyOrDeleted) {
+TEST(Group, MaskEmptyOrDeleted) {
if (Group::kWidth == 16) {
- ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), ctrl_t::kDeleted, CtrlT(3),
- ctrl_t::kEmpty, CtrlT(5), ctrl_t::kSentinel, CtrlT(7),
- CtrlT(7), CtrlT(5), CtrlT(3), CtrlT(1),
- CtrlT(1), CtrlT(1), CtrlT(1), CtrlT(1)};
- EXPECT_THAT(Group{group}.MatchEmptyOrDeleted(), ElementsAre(0, 2, 4));
+ ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), ctrl_t::kEmpty, CtrlT(3),
+ ctrl_t::kDeleted, CtrlT(5), ctrl_t::kSentinel, CtrlT(7),
+ CtrlT(7), CtrlT(5), CtrlT(3), CtrlT(1),
+ CtrlT(1), CtrlT(1), CtrlT(1), CtrlT(1)};
+ EXPECT_THAT(Group{group}.MaskEmptyOrDeleted().LowestBitSet(), 0);
+ EXPECT_THAT(Group{group}.MaskEmptyOrDeleted().HighestBitSet(), 4);
} else if (Group::kWidth == 8) {
ctrl_t group[] = {ctrl_t::kEmpty, CtrlT(1), CtrlT(2),
ctrl_t::kDeleted, CtrlT(2), CtrlT(1),
ctrl_t::kSentinel, CtrlT(1)};
- EXPECT_THAT(Group{group}.MatchEmptyOrDeleted(), ElementsAre(0, 3));
+ EXPECT_THAT(Group{group}.MaskEmptyOrDeleted().LowestBitSet(), 0);
+ EXPECT_THAT(Group{group}.MaskEmptyOrDeleted().HighestBitSet(), 3);
} else {
FAIL() << "No test coverage for Group::kWidth==" << Group::kWidth;
}