diff options
Diffstat (limited to 'tensorflow/core/util/guarded_philox_random.h')
-rw-r--r-- | tensorflow/core/util/guarded_philox_random.h | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h new file mode 100644 index 0000000000..6e9cb9f99c --- /dev/null +++ b/tensorflow/core/util/guarded_philox_random.h @@ -0,0 +1,56 @@ +#ifndef TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ +#define TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// A thread safe wrapper around a Philox generator. Example usage: +// +// GuardedRandomPhilox generator; +// generator.Init(context); +// +// // In thread safe code +// const int samples = ...; +// auto local_generator = generator.ReserveSamples128(samples); +// for (int i = 0; i < samples; i++) +// Array<uint32, 4> sample = local_generator(); +// // Use sample +// } +// +class GuardedPhiloxRandom { + public: + // Must call Init to finish initialization + GuardedPhiloxRandom() : initialized_(false) {} + + // Initialize the generator from attributes "seed" and "seed2". + // If both seeds are unspecified, use random seeds. + // Must be called exactly once. + Status Init(OpKernelConstruction* context); + + // Initialize with given seeds. + void Init(int64 seed, int64 seed2); + + // Reserve a certain number of 128-bit samples. + // This function is thread safe. The returned generator is valid for the + // given number of samples, and can be used without a lock. + random::PhiloxRandom ReserveSamples128(int64 samples); + + // Reserve a certain number of 32-bit samples + random::PhiloxRandom ReserveSamples32(int64 samples) { + return ReserveSamples128((samples + 3) / 4); + } + + private: + mutex mu_; + random::PhiloxRandom generator_ GUARDED_BY(mu_); + bool initialized_; + + TF_DISALLOW_COPY_AND_ASSIGN(GuardedPhiloxRandom); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ |