diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/stream_executor/cuda/cuda_rng.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_rng.cc')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_rng.cc | 317 |
1 files changed, 317 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_rng.cc b/tensorflow/stream_executor/cuda/cuda_rng.cc new file mode 100644 index 0000000000..ad48c8b59a --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_rng.cc @@ -0,0 +1,317 @@ +#include "tensorflow/stream_executor/cuda/cuda_rng.h" + +#include <dlfcn.h> + +#include "tensorflow/stream_executor/cuda/cuda_activation.h" +#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" +#include "tensorflow/stream_executor/cuda/cuda_helpers.h" +#include "tensorflow/stream_executor/cuda/cuda_platform.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/dso_loader.h" +#include "tensorflow/stream_executor/lib/initialize.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform/logging.h" +#include "tensorflow/stream_executor/rng.h" +#include "third_party/gpus/cuda/include/curand.h" + +// Formats curandStatus_t to output prettified values into a log stream. +std::ostream &operator<<(std::ostream &in, const curandStatus_t &status) { +#define OSTREAM_CURAND_STATUS(__name) \ + case CURAND_STATUS_##__name: \ + in << "CURAND_STATUS_" #__name; \ + return in; + + switch (status) { + OSTREAM_CURAND_STATUS(SUCCESS) + OSTREAM_CURAND_STATUS(VERSION_MISMATCH) + OSTREAM_CURAND_STATUS(NOT_INITIALIZED) + OSTREAM_CURAND_STATUS(ALLOCATION_FAILED) + OSTREAM_CURAND_STATUS(TYPE_ERROR) + OSTREAM_CURAND_STATUS(OUT_OF_RANGE) + OSTREAM_CURAND_STATUS(LENGTH_NOT_MULTIPLE) + OSTREAM_CURAND_STATUS(LAUNCH_FAILURE) + OSTREAM_CURAND_STATUS(PREEXISTING_FAILURE) + OSTREAM_CURAND_STATUS(INITIALIZATION_FAILED) + OSTREAM_CURAND_STATUS(ARCH_MISMATCH) + OSTREAM_CURAND_STATUS(INTERNAL_ERROR) + default: + in << "curandStatus_t(" << static_cast<int>(status) << ")"; + return in; + } +} + +namespace perftools { +namespace gputools { +namespace cuda { + +PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuRandPlugin); + +namespace dynload { + +#define PERFTOOLS_GPUTOOLS_CURAND_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char *kName; \ + using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \ + static void *GetDsoHandle() { \ + static auto status = internal::CachedDsoLoader::GetCurandDsoHandle(); \ + return status.ValueOrDie(); \ + } \ + static FuncPointerT DynLoad() { \ + static void *f = dlsym(GetDsoHandle(), kName); \ + CHECK(f != nullptr) << "could not find " << kName \ + << " in curand DSO; dlerror: " << dlerror(); \ + return reinterpret_cast<FuncPointerT>(f); \ + } \ + template <typename... Args> \ + curandStatus_t operator()(CUDAExecutor * parent, Args... args) { \ + cuda::ScopedActivateExecutorContext sac{parent}; \ + return DynLoad()(args...); \ + } \ + } __name; \ + const char *DynLoadShim__##__name::kName = #__name; + +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandCreateGenerator); +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandDestroyGenerator); +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetStream); +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniform); +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniformDouble); +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed); +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetGeneratorOffset); +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormal); +PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormalDouble); + +} // namespace dynload + +template <typename T> +string TypeString(); + +template <> +string TypeString<float>() { + return "float"; +} + +template <> +string TypeString<double>() { + return "double"; +} + +template <> +string TypeString<std::complex<float>>() { + return "std::complex<float>"; +} + +template <> +string TypeString<std::complex<double>>() { + return "std::complex<double>"; +} + +CUDARng::CUDARng(CUDAExecutor *parent) : parent_(parent), rng_(nullptr) {} + +CUDARng::~CUDARng() { + if (rng_ != nullptr) { + dynload::curandDestroyGenerator(parent_, rng_); + } +} + +bool CUDARng::Init() { + mutex_lock lock{mu_}; + CHECK(rng_ == nullptr); + + curandStatus_t ret = + dynload::curandCreateGenerator(parent_, &rng_, CURAND_RNG_PSEUDO_DEFAULT); + if (ret != CURAND_STATUS_SUCCESS) { + LOG(ERROR) << "failed to create random number generator: " << ret; + return false; + } + + CHECK(rng_ != nullptr); + return true; +} + +bool CUDARng::SetStream(Stream *stream) { + curandStatus_t ret = + dynload::curandSetStream(parent_, rng_, AsCUDAStreamValue(stream)); + if (ret != CURAND_STATUS_SUCCESS) { + LOG(ERROR) << "failed to set stream for random generation: " << ret; + return false; + } + + return true; +} + +// Returns true if std::complex stores its contents as two consecutive +// elements. Tests int, float and double, as the last two are independent +// specializations. +constexpr bool ComplexIsConsecutiveFloats() { + return sizeof(std::complex<int>) == 8 && sizeof(std::complex<float>) == 8 && + sizeof(std::complex<double>) == 16; +} + +template <typename T> +bool CUDARng::DoPopulateRandUniformInternal(Stream *stream, + DeviceMemory<T> *v) { + mutex_lock lock{mu_}; + static_assert(ComplexIsConsecutiveFloats(), + "std::complex values are not stored as consecutive values"); + + if (!SetStream(stream)) { + return false; + } + + // std::complex<T> is currently implemented as two consecutive T variables. + uint64 element_count = v->ElementCount(); + if (std::is_same<T, std::complex<float>>::value || + std::is_same<T, std::complex<double>>::value) { + element_count *= 2; + } + + curandStatus_t ret; + if (std::is_same<T, float>::value || + std::is_same<T, std::complex<float>>::value) { + ret = dynload::curandGenerateUniform( + parent_, rng_, reinterpret_cast<float *>(CUDAMemoryMutable(v)), + element_count); + } else { + ret = dynload::curandGenerateUniformDouble( + parent_, rng_, reinterpret_cast<double *>(CUDAMemoryMutable(v)), + element_count); + } + if (ret != CURAND_STATUS_SUCCESS) { + LOG(ERROR) << "failed to do uniform generation of " << v->ElementCount() + << " " << TypeString<T>() << "s at " << v->opaque() << ": " + << ret; + return false; + } + + return true; +} + +bool CUDARng::DoPopulateRandUniform(Stream *stream, DeviceMemory<float> *v) { + return DoPopulateRandUniformInternal(stream, v); +} + +bool CUDARng::DoPopulateRandUniform(Stream *stream, DeviceMemory<double> *v) { + return DoPopulateRandUniformInternal(stream, v); +} + +bool CUDARng::DoPopulateRandUniform(Stream *stream, + DeviceMemory<std::complex<float>> *v) { + return DoPopulateRandUniformInternal(stream, v); +} + +bool CUDARng::DoPopulateRandUniform(Stream *stream, + DeviceMemory<std::complex<double>> *v) { + return DoPopulateRandUniformInternal(stream, v); +} + +template <typename ElemT, typename FuncT> +bool CUDARng::DoPopulateRandGaussianInternal(Stream *stream, ElemT mean, + ElemT stddev, + DeviceMemory<ElemT> *v, + FuncT func) { + mutex_lock lock{mu_}; + + if (!SetStream(stream)) { + return false; + } + + uint64 element_count = v->ElementCount(); + curandStatus_t ret = + func(parent_, rng_, CUDAMemoryMutable(v), element_count, mean, stddev); + + if (ret != CURAND_STATUS_SUCCESS) { + LOG(ERROR) << "failed to do gaussian generation of " << v->ElementCount() + << " floats at " << v->opaque() << ": " << ret; + return false; + } + + return true; +} + +bool CUDARng::DoPopulateRandGaussian(Stream *stream, float mean, float stddev, + DeviceMemory<float> *v) { + return DoPopulateRandGaussianInternal(stream, mean, stddev, v, + dynload::curandGenerateNormal); +} + +bool CUDARng::DoPopulateRandGaussian(Stream *stream, double mean, double stddev, + DeviceMemory<double> *v) { + return DoPopulateRandGaussianInternal(stream, mean, stddev, v, + dynload::curandGenerateNormalDouble); +} + +bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) { + mutex_lock lock{mu_}; + CHECK(rng_ != nullptr); + + if (!CheckSeed(seed, seed_bytes)) { + return false; + } + + if (!SetStream(stream)) { + return false; + } + + // Requires 8 bytes of seed data; checked in RngSupport::CheckSeed (above) + // (which itself requires 16 for API consistency with host RNG fallbacks). + curandStatus_t ret = dynload::curandSetPseudoRandomGeneratorSeed( + parent_, rng_, *(reinterpret_cast<const uint64 *>(seed))); + if (ret != CURAND_STATUS_SUCCESS) { + LOG(ERROR) << "failed to set rng seed: " << ret; + return false; + } + + ret = dynload::curandSetGeneratorOffset(parent_, rng_, 0); + if (ret != CURAND_STATUS_SUCCESS) { + LOG(ERROR) << "failed to reset rng position: " << ret; + return false; + } + return true; +} + +} // namespace cuda +} // namespace gputools +} // namespace perftools + +namespace gpu = ::perftools::gputools; + +REGISTER_MODULE_INITIALIZER(register_curand, { + gpu::port::Status status = + gpu::PluginRegistry::Instance() + ->RegisterFactory<gpu::PluginRegistry::RngFactory>( + gpu::cuda::kCudaPlatformId, gpu::cuda::kCuRandPlugin, "cuRAND", + [](gpu::internal::StreamExecutorInterface + *parent) -> gpu::rng::RngSupport * { + gpu::cuda::CUDAExecutor *cuda_executor = + dynamic_cast<gpu::cuda::CUDAExecutor *>(parent); + if (cuda_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the cuRAND " + << "support library with a non-CUDA StreamExecutor"; + return nullptr; + } + + gpu::cuda::CUDARng *rng = new gpu::cuda::CUDARng(cuda_executor); + if (!rng->Init()) { + // Note: Init() will log a more specific error. + delete rng; + return nullptr; + } + return rng; + }); + + if (!status.ok()) { + LOG(ERROR) << "Unable to register cuRAND factory: " + << status.error_message(); + } + + // Prime the cuRAND DSO. The loader will log more information. + auto statusor = gpu::internal::CachedDsoLoader::GetCurandDsoHandle(); + if (!statusor.ok()) { + LOG(INFO) << "Unable to load cuRAND DSO."; + } + + gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId, + gpu::PluginKind::kRng, + gpu::cuda::kCuRandPlugin); +}); |