From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/core/lib/random/weighted_picker_test.cc | 254 +++++++++++++++++++++ 1 file changed, 254 insertions(+) create mode 100644 tensorflow/core/lib/random/weighted_picker_test.cc (limited to 'tensorflow/core/lib/random/weighted_picker_test.cc') diff --git a/tensorflow/core/lib/random/weighted_picker_test.cc b/tensorflow/core/lib/random/weighted_picker_test.cc new file mode 100644 index 0000000000..0b27d437d5 --- /dev/null +++ b/tensorflow/core/lib/random/weighted_picker_test.cc @@ -0,0 +1,254 @@ +#include "tensorflow/core/lib/random/weighted_picker.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include + +namespace tensorflow { +namespace random { + +static void TestPicker(SimplePhilox* rnd, int size); +static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials); +static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials); +static void TestPickAt(int items, const int32* weights); + +TEST(WeightedPicker, Simple) { + PhiloxRandom philox(testing::RandomSeed(), 17); + SimplePhilox rnd(&philox); + + { + VLOG(0) << "======= Zero-length picker"; + WeightedPicker picker(0); + EXPECT_EQ(picker.Pick(&rnd), -1); + } + + { + VLOG(0) << "======= Singleton picker"; + WeightedPicker picker(1); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + } + + { + VLOG(0) << "======= Grown picker"; + WeightedPicker picker(0); + for (int i = 0; i < 10; i++) { + picker.Append(1); + } + CheckUniform(&rnd, &picker, 100000); + } + + { + VLOG(0) << "======= Grown picker with zero weights"; + WeightedPicker picker(1); + picker.Resize(10); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + } + + { + VLOG(0) << "======= Shrink picker and check weights"; + WeightedPicker picker(1); + picker.Resize(10); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + for (int i = 0; i < 10; i++) { + picker.set_weight(i, i); + } + EXPECT_EQ(picker.total_weight(), 45); + picker.Resize(5); + EXPECT_EQ(picker.total_weight(), 10); + picker.Resize(2); + EXPECT_EQ(picker.total_weight(), 1); + picker.Resize(1); + EXPECT_EQ(picker.total_weight(), 0); + } +} + +TEST(WeightedPicker, BigWeights) { + PhiloxRandom philox(testing::RandomSeed() + 1, 17); + SimplePhilox rnd(&philox); + VLOG(0) << "======= Check uniform with big weights"; + WeightedPicker picker(2); + picker.SetAllWeights(2147483646L / 3); // (2^31 - 2) / 3 + CheckUniform(&rnd, &picker, 100000); +} + +TEST(WeightedPicker, Deterministic) { + VLOG(0) << "======= Testing deterministic pick"; + static const int32 weights[] = {1, 0, 200, 5, 42}; + TestPickAt(TF_ARRAYSIZE(weights), weights); +} + +TEST(WeightedPicker, Randomized) { + PhiloxRandom philox(testing::RandomSeed() + 10, 17); + SimplePhilox rnd(&philox); + TestPicker(&rnd, 1); + TestPicker(&rnd, 2); + TestPicker(&rnd, 3); + TestPicker(&rnd, 4); + TestPicker(&rnd, 7); + TestPicker(&rnd, 8); + TestPicker(&rnd, 9); + TestPicker(&rnd, 10); + TestPicker(&rnd, 100); +} + +static void TestPicker(SimplePhilox* rnd, int size) { + VLOG(0) << "======= Testing size " << size; + + // Check that empty picker returns -1 + { + WeightedPicker picker(size); + picker.SetAllWeights(0); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1); + } + + // Create zero weights array + std::vector weights(size); + for (int elem = 0; elem < size; elem++) { + weights[elem] = 0; + } + + // Check that singleton picker always returns the same element + for (int elem = 0; elem < size; elem++) { + WeightedPicker picker(size); + picker.SetAllWeights(0); + picker.set_weight(elem, elem + 1); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem); + weights[elem] = 10; + picker.SetWeightsFromArray(size, &weights[0]); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem); + weights[elem] = 0; + } + + // Check that uniform picker generates elements roughly uniformly + { + WeightedPicker picker(size); + CheckUniform(rnd, &picker, 100000); + } + + // Check uniform picker that was grown piecemeal + if (size / 3 > 0) { + WeightedPicker picker(size / 3); + while (picker.num_elements() != size) { + picker.Append(1); + } + CheckUniform(rnd, &picker, 100000); + } + + // Check that skewed distribution works + if (size <= 10) { + // When picker grows one element at a time + WeightedPicker picker(size); + int32 weight = 1; + for (int elem = 0; elem < size; elem++) { + picker.set_weight(elem, weight); + weights[elem] = weight; + weight *= 2; + } + CheckSkewed(rnd, &picker, 1000000); + + // When picker is created from an array + WeightedPicker array_picker(0); + array_picker.SetWeightsFromArray(size, &weights[0]); + CheckSkewed(rnd, &array_picker, 1000000); + } +} + +static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, + int trials) { + const int size = picker->num_elements(); + int* count = new int[size]; + memset(count, 0, sizeof(count[0]) * size); + for (int i = 0; i < size * trials; i++) { + const int elem = picker->Pick(rnd); + EXPECT_GE(elem, 0); + EXPECT_LT(elem, size); + count[elem]++; + } + const int expected_min = int(0.9 * trials); + const int expected_max = int(1.1 * trials); + for (int i = 0; i < size; i++) { + EXPECT_GE(count[i], expected_min); + EXPECT_LE(count[i], expected_max); + } + delete[] count; +} + +static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) { + const int size = picker->num_elements(); + int* count = new int[size]; + memset(count, 0, sizeof(count[0]) * size); + for (int i = 0; i < size * trials; i++) { + const int elem = picker->Pick(rnd); + EXPECT_GE(elem, 0); + EXPECT_LT(elem, size); + count[elem]++; + } + + for (int i = 0; i < size - 1; i++) { + LOG(INFO) << i << ": " << count[i]; + const float ratio = float(count[i + 1]) / float(count[i]); + EXPECT_GE(ratio, 1.6f); + EXPECT_LE(ratio, 2.4f); + } + delete[] count; +} + +static void TestPickAt(int items, const int32* weights) { + WeightedPicker picker(items); + picker.SetWeightsFromArray(items, weights); + int weight_index = 0; + for (int i = 0; i < items; ++i) { + for (int j = 0; j < weights[i]; ++j) { + int pick = picker.PickAt(weight_index); + EXPECT_EQ(pick, i); + ++weight_index; + } + } + EXPECT_EQ(weight_index, picker.total_weight()); +} + +static void BM_Create(int iters, int arg) { + while (--iters > 0) { + WeightedPicker p(arg); + } +} +BENCHMARK(BM_Create)->Range(1, 1024); + +static void BM_CreateAndSetWeights(int iters, int arg) { + std::vector weights(arg); + for (int i = 0; i < arg; i++) { + weights[i] = i * 10; + } + while (--iters > 0) { + WeightedPicker p(arg); + p.SetWeightsFromArray(arg, &weights[0]); + } +} +BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024); + +static void BM_Pick(int iters, int arg) { + PhiloxRandom philox(301, 17); + SimplePhilox rnd(&philox); + WeightedPicker p(arg); + int result = 0; + while (--iters > 0) { + result += p.Pick(&rnd); + } + VLOG(4) << result; // Dummy use +} +BENCHMARK(BM_Pick)->Range(1, 1024); + +} // namespace random +} // namespace tensorflow -- cgit v1.2.3