aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/random
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/lib/random')
-rw-r--r--tensorflow/core/lib/random/distribution_sampler.cc80
-rw-r--r--tensorflow/core/lib/random/distribution_sampler.h79
-rw-r--r--tensorflow/core/lib/random/distribution_sampler_test.cc90
-rw-r--r--tensorflow/core/lib/random/exact_uniform_int.h68
-rw-r--r--tensorflow/core/lib/random/philox_random.h232
-rw-r--r--tensorflow/core/lib/random/philox_random_test.cc58
-rw-r--r--tensorflow/core/lib/random/philox_random_test_utils.h36
-rw-r--r--tensorflow/core/lib/random/random.cc22
-rw-r--r--tensorflow/core/lib/random/random.h16
-rw-r--r--tensorflow/core/lib/random/random_distributions.h361
-rw-r--r--tensorflow/core/lib/random/random_distributions_test.cc270
-rw-r--r--tensorflow/core/lib/random/random_test.cc21
-rw-r--r--tensorflow/core/lib/random/simple_philox.cc24
-rw-r--r--tensorflow/core/lib/random/simple_philox.h61
-rw-r--r--tensorflow/core/lib/random/simple_philox_test.cc120
-rw-r--r--tensorflow/core/lib/random/weighted_picker.cc203
-rw-r--r--tensorflow/core/lib/random/weighted_picker.h118
-rw-r--r--tensorflow/core/lib/random/weighted_picker_test.cc254
18 files changed, 2113 insertions, 0 deletions
diff --git a/tensorflow/core/lib/random/distribution_sampler.cc b/tensorflow/core/lib/random/distribution_sampler.cc
new file mode 100644
index 0000000000..341f1bd595
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler.cc
@@ -0,0 +1,80 @@
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+
+#include <memory>
+#include <vector>
+
+namespace tensorflow {
+namespace random {
+
+DistributionSampler::DistributionSampler(
+    const gtl::ArraySlice<float>& weights) {
+  // Precomputes the (probability, alternative) table read by Sample(): one
+  // entry per weight.
+  DCHECK(!weights.empty());
+  const int n = weights.size();
+  num_ = n;
+  data_.reset(new std::pair<float, int>[n]);
+
+  // Scaled probabilities, kept in double precision while the table is built.
+  std::unique_ptr<double[]> scaled(new double[n]);
+
+  // Accumulate the total weight. Every alternative starts out unset (-1) and
+  // is filled in below.
+  double total = 0.0;
+  for (int i = 0; i < n; ++i) {
+    total += weights[i];
+    set_alt(i, -1);
+  }
+
+  // Work lists: "over" holds entries with scaled probability >= 1.0 and
+  // "under" holds those with scaled probability < 1.0.
+  std::vector<int> over;
+  std::vector<int> under;
+  over.reserve(n);
+  under.reserve(n);
+
+  // Compute proportional weights, scaled so that their mean is 1.0.
+  for (int i = 0; i < n; ++i) {
+    const double p = (weights[i] * n) / total;
+    scaled[i] = p;
+    (p < 1.0 ? under : over).push_back(i);
+  }
+
+  // Now pair each under-full entry with an over-full donor. The donor gives
+  // up exactly what the receiver lacks and is then filed again according to
+  // what it has left.
+  while (!over.empty() && !under.empty()) {
+    const int receiver = under.back();
+    under.pop_back();
+    const int donor = over.back();
+    over.pop_back();
+
+    set_alt(receiver, donor);
+    DCHECK_GE(scaled[donor], 1.0);
+    const double left_over = scaled[donor] - (1.0 - scaled[receiver]);
+    scaled[donor] = left_over;
+
+    (left_over < 1.0 ? under : over).push_back(donor);
+  }
+
+  // Transfer the scaled probabilities into the table, narrowing to float.
+  for (int i = 0; i < n; ++i) {
+    set_prob(i, scaled[i]);
+  }
+
+  // Because of rounding errors, either work list may still hold entries
+  // whose probability is close to 1.0. Pin it to exactly 1.0 and make each
+  // such entry its own alternative, so that it never refers to the unset
+  // alternative (-1).
+  const auto pin_to_self = [this](const std::vector<int>& leftovers) {
+    for (const int idx : leftovers) {
+      set_prob(idx, 1.0);
+      set_alt(idx, idx);
+    }
+  };
+  pin_to_self(over);
+  pin_to_self(under);
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h
new file mode 100644
index 0000000000..ab9598a205
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler.h
@@ -0,0 +1,79 @@
+// DistributionSampler allows generating a discrete random variable with a given
+// distribution.
+// The values taken by the variable are [0, N) and relative weights for each
+// value are specified using a vector of size N.
+//
+// The Algorithm takes O(N) time to precompute data at construction time and
+// takes O(1) time (2 random number generations, 2 lookups) for each sample.
+// The data structure takes O(N) memory.
+//
+// In contrast, weighted_picker.h (in this directory) provides O(lg N) sampling.
+// The advantage of that implementation is that weights can be adjusted
+// dynamically, while DistributionSampler doesn't allow weight adjustment.
+//
+// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2.
+
+#ifndef TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#define TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+class DistributionSampler {
+ public:
+  // Precomputes the sampling table from the relative weights.
+  explicit DistributionSampler(const gtl::ArraySlice<float>& weights);
+  ~DistributionSampler() {}
+
+  // Draws one value using two random numbers from rand and two table lookups.
+  int Sample(SimplePhilox* rand) const {
+    const float threshold = rand->RandFloat();
+    // Since n is typically low, we don't bother with UnbiasedUniform.
+    const int bucket = rand->Uniform(num_);
+    if (threshold < prob(bucket)) return bucket;
+    // Otherwise pick the alternative recorded for that bucket.
+    DCHECK_NE(-1, alt(bucket));
+    return alt(bucket);
+  }
+
+  int num() const { return num_; }  // Number of distinct values.
+
+ private:
+  // Debug-checked accessors for the (probability, alternative) pairs.
+  float prob(int i) const {
+    DCHECK_LT(i, num_);
+    return data_[i].first;
+  }
+  int alt(int i) const {
+    DCHECK_LT(i, num_);
+    return data_[i].second;
+  }
+
+  void set_prob(int i, float f) {
+    DCHECK_LT(i, num_);
+    data_[i].first = f;
+  }
+  void set_alt(int i, int val) {
+    DCHECK_LT(i, num_);
+    data_[i].second = val;
+  }
+
+  int num_;                                        // Number of entries in data_.
+  std::unique_ptr<std::pair<float, int>[]> data_;  // (prob, alt) per value.
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DistributionSampler);
+};
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
diff --git a/tensorflow/core/lib/random/distribution_sampler_test.cc b/tensorflow/core/lib/random/distribution_sampler_test.cc
new file mode 100644
index 0000000000..d61a8daa0f
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler_test.cc
@@ -0,0 +1,90 @@
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+
+#include <string.h>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+
+class DistributionSamplerTest : public ::testing::Test {
+ protected:
+  // Samples from `weights` and returns the chi-squared statistic between the
+  // observed frequencies and `weights` (which are compared as probabilities).
+  float TestWeights(const std::vector<float>& weights, int trials_per_bin) {
+    const int num_bins = static_cast<int>(weights.size());
+    const int iters = num_bins * trials_per_bin;
+    std::vector<float> counts(weights.size(), 0.0f);
+    DistributionSampler sampler(weights);
+    PhiloxRandom philox(testing::RandomSeed(), 17);
+    SimplePhilox random(&philox);
+    for (int i = 0; i < iters; i++) {
+      const int r = sampler.Sample(&random);
+      EXPECT_GE(r, 0);
+      EXPECT_LT(r, num_bins);
+      // Never index out of bounds, even if the expectations above failed.
+      if (r >= 0 && r < num_bins) counts[r] += 1.0;
+    }
+    float chi2 = 0.0;
+    for (size_t i = 0; i < weights.size(); i++) {
+      counts[i] /= iters;
+      const float err = counts[i] - weights[i];
+      chi2 += (err * err) / weights[i];
+    }
+    return chi2;
+  }
+
+  // Checks that sampling from the n weights in arr reproduces them, retrying
+  // once with more trials before reporting a failure.
+  void TestDistribution(float* arr, int n) {
+    const std::vector<float> w(arr, arr + n);
+    float var = TestWeights(w, 1000);
+    if (var < 0.001) return;
+    // Maybe a statistical skew. Let's try more iterations.
+    var = TestWeights(w, 100000);
+    if (var < 0.001) return;
+    EXPECT_TRUE(false) << "Chi2 is " << var << " in " << n * 100000
+                       << " iterations";
+  }
+};
+
+TEST_F(DistributionSamplerTest, KnownDistribution) {
+  // Uniform distributions over two, three and four outcomes.
+  float kEven2[] = {0.5, 0.5};
+  TestDistribution(kEven2, TF_ARRAYSIZE(kEven2));
+  float kEven3[] = {0.33333333, 0.33333333, 0.33333333};
+  TestDistribution(kEven3, TF_ARRAYSIZE(kEven3));
+  float kEven4[] = {0.25, 0.25, 0.25, 0.25};
+  TestDistribution(kEven4, TF_ARRAYSIZE(kEven4));
+  // A skewed distribution.
+  float kDist1[] = {0.8, 0.15, 0.05};
+  TestDistribution(kDist1, TF_ARRAYSIZE(kDist1));
+}
+
+static void BM_DistributionSampler(int iters, int n) {
+  testing::StopTiming();  // Building the sampler is not part of the timing.
+  PhiloxRandom philox(173, 371);
+  SimplePhilox rand(&philox);
+  std::vector<float> weights(n, 0);
+  for (float& weight : weights) {
+    weight = rand.Uniform(100);
+  }
+  DistributionSampler picker(weights);
+  testing::StartTiming();
+  int combined = 0;  // OR of all samples, consumed by the CHECK below.
+  for (int done = 0; done < iters; ++done) {
+    combined |= picker.Sample(&rand);
+  }
+  CHECK_NE(combined, kint32max);
+}
+
+BENCHMARK(BM_DistributionSampler)->Arg(10)->Arg(100)->Arg(1000);
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/exact_uniform_int.h b/tensorflow/core/lib/random/exact_uniform_int.h
new file mode 100644
index 0000000000..616354cc5c
--- /dev/null
+++ b/tensorflow/core/lib/random/exact_uniform_int.h
@@ -0,0 +1,68 @@
+// Exact uniform integers using rejection sampling
+
+#ifndef TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
+#define TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
+
+#include <type_traits>
+
+namespace tensorflow {
+namespace random {
+
+template <typename UintType, typename RandomBits>
+UintType ExactUniformInt(const UintType n, const RandomBits& random) {
+  static_assert(std::is_unsigned<UintType>::value,
+                "UintType must be an unsigned int");
+  static_assert(std::is_same<UintType, decltype(random())>::value,
+                "random() should return UintType");
+  if (n == 0) {
+    // Consume a value anyway
+    // TODO(irving): Assert n != 0, since this case makes no sense.
+    return random() * n;
+  }
+
+  const UintType low_mask = n - 1;
+  if ((n & low_mask) == 0) {
+    // n is a power of two, so just mask off the lower bits.
+    return random() & low_mask;
+  }
+
+  // General case: reject all numbers that skew the distribution towards 0.
+  //
+  // random's output is uniform in the half-open interval [0, 2^{bits}).
+  // For any interval [m,n), the number of elements in it is n-m.
+  //
+  // rem = ((2^bits-1) \bmod n) + 1
+  // 1 <= rem <= n
+  //
+  // NB: rem == n is impossible, since n is not a power of 2 (from
+  // earlier check).
+  const UintType max_value = ~static_cast<UintType>(0);
+  const UintType rem = (max_value % n) + 1;
+
+  UintType rnd = random();  // rnd uniform over [0, 2^{bits})
+  while (rnd < rem) {       // reject [0, rem)
+    rnd = random();
+  }
+
+  // rnd is now uniform over [rem, 2^{bits}).
+  //
+  // The number of elements in that half-open interval is
+  //
+  //   2^{bits} - rem = 2^{bits} - ((2^{bits}-1) \bmod n) - 1
+  //                  = 2^{bits}-1 - ((2^{bits}-1) \bmod n)
+  //                  = n \cdot \lfloor (2^{bits}-1)/n \rfloor
+  //
+  // therefore n evenly divides the number of integers in the interval.
+  //
+  // The function v \rightarrow v % n takes values from [rem, 2^{bits}) to
+  // [0, n). Each integer in the range interval [0, n) will have exactly
+  // \lfloor (2^{bits}-1)/n \rfloor preimages from the domain interval.
+  //
+  // Therefore, v % n is uniform over [0, n). QED.
+  return rnd % n;
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h
new file mode 100644
index 0000000000..2c3cd0c4b9
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random.h
@@ -0,0 +1,232 @@
+// Implement the Philox algorithm to generate random numbers in parallel.
+// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
+// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
+
+#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+
+#include <stdlib.h>
+
+#include "tensorflow/core/platform/port.h"
+
+// Function qualifiers that need to work on both CPU and GPU.
+#ifdef __CUDA_ARCH__
+// For nvcc.
+#define PHILOX_DEVICE_FUNC __host__ __device__
+#define PHILOX_INLINE __inline__
+#else
+// For non-nvcc.
+#define PHILOX_DEVICE_FUNC
+#define PHILOX_INLINE inline
+#endif
+#define PHILOX_DEVICE_INLINE PHILOX_DEVICE_FUNC PHILOX_INLINE
+
+#include <math.h>
+
+namespace tensorflow {
+namespace random {
+
+// A class that represents an inline array. It can be used on both CPU and GPU,
+// and also trivially copyable between CPU and GPU.
+// Arguments:
+// T: the array element type;
+// ElementCount: the fixed size of the array;
+template <typename T, int ElementCount>
+class Array {
+ public:
+  PHILOX_DEVICE_INLINE Array() {  // Value-initializes every element.
+    for (int i = 0; i < ElementCount; ++i) {
+      data_[i] = T();
+    }
+  }
+
+  // Unchecked element access.
+  PHILOX_DEVICE_INLINE T& operator[](int index) { return data_[index]; }
+  PHILOX_DEVICE_INLINE const T& operator[](int index) const {
+    return data_[index];
+  }
+
+  size_t size() const { return ElementCount; }  // Fixed at compile time.
+
+ private:
+  T data_[ElementCount];
+};
+
+// A class that encapsulates all the state for a random number generator using
+// the philox_4x32_10 algorithm. Each invocation returns 128 random bits
+// in the form of four uint32.
+// There are multiple variants of this algorithm; we picked the 4x32_10 version
+// that is most suited for our applications.
+// Since this class is meant to be copied between CPU and GPU, it has
+// value semantics.
+//
+// For example: To use this class and populate an array of 1024 randoms on CPU
+// with two threads,
+//
+// void Fill(PhiloxRandom rnd, uint32* output, int start, int limit) {
+// assert(start % 4 == 0);
+// assert(limit % 4 == 0);
+// rnd.Skip(start / 4);
+// for (int i = start; i < limit; i += 4) {
+// auto sample = rnd();
+// ... copy sample[0..3] to output[i..i+3]
+// }
+// }
+//
+// PhiloxRandom rng(seed);
+// PhiloxRandom rng_copy = rng;
+// rng.Skip(1024/4);
+//
+// ... schedule Fill(rng_copy, output, 0, 512) in thread 1;
+// ... schedule Fill(rng_copy, output, 512, 1024) in thread 2;
+// ... wait for thread 1 & 2 to finish executing Fill().
+//
+// NOTE:
+// 1. PhiloxRandom is trivially copyable.
+// 2. PhiloxRandom is compilable by gcc and nvcc.
+class PhiloxRandom {
+ public:
+  typedef Array<uint32, 4> ResultType;
+  typedef uint32 ResultElementType;
+  // The number of elements that will be returned.
+  static const int kResultElementCount = 4;
+
+  PHILOX_DEVICE_INLINE
+  PhiloxRandom() {}  // Key and counter keep their default (zero) values.
+
+  PHILOX_DEVICE_INLINE
+  explicit PhiloxRandom(uint64 seed) {  // The 64-bit seed becomes the key.
+    key_[0] = static_cast<uint32>(seed);
+    key_[1] = static_cast<uint32>(seed >> 32);
+  }
+
+  PHILOX_DEVICE_INLINE
+  explicit PhiloxRandom(uint64 seed_lo, uint64 seed_hi) {
+    key_[0] = static_cast<uint32>(seed_lo);  // seed_lo becomes the key ...
+    key_[1] = static_cast<uint32>(seed_lo >> 32);
+    counter_[2] = static_cast<uint32>(seed_hi);  // ... and seed_hi the upper
+    counter_[3] = static_cast<uint32>(seed_hi >> 32);  // half of the counter.
+  }
+
+  // Skip the specified number of samples of 128-bits in the current stream.
+  PHILOX_DEVICE_INLINE
+  void Skip(uint64 count) {
+    // Add count to the low 64 bits of the counter in one 64-bit operation.
+    // (Adding the two 32-bit halves separately drops the carry into
+    // counter_[2] when the high half of count wraps to zero after absorbing
+    // the carry from the low half.)
+    const uint64 low64 =
+        (static_cast<uint64>(counter_[1]) << 32) | counter_[0];
+    const uint64 sum = low64 + count;
+    counter_[0] = static_cast<uint32>(sum);
+    counter_[1] = static_cast<uint32>(sum >> 32);
+
+    // Unsigned wrap-around means a carry into the upper 64 bits.
+    if (sum < count && ++counter_[2] == 0) {
+      ++counter_[3];
+    }
+  }
+
+  // Returns a group of four random numbers using the underlying Philox
+  // algorithm, then advances the stored counter by one.
+  PHILOX_DEVICE_INLINE ResultType operator()() {
+    ResultType counter = counter_;
+    Key key = key_;
+
+    // Run the single rounds for ten times. Manually unrolling the loop
+    // for better performance.
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+    RaiseKey(&key);
+    counter = ComputeSingleRound(counter, key);
+
+    SkipOne();  // The next call starts from the following counter value.
+
+    return counter;
+  }
+
+ private:
+  // The type for the 64-bit key stored in the form of two 32-bit uint
+  // that are used in the diffusion process.
+  typedef Array<uint32, 2> Key;
+
+  // We use the same constants as recommended by the original paper.
+  static const uint32 kPhiloxW32A = 0x9E3779B9;
+  static const uint32 kPhiloxW32B = 0xBB67AE85;
+  static const uint32 kPhiloxM4x32A = 0xD2511F53;
+  static const uint32 kPhiloxM4x32B = 0xCD9E8D57;
+
+  // Helper function to skip the next sample of 128-bits in the current stream.
+  PHILOX_DEVICE_INLINE void SkipOne() {
+    if (++counter_[0] == 0) {
+      if (++counter_[1] == 0) {
+        if (++counter_[2] == 0) {
+          ++counter_[3];
+        }
+      }
+    }
+  }
+
+  // Helper function to return the lower and higher 32-bits from two 32-bit
+  // integer multiplications.
+  PHILOX_DEVICE_INLINE
+  static void MultiplyHighLow(uint32 a, uint32 b, uint32* result_low,
+                              uint32* result_high) {
+#ifndef __GCUDACC__
+    const uint64 product = static_cast<uint64>(a) * b;
+    *result_low = static_cast<uint32>(product);
+    *result_high = static_cast<uint32>(product >> 32);
+#else
+    *result_low = a * b;
+    *result_high = __umulhi(a, b);
+#endif
+  }
+
+  // Helper function for a single round of the underlying Philox algorithm.
+  PHILOX_DEVICE_INLINE static ResultType ComputeSingleRound(
+      const ResultType& counter, const Key& key) {
+    uint32 lo0;
+    uint32 hi0;
+    MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);
+
+    uint32 lo1;
+    uint32 hi1;
+    MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);
+
+    ResultType result;
+    result[0] = hi1 ^ counter[1] ^ key[0];
+    result[1] = lo1;
+    result[2] = hi0 ^ counter[3] ^ key[1];
+    result[3] = lo0;
+    return result;
+  }
+
+  PHILOX_DEVICE_INLINE void RaiseKey(Key* key) {  // Per-round key schedule.
+    (*key)[0] += kPhiloxW32A;
+    (*key)[1] += kPhiloxW32B;
+  }
+
+ private:
+  ResultType counter_;  // 128-bit counter; element 0 is the lowest word.
+  Key key_;             // 64-bit key taken from the seed.
+};
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/lib/random/philox_random_test.cc b/tensorflow/core/lib/random/philox_random_test.cc
new file mode 100644
index 0000000000..997c0263b7
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random_test.cc
@@ -0,0 +1,58 @@
+#include "tensorflow/core/lib/random/philox_random.h"
+
+#include <math.h>
+#include <algorithm>
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/random/philox_random_test_utils.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+// A trivial distribution that just returns the PhiloxRandom as a distribution
+class TrivialPhiloxDistribution {
+ public:
+  // All three are forwarded unchanged from PhiloxRandom.
+  typedef PhiloxRandom::ResultType ResultType;
+  typedef PhiloxRandom::ResultElementType ResultElementType;
+  static constexpr int kResultElementCount = PhiloxRandom::kResultElementCount;
+  // Returns the generator's next sample as is.
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(PhiloxRandom* gen) { return gen->operator()(); }
+};
+
+// This test checks that skipping certain number of samples, is equivalent to
+// generate the same number of samples without skipping.
+TEST(PhiloxRandomTest, SkipMatchTest) {
+  constexpr int count = 1024;
+  constexpr int skip_count = 2048;
+
+  const uint64 test_seed = GetTestSeed();
+
+  // Values produced after skipping skip_count values up front.
+  std::vector<uint32> skipped(count);
+  PhiloxRandom skipping_gen(test_seed);
+  skipping_gen.Skip(skip_count / 4);
+  FillRandoms<TrivialPhiloxDistribution>(skipping_gen, skipped.data(),
+                                         skipped.size());
+
+  // The same stream generated from its start, without skipping.
+  std::vector<uint32> full(count + skip_count);
+  PhiloxRandom plain_gen(test_seed);
+  FillRandoms<TrivialPhiloxDistribution>(plain_gen, full.data(), full.size());
+
+  for (int i = 0; i < count; ++i) {
+    ASSERT_EQ(skipped[i], full[i + skip_count]);
+  }
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/philox_random_test_utils.h b/tensorflow/core/lib/random/philox_random_test_utils.h
new file mode 100644
index 0000000000..d22f6b36e4
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random_test_utils.h
@@ -0,0 +1,36 @@
+#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
+#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
+
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace random {
+
+// Returns a random seed for tests, obtained from New64().
+inline uint64 GetTestSeed() { return New64(); }
+
+// Fills p[0..size) with samples drawn from Distribution using a private copy
+// of gen. size must be a multiple of Distribution::kResultElementCount.
+template <class Distribution>
+void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p,
+                 int64 size) {
+  const int granularity = Distribution::kResultElementCount;
+  CHECK(size % granularity == 0) << " size: " << size
+                                 << " granularity: " << granularity;
+
+  Distribution dist;
+  // The index is 64-bit like size, so sizes beyond INT_MAX cannot overflow it.
+  for (int64 i = 0; i < size; i += granularity) {
+    const auto sample = dist(&gen);
+    std::copy(&sample[0], &sample[0] + granularity, &p[i]);
+  }
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
diff --git a/tensorflow/core/lib/random/random.cc b/tensorflow/core/lib/random/random.cc
new file mode 100644
index 0000000000..2959b05382
--- /dev/null
+++ b/tensorflow/core/lib/random/random.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/lib/random/random.h"
+
+#include <random>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+std::mt19937_64* InitRng() {
+  std::random_device device("/dev/urandom");  // non-blocking entropy source
+  return new std::mt19937_64(device());       // caller owns the generator
+}
+
+uint64 New64() {
+  static std::mt19937_64* generator = InitRng();  // created on first call
+  static mutex guard;
+  mutex_lock hold(guard);  // serializes access to the shared generator
+  return (*generator)();
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/random.h b/tensorflow/core/lib/random/random.h
new file mode 100644
index 0000000000..1a20436c4e
--- /dev/null
+++ b/tensorflow/core/lib/random/random.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_
+#define TENSORFLOW_LIB_RANDOM_RANDOM_H_
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+// Returns a 64-bit random value. Different sequences are generated in
+// different processes. Calls are serialized internally by a mutex.
+uint64 New64();
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_RANDOM_H_
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
new file mode 100644
index 0000000000..caafcde513
--- /dev/null
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -0,0 +1,361 @@
+#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+
+#include <math.h>
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+
+namespace tensorflow {
+namespace random {
+
+// Helper function to convert a 32-bit integer to a float between [0..1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x);
+// Helper function to convert two 32-bit integers to a double between [0..1).
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1);
+
+// A class that generates uniform distribution random numbers from the
+// underlying random integer generator.
+// Arguments:
+// Generator: a generator type that returns a number of uint32 upon
+// each invocation. It needs to define kResultElementCount for the
+// sample count for each invocation, and ResultType for actual
+// returned sample type.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType>
+class UniformDistribution;
+
+template <class Generator>
+class UniformDistribution<Generator, float> {
+ public:
+  // One float is produced for every element the generator returns.
+  static const int kResultElementCount = Generator::kResultElementCount;
+  // This distribution consumes a fixed number of samples per output.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef float ResultElementType;
+  typedef Array<ResultElementType, kResultElementCount> ResultType;
+
+  // Draws one native sample from gen and converts each element to a float.
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType raw = (*gen)();
+    ResultType floats;
+    for (int k = 0; k < kResultElementCount; ++k) {
+      floats[k] = Uint32ToFloat(raw[k]);
+    }
+    return floats;
+  }
+};
+
+template <class Generator>
+class UniformDistribution<Generator, double> {
+ public:
+  // Two elements are consumed per double, so half as many values come back.
+  static const int kResultElementCount = Generator::kResultElementCount / 2;
+  // This distribution consumes a fixed number of samples per output.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef double ResultElementType;
+  typedef Array<ResultElementType, kResultElementCount> ResultType;
+
+  // Draws one native sample from gen and turns each pair into a double.
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType raw = (*gen)();
+    ResultType doubles;
+    for (int k = 0; k < kResultElementCount; ++k) {
+      doubles[k] = Uint64ToDouble(raw[2 * k], raw[2 * k + 1]);
+    }
+    return doubles;
+  }
+};
+
+// A class that adapts the underlying native multiple samples to return a single
+// sample at a time.
+template <class Generator>
+class SingleSampleAdapter {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = 1;
+  // The number of elements that will be returned by the underlying generator.
+  static const int kNativeElementCount = Generator::kResultElementCount;
+  typedef typename Generator::ResultElementType ResultElementType;
+  typedef ResultElementType ResultType;
+
+  PHILOX_DEVICE_INLINE
+  explicit SingleSampleAdapter(Generator* gen)
+      : generator_(gen), used_result_index_(kNativeElementCount) {}
+
+  // Returns the next single element, refilling the cache when it runs out.
+  PHILOX_DEVICE_INLINE
+  ResultType operator()() {
+    if (used_result_index_ == kNativeElementCount) {
+      unused_results_ = (*generator_)();
+      used_result_index_ = 0;
+    }
+    return unused_results_[used_result_index_++];
+  }
+
+ private:
+  Generator* generator_;                           // Never deleted here.
+  typename Generator::ResultType unused_results_;  // Cached native sample.
+  int used_result_index_;                          // Next cache slot to use.
+};
+
+// A class that generates unit normal distribution random numbers from the
+// underlying random integer generator.
+// Arguments:
+// Generator: a generator type that returns a number of uint32 upon
+// each invocation. It needs to define kResultElementCount for the
+// sample count for each invocation, and ResultType for actual
+// returned sample type.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType>
+class NormalDistribution;
+
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1);
+
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
+ double* d1);
+
+template <class Generator>
+class NormalDistribution<Generator, float> {
+ public:
+  // One float is produced for every element the generator returns.
+  static const int kResultElementCount = Generator::kResultElementCount;
+  // This distribution consumes a fixed number of samples per output.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef float ResultElementType;
+  typedef Array<ResultElementType, kResultElementCount> ResultType;
+
+  // Draws one native sample from gen and transforms its elements pairwise.
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType raw = (*gen)();
+    ResultType normals;
+    for (int k = 0; k < kResultElementCount; k += 2) {
+      BoxMullerFloat(raw[k], raw[k + 1], &normals[k], &normals[k + 1]);
+    }
+    return normals;
+  }
+};
+
+template <class Generator>
+class NormalDistribution<Generator, double> {
+ public:
+  // Two elements are consumed per double, so half as many values come back.
+  static const int kResultElementCount = Generator::kResultElementCount / 2;
+  // This distribution consumes a fixed number of samples per output.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef double ResultElementType;
+  typedef Array<ResultElementType, kResultElementCount> ResultType;
+
+  // Draws one native sample from gen; four elements yield two doubles.
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType raw = (*gen)();
+    ResultType normals;
+    for (int out = 0; out < kResultElementCount; out += 2) {
+      const int in = 2 * out;
+      BoxMullerDouble(raw[in], raw[in + 1], raw[in + 2], raw[in + 3],
+                      &normals[out], &normals[out + 1]);
+    }
+    return normals;
+  }
+};
+
+// A class that returns standard normal distribution between
+// [-kTruncateValue, kTruncateValue].
+// Arguments:
+// Generator: a generator type that returns a number of uint32 upon
+// each invocation. It needs to define kResultElementCount for the
+// sample count for each invocation, and ResultType for actual
+// returned sample type.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class SingleSampleGenerator, typename RealType>
+class TruncatedNormalDistribution;
+
+// Partial specialization for float.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, float> {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount =
+      SingleSampleGenerator::kNativeElementCount;
+  // Rejection sampling consumes a variable number of samples per output.
+  static const bool kVariableSamplesPerOutput = true;
+  // The threshold where the normal distribution is truncated.
+  const float kTruncateValue = 2.0f;
+
+  typedef float ResultElementType;
+  typedef Array<ResultElementType, kResultElementCount> ResultType;
+
+  // Returns kResultElementCount normal samples whose absolute value is
+  // below kTruncateValue.
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(SingleSampleGenerator* gen) {
+    ResultType accepted;
+    int filled = 0;
+    for (;;) {
+      // Repeatedly take pairs of samples from the normal distribution and
+      // keep the ones inside the cutoff, until enough have been collected.
+      const uint32 x0 = (*gen)();
+      const uint32 x1 = (*gen)();
+      float candidates[2];
+      BoxMullerFloat(x0, x1, &candidates[0], &candidates[1]);
+
+      for (const float candidate : candidates) {
+        if (fabs(candidate) < kTruncateValue) {
+          accepted[filled] = candidate;
+          if (++filled >= kResultElementCount) {
+            return accepted;
+          }
+        }
+      }
+    }
+  }
+};
+
+// Partial specialization for double.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, double> {
+ public:
+  // Half the native element count, but never fewer than one element.
+  static const int kResultElementCount =
+      (SingleSampleGenerator::kNativeElementCount > 1)
+          ? SingleSampleGenerator::kNativeElementCount / 2
+          : 1;
+  // Rejection sampling consumes a variable number of samples per output.
+  static const bool kVariableSamplesPerOutput = true;
+  // The threshold where the normal distribution is truncated.
+  const double kTruncateValue = 2.0;
+
+  typedef double ResultElementType;
+  typedef Array<ResultElementType, kResultElementCount> ResultType;
+  // Returns kResultElementCount normal samples below kTruncateValue in size.
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(SingleSampleGenerator* gen) {
+    ResultType accepted;
+    int filled = 0;
+    for (;;) {
+      const uint32 x0 = (*gen)();
+      const uint32 x1 = (*gen)();
+      const uint32 x2 = (*gen)();
+      const uint32 x3 = (*gen)();
+      double candidates[2];
+      BoxMullerDouble(x0, x1, x2, x3, &candidates[0], &candidates[1]);
+      for (const double candidate : candidates) {
+        if (fabs(candidate) < kTruncateValue) {
+          accepted[filled] = candidate;
+          if (++filled >= kResultElementCount) {
+            return accepted;
+          }
+        }
+      }
+    }
+  }
+};
+
+// Helper function to convert two 32-bit uniform integers to two floats
+// under the unit normal distribution.
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1) {
+  // This function implements the Box-Muller transform:
+  // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
+  //
+  // Do not send a really small number to log(): the first uniform value is
+  // clamped from below by epsilon.
+  // We cannot mark "epsilon" as "static const" because NVCC would complain
+  const float epsilon = 1.0e-7f;
+  const float uniform = Uint32ToFloat(x0);
+  const float clamped = (uniform < epsilon) ? epsilon : uniform;
+  const float angle = 2.0f * M_PI * Uint32ToFloat(x1);
+  const float radius = sqrt(-2.0f * log(clamped));
+#if defined(__linux)
+  sincosf(angle, f0, f1);
+#else
+  *f0 = sinf(angle);
+  *f1 = cosf(angle);
+#endif
+  *f0 *= radius;
+  *f1 *= radius;
+}
+
+// Helper function to convert four 32-bit uniform integers to two doubles
+// under the unit normal distribution.
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
+                     double* d1) {
+  // This function implements the Box-Muller transform:
+  // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
+  //
+  // Do not send a really small number to log(): the first uniform value is
+  // clamped from below by epsilon.
+  // We cannot mark "epsilon" as "static const" because NVCC would complain
+  const double epsilon = 1.0e-7;
+  const double uniform = Uint64ToDouble(x0, x1);
+  const double clamped = (uniform < epsilon) ? epsilon : uniform;
+  const double angle = 2 * M_PI * Uint64ToDouble(x2, x3);
+  const double radius = sqrt(-2.0 * log(clamped));
+#if defined(__linux)
+  sincos(angle, d0, d1);
+#else
+  *d0 = sin(angle);
+  *d1 = cos(angle);
+#endif
+  *d0 *= radius;
+  *d1 *= radius;
+}
+
+// Helper function to convert a 32-bit integer to a float between [0..1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) {
+  // IEEE754 floats are formatted as follows (MSB first):
+  //    sign(1) exponent(8) mantissa(23)
+  // Conceptually construct the following:
+  //    sign == 0
+  //    exponent == 127  -- an excess 127 representation of a zero exponent
+  //    mantissa == 23 random bits
+  // which is a float in [1, 2); subtracting 1 maps it onto [0, 1).
+  const uint32 mantissa = x & 0x7fffffu;  // 23 bit mantissa
+  const uint32 one_exponent = static_cast<uint32>(127) << 23;
+  const uint32 bits = one_exponent | mantissa;
+  // Assumes that endian-ness is same for float and uint32.
+  float one_to_two;
+  memcpy(&one_to_two, &bits, sizeof(bits));
+  return one_to_two - 1.0f;
+}
+
+// Helper function to convert two 32-bit integers to a double between [0..1).
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1) {
+  // IEEE754 doubles are formatted as follows (MSB first):
+  //    sign(1) exponent(11) mantissa(52)
+  // Conceptually construct the following:
+  //    sign == 0
+  //    exponent == 1023  -- an excess 1023 representation of a zero exponent
+  //    mantissa == 52 random bits
+  // which is a double in [1, 2); subtracting 1 maps it onto [0, 1).
+  const uint64 mantissa_hi = x0 & 0xfffffu;  // upper 20 bits of mantissa
+  const uint64 mantissa_lo = x1;             // lower 32 bits of mantissa
+  const uint64 one_exponent = static_cast<uint64>(1023) << 52;
+  const uint64 bits = one_exponent | (mantissa_hi << 32) | mantissa_lo;
+  // Assumes that endian-ness is same for double and uint64.
+  double one_to_two;
+  memcpy(&one_to_two, &bits, sizeof(bits));
+  return one_to_two - 1.0;
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc
new file mode 100644
index 0000000000..3ce86a907a
--- /dev/null
+++ b/tensorflow/core/lib/random/random_distributions_test.cc
@@ -0,0 +1,270 @@
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+#include <math.h>
+#include <algorithm>
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/philox_random_test_utils.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+// The largest z-value the tests below tolerate. Since the z-test approximates
+// a unit normal distribution, it should almost definitely never exceed 6.
+static constexpr float kZLimit = 6.0;
+
+// A utility function to fill the given array with samples from the given
+// distribution, using the single-sample adapter of the underlying generator.
+// Arguments:
+//   gen: the generator to draw from (taken by value, so the caller's copy is
+//        not advanced);
+//   p: destination array with room for at least `size` elements;
+//   size: number of elements to produce; must be a multiple of
+//         Distribution::kResultElementCount.
+template <class Distribution>
+void FillRandomsWithSingles(PhiloxRandom gen,
+                            typename Distribution::ResultElementType* p,
+                            int64 size) {
+  const int granularity = Distribution::kResultElementCount;
+
+  CHECK(size % granularity == 0) << " size: " << size
+                                 << " granularity: " << granularity;
+
+  SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+  Distribution dist;
+  // The index is int64 to match `size`; a plain int would overflow for sizes
+  // beyond INT_MAX.
+  for (int64 i = 0; i < size; i += granularity) {
+    auto sample = dist(&single_samples);
+    std::copy(&sample[0], &sample[0] + granularity, &p[i]);
+  }
+}
+
+// Check the given array of samples matches the given theoretical moment
+// function at different orders. The test is considered passing if the z-tests
+// of all statistical moments are all below z_limit.
+// typename T in the template argument could be either float or double.
+// Arguments:
+// samples: an array of samples to be tested for their statistical properties;
+// theoretical_moments: a functor that can calculate moments of arbitrary order
+// of the given distribution;
+// max_moments: the largest order of the moments to be tested;
+// stride: the distance between samples to check for statistical properties
+// 0 means the n-th moment of each sample
+// any other stride tests for spatial correlation between samples;
+// z_limit: the maximum z-test we would consider the test to pass;
+template <typename T>
+bool CheckSamplesMoments(const std::vector<T>& samples,
+ std::function<double(int)> theoretical_moments,
+ int max_moments, int stride, T z_limit) {
+ const T* const samples_data = &samples[0];
+ const int samples_size = samples.size();
+ std::vector<double> moments(max_moments + 1);
+ double* const moments_data = &moments[0];
+ std::vector<int> moments_sample_count(max_moments + 1);
+ int* const moments_sample_count_data = &moments_sample_count[0];
+
+ for (int k = 0; k < samples_size; ++k) {
+ double moment = 1.;
+ for (int i = 0; i <= max_moments; ++i) {
+ int index = k + i * stride;
+ if (index >= samples_size) {
+ break;
+ }
+ // moments[i] stores the i-th order measured moment.
+ // bypass std::vector::operator[] because it is too slow in the debug
+ // mode, given the large number of samples.
+ moments_data[i] += moment;
+ ++moments_sample_count_data[i];
+ moment *= samples_data[index];
+ }
+ }
+
+ // normalize the moments. NOTE(review): a count of 0 (i * stride >= samples.size()) yields 0/0 = NaN, which the z-test below never flags.
+ for (int i = 0; i <= max_moments; ++i) {
+ moments[i] /= moments_sample_count[i];
+ }
+
+ bool status = true;
+
+ for (int i = 1; i <= max_moments; ++i) {
+ // Calculate the theoretical mean and variance
+ const double moments_i_mean = (stride == 0)
+ ? theoretical_moments(i)
+ : std::pow(theoretical_moments(1), i);
+ const double moments_i_squared = (stride == 0)
+ ? theoretical_moments(2 * i)
+ : std::pow(theoretical_moments(2), i);
+ const double moments_i_var =
+ moments_i_squared - moments_i_mean * moments_i_mean;
+
+ // assume every operation has a small numerical error.
+ static const double kNumericalError = 1e-6;
+ // it takes i multiplications to calculate one i-th moment.
+ const double error_per_moment = i * kNumericalError;
+ const double total_variance =
+ moments_i_var / moments_sample_count[i] + error_per_moment;
+ // z_test is approximately a unit normal distribution.
+ const double z_test =
+ fabs((moments[i] - moments_i_mean) / sqrt(total_variance));
+
+ if (z_test > z_limit) {
+ LOG(ERROR) << "failing z_test:"
+ << " moment: " << i << " stride: " << stride
+ << " z_test: " << z_test << " z_limit: " << z_limit
+ << " measured moments: " << moments[i]
+ << " theoretical mean of the moments: " << moments_i_mean
+ << " theoretical var of the moments: " << moments_i_var
+ << " sample count: " << moments_sample_count[i];
+ status = false;
+ }
+ }
+
+ return status;
+}
+
+// Verifies that samples produced by UniformDistribution reproduce the
+// theoretical moments of the uniform distribution for every given stride.
+template <typename T>
+void UniformMomentsTest(int count, int max_moments,
+                        const std::vector<int>& strides, T z_limit) {
+  // Reference moments: the n-th one is 1 / (n + 1).
+  const auto uniform_moments = [](int n) -> double { return 1. / (n + 1); };
+
+  std::vector<T> v1(count);
+  const uint64 seed = GetTestSeed();
+  PhiloxRandom gen(seed);
+  FillRandoms<UniformDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size());
+
+  // Stop at the first stride whose z-tests fail; the seed is reported so the
+  // failure can be reproduced.
+  for (size_t s = 0; s < strides.size(); ++s) {
+    const int stride = strides[s];
+    bool status = CheckSamplesMoments<T>(v1, uniform_moments, max_moments,
+                                         stride, z_limit);
+    ASSERT_TRUE(status) << " UniformMomentsTest failing. seed: " << seed;
+  }
+}
+
+// Verifies that samples produced by NormalDistribution reproduce the
+// theoretical moments of the unit normal distribution for every given stride.
+template <typename T>
+void NormalMomentsTest(int count, int max_moments,
+                       const std::vector<int>& strides, T z_limit) {
+  const auto normal_moments = [](int n) -> double {
+    // Odd-order moments of the unit normal are zero.
+    if (n % 2 == 1) return 0.;
+
+    // Even-order moments equal the double factorial (n-1)!!.
+    double double_factorial = 1.;
+    for (int factor = n - 1; factor >= 1; factor -= 2) {
+      double_factorial *= factor;
+    }
+    return double_factorial;
+  };
+
+  std::vector<T> v1(count);
+  const uint64 seed = GetTestSeed();
+  PhiloxRandom gen(seed);
+  FillRandoms<NormalDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size());
+
+  // Stop at the first stride whose z-tests fail; the seed is reported so the
+  // failure can be reproduced.
+  for (size_t s = 0; s < strides.size(); ++s) {
+    const int stride = strides[s];
+    bool status = CheckSamplesMoments<T>(v1, normal_moments, max_moments,
+                                         stride, z_limit);
+    ASSERT_TRUE(status) << " NormalMomentsTest failing. seed: " << seed;
+  }
+}
+
+// Functor returning the n-th moment of the standard normal distribution
+// truncated at +/- v. Odd-order moments are zero; the even-order ones follow
+// the recurrence
+//   m(n) = (n - 1) * m(n - 2) - 2 * v ^ (n - 1) * f(v) / (2 * Phi(v) - 1)
+// where v is the cut-off value, f(v) is the p.d.f of the standard normal and
+// Phi(v) is the c.d.f of the standard normal. Computed values are memoized.
+class TruncatedNormalMoments {
+ public:
+  double operator()(int n) {
+    // Base case of the recurrence.
+    if (n == 0) return 1;
+
+    // Odd orders vanish.
+    if (n % 2 == 1) return 0.;
+
+    // Reuse a previously computed value when there is one.
+    const auto cached = cached_results_.find(n);
+    if (cached != cached_results_.end()) return cached->second;
+
+    // Apply the recurrence, descending to order n - 2.
+    const double bias = 2.0 * std::pow(kV, n - 1) * kFV / (2.0 * kPhiV - 1.0);
+    const double lower_moment = (*this)(n - 2);
+    const double moment_n = (n - 1) * lower_moment - bias;
+
+    cached_results_[n] = moment_n;
+    return moment_n;
+  }
+
+ private:
+  // The cut-off value v.
+  const double kV = 2.0;
+  // f(v): the p.d.f of the standard normal evaluated at v = 2.
+  const double kFV = 1.0 / sqrt(2.0 * M_PI) * exp(-kV * kV / 2.0);
+  // Numerical value of Phi(v) for v = 2.
+  const double kPhiV = 0.977249868051821;
+  // Memo table: order -> moment.
+  std::unordered_map<int, double> cached_results_;
+};
+
+// This test checks that the generated samples match the theoretical moments
+// of the truncated normal distribution. Samples are drawn one at a time
+// through a SingleSampleAdapter (see FillRandomsWithSingles).
+// Arguments:
+//   count: number of samples to generate;
+//   max_moments: highest moment order to check;
+//   strides: sample distances handed to CheckSamplesMoments, one run each;
+//   z_limit: the largest z-test value that still passes.
+template <typename T>
+void RandomParametersMomentsTest(int count, int max_moments,
+                                 const std::vector<int>& strides, T z_limit) {
+  std::vector<T> v1(count);
+  uint64 seed = GetTestSeed();
+  PhiloxRandom gen(seed);
+  FillRandomsWithSingles<
+      TruncatedNormalDistribution<SingleSampleAdapter<PhiloxRandom>, T> >(
+      gen, &v1[0], v1.size());
+
+  for (int stride : strides) {
+    bool status = CheckSamplesMoments<T>(v1, TruncatedNormalMoments(),
+                                         max_moments, stride, z_limit);
+    // Name this test (not NormalMomentsTest) so a failure points at the
+    // right distribution.
+    ASSERT_TRUE(status) << " RandomParametersMomentsTest failing. seed: "
+                        << seed;
+  }
+}
+
+// Each test below checks sample moments at strides 0, 1, 4 and 17.
+
+// Uniform floats: 2^20 samples, moments up to order 40.
+TEST(PhiloxRandomTest, UniformFloatMomentsTest) {
+  UniformMomentsTest<float>(1 << 20, 40, {0, 1, 4, 17}, kZLimit);
+}
+
+// Normal floats: 2^23 samples, moments up to order 25.
+TEST(PhiloxRandomTest, NormalFloatMomentsTest) {
+  NormalMomentsTest<float>(8 << 20, 25, {0, 1, 4, 17}, kZLimit);
+}
+
+// Truncated-normal floats: 2^20 samples, moments up to order 40.
+TEST(PhiloxRandomTest, RandomParametersFloatMomentsTest) {
+  RandomParametersMomentsTest<float>(1 << 20, 40, {0, 1, 4, 17}, kZLimit);
+}
+
+// Uniform doubles: 2^20 samples, moments up to order 40.
+TEST(PhiloxRandomTest, UniformDoubleMomentsTest) {
+  UniformMomentsTest<double>(1 << 20, 40, {0, 1, 4, 17}, kZLimit);
+}
+
+// Normal doubles: 2^23 samples, moments up to order 25.
+TEST(PhiloxRandomTest, NormalDoubleMomentsTest) {
+  NormalMomentsTest<double>(8 << 20, 25, {0, 1, 4, 17}, kZLimit);
+}
+
+// Truncated-normal doubles: 2^20 samples, moments up to order 40.
+TEST(PhiloxRandomTest, RandomParametersDoubleMomentsTest) {
+  RandomParametersMomentsTest<double>(1 << 20, 40, {0, 1, 4, 17}, kZLimit);
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/random_test.cc b/tensorflow/core/lib/random/random_test.cc
new file mode 100644
index 0000000000..7ed37c8b5e
--- /dev/null
+++ b/tensorflow/core/lib/random/random_test.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/lib/random/random.h"
+
+#include <set>
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+// Draws one million values from New64() and expects all of them to differ.
+TEST(New64Test, SanityCheck) {
+  std::set<uint64> values;
+  int remaining = 1000000;
+  while (remaining-- > 0) {
+    const uint64 x = New64();
+    // insert() reports false in .second when x was already present.
+    EXPECT_TRUE(values.insert(x).second) << "duplicate " << x;
+  }
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/simple_philox.cc b/tensorflow/core/lib/random/simple_philox.cc
new file mode 100644
index 0000000000..1035e1f017
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox.cc
@@ -0,0 +1,24 @@
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/random/exact_uniform_int.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace random {
+
+// Delegates to ExactUniformInt, feeding it 32-bit words from Rand32().
+uint32 SimplePhilox::Uniform(uint32 n) {
+  const auto next_word = [this]() { return Rand32(); };
+  return ExactUniformInt<uint32>(n, next_word);
+}
+
+// Delegates to ExactUniformInt, feeding it 64-bit words from Rand64().
+uint64 SimplePhilox::Uniform64(uint64 n) {
+  const auto next_word = [this]() { return Rand64(); };
+  return ExactUniformInt<uint64>(n, next_word);
+}
+
+// Picks a bit count uniformly from [0, max_log] (via Rand32() modulo
+// max_log + 1) and returns that many random low-order bits.
+// REQUIRES: 0 <= max_log <= 32 (CHECK-enforced).
+uint32 SimplePhilox::Skewed(int max_log) {
+  CHECK(0 <= max_log && max_log <= 32);
+  const int shift = Rand32() % (max_log + 1);
+  // Build the mask in unsigned arithmetic: with a signed literal,
+  // (1 << 31) - 1 overflows int, which is undefined behaviour. A shift by 32
+  // is not a valid shift of a 32-bit value, so that case is special-cased.
+  const uint32 mask = shift == 32 ? ~static_cast<uint32>(0)
+                                  : (static_cast<uint32>(1) << shift) - 1;
+  return Rand32() & mask;
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h
new file mode 100644
index 0000000000..12b15d7616
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox.h
@@ -0,0 +1,61 @@
+#ifndef TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+#define TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+
+#include <math.h>
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+namespace random {
+
+// A simple imperative interface to Philox. Wraps a PhiloxRandom generator in
+// a SingleSampleAdapter and hands out one value at a time.
+class SimplePhilox {
+ public:
+  PHILOX_DEVICE_INLINE
+  explicit SimplePhilox(PhiloxRandom* gen) : single_(gen) {}
+
+  // Returns 32 random bits.
+  PHILOX_DEVICE_INLINE uint32 Rand32() { return single_(); }
+
+  // Returns 64 random bits: the first draw supplies the low word, the second
+  // draw the high word.
+  PHILOX_DEVICE_INLINE uint64 Rand64() {
+    const uint32 low_word = single_();
+    const uint32 high_word = single_();
+    return (static_cast<uint64>(high_word) << 32) | low_word;
+  }
+
+  // Returns a uniform float in [0, 1).
+  PHILOX_DEVICE_INLINE float RandFloat() {
+    const uint32 bits = single_();
+    return Uint32ToFloat(bits);
+  }
+
+  // Returns a uniform double in [0, 1), built from two 32-bit draws.
+  PHILOX_DEVICE_INLINE double RandDouble() {
+    const uint32 first_draw = single_();
+    const uint32 second_draw = single_();
+    return Uint64ToDouble(first_draw, second_draw);
+  }
+
+  // Uniform integer in [0, n).
+  // Uses rejection sampling, so may need more than one 32-bit sample.
+  uint32 Uniform(uint32 n);
+
+  // Approximately uniform integer in [0, n).
+  // Uses rejection sampling, so may need more than one 64-bit sample.
+  uint64 Uniform64(uint64 n);
+
+  // True with probability 1/n.
+  bool OneIn(uint32 n) {
+    const uint32 draw = Uniform(n);
+    return draw == 0;
+  }
+
+  // Skewed: pick "base" uniformly from range [0,max_log] and then
+  // return "base" random bits. The effect is to pick a number in the
+  // range [0,2^max_log-1] with bias towards smaller numbers.
+  uint32 Skewed(int max_log);
+
+ private:
+  // Adapter that yields one 32-bit sample per call.
+  SingleSampleAdapter<PhiloxRandom> single_;
+};
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
diff --git a/tensorflow/core/lib/random/simple_philox_test.cc b/tensorflow/core/lib/random/simple_philox_test.cc
new file mode 100644
index 0000000000..4246b8b4dd
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox_test.cc
@@ -0,0 +1,120 @@
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+#include <set>
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+// RandFloat() and RandDouble() must always return values in [0, 1).
+TEST(SimplePhiloxTest, FloatTest) {
+  static const int kIters = 1000000;
+  PhiloxRandom philox(7, 7);
+  SimplePhilox gen(&philox);
+
+  // Single-precision samples.
+  for (int remaining = kIters; remaining > 0; --remaining) {
+    const float f = gen.RandFloat();
+    EXPECT_LE(0.0f, f);
+    EXPECT_GT(1.0f, f);
+  }
+
+  // Double-precision samples.
+  for (int remaining = kIters; remaining > 0; --remaining) {
+    const double d = gen.RandDouble();
+    EXPECT_LE(0.0, d);
+    EXPECT_GT(1.0, d);
+  }
+}
+
+// Checks that two generators produce different output: at least one of their
+// first kIters Rand32() values must differ.
+// Arguments:
+//   names: label describing the generator pair; included in the failure
+//          message so the failing case can be identified;
+//   gen1, gen2: the generators to compare (both are advanced).
+static void DifferenceTest(const char *names, SimplePhilox *gen1,
+                           SimplePhilox *gen2) {
+  static const int kIters = 100;
+  bool different = false;
+  for (int i = 0; i < kIters; ++i) {
+    if (gen1->Rand32() != gen2->Rand32()) {
+      different = true;
+      break;
+    }
+  }
+  CHECK(different) << names << ": different seeds but same output!";
+}
+
+// Generators seeded with (1, 1) and (17, 17) must produce different output.
+TEST(SimplePhiloxTest, DifferenceTest) {
+  PhiloxRandom philox1(1, 1);
+  PhiloxRandom philox2(17, 17);
+  SimplePhilox gen1(&philox1);
+  SimplePhilox gen2(&philox2);
+
+  DifferenceTest("SimplePhilox: different seeds", &gen1, &gen2);
+}
+
+// Generators seeded with (1, 1) and (2, 1) must produce different output.
+TEST(SimplePhiloxTest, DifferenceTestCloseSeeds) {
+  PhiloxRandom philox1(1, 1);
+  PhiloxRandom philox2(2, 1);
+  SimplePhilox gen1(&philox1);
+  SimplePhilox gen2(&philox2);
+
+  DifferenceTest("SimplePhilox: close seeds", &gen1, &gen2);
+}
+
+// Regression test: two generators whose seeds differ only in the last bit
+// must produce disjoint output.
+TEST(SimplePhiloxTest, Regression_CloseSeedsAreDifferent) {
+  const int kCount = 1000;
+
+  // Seeds (0, 1) and (1, 1) differ only by the last bit.
+  PhiloxRandom philox1(0, 1);
+  PhiloxRandom philox2(1, 1);
+  SimplePhilox gen1(&philox1);
+  SimplePhilox gen2(&philox2);
+
+  // `first` collects the output of gen1 alone; `all` collects the output of
+  // both generators.
+  std::set<uint32> first;
+  std::set<uint32> all;
+  int drawn = 0;
+  while (drawn < kCount) {
+    const uint32 from_gen1 = gen1.Rand32();
+    first.insert(from_gen1);
+    all.insert(from_gen1);
+    all.insert(gen2.Rand32());
+    ++drawn;
+  }
+
+  // Broken array initialization implementation (before 2009-08-18) using the
+  // above seeds return <1000, 1007>, generating output that is >99% similar.
+  // The fix returns <1000, 2000> for completely disjoint sets.
+  EXPECT_EQ(kCount, first.size());
+  EXPECT_EQ(2 * kCount, all.size());
+}
+
+// Draws from [0, 3 * 2^29) and checks that the fraction of samples below 2^30
+// is within 0.005 of threshold / range.
+TEST(SimplePhiloxTest, TestUniform) {
+  static const int kTrials = 100000;
+  PhiloxRandom philox(17, 17);
+  SimplePhilox gen(&philox);
+
+  uint32 range = 3 * (1L << 29);
+  uint32 threshold = 1L << 30;
+
+  // Number of samples that fell below the threshold.
+  size_t count = 0;
+  for (int trial = 0; trial != kTrials; ++trial) {
+    const uint32 rnd = gen.Uniform(range);
+    count += (rnd < threshold) ? 1 : 0;
+  }
+
+  EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005);
+}
+
+// Draws from [0, 3 * 2^59) and checks that the fraction of samples below 2^60
+// is within 0.005 of threshold / range.
+TEST(SimplePhiloxTest, TestUniform64) {
+  static const int kTrials = 100000;
+  PhiloxRandom philox(17, 17);
+  SimplePhilox gen(&philox);
+
+  uint64 range = 3 * (1LL << 59);
+  uint64 threshold = 1LL << 60;
+
+  // Number of samples that fell below the threshold.
+  size_t count = 0;
+  for (int trial = 0; trial != kTrials; ++trial) {
+    const uint64 rnd = gen.Uniform64(range);
+    count += (rnd < threshold) ? 1 : 0;
+  }
+
+  EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005);
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/weighted_picker.cc b/tensorflow/core/lib/random/weighted_picker.cc
new file mode 100644
index 0000000000..f96da578ec
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker.cc
@@ -0,0 +1,203 @@
+#include "tensorflow/core/lib/random/weighted_picker.h"
+
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+namespace tensorflow {
+namespace random {
+
+// Builds a picker with N elements (N >= 0, CHECK-enforced), each starting
+// with weight one.
+WeightedPicker::WeightedPicker(int N) {
+  CHECK_GE(N, 0);
+  N_ = N;
+
+  // Add levels until the deepest one has room for N leaves.
+  int levels = 1;
+  while (LevelSize(levels - 1) < N) {
+    ++levels;
+  }
+  num_levels_ = levels;
+
+  // Allocate one node array per level.
+  level_ = new int32*[num_levels_];
+  int depth = 0;
+  while (depth < num_levels_) {
+    level_[depth] = new int32[LevelSize(depth)];
+    ++depth;
+  }
+
+  SetAllWeights(1);
+}
+
+// Frees every per-level node array, then the array of level pointers.
+WeightedPicker::~WeightedPicker() {
+  int32** const levels_end = level_ + num_levels_;
+  for (int32** nodes = level_; nodes != levels_end; ++nodes) {
+    delete[] *nodes;
+  }
+  delete[] level_;
+}
+
+static int32 UnbiasedUniform(SimplePhilox* r, int32 n) {  // Uniform draw from [0, n) for n > 0; yields 0 when n == 0.
+ CHECK_LE(0, n);
+ const uint32 range = ~static_cast<uint32>(0);
+ if (n == 0) {
+ return r->Rand32() * n;  // Always 0, but still consumes one Rand32() draw.
+ } else if (0 == (n & (n - 1))) {
+ // n is a power of two, so just mask off the lower bits.
+ return r->Rand32() & (n - 1);
+ } else {
+ // Reject all numbers that skew the distribution towards 0.
+
+ // Rand32's output is uniform in the half-open interval [0, 2^{32}).
+ // For any interval [m,n), the number of elements in it is n-m.
+
+ uint32 rem = (range % n) + 1;
+ uint32 rnd;
+
+ // rem = ((2^{32}-1) \bmod n) + 1
+ // 1 <= rem <= n
+
+ // NB: rem == n is impossible, since n is not a power of 2 (from
+ // earlier check).
+
+ do {
+ rnd = r->Rand32(); // rnd uniform over [0, 2^{32})
+ } while (rnd < rem); // reject [0, rem)
+ // rnd is uniform over [rem, 2^{32})
+ //
+ // The number of elements in the half-open interval is
+ //
+ // 2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1
+ // = 2^{32}-1 - ((2^{32}-1) \bmod n)
+ // = n \cdot \lfloor (2^{32}-1)/n \rfloor
+ //
+ // therefore n evenly divides the number of integers in the
+ // interval.
+ //
+ // The function v \rightarrow v % n takes values from [rem,
+ // 2^{32}) to [0, n). Each integer in the range interval [0, n)
+ // will have exactly \lfloor (2^{32}-1)/n \rfloor preimages from
+ // the domain interval.
+ //
+ // Therefore, v % n is uniform over [0, n). QED.
+
+ return rnd % n;
+ }
+}
+
+// Picks an element at random, weighted by the element weights. Returns -1
+// when the total weight is zero.
+int WeightedPicker::Pick(SimplePhilox* rnd) const {
+  const int32 total = total_weight();
+  if (total == 0) {
+    return -1;
+  }
+
+  // An unbiased uniform draw avoids the bias toward low elements that large
+  // weights could otherwise introduce.
+  const int32 weight_index = UnbiasedUniform(rnd, total);
+  return PickAt(weight_index);
+}
+
+// Returns the element whose weight span covers weight_index, or -1 when
+// weight_index lies outside [0, total_weight()).
+int WeightedPicker::PickAt(int32 weight_index) const {
+  const bool in_range = weight_index >= 0 && weight_index < total_weight();
+  if (!in_range) return -1;
+
+  int32 position = weight_index;
+  int index = 0;
+
+  // Walk from the root to a leaf. At every level, go left when the remaining
+  // position falls inside the left child's weight; otherwise subtract that
+  // weight and go right.
+  for (int depth = 1; depth < num_levels_; ++depth) {
+    const int left_child = 2 * index;
+    const int32 left_weight = level_[depth][left_child];
+    if (position >= left_weight) {
+      position -= left_weight;
+      index = left_child + 1;
+    } else {
+      index = left_child;
+    }
+  }
+  CHECK_GE(index, 0);
+  CHECK_LT(index, N_);
+  CHECK_LE(position, level_[num_levels_ - 1][index]);
+  return index;
+}
+
+// Sets the weight of element `index` and updates the sums stored in all of
+// its ancestors.
+void WeightedPicker::set_weight(int index, int32 weight) {
+  assert(index >= 0);
+  assert(index < N_);
+
+  // Apply the change in weight to the leaf and to each ancestor up to the
+  // root; the node position halves at every step.
+  const int32 delta = weight - get_weight(index);
+  int l = num_levels_;
+  while (l-- > 0) {
+    level_[l][index] += delta;
+    index = index >> 1;
+  }
+}
+
+// Gives every element the same weight and recomputes the internal sums.
+void WeightedPicker::SetAllWeights(int32 weight) {
+  int32* const leaves = level_[num_levels_ - 1];
+  const int leaf_slots = LevelSize(num_levels_ - 1);
+
+  // The first N_ leaves are real elements; the remaining slots are padding
+  // and carry weight zero.
+  for (int slot = 0; slot < N_; ++slot) {
+    leaves[slot] = weight;
+  }
+  for (int slot = N_; slot < leaf_slots; ++slot) {
+    leaves[slot] = 0;
+  }
+
+  RebuildTreeWeights();
+}
+
+// Resizes the picker to N elements, copies weights[0..N) into the leaves and
+// recomputes the internal sums.
+void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) {
+  Resize(N);
+
+  int32* const leaves = level_[num_levels_ - 1];
+  const int leaf_slots = LevelSize(num_levels_ - 1);
+
+  // The first N_ leaves take the supplied weights; the remaining slots are
+  // padding and carry weight zero.
+  for (int slot = 0; slot < N_; ++slot) {
+    leaves[slot] = weights[slot];
+  }
+  for (int slot = N_; slot < leaf_slots; ++slot) {
+    leaves[slot] = 0;
+  }
+
+  RebuildTreeWeights();
+}
+
+// Recomputes every internal node as the sum of its two children, working
+// from the level just above the leaves up to the root.
+void WeightedPicker::RebuildTreeWeights() {
+  for (int parent_level = num_levels_ - 1; parent_level-- > 0;) {
+    int32* const parents = level_[parent_level];
+    const int32* const children = level_[parent_level + 1];
+    const int width = LevelSize(parent_level);
+    for (int node = 0; node < width; ++node) {
+      const int32* const pair = children + 2 * node;
+      parents[node] = pair[0] + pair[1];
+    }
+  }
+}
+
+// Grows the picker by one element and gives the new element `weight`.
+void WeightedPicker::Append(int32 weight) {
+  const int new_index = num_elements();
+  Resize(new_index + 1);
+  set_weight(new_index, weight);
+}
+
+// Changes the number of elements to new_size (>= 0, CHECK-enforced). Entries
+// that are dropped lose their weight; entries that are added have weight zero.
+void WeightedPicker::Resize(int new_size) {
+  CHECK_GE(new_size, 0);
+
+  const bool fits_in_place = new_size <= LevelSize(num_levels_ - 1);
+  if (!fits_in_place) {
+    // Not enough leaf slots: build a larger picker, carry the existing leaf
+    // weights over, and take over its storage. The cost is O(N) regardless.
+    assert(new_size > N_);
+    WeightedPicker new_picker(new_size);
+    int32* dst = new_picker.level_[new_picker.num_levels_ - 1];
+    int32* src = this->level_[this->num_levels_ - 1];
+    std::copy(src, src + N_, dst);
+    std::fill(dst + N_, dst + new_size, 0);
+    new_picker.RebuildTreeWeights();
+
+    // Swap state; new_picker's destructor then releases the old storage.
+    std::swap(new_picker.N_, this->N_);
+    std::swap(new_picker.num_levels_, this->num_levels_);
+    std::swap(new_picker.level_, this->level_);
+    assert(this->N_ == new_size);
+    return;
+  }
+
+  // The new size fits in the existing levels. When shrinking, zero the
+  // weights being dropped so the sums in the upper levels stay correct.
+  for (int dropped = new_size; dropped < N_; ++dropped) {
+    set_weight(dropped, 0);
+  }
+
+  // Nothing to initialise when enlarging: unused leaf slots always hold
+  // weight zero.
+  N_ = new_size;
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/weighted_picker.h b/tensorflow/core/lib/random/weighted_picker.h
new file mode 100644
index 0000000000..3d2c2dbb39
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker.h
@@ -0,0 +1,118 @@
+
+// An abstraction to pick from one of N elements with a specified
+// weight per element.
+//
+// The weight for a given element can be changed in O(lg N) time
+// An element can be picked in O(lg N) time.
+//
+// Uses O(N) bytes of memory.
+//
+// Alternative: distribution_sampler.h allows O(1) time picking, but no weight
+// adjustment after construction.
+
+#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
+#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
+
+#include <assert.h>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+class SimplePhilox;
+
+class WeightedPicker {
+ public:
+ // REQUIRES N >= 0
+ // Initializes the elements with a weight of one per element
+ explicit WeightedPicker(int N);
+
+ // Releases all resources
+ ~WeightedPicker();
+
+ // Pick a random element with probability proportional to its weight.
+ // If total weight is zero, returns -1.
+ int Pick(SimplePhilox* rnd) const;
+
+ // Deterministically pick element x whose weight covers the
+ // specified weight_index.
+ // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
+ int PickAt(int32 weight_index) const;
+
+ // Get the weight associated with an element
+ // REQUIRES 0 <= index < N
+ int32 get_weight(int index) const;
+
+ // Set the weight associated with an element
+ // REQUIRES weight >= 0
+ // REQUIRES 0 <= index < N
+ void set_weight(int index, int32 weight);
+
+ // Get the total combined weight of all elements
+ int32 total_weight() const;
+
+ // Get the number of elements in the picker
+ int num_elements() const;
+
+ // Set weight of each element to "weight"
+ void SetAllWeights(int32 weight);
+
+ // Resizes the picker to N and
+ // sets the weight of each element i to weights[i].
+ // The sum of the weights should not exceed 2^31 - 2
+ // Complexity O(N).
+ void SetWeightsFromArray(int N, const int32* weights);
+
+ // REQUIRES N >= 0
+ //
+ // Resize the weighted picker so that it has "N" elements.
+ // Any newly added entries have zero weight.
+ //
+ // Note: Resizing to a smaller size than num_elements() will
+ // not reclaim any memory. If you wish to reduce memory usage,
+ // allocate a new WeightedPicker of the appropriate size.
+ //
+ // It is efficient to use repeated calls to Resize(num_elements() + 1)
+ // to grow the picker to size X (takes total time O(X)).
+ void Resize(int N);
+
+ // Grow the picker by one and set the weight of the new entry to "weight".
+ //
+ // Repeated calls to Append() in order to grow the
+ // picker to size X takes a total time of O(X lg(X)).
+ // Consider using SetWeightsFromArray instead.
+ void Append(int32 weight);
+
+ private:
+ // We keep a binary tree with N leaves. The "i"th leaf contains
+ // the weight of the "i"th element. An internal node contains
+ // the sum of the weights of its children.
+ int N_; // Number of elements
+ int num_levels_; // Number of levels in tree (level-0 is root)
+ int32** level_; // Array that holds nodes per level
+
+ // Size of each level: level "level" holds 2^level nodes
+ static int LevelSize(int level) { return 1 << level; }
+
+ // Rebuild the tree weights using the leaf weights
+ void RebuildTreeWeights();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
+};
+
+// Reads the weight of element `index` from the deepest (leaf) level.
+inline int32 WeightedPicker::get_weight(int index) const {
+  DCHECK_GE(index, 0);
+  DCHECK_LT(index, N_);
+  const int32* const leaves = level_[num_levels_ - 1];
+  return leaves[index];
+}
+
+// The root node (level 0, slot 0) holds the sum of all leaf weights.
+inline int32 WeightedPicker::total_weight() const {
+  const int32* const root_level = level_[0];
+  return root_level[0];
+}
+
+// Number of elements currently in the picker.
+inline int WeightedPicker::num_elements() const {
+  return N_;
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
diff --git a/tensorflow/core/lib/random/weighted_picker_test.cc b/tensorflow/core/lib/random/weighted_picker_test.cc
new file mode 100644
index 0000000000..0b27d437d5
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker_test.cc
@@ -0,0 +1,254 @@
+#include "tensorflow/core/lib/random/weighted_picker.h"
+
+#include <string.h>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+
+// Forward declarations of the helpers used by the TEST bodies below.
+//
+// Runs the checks in this file against pickers holding "size" elements.
+static void TestPicker(SimplePhilox* rnd, int size);
+// Picks num_elements() * trials times and expects each element's count to be
+// within 10% of "trials".
+static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials);
+// Picks num_elements() * trials times and expects each element to be picked
+// roughly twice as often as the element before it.
+static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials);
+// Loads "weights" into a picker and checks PickAt() for every weight index.
+static void TestPickAt(int items, const int32* weights);
+
+// Covers basic construction, growing via Append()/Resize(), and shrinking.
+TEST(WeightedPicker, Simple) {
+  PhiloxRandom philox(testing::RandomSeed(), 17);
+  SimplePhilox rnd(&philox);
+
+  {
+    VLOG(0) << "======= Zero-length picker";
+    WeightedPicker empty(0);
+    // Nothing can be picked from an empty picker.
+    EXPECT_EQ(empty.Pick(&rnd), -1);
+  }
+
+  {
+    VLOG(0) << "======= Singleton picker";
+    WeightedPicker single(1);
+    for (int attempt = 0; attempt < 3; ++attempt) {
+      EXPECT_EQ(single.Pick(&rnd), 0);
+    }
+  }
+
+  {
+    VLOG(0) << "======= Grown picker";
+    WeightedPicker grown(0);
+    for (int added = 0; added < 10; ++added) {
+      grown.Append(1);
+    }
+    CheckUniform(&rnd, &grown, 100000);
+  }
+
+  {
+    VLOG(0) << "======= Grown picker with zero weights";
+    WeightedPicker grown(1);
+    grown.Resize(10);
+    // Entries added by Resize() have zero weight, so only element 0 can be
+    // picked.
+    for (int attempt = 0; attempt < 3; ++attempt) {
+      EXPECT_EQ(grown.Pick(&rnd), 0);
+    }
+  }
+
+  {
+    VLOG(0) << "======= Shrink picker and check weights";
+    WeightedPicker shrinking(1);
+    shrinking.Resize(10);
+    for (int attempt = 0; attempt < 3; ++attempt) {
+      EXPECT_EQ(shrinking.Pick(&rnd), 0);
+    }
+
+    // Element i gets weight i, so the total is 0 + 1 + ... + 9 = 45.
+    for (int elem = 0; elem < 10; ++elem) {
+      shrinking.set_weight(elem, elem);
+    }
+    EXPECT_EQ(shrinking.total_weight(), 45);
+
+    // Each row is {new size, expected total weight after shrinking}.
+    static const int kShrinkSteps[][2] = {{5, 10}, {2, 1}, {1, 0}};
+    for (const auto& step : kShrinkSteps) {
+      shrinking.Resize(step[0]);
+      EXPECT_EQ(shrinking.total_weight(), step[1]);
+    }
+  }
+}
+
+// Checks that two elements with equal, very large weights are still picked
+// uniformly.
+TEST(WeightedPicker, BigWeights) {
+  // (2^31 - 2) / 3
+  const int32 kBigWeight = 2147483646L / 3;
+
+  PhiloxRandom philox(testing::RandomSeed() + 1, 17);
+  SimplePhilox rnd(&philox);
+  VLOG(0) << "======= Check uniform with big weights";
+  WeightedPicker picker(2);
+  picker.SetAllWeights(kBigWeight);
+  CheckUniform(&rnd, &picker, 100000);
+}
+
+// Checks PickAt() against a small fixed weight table, including a
+// zero-weight entry.
+TEST(WeightedPicker, Deterministic) {
+  VLOG(0) << "======= Testing deterministic pick";
+  static const int32 kWeights[] = {1, 0, 200, 5, 42};
+  const int num_weights = TF_ARRAYSIZE(kWeights);
+  TestPickAt(num_weights, kWeights);
+}
+
+// Runs TestPicker() over a range of picker sizes, all sharing one random
+// stream.
+TEST(WeightedPicker, Randomized) {
+  PhiloxRandom philox(testing::RandomSeed() + 10, 17);
+  SimplePhilox rnd(&philox);
+
+  static const int kSizes[] = {1, 2, 3, 4, 7, 8, 9, 10, 100};
+  for (const int size : kSizes) {
+    TestPicker(&rnd, size);
+  }
+}
+
+// Runs a series of checks against pickers holding "size" elements, drawing
+// all randomness from "rnd".
+static void TestPicker(SimplePhilox* rnd, int size) {
+  VLOG(0) << "======= Testing size " << size;
+
+  // When every weight is zero there is nothing to pick: expect -1.
+  {
+    WeightedPicker all_zero(size);
+    all_zero.SetAllWeights(0);
+    for (int trial = 0; trial < 100; ++trial) {
+      EXPECT_EQ(all_zero.Pick(rnd), -1);
+    }
+  }
+
+  // Scratch weight array; every entry starts at zero.
+  std::vector<int32> weights(size, 0);
+
+  // With exactly one non-zero weight that element must always be picked,
+  // whether the weight was set individually or loaded from an array.
+  for (int elem = 0; elem < size; ++elem) {
+    WeightedPicker single(size);
+    single.SetAllWeights(0);
+    single.set_weight(elem, elem + 1);
+    for (int trial = 0; trial < 100; ++trial) {
+      EXPECT_EQ(single.Pick(rnd), elem);
+    }
+
+    weights[elem] = 10;
+    single.SetWeightsFromArray(size, weights.data());
+    for (int trial = 0; trial < 100; ++trial) {
+      EXPECT_EQ(single.Pick(rnd), elem);
+    }
+    weights[elem] = 0;  // Restore the all-zero scratch array.
+  }
+
+  // A freshly constructed picker should pick its elements roughly uniformly.
+  {
+    WeightedPicker fresh(size);
+    CheckUniform(rnd, &fresh, 100000);
+  }
+
+  // Same check for a picker that starts at a third of the size and is grown
+  // to "size" one Append(1) at a time.
+  const int initial_size = size / 3;
+  if (initial_size > 0) {
+    WeightedPicker grown(initial_size);
+    while (grown.num_elements() != size) {
+      grown.Append(1);
+    }
+    CheckUniform(rnd, &grown, 100000);
+  }
+
+  // Skewed distribution: element i gets weight 2^i. Only run for small sizes.
+  if (size <= 10) {
+    // Weights assigned one element at a time via set_weight().
+    WeightedPicker skewed(size);
+    for (int elem = 0; elem < size; ++elem) {
+      const int32 weight = 1 << elem;
+      skewed.set_weight(elem, weight);
+      weights[elem] = weight;
+    }
+    CheckSkewed(rnd, &skewed, 1000000);
+
+    // The same weights loaded in one shot from the array.
+    WeightedPicker from_array(0);
+    from_array.SetWeightsFromArray(size, weights.data());
+    CheckSkewed(rnd, &from_array, 1000000);
+  }
+}
+
+// Picks from "picker" num_elements() * trials times using "rnd" and expects
+// every element to have been picked between 0.9 * trials and 1.1 * trials
+// times.
+//
+// Stops at the first pick that falls outside [0, num_elements()): such a
+// value cannot be counted, and using it as an index would write outside the
+// histogram.
+static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker,
+                         int trials) {
+  const int size = picker->num_elements();
+  // Zero-initialized histogram; std::vector releases it on every exit path,
+  // including the early returns taken by ASSERT_* below.
+  std::vector<int> count(size, 0);
+  // Widen before multiplying so a large size * trials cannot overflow int.
+  const long long total_picks = static_cast<long long>(size) * trials;
+  for (long long i = 0; i < total_picks; i++) {
+    const int elem = picker->Pick(rnd);
+    // Fatal assertions: a non-fatal EXPECT_* would fall through to the
+    // indexing below with an invalid element.
+    ASSERT_GE(elem, 0);
+    ASSERT_LT(elem, size);
+    count[elem]++;
+  }
+  const int expected_min = static_cast<int>(0.9 * trials);
+  const int expected_max = static_cast<int>(1.1 * trials);
+  for (int i = 0; i < size; i++) {
+    EXPECT_GE(count[i], expected_min);
+    EXPECT_LE(count[i], expected_max);
+  }
+}
+
+// Picks from "picker" num_elements() * trials times using "rnd" and expects
+// each element to have been picked between 1.6 and 2.4 times as often as the
+// element before it.
+//
+// Stops at the first pick that falls outside [0, num_elements()): such a
+// value cannot be counted, and using it as an index would write outside the
+// histogram.
+static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) {
+  const int size = picker->num_elements();
+  // Zero-initialized histogram; std::vector releases it on every exit path,
+  // including the early returns taken by ASSERT_* below.
+  std::vector<int> count(size, 0);
+  // Widen before multiplying so a large size * trials cannot overflow int.
+  const long long total_picks = static_cast<long long>(size) * trials;
+  for (long long i = 0; i < total_picks; i++) {
+    const int elem = picker->Pick(rnd);
+    // Fatal assertions: a non-fatal EXPECT_* would fall through to the
+    // indexing below with an invalid element.
+    ASSERT_GE(elem, 0);
+    ASSERT_LT(elem, size);
+    count[elem]++;
+  }
+
+  for (int i = 0; i < size - 1; i++) {
+    LOG(INFO) << i << ": " << count[i];
+    // Ratio of consecutive counts; a zero count[i] gives inf/NaN, which
+    // fails the expectations below rather than crashing.
+    const float ratio =
+        static_cast<float>(count[i + 1]) / static_cast<float>(count[i]);
+    EXPECT_GE(ratio, 1.6f);
+    EXPECT_LE(ratio, 2.4f);
+  }
+}
+
+// Loads the "items" entries of "weights" into a picker and checks that
+// PickAt() maps every weight index to the element whose weight range covers
+// it. Elements with zero weight own no indices.
+static void TestPickAt(int items, const int32* weights) {
+  WeightedPicker picker(items);
+  picker.SetWeightsFromArray(items, weights);
+
+  int weight_index = 0;
+  for (int item = 0; item < items; ++item) {
+    // Indices [weight_index, range_end) belong to "item".
+    const int range_end = weight_index + weights[item];
+    for (; weight_index < range_end; ++weight_index) {
+      const int picked = picker.PickAt(weight_index);
+      EXPECT_EQ(picked, item);
+    }
+  }
+  // Every unit of weight must have been visited exactly once.
+  EXPECT_EQ(weight_index, picker.total_weight());
+}
+
+// Benchmarks constructing (and destroying) a WeightedPicker of "arg"
+// elements, "iters" times.
+static void BM_Create(int iters, int arg) {
+  // Post-decrement so the body runs exactly "iters" times; the previous
+  // "--iters > 0" form ran one iteration short.
+  while (iters-- > 0) {
+    WeightedPicker p(arg);
+  }
+}
+BENCHMARK(BM_Create)->Range(1, 1024);
+
+// Benchmarks constructing a WeightedPicker of "arg" elements and loading its
+// weights from an array, "iters" times. The weight array is built once,
+// before the loop.
+static void BM_CreateAndSetWeights(int iters, int arg) {
+  std::vector<int32> weights(arg);
+  for (int i = 0; i < arg; i++) {
+    weights[i] = i * 10;
+  }
+  // Post-decrement so the body runs exactly "iters" times; the previous
+  // "--iters > 0" form ran one iteration short.
+  while (iters-- > 0) {
+    WeightedPicker p(arg);
+    p.SetWeightsFromArray(arg, &weights[0]);
+  }
+}
+BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024);
+
+// Benchmarks Pick() on a freshly constructed WeightedPicker of "arg"
+// elements, using a fixed-seed random stream.
+static void BM_Pick(int iters, int arg) {
+  PhiloxRandom philox(301, 17);
+  SimplePhilox rnd(&philox);
+  WeightedPicker p(arg);
+  int result = 0;
+  // Post-decrement so the body runs exactly "iters" times; the previous
+  // "--iters > 0" form ran one iteration short.
+  while (iters-- > 0) {
+    result += p.Pick(&rnd);
+  }
+  VLOG(4) << result;  // Dummy use
+}
+BENCHMARK(BM_Pick)->Range(1, 1024);
+
+} // namespace random
+} // namespace tensorflow