#include "tensorflow/core/lib/random/weighted_picker.h"

#include <string.h>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include <gtest/gtest.h>

namespace tensorflow {
namespace random {

static void TestPicker(SimplePhilox* rnd, int size);
static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker,
                         int trials);
static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials);
static void TestPickAt(int items, const int32* weights);

TEST(WeightedPicker, Simple) {
  PhiloxRandom philox(testing::RandomSeed(), 17);
  SimplePhilox rnd(&philox);

  {
    VLOG(0) << "======= Zero-length picker";
    WeightedPicker picker(0);
    EXPECT_EQ(picker.Pick(&rnd), -1);
  }

  {
    VLOG(0) << "======= Singleton picker";
    WeightedPicker picker(1);
    EXPECT_EQ(picker.Pick(&rnd), 0);
    EXPECT_EQ(picker.Pick(&rnd), 0);
    EXPECT_EQ(picker.Pick(&rnd), 0);
  }

  {
    VLOG(0) << "======= Grown picker";
    WeightedPicker picker(0);
    for (int i = 0; i < 10; i++) {
      picker.Append(1);
    }
    CheckUniform(&rnd, &picker, 100000);
  }

  {
    // After Resize(10) only element 0 can be picked, so the test expects
    // the newly added elements to carry zero weight.
    VLOG(0) << "======= Grown picker with zero weights";
    WeightedPicker picker(1);
    picker.Resize(10);
    EXPECT_EQ(picker.Pick(&rnd), 0);
    EXPECT_EQ(picker.Pick(&rnd), 0);
    EXPECT_EQ(picker.Pick(&rnd), 0);
  }

  {
    VLOG(0) << "======= Shrink picker and check weights";
    WeightedPicker picker(1);
    picker.Resize(10);
    EXPECT_EQ(picker.Pick(&rnd), 0);
    EXPECT_EQ(picker.Pick(&rnd), 0);
    EXPECT_EQ(picker.Pick(&rnd), 0);
    for (int i = 0; i < 10; i++) {
      picker.set_weight(i, i);
    }
    // Weights are 0..9, so the total after each shrink is the sum of the
    // weights that remain: 0+..+9, 0+..+4, 0+1, 0.
    EXPECT_EQ(picker.total_weight(), 45);
    picker.Resize(5);
    EXPECT_EQ(picker.total_weight(), 10);
    picker.Resize(2);
    EXPECT_EQ(picker.total_weight(), 1);
    picker.Resize(1);
    EXPECT_EQ(picker.total_weight(), 0);
  }
}

TEST(WeightedPicker, BigWeights) {
  PhiloxRandom philox(testing::RandomSeed() + 1, 17);
  SimplePhilox rnd(&philox);
  VLOG(0) << "======= Check uniform with big weights";
  WeightedPicker picker(2);
  // Large per-element weights; the total of the two (1431655764) still fits
  // in int32.
  picker.SetAllWeights(2147483646L / 3);  // (2^31 - 2) / 3
  CheckUniform(&rnd, &picker, 100000);
}

TEST(WeightedPicker, Deterministic) {
  VLOG(0) << "======= Testing deterministic pick";
  static const int32 weights[] = {1, 0, 200, 5, 42};
  TestPickAt(TF_ARRAYSIZE(weights), weights);
}

TEST(WeightedPicker, Randomized) {
  PhiloxRandom philox(testing::RandomSeed() + 10, 17);
  SimplePhilox rnd(&philox);
  for (const int size : {1, 2, 3, 4, 7, 8, 9, 10, 100}) {
    TestPicker(&rnd, size);
  }
}

// Exercises a picker of `size` elements: all-zero weights, each singleton
// weight, uniform weights (freshly constructed and grown via Append), and,
// for small sizes, weights that double from one element to the next.
static void TestPicker(SimplePhilox* rnd, int size) {
  VLOG(0) << "======= Testing size " << size;

  // A picker whose weights are all zero has nothing to pick and must
  // return -1.
  {
    WeightedPicker picker(size);
    picker.SetAllWeights(0);
    for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1);
  }

  // Scratch weights array; std::vector value-initializes, so it starts at 0.
  std::vector<int32> weights(size);

  // Check that singleton picker always returns the same element
  for (int elem = 0; elem < size; elem++) {
    WeightedPicker picker(size);
    picker.SetAllWeights(0);
    picker.set_weight(elem, elem + 1);
    for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
    weights[elem] = 10;
    picker.SetWeightsFromArray(size, weights.data());
    for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
    weights[elem] = 0;  // Restore the all-zero state for the next element.
  }

  // Check that uniform picker generates elements roughly uniformly
  {
    WeightedPicker picker(size);
    CheckUniform(rnd, &picker, 100000);
  }

  // Check uniform picker that was grown piecemeal
  if (size / 3 > 0) {
    WeightedPicker picker(size / 3);
    while (picker.num_elements() != size) {
      picker.Append(1);
    }
    CheckUniform(rnd, &picker, 100000);
  }

  // Check that skewed distribution works. Limited to small sizes so the
  // doubling weights stay small (at most 2^9 for size 10).
  if (size <= 10) {
    // When picker grows one element at a time
    WeightedPicker picker(size);
    int32 weight = 1;
    for (int elem = 0; elem < size; elem++) {
      picker.set_weight(elem, weight);
      weights[elem] = weight;
      weight *= 2;
    }
    CheckSkewed(rnd, &picker, 1000000);

    // When picker is created from an array
    WeightedPicker array_picker(0);
    array_picker.SetWeightsFromArray(size, weights.data());
    CheckSkewed(rnd, &array_picker, 1000000);
  }
}

// Draws `num_elements() * trials` samples from `picker` and returns how many
// times each index was drawn. Out-of-range results are reported as
// (non-fatal) failures and NOT counted. EXPECT_* does not stop execution, so
// without this guard a bad result such as -1 would index outside `counts`.
static std::vector<int> CountPicks(SimplePhilox* rnd, WeightedPicker* picker,
                                   int trials) {
  const int size = picker->num_elements();
  std::vector<int> counts(size, 0);
  for (int i = 0; i < size * trials; i++) {
    const int elem = picker->Pick(rnd);
    EXPECT_GE(elem, 0);
    EXPECT_LT(elem, size);
    if (elem >= 0 && elem < size) {
      counts[elem]++;
    }
  }
  return counts;
}

// Samples `trials` picks per element and expects every element to be drawn
// within +/-10% of `trials` times, i.e. roughly uniformly.
static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker,
                         int trials) {
  const std::vector<int> counts = CountPicks(rnd, picker, trials);
  const int expected_min = static_cast<int>(0.9 * trials);
  const int expected_max = static_cast<int>(1.1 * trials);
  for (size_t i = 0; i < counts.size(); i++) {
    EXPECT_GE(counts[i], expected_min) << "element " << i;
    EXPECT_LE(counts[i], expected_max) << "element " << i;
  }
}

// Samples `trials` picks per element and expects each element to be drawn
// about twice as often as its predecessor (ratio in [1.6, 2.4]). Intended for
// pickers whose weights double from one element to the next.
static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker,
                        int trials) {
  const std::vector<int> counts = CountPicks(rnd, picker, trials);
  for (size_t i = 0; i + 1 < counts.size(); i++) {
    LOG(INFO) << i << ": " << counts[i];
    const float ratio =
        static_cast<float>(counts[i + 1]) / static_cast<float>(counts[i]);
    EXPECT_GE(ratio, 1.6f) << "element " << i;
    EXPECT_LE(ratio, 2.4f) << "element " << i;
  }
}

// Walks every weight unit in order. It expects PickAt(w) to return element i
// for each of the weights[i] consecutive positions assigned to that element,
// and the number of positions to equal total_weight().
static void TestPickAt(int items, const int32* weights) {
  WeightedPicker picker(items);
  picker.SetWeightsFromArray(items, weights);
  int weight_index = 0;
  for (int i = 0; i < items; ++i) {
    for (int j = 0; j < weights[i]; ++j) {
      int pick = picker.PickAt(weight_index);
      EXPECT_EQ(pick, i);
      ++weight_index;
    }
  }
  EXPECT_EQ(weight_index, picker.total_weight());
}

// Benchmarks below run exactly `iters` iterations (the previous
// `while (--iters > 0)` form ran only iters - 1).

static void BM_Create(int iters, int arg) {
  for (int i = 0; i < iters; i++) {
    WeightedPicker p(arg);
  }
}
BENCHMARK(BM_Create)->Range(1, 1024);

static void BM_CreateAndSetWeights(int iters, int arg) {
  std::vector<int32> weights(arg);
  for (int i = 0; i < arg; i++) {
    weights[i] = i * 10;
  }
  for (int i = 0; i < iters; i++) {
    WeightedPicker p(arg);
    p.SetWeightsFromArray(arg, weights.data());
  }
}
BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024);

static void BM_Pick(int iters, int arg) {
  PhiloxRandom philox(301, 17);
  SimplePhilox rnd(&philox);
  WeightedPicker p(arg);
  int result = 0;
  for (int i = 0; i < iters; i++) {
    result += p.Pick(&rnd);
  }
  VLOG(4) << result;  // Dummy use so the picks are not optimized away.
}
BENCHMARK(BM_Pick)->Range(1, 1024);

}  // namespace random
}  // namespace tensorflow