aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_rng.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_rng.h')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_rng.h89
1 files changed, 89 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_rng.h b/tensorflow/stream_executor/cuda/cuda_rng.h
new file mode 100644
index 0000000000..4e1b82969b
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_rng.h
@@ -0,0 +1,89 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_
+
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/rng.h"
+
+typedef struct curandGenerator_st *curandGenerator_t;
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+template <typename ElemT>
+class DeviceMemory;
+
+namespace cuda {
+
+// Opaque and unique identifier for the cuRAND plugin.
+extern const PluginId kCuRandPlugin;
+
+class CUDAExecutor;
+
+// CUDA-platform implementation of the random number generation support
+// interface.
+//
+// Thread-safe post-initialization.
+class CUDARng : public rng::RngSupport {
+ public:
+ explicit CUDARng(CUDAExecutor *parent);
+
+ // Retrieves a curand library generator handle. This is necessary for
+ // enqueuing random number generation work onto the device.
+ // TODO(leary) provide a way for users to select the RNG algorithm.
+ bool Init();
+
+ // Releases a curand library generator handle, if one was acquired.
+ ~CUDARng() override;
+
+ // See rng::RngSupport for details on the following overrides.
+ bool DoPopulateRandUniform(Stream *stream, DeviceMemory<float> *v) override;
+ bool DoPopulateRandUniform(Stream *stream, DeviceMemory<double> *v) override;
+ bool DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<std::complex<float>> *v) override;
+ bool DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<std::complex<double>> *v) override;
+ bool DoPopulateRandGaussian(Stream *stream, float mean, float stddev,
+ DeviceMemory<float> *v) override;
+ bool DoPopulateRandGaussian(Stream *stream, double mean, double stddev,
+ DeviceMemory<double> *v) override;
+
+ bool SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) override;
+
+ private:
+ // Actually performs the work of generating random numbers - the public
+ // methods are thin wrappers to this interface.
+ template <typename T>
+ bool DoPopulateRandUniformInternal(Stream *stream, DeviceMemory<T> *v);
+ template <typename ElemT, typename FuncT>
+ bool DoPopulateRandGaussianInternal(Stream *stream, ElemT mean, ElemT stddev,
+ DeviceMemory<ElemT> *v, FuncT func);
+
+ // Sets the stream for the internal curand generator.
+ //
+ // This is a stateful operation, as the handle can only have one stream set at
+ // a given time, so it is usually performed right before enqueuing work to do
+ // with random number generation.
+ bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // mutex that guards the cuRAND handle for this device.
+ mutex mu_;
+
+ // CUDAExecutor which instantiated this CUDARng.
+ // Immutable post-initialization.
+ CUDAExecutor *parent_;
+
+ // cuRANDalibrary handle on the device.
+ curandGenerator_t rng_ GUARDED_BY(mu_);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CUDARng);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_