aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/guarded_philox_random.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/guarded_philox_random.cc')
-rw-r--r--tensorflow/core/util/guarded_philox_random.cc39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/core/util/guarded_philox_random.cc b/tensorflow/core/util/guarded_philox_random.cc
new file mode 100644
index 0000000000..4cf58b8979
--- /dev/null
+++ b/tensorflow/core/util/guarded_philox_random.cc
@@ -0,0 +1,39 @@
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+Status GuardedPhiloxRandom::Init(OpKernelConstruction* context) {
+ // Grab seed Attrs.
+ int64 seed, seed2;
+ auto status = context->GetAttr("seed", &seed);
+ if (!status.ok()) return status;
+ status = context->GetAttr("seed2", &seed2);
+ if (!status.ok()) return status;
+
+ // Initialize with the given seeds
+ Init(seed, seed2);
+ return Status::OK();
+}
+
+void GuardedPhiloxRandom::Init(int64 seed, int64 seed2) {
+ CHECK(!initialized_);
+ if (seed == 0 && seed2 == 0) {
+ // If both seeds are unspecified, use completely random seeds.
+ seed = random::New64();
+ seed2 = random::New64();
+ }
+ mutex_lock lock(mu_);
+ generator_ = random::PhiloxRandom(seed, seed2);
+ initialized_ = true;
+}
+
+random::PhiloxRandom GuardedPhiloxRandom::ReserveSamples128(int64 samples) {
+ CHECK(initialized_);
+ mutex_lock lock(mu_);
+ auto local = generator_;
+ generator_.Skip(samples);
+ return local;
+}
+
+} // namespace tensorflow