aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/random_op_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/random_op_test.cc')
-rw-r--r--tensorflow/core/kernels/random_op_test.cc99
1 files changed, 99 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/random_op_test.cc b/tensorflow/core/kernels/random_op_test.cc
new file mode 100644
index 0000000000..751b61cfba
--- /dev/null
+++ b/tensorflow/core/kernels/random_op_test.cc
@@ -0,0 +1,99 @@
+#include <random>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+Tensor Int32(int32 v) {
+ Tensor t(DT_INT32, TensorShape({}));
+ t.scalar<int32>()() = v;
+ return t;
+}
+
+Graph* RandomUniform(int64 n) {
+ Graph* g = new Graph(OpRegistry::Global());
+ test::graph::RandomUniform(g, test::graph::Constant(g, Int32(n)), DT_FLOAT);
+ return g;
+}
+
+Graph* RandomNormal(int64 n) {
+ Graph* g = new Graph(OpRegistry::Global());
+ test::graph::RandomGaussian(g, test::graph::Constant(g, Int32(n)), DT_FLOAT);
+ return g;
+}
+
+Graph* RandomParameters(int64 n) {
+ Graph* g = new Graph(OpRegistry::Global());
+ test::graph::RandomParameters(g, test::graph::Constant(g, Int32(n)),
+ DT_FLOAT);
+ return g;
+}
+
+#define BM_RNG(DEVICE, RNG) \
+ static void BM_##DEVICE##_##RNG(int iters, int arg) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * arg); \
+ test::Benchmark(#DEVICE, RNG(arg)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20);
+
+BM_RNG(cpu, RandomUniform);
+BM_RNG(cpu, RandomNormal);
+BM_RNG(cpu, RandomParameters);
+
+BM_RNG(gpu, RandomUniform);
+BM_RNG(gpu, RandomNormal);
+BM_RNG(gpu, RandomParameters);
+
+static void BM_PhiloxRandom(int iters) {
+ // Fill 2M random numbers
+ int count = 2 << 20;
+
+ testing::ItemsProcessed(static_cast<int64>(iters) * count);
+
+ random::PhiloxRandom gen(0x12345);
+
+ int val = 1;
+ for (int i = 0; i < iters; ++i) {
+ for (int j = 0; j < count; j += 4) {
+ /// each invocation of gen() returns 128-bit samples
+ auto samples = gen();
+
+ // use the result trivially so the compiler does not optimize it away
+ val ^= samples[0] ^ samples[1] ^ samples[2] ^ samples[3];
+ }
+ }
+
+ // A anchor point to make sure the compiler does not cut corners
+ CHECK(val) << val;
+}
+BENCHMARK(BM_PhiloxRandom);
+
+static void BM_StdMTRandom(int iters) {
+ // Fill 2M random numbers
+ int count = 2 << 20;
+
+ testing::ItemsProcessed(static_cast<int64>(iters) * count);
+
+ std::mt19937 gen(0x12345);
+
+ int val = 1;
+ for (int i = 0; i < iters; ++i) {
+ for (int j = 0; j < count; ++j) {
+ /// each invocation of gen() returns 32-bit sample
+ uint32 sample = gen();
+
+ // use the result trivially so the compiler does not optimize it away
+ val ^= sample;
+ }
+ }
+
+ // A anchor point to make sure the compiler does not cut corners
+ CHECK(val) << val;
+}
+BENCHMARK(BM_StdMTRandom);
+
+} // end namespace tensorflow