summaryrefslogtreecommitdiff
path: root/absl/algorithm
diff options
context:
space:
mode:
authorGravatar Eric Astor <epastor@google.com>2023-12-21 08:11:01 -0800
committerGravatar Copybara-Service <copybara-worker@google.com>2023-12-21 08:12:11 -0800
commit258e5a15759cc3d122d4a4826bc499af91d40aa9 (patch)
tree0696c01c1d40217b8c339a3e81418dace1e10640 /absl/algorithm
parent794352a92f09425714b9116974b29e58ce8f9ba9 (diff)
Add a container-based version of `std::sample()`
PiperOrigin-RevId: 592864147 Change-Id: I83179b0225aa446ae0b57b46b604af14f1fa14df
Diffstat (limited to 'absl/algorithm')
-rw-r--r--absl/algorithm/container.h30
-rw-r--r--absl/algorithm/container_test.cc22
2 files changed, 51 insertions, 1 deletions
diff --git a/absl/algorithm/container.h b/absl/algorithm/container.h
index 934dd179..c7bafae1 100644
--- a/absl/algorithm/container.h
+++ b/absl/algorithm/container.h
@@ -774,6 +774,36 @@ void c_shuffle(RandomAccessContainer& c, UniformRandomBitGenerator&& gen) {
std::forward<UniformRandomBitGenerator>(gen));
}
+// c_sample()
+//
+// Container-based version of the <algorithm> `std::sample()` function to
+// randomly sample elements from the container without replacement using a
+// `gen()` uniform random number generator and write them to an iterator range.
+template <typename C, typename OutputIterator, typename Distance,
+ typename UniformRandomBitGenerator>
+OutputIterator c_sample(const C& c, OutputIterator result, Distance n,
+ UniformRandomBitGenerator&& gen) {
+#if defined(__cpp_lib_sample) && __cpp_lib_sample >= 201603L
+ return std::sample(container_algorithm_internal::c_begin(c),
+ container_algorithm_internal::c_end(c), result, n,
+ std::forward<UniformRandomBitGenerator>(gen));
+#else
+ // Fall back to a stable selection-sampling implementation.
+ auto first = container_algorithm_internal::c_begin(c);
+ Distance unsampled_elements = c_distance(c);
+ n = (std::min)(n, unsampled_elements);
+ for (; n != 0; ++first) {
+ Distance r =
+ std::uniform_int_distribution<Distance>(0, --unsampled_elements)(gen);
+ if (r < n) {
+ *result++ = *first;
+ --n;
+ }
+ }
+ return result;
+#endif
+}
+
//------------------------------------------------------------------------------
// <algorithm> Partition functions
//------------------------------------------------------------------------------
diff --git a/absl/algorithm/container_test.cc b/absl/algorithm/container_test.cc
index 0fbc7773..c01f5fc0 100644
--- a/absl/algorithm/container_test.cc
+++ b/absl/algorithm/container_test.cc
@@ -14,6 +14,7 @@
#include "absl/algorithm/container.h"
+#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
@@ -40,8 +41,10 @@ using ::testing::Each;
using ::testing::ElementsAre;
using ::testing::Gt;
using ::testing::IsNull;
+using ::testing::IsSubsetOf;
using ::testing::Lt;
using ::testing::Pointee;
+using ::testing::SizeIs;
using ::testing::Truly;
using ::testing::UnorderedElementsAre;
@@ -963,12 +966,29 @@ TEST(MutatingTest, RotateCopy) {
EXPECT_THAT(actual, ElementsAre(3, 4, 1, 2, 5));
}
+template <typename T>
+T RandomlySeededPrng() {
+ std::random_device rdev;
+ std::seed_seq::result_type data[T::state_size];
+ std::generate_n(data, T::state_size, std::ref(rdev));
+ std::seed_seq prng_seed(data, data + T::state_size);
+ return T(prng_seed);
+}
+
TEST(MutatingTest, Shuffle) {
std::vector<int> actual = {1, 2, 3, 4, 5};
- absl::c_shuffle(actual, std::random_device());
+ absl::c_shuffle(actual, RandomlySeededPrng<std::mt19937_64>());
EXPECT_THAT(actual, UnorderedElementsAre(1, 2, 3, 4, 5));
}
+TEST(MutatingTest, Sample) {
+ std::vector<int> actual;
+ absl::c_sample(std::vector<int>{1, 2, 3, 4, 5}, std::back_inserter(actual), 3,
+ RandomlySeededPrng<std::mt19937_64>());
+ EXPECT_THAT(actual, IsSubsetOf({1, 2, 3, 4, 5}));
+ EXPECT_THAT(actual, SizeIs(3));
+}
+
TEST(MutatingTest, PartialSort) {
std::vector<int> sequence{5, 3, 42, 0};
absl::c_partial_sort(sequence, sequence.begin() + 2);