From e6297741c9d5e6106b6fa4876afac9571e038161 Mon Sep 17 00:00:00 2001
From: Benoit Steiner
Date: Tue, 7 Jul 2015 17:40:49 -0700
Subject: Added support for generation of random complex numbers on CUDA devices

---
 .../Eigen/CXX11/src/Tensor/TensorFunctors.h | 104 +++++++++++++++++++++
 1 file changed, 104 insertions(+)

(limited to 'unsupported/Eigen/CXX11/src')

diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index 33e8c01c2..14ffd5c93 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -387,6 +387,58 @@ template <> class UniformRandomGenerator<double> {
   mutable curandStatePhilox4_32_10_t m_state;
 };
 
+template <> class UniformRandomGenerator<std::complex<float> > {
+ public:
+  static const bool PacketAccess = false;
+
+  __device__ UniformRandomGenerator(bool deterministic = true) : m_deterministic(deterministic) {
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int seed = deterministic ? 0 : get_random_seed();
+    curand_init(seed, tid, 0, &m_state);
+  }
+  __device__ UniformRandomGenerator(const UniformRandomGenerator& other) {
+    m_deterministic = other.m_deterministic;
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int seed = m_deterministic ? 0 : get_random_seed();
+    curand_init(seed, tid, 0, &m_state);
+  }
+  template <typename Index>
+  __device__ std::complex<float> operator()(Index, Index = 0) const {
+    float4 vals = curand_uniform4(&m_state);
+    return std::complex<float>(vals.x, vals.y);
+  }
+
+ private:
+  bool m_deterministic;
+  mutable curandStatePhilox4_32_10_t m_state;
+};
+
+template <> class UniformRandomGenerator<std::complex<double> > {
+ public:
+  static const bool PacketAccess = false;
+
+  __device__ UniformRandomGenerator(bool deterministic = true) : m_deterministic(deterministic) {
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int seed = deterministic ? 0 : get_random_seed();
+    curand_init(seed, tid, 0, &m_state);
+  }
+  __device__ UniformRandomGenerator(const UniformRandomGenerator& other) {
+    m_deterministic = other.m_deterministic;
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int seed = m_deterministic ? 0 : get_random_seed();
+    curand_init(seed, tid, 0, &m_state);
+  }
+  template <typename Index>
+  __device__ std::complex<double> operator()(Index, Index = 0) const {
+    double2 vals = curand_uniform2_double(&m_state);
+    return std::complex<double>(vals.x, vals.y);
+  }
+
+ private:
+  bool m_deterministic;
+  mutable curandStatePhilox4_32_10_t m_state;
+};
+
 #endif
 
 
@@ -489,6 +541,58 @@ template <> class NormalRandomGenerator<double> {
   mutable curandStatePhilox4_32_10_t m_state;
 };
 
+template <> class NormalRandomGenerator<std::complex<float> > {
+ public:
+  static const bool PacketAccess = false;
+
+  __device__ NormalRandomGenerator(bool deterministic = true) : m_deterministic(deterministic) {
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int seed = deterministic ? 0 : get_random_seed();
+    curand_init(seed, tid, 0, &m_state);
+  }
+  __device__ NormalRandomGenerator(const NormalRandomGenerator& other) {
+    m_deterministic = other.m_deterministic;
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int seed = m_deterministic ? 0 : get_random_seed();
+    curand_init(seed, tid, 0, &m_state);
+  }
+  template <typename Index>
+  __device__ std::complex<float> operator()(Index, Index = 0) const {
+    float4 vals = curand_normal4(&m_state);
+    return std::complex<float>(vals.x, vals.y);
+  }
+
+ private:
+  bool m_deterministic;
+  mutable curandStatePhilox4_32_10_t m_state;
+};
+
+template <> class NormalRandomGenerator<std::complex<double> > {
+ public:
+  static const bool PacketAccess = false;
+
+  __device__ NormalRandomGenerator(bool deterministic = true) : m_deterministic(deterministic) {
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int seed = deterministic ? 0 : get_random_seed();
+    curand_init(seed, tid, 0, &m_state);
+  }
+  __device__ NormalRandomGenerator(const NormalRandomGenerator& other) {
+    m_deterministic = other.m_deterministic;
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int seed = m_deterministic ? 0 : get_random_seed();
+    curand_init(seed, tid, 0, &m_state);
+  }
+  template <typename Index>
+  __device__ std::complex<double> operator()(Index, Index = 0) const {
+    double2 vals = curand_normal2_double(&m_state);
+    return std::complex<double>(vals.x, vals.y);
+  }
+
+ private:
+  bool m_deterministic;
+  mutable curandStatePhilox4_32_10_t m_state;
+};
+
 #else
 
 template <typename T> class NormalRandomGenerator {
-- 
cgit v1.2.3