aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/random/philox_random_test_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/lib/random/philox_random_test_utils.h')
-rw-r--r--tensorflow/core/lib/random/philox_random_test_utils.h36
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/core/lib/random/philox_random_test_utils.h b/tensorflow/core/lib/random/philox_random_test_utils.h
new file mode 100644
index 0000000000..d22f6b36e4
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random_test_utils.h
@@ -0,0 +1,36 @@
+#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
+#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
+
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace random {
+
+// Return a random seed.
+inline uint64 GetTestSeed() { return New64(); }
+
+// A utility function to fill the given array with samples from the given
+// distribution.
+template <class Distribution>
+void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p,
+ int64 size) {
+ const int granularity = Distribution::kResultElementCount;
+
+ CHECK(size % granularity == 0) << " size: " << size
+ << " granularity: " << granularity;
+
+ Distribution dist;
+ for (int i = 0; i < size; i += granularity) {
+ const auto sample = dist(&gen);
+ std::copy(&sample[0], &sample[0] + granularity, &p[i]);
+ }
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_