blob: 6e9cb9f99c9eee7b1db09f4e692bb7f1ea5aae5a (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
|
#ifndef TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
#define TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/platform/port.h"
namespace tensorflow {
// A thread safe wrapper around a Philox generator. Example usage:
//
// GuardedPhiloxRandom generator;
// generator.Init(context);
//
// // In thread safe code
// const int samples = ...;
// auto local_generator = generator.ReserveSamples128(samples);
// for (int i = 0; i < samples; i++) {
// Array<uint32, 4> sample = local_generator();
// // Use sample
// }
//
class GuardedPhiloxRandom {
 public:
  // Constructs the wrapper in an uninitialized state; one of the Init
  // overloads below must be called to finish initialization.
  GuardedPhiloxRandom() : initialized_(false) {}

  // Sets up the generator from the attributes "seed" and "seed2".
  // When both seeds are unspecified, random seeds are used instead.
  // This must be called exactly once.
  Status Init(OpKernelConstruction* context);

  // Sets up the generator from the two explicitly supplied seeds.
  void Init(int64 seed, int64 seed2);

  // Reserves `samples` 128-bit samples.
  // Thread safe. The generator handed back is valid for the requested number
  // of samples and may be used without holding a lock.
  random::PhiloxRandom ReserveSamples128(int64 samples);

  // Reserves `samples` 32-bit samples. Four 32-bit samples fit in one 128-bit
  // sample, so the request is rounded up to whole 128-bit samples and
  // forwarded to ReserveSamples128.
  random::PhiloxRandom ReserveSamples32(int64 samples) {
    const int64 samples128 = (samples + 3) / 4;
    return ReserveSamples128(samples128);
  }

 private:
  // Lock protecting generator_.
  mutex mu_;
  // The shared underlying Philox generator.
  random::PhiloxRandom generator_ GUARDED_BY(mu_);
  // False after construction; presumably set by Init (the definitions are
  // not in this header -- confirm against the implementation file).
  bool initialized_;

  TF_DISALLOW_COPY_AND_ASSIGN(GuardedPhiloxRandom);
};
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
|