/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/cuda/cuda_rng.h"

#include <cstring>

#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
#include "cuda/include/curand.h"

// Formats curandStatus_t to output prettified values into a log stream.
std::ostream &operator<<(std::ostream &in, const curandStatus_t &status) { #define OSTREAM_CURAND_STATUS(__name) \ case CURAND_STATUS_##__name: \ in << "CURAND_STATUS_" #__name; \ return in; switch (status) { OSTREAM_CURAND_STATUS(SUCCESS) OSTREAM_CURAND_STATUS(VERSION_MISMATCH) OSTREAM_CURAND_STATUS(NOT_INITIALIZED) OSTREAM_CURAND_STATUS(ALLOCATION_FAILED) OSTREAM_CURAND_STATUS(TYPE_ERROR) OSTREAM_CURAND_STATUS(OUT_OF_RANGE) OSTREAM_CURAND_STATUS(LENGTH_NOT_MULTIPLE) OSTREAM_CURAND_STATUS(LAUNCH_FAILURE) OSTREAM_CURAND_STATUS(PREEXISTING_FAILURE) OSTREAM_CURAND_STATUS(INITIALIZATION_FAILED) OSTREAM_CURAND_STATUS(ARCH_MISMATCH) OSTREAM_CURAND_STATUS(INTERNAL_ERROR) default: in << "curandStatus_t(" << static_cast(status) << ")"; return in; } } namespace stream_executor { namespace cuda { PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuRandPlugin); namespace wrap { #define STREAM_EXECUTOR_CURAND_WRAP(__name) \ struct WrapperShim__##__name { \ template \ curandStatus_t operator()(CUDAExecutor *parent, Args... 
args) { \ cuda::ScopedActivateExecutorContext sac{parent}; \ return ::__name(args...); \ } \ } __name; STREAM_EXECUTOR_CURAND_WRAP(curandCreateGenerator); STREAM_EXECUTOR_CURAND_WRAP(curandDestroyGenerator); STREAM_EXECUTOR_CURAND_WRAP(curandSetStream); STREAM_EXECUTOR_CURAND_WRAP(curandGenerateUniform); STREAM_EXECUTOR_CURAND_WRAP(curandGenerateUniformDouble); STREAM_EXECUTOR_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed); STREAM_EXECUTOR_CURAND_WRAP(curandSetGeneratorOffset); STREAM_EXECUTOR_CURAND_WRAP(curandGenerateNormal); STREAM_EXECUTOR_CURAND_WRAP(curandGenerateNormalDouble); } // namespace wrap template string TypeString(); template <> string TypeString() { return "float"; } template <> string TypeString() { return "double"; } template <> string TypeString>() { return "std::complex"; } template <> string TypeString>() { return "std::complex"; } CUDARng::CUDARng(CUDAExecutor *parent) : parent_(parent), rng_(nullptr) {} CUDARng::~CUDARng() { if (rng_ != nullptr) { wrap::curandDestroyGenerator(parent_, rng_); } } bool CUDARng::Init() { mutex_lock lock(mu_); CHECK(rng_ == nullptr); curandStatus_t ret = wrap::curandCreateGenerator(parent_, &rng_, CURAND_RNG_PSEUDO_DEFAULT); if (ret != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "failed to create random number generator: " << ret; return false; } CHECK(rng_ != nullptr); return true; } bool CUDARng::SetStream(Stream *stream) { curandStatus_t ret = wrap::curandSetStream(parent_, rng_, AsCUDAStreamValue(stream)); if (ret != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "failed to set stream for random generation: " << ret; return false; } return true; } // Returns true if std::complex stores its contents as two consecutive // elements. Tests int, float and double, as the last two are independent // specializations. 
constexpr bool ComplexIsConsecutiveFloats() { return sizeof(std::complex) == 8 && sizeof(std::complex) == 8 && sizeof(std::complex) == 16; } template bool CUDARng::DoPopulateRandUniformInternal(Stream *stream, DeviceMemory *v) { mutex_lock lock(mu_); static_assert(ComplexIsConsecutiveFloats(), "std::complex values are not stored as consecutive values"); if (!SetStream(stream)) { return false; } // std::complex is currently implemented as two consecutive T variables. uint64 element_count = v->ElementCount(); if (std::is_same>::value || std::is_same>::value) { element_count *= 2; } curandStatus_t ret; if (std::is_same::value || std::is_same>::value) { ret = wrap::curandGenerateUniform( parent_, rng_, reinterpret_cast(CUDAMemoryMutable(v)), element_count); } else { ret = wrap::curandGenerateUniformDouble( parent_, rng_, reinterpret_cast(CUDAMemoryMutable(v)), element_count); } if (ret != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "failed to do uniform generation of " << v->ElementCount() << " " << TypeString() << "s at " << v->opaque() << ": " << ret; return false; } return true; } bool CUDARng::DoPopulateRandUniform(Stream *stream, DeviceMemory *v) { return DoPopulateRandUniformInternal(stream, v); } bool CUDARng::DoPopulateRandUniform(Stream *stream, DeviceMemory *v) { return DoPopulateRandUniformInternal(stream, v); } bool CUDARng::DoPopulateRandUniform(Stream *stream, DeviceMemory> *v) { return DoPopulateRandUniformInternal(stream, v); } bool CUDARng::DoPopulateRandUniform(Stream *stream, DeviceMemory> *v) { return DoPopulateRandUniformInternal(stream, v); } template bool CUDARng::DoPopulateRandGaussianInternal(Stream *stream, ElemT mean, ElemT stddev, DeviceMemory *v, FuncT func) { mutex_lock lock(mu_); if (!SetStream(stream)) { return false; } uint64 element_count = v->ElementCount(); curandStatus_t ret = func(parent_, rng_, CUDAMemoryMutable(v), element_count, mean, stddev); if (ret != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "failed to do gaussian generation of " << 
v->ElementCount() << " floats at " << v->opaque() << ": " << ret; return false; } return true; } bool CUDARng::DoPopulateRandGaussian(Stream *stream, float mean, float stddev, DeviceMemory *v) { return DoPopulateRandGaussianInternal(stream, mean, stddev, v, wrap::curandGenerateNormal); } bool CUDARng::DoPopulateRandGaussian(Stream *stream, double mean, double stddev, DeviceMemory *v) { return DoPopulateRandGaussianInternal(stream, mean, stddev, v, wrap::curandGenerateNormalDouble); } bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) { mutex_lock lock(mu_); CHECK(rng_ != nullptr); if (!CheckSeed(seed, seed_bytes)) { return false; } if (!SetStream(stream)) { return false; } // Requires 8 bytes of seed data; checked in RngSupport::CheckSeed (above) // (which itself requires 16 for API consistency with host RNG fallbacks). curandStatus_t ret = wrap::curandSetPseudoRandomGeneratorSeed( parent_, rng_, *(reinterpret_cast(seed))); if (ret != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "failed to set rng seed: " << ret; return false; } ret = wrap::curandSetGeneratorOffset(parent_, rng_, 0); if (ret != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "failed to reset rng position: " << ret; return false; } return true; } } // namespace cuda void initialize_curand() { port::Status status = PluginRegistry::Instance()->RegisterFactory( cuda::kCudaPlatformId, cuda::kCuRandPlugin, "cuRAND", [](internal::StreamExecutorInterface *parent) -> rng::RngSupport * { cuda::CUDAExecutor *cuda_executor = dynamic_cast(parent); if (cuda_executor == nullptr) { LOG(ERROR) << "Attempting to initialize an instance of the cuRAND " << "support library with a non-CUDA StreamExecutor"; return nullptr; } cuda::CUDARng *rng = new cuda::CUDARng(cuda_executor); if (!rng->Init()) { // Note: Init() will log a more specific error. 
delete rng; return nullptr; } return rng; }); if (!status.ok()) { LOG(ERROR) << "Unable to register cuRAND factory: " << status.error_message(); } PluginRegistry::Instance()->SetDefaultFactory( cuda::kCudaPlatformId, PluginKind::kRng, cuda::kCuRandPlugin); } } // namespace stream_executor REGISTER_MODULE_INITIALIZER(register_curand, { stream_executor::initialize_curand(); });