diff options
Diffstat (limited to 'tensorflow/core/lib/random/philox_random_test.cc')
-rw-r--r-- | tensorflow/core/lib/random/philox_random_test.cc | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/core/lib/random/philox_random_test.cc b/tensorflow/core/lib/random/philox_random_test.cc new file mode 100644 index 0000000000..997c0263b7 --- /dev/null +++ b/tensorflow/core/lib/random/philox_random_test.cc @@ -0,0 +1,58 @@ +#include "tensorflow/core/lib/random/philox_random.h" + +#include <math.h> +#include <algorithm> +#include <functional> +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/random/philox_random_test_utils.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace random { +namespace { + +// A trivial distribution that just returns the PhiloxRandom as a distribution +class TrivialPhiloxDistribution { + public: + // The number of elements that will be returned. + static constexpr int kResultElementCount = PhiloxRandom::kResultElementCount; + typedef PhiloxRandom::ResultType ResultType; + typedef PhiloxRandom::ResultElementType ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(PhiloxRandom* gen) { return (*gen)(); } +}; + +// This test checks that skipping certain number of samples, is equivalent to +// generate the same number of samples without skipping. +TEST(PhiloxRandomTest, SkipMatchTest) { + constexpr int count = 1024; + constexpr int skip_count = 2048; + + uint64 test_seed = GetTestSeed(); + std::vector<uint32> v1(count); + { + PhiloxRandom gen(test_seed); + gen.Skip(skip_count / 4); + FillRandoms<TrivialPhiloxDistribution>(gen, &v1[0], v1.size()); + } + + std::vector<uint32> v2(count + skip_count); + { + PhiloxRandom gen(test_seed); + FillRandoms<TrivialPhiloxDistribution>(gen, &v2[0], v2.size()); + } + + for (int i = 0; i < count; ++i) { + ASSERT_EQ(v1[i], v2[i + skip_count]); + } +} + +} // namespace +} // namespace random +} // namespace tensorflow |