aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/random/weighted_picker_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/lib/random/weighted_picker_test.cc')
-rw-r--r--tensorflow/core/lib/random/weighted_picker_test.cc254
1 files changed, 254 insertions, 0 deletions
diff --git a/tensorflow/core/lib/random/weighted_picker_test.cc b/tensorflow/core/lib/random/weighted_picker_test.cc
new file mode 100644
index 0000000000..0b27d437d5
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker_test.cc
@@ -0,0 +1,254 @@
+#include "tensorflow/core/lib/random/weighted_picker.h"
+
+#include <string.h>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+
+static void TestPicker(SimplePhilox* rnd, int size);
+static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials);
+static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials);
+static void TestPickAt(int items, const int32* weights);
+
+TEST(WeightedPicker, Simple) {
+ PhiloxRandom philox(testing::RandomSeed(), 17);
+ SimplePhilox rnd(&philox);
+
+ {
+ VLOG(0) << "======= Zero-length picker";
+ WeightedPicker picker(0);
+ EXPECT_EQ(picker.Pick(&rnd), -1);
+ }
+
+ {
+ VLOG(0) << "======= Singleton picker";
+ WeightedPicker picker(1);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ }
+
+ {
+ VLOG(0) << "======= Grown picker";
+ WeightedPicker picker(0);
+ for (int i = 0; i < 10; i++) {
+ picker.Append(1);
+ }
+ CheckUniform(&rnd, &picker, 100000);
+ }
+
+ {
+ VLOG(0) << "======= Grown picker with zero weights";
+ WeightedPicker picker(1);
+ picker.Resize(10);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ }
+
+ {
+ VLOG(0) << "======= Shrink picker and check weights";
+ WeightedPicker picker(1);
+ picker.Resize(10);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ for (int i = 0; i < 10; i++) {
+ picker.set_weight(i, i);
+ }
+ EXPECT_EQ(picker.total_weight(), 45);
+ picker.Resize(5);
+ EXPECT_EQ(picker.total_weight(), 10);
+ picker.Resize(2);
+ EXPECT_EQ(picker.total_weight(), 1);
+ picker.Resize(1);
+ EXPECT_EQ(picker.total_weight(), 0);
+ }
+}
+
+TEST(WeightedPicker, BigWeights) {
+ PhiloxRandom philox(testing::RandomSeed() + 1, 17);
+ SimplePhilox rnd(&philox);
+ VLOG(0) << "======= Check uniform with big weights";
+ WeightedPicker picker(2);
+ picker.SetAllWeights(2147483646L / 3); // (2^31 - 2) / 3
+ CheckUniform(&rnd, &picker, 100000);
+}
+
+TEST(WeightedPicker, Deterministic) {
+ VLOG(0) << "======= Testing deterministic pick";
+ static const int32 weights[] = {1, 0, 200, 5, 42};
+ TestPickAt(TF_ARRAYSIZE(weights), weights);
+}
+
+TEST(WeightedPicker, Randomized) {
+ PhiloxRandom philox(testing::RandomSeed() + 10, 17);
+ SimplePhilox rnd(&philox);
+ TestPicker(&rnd, 1);
+ TestPicker(&rnd, 2);
+ TestPicker(&rnd, 3);
+ TestPicker(&rnd, 4);
+ TestPicker(&rnd, 7);
+ TestPicker(&rnd, 8);
+ TestPicker(&rnd, 9);
+ TestPicker(&rnd, 10);
+ TestPicker(&rnd, 100);
+}
+
+static void TestPicker(SimplePhilox* rnd, int size) {
+ VLOG(0) << "======= Testing size " << size;
+
+ // Check that empty picker returns -1
+ {
+ WeightedPicker picker(size);
+ picker.SetAllWeights(0);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1);
+ }
+
+ // Create zero weights array
+ std::vector<int32> weights(size);
+ for (int elem = 0; elem < size; elem++) {
+ weights[elem] = 0;
+ }
+
+ // Check that singleton picker always returns the same element
+ for (int elem = 0; elem < size; elem++) {
+ WeightedPicker picker(size);
+ picker.SetAllWeights(0);
+ picker.set_weight(elem, elem + 1);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
+ weights[elem] = 10;
+ picker.SetWeightsFromArray(size, &weights[0]);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
+ weights[elem] = 0;
+ }
+
+ // Check that uniform picker generates elements roughly uniformly
+ {
+ WeightedPicker picker(size);
+ CheckUniform(rnd, &picker, 100000);
+ }
+
+ // Check uniform picker that was grown piecemeal
+ if (size / 3 > 0) {
+ WeightedPicker picker(size / 3);
+ while (picker.num_elements() != size) {
+ picker.Append(1);
+ }
+ CheckUniform(rnd, &picker, 100000);
+ }
+
+ // Check that skewed distribution works
+ if (size <= 10) {
+ // When picker grows one element at a time
+ WeightedPicker picker(size);
+ int32 weight = 1;
+ for (int elem = 0; elem < size; elem++) {
+ picker.set_weight(elem, weight);
+ weights[elem] = weight;
+ weight *= 2;
+ }
+ CheckSkewed(rnd, &picker, 1000000);
+
+ // When picker is created from an array
+ WeightedPicker array_picker(0);
+ array_picker.SetWeightsFromArray(size, &weights[0]);
+ CheckSkewed(rnd, &array_picker, 1000000);
+ }
+}
+
+static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker,
+ int trials) {
+ const int size = picker->num_elements();
+ int* count = new int[size];
+ memset(count, 0, sizeof(count[0]) * size);
+ for (int i = 0; i < size * trials; i++) {
+ const int elem = picker->Pick(rnd);
+ EXPECT_GE(elem, 0);
+ EXPECT_LT(elem, size);
+ count[elem]++;
+ }
+ const int expected_min = int(0.9 * trials);
+ const int expected_max = int(1.1 * trials);
+ for (int i = 0; i < size; i++) {
+ EXPECT_GE(count[i], expected_min);
+ EXPECT_LE(count[i], expected_max);
+ }
+ delete[] count;
+}
+
+static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) {
+ const int size = picker->num_elements();
+ int* count = new int[size];
+ memset(count, 0, sizeof(count[0]) * size);
+ for (int i = 0; i < size * trials; i++) {
+ const int elem = picker->Pick(rnd);
+ EXPECT_GE(elem, 0);
+ EXPECT_LT(elem, size);
+ count[elem]++;
+ }
+
+ for (int i = 0; i < size - 1; i++) {
+ LOG(INFO) << i << ": " << count[i];
+ const float ratio = float(count[i + 1]) / float(count[i]);
+ EXPECT_GE(ratio, 1.6f);
+ EXPECT_LE(ratio, 2.4f);
+ }
+ delete[] count;
+}
+
+static void TestPickAt(int items, const int32* weights) {
+ WeightedPicker picker(items);
+ picker.SetWeightsFromArray(items, weights);
+ int weight_index = 0;
+ for (int i = 0; i < items; ++i) {
+ for (int j = 0; j < weights[i]; ++j) {
+ int pick = picker.PickAt(weight_index);
+ EXPECT_EQ(pick, i);
+ ++weight_index;
+ }
+ }
+ EXPECT_EQ(weight_index, picker.total_weight());
+}
+
+static void BM_Create(int iters, int arg) {
+ while (--iters > 0) {
+ WeightedPicker p(arg);
+ }
+}
+BENCHMARK(BM_Create)->Range(1, 1024);
+
+static void BM_CreateAndSetWeights(int iters, int arg) {
+ std::vector<int32> weights(arg);
+ for (int i = 0; i < arg; i++) {
+ weights[i] = i * 10;
+ }
+ while (--iters > 0) {
+ WeightedPicker p(arg);
+ p.SetWeightsFromArray(arg, &weights[0]);
+ }
+}
+BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024);
+
+static void BM_Pick(int iters, int arg) {
+ PhiloxRandom philox(301, 17);
+ SimplePhilox rnd(&philox);
+ WeightedPicker p(arg);
+ int result = 0;
+ while (--iters > 0) {
+ result += p.Pick(&rnd);
+ }
+ VLOG(4) << result; // Dummy use
+}
+BENCHMARK(BM_Pick)->Range(1, 1024);
+
+} // namespace random
+} // namespace tensorflow