aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/range_sampler_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/range_sampler_test.cc')
-rw-r--r--tensorflow/core/kernels/range_sampler_test.cc320
1 files changed, 320 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/range_sampler_test.cc b/tensorflow/core/kernels/range_sampler_test.cc
new file mode 100644
index 0000000000..72c39009e4
--- /dev/null
+++ b/tensorflow/core/kernels/range_sampler_test.cc
@@ -0,0 +1,320 @@
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/kernels/range_sampler.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace {
+
+using gtl::ArraySlice;
+using gtl::MutableArraySlice;
+
+class RangeSamplerTest : public ::testing::Test {
+ protected:
+ void CheckProbabilitiesSumToOne() {
+ double sum = 0;
+ for (int i = 0; i < sampler_->range(); i++) {
+ sum += sampler_->Probability(i);
+ }
+ EXPECT_NEAR(sum, 1.0, 1e-4);
+ }
+ void CheckHistogram(int num_samples, float tolerance) {
+ const int range = sampler_->range();
+ std::vector<int> h(range);
+ std::vector<int64> a(num_samples);
+ // Using a fixed random seed to make the test deterministic.
+ random::PhiloxRandom philox(123, 17);
+ random::SimplePhilox rnd(&philox);
+ sampler_->SampleBatch(&rnd, false, &a);
+ for (int i = 0; i < num_samples; i++) {
+ int64 val = a[i];
+ ASSERT_GE(val, 0);
+ ASSERT_LT(val, range);
+ h[val]++;
+ }
+ for (int val = 0; val < range; val++) {
+ EXPECT_NEAR((h[val] + 0.0) / num_samples, sampler_->Probability(val),
+ tolerance);
+ }
+ }
+ void Update1() {
+ // Add the value 3 ten times.
+ std::vector<int64> a(10);
+ for (int i = 0; i < 10; i++) {
+ a[i] = 3;
+ }
+ sampler_->Update(a);
+ }
+ void Update2() {
+ // Add the value n n times.
+ int64 a[10];
+ for (int i = 0; i < 10; i++) {
+ a[i] = i;
+ }
+ for (int64 i = 1; i < 10; i++) {
+ sampler_->Update(ArraySlice<int64>(a + i, 10 - i));
+ }
+ }
+ std::unique_ptr<RangeSampler> sampler_;
+};
+
+TEST_F(RangeSamplerTest, UniformProbabilities) {
+ sampler_.reset(new UniformSampler(10));
+ for (int i = 0; i < 10; i++) {
+ CHECK_EQ(sampler_->Probability(i), sampler_->Probability(0));
+ }
+}
+
+TEST_F(RangeSamplerTest, UniformChecksum) {
+ sampler_.reset(new UniformSampler(10));
+ CheckProbabilitiesSumToOne();
+}
+
+TEST_F(RangeSamplerTest, UniformHistogram) {
+ sampler_.reset(new UniformSampler(10));
+ CheckHistogram(1000, 0.05);
+}
+
+TEST_F(RangeSamplerTest, LogUniformProbabilities) {
+ int range = 1000000;
+ sampler_.reset(new LogUniformSampler(range));
+ for (int i = 100; i < range; i *= 2) {
+ float ratio = sampler_->Probability(i) / sampler_->Probability(i / 2);
+ EXPECT_NEAR(ratio, 0.5, 0.1);
+ }
+}
+
+TEST_F(RangeSamplerTest, LogUniformChecksum) {
+ sampler_.reset(new LogUniformSampler(10));
+ CheckProbabilitiesSumToOne();
+}
+
+TEST_F(RangeSamplerTest, LogUniformHistogram) {
+ sampler_.reset(new LogUniformSampler(10));
+ CheckHistogram(1000, 0.05);
+}
+
+TEST_F(RangeSamplerTest, UnigramProbabilities1) {
+ sampler_.reset(new UnigramSampler(10));
+ Update1();
+ EXPECT_NEAR(sampler_->Probability(3), 0.55, 1e-4);
+ for (int i = 0; i < 10; i++) {
+ if (i != 3) {
+ ASSERT_NEAR(sampler_->Probability(i), 0.05, 1e-4);
+ }
+ }
+}
+TEST_F(RangeSamplerTest, UnigramProbabilities2) {
+ sampler_.reset(new UnigramSampler(10));
+ Update2();
+ for (int i = 0; i < 10; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), (i + 1) / 55.0, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, UnigramChecksum) {
+ sampler_.reset(new UnigramSampler(10));
+ Update1();
+ CheckProbabilitiesSumToOne();
+}
+
+TEST_F(RangeSamplerTest, UnigramHistogram) {
+ sampler_.reset(new UnigramSampler(10));
+ Update1();
+ CheckHistogram(1000, 0.05);
+}
+
+static const char kVocabContent[] =
+ "w1,1\n"
+ "w2,2\n"
+ "w3,4\n"
+ "w4,8\n"
+ "w5,16\n"
+ "w6,32\n"
+ "w7,64\n"
+ "w8,128\n"
+ "w9,256";
+TEST_F(RangeSamplerTest, FixedUnigramProbabilities) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 0; i < 9; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramChecksum) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
+ CheckProbabilitiesSumToOne();
+}
+
+TEST_F(RangeSamplerTest, FixedUnigramHistogram) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
+ CheckHistogram(1000, 0.05);
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 10, fname, 0.8, 1, 1, 0));
+ ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 1; i < 10; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 11, fname, 0.8, 2, 1, 0));
+ ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
+ ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 2; i < 11; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesFromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 0; i < 9; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramChecksumFromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
+ CheckProbabilitiesSumToOne();
+}
+TEST_F(RangeSamplerTest, FixedUnigramHistogramFromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
+ CheckHistogram(1000, 0.05);
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1FromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(10, weights, 0.8, 1, 1, 0));
+ ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 1; i < 10; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2FromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(11, weights, 0.8, 2, 1, 0));
+ ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
+ ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 2; i < 11; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4);
+ }
+}
+
+// AllSampler cannot call Sample or Probability directly.
+// We will test SampleBatchGetExpectedCount instead.
+TEST_F(RangeSamplerTest, All) {
+ int batch_size = 10;
+ sampler_.reset(new AllSampler(10));
+ std::vector<int64> batch(batch_size);
+ std::vector<float> batch_expected(batch_size);
+ std::vector<int64> extras(2);
+ std::vector<float> extras_expected(2);
+ extras[0] = 0;
+ extras[1] = batch_size - 1;
+ sampler_->SampleBatchGetExpectedCount(nullptr, // no random numbers needed
+ false, &batch, &batch_expected, extras,
+ &extras_expected);
+ for (int i = 0; i < batch_size; i++) {
+ EXPECT_EQ(i, batch[i]);
+ EXPECT_EQ(1, batch_expected[i]);
+ }
+ EXPECT_EQ(1, extras_expected[0]);
+ EXPECT_EQ(1, extras_expected[1]);
+}
+
+TEST_F(RangeSamplerTest, Unique) {
+ // We sample num_batches batches, each without replacement.
+ //
+ // We check that the returned expected counts roughly agree with each other
+ // and with the average observed frequencies over the set of batches.
+ random::PhiloxRandom philox(123, 17);
+ random::SimplePhilox rnd(&philox);
+ const int range = 100;
+ const int batch_size = 50;
+ const int num_batches = 100;
+ sampler_.reset(new LogUniformSampler(range));
+ std::vector<int> histogram(range);
+ std::vector<int64> batch(batch_size);
+ std::vector<int64> all_values(range);
+ for (int i = 0; i < range; i++) {
+ all_values[i] = i;
+ }
+ std::vector<float> expected(range);
+
+ // Sample one batch and get the expected counts of all values
+ sampler_->SampleBatchGetExpectedCount(
+ &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected);
+ // Check that all elements are unique
+ std::set<int64> s(batch.begin(), batch.end());
+ CHECK_EQ(batch_size, s.size());
+
+ for (int trial = 0; trial < num_batches; trial++) {
+ std::vector<float> trial_expected(range);
+ sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch,
+ MutableArraySlice<float>(),
+ all_values, &trial_expected);
+ for (int i = 0; i < range; i++) {
+ EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
+ }
+ for (int i = 0; i < batch_size; i++) {
+ histogram[batch[i]]++;
+ }
+ }
+ for (int i = 0; i < range; i++) {
+ // Check that the computed expected count agrees with the average observed
+ // count.
+ const float average_count = static_cast<float>(histogram[i]) / num_batches;
+ EXPECT_NEAR(expected[i], average_count, 0.2);
+ }
+}
+
+TEST_F(RangeSamplerTest, Avoid) {
+ random::PhiloxRandom philox(123, 17);
+ random::SimplePhilox rnd(&philox);
+ sampler_.reset(new LogUniformSampler(100));
+ std::vector<int64> avoided(2);
+ avoided[0] = 17;
+ avoided[1] = 23;
+ std::vector<int64> batch(98);
+
+ // We expect to pick all elements of [0, 100) except the avoided two.
+ sampler_->SampleBatchGetExpectedCountAvoid(
+ &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(),
+ MutableArraySlice<float>(), avoided);
+
+ int sum = 0;
+ for (auto val : batch) {
+ sum += val;
+ }
+ const int expected_sum = 100 * 99 / 2 - avoided[0] - avoided[1];
+ EXPECT_EQ(expected_sum, sum);
+}
+
+} // namespace
+
+} // namespace tensorflow