diff options
-rw-r--r-- | unsupported/Eigen/CXX11/Tensor | 4 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h | 28 |
2 files changed, 31 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index f7b94cee1..1d9f89864 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -64,6 +64,10 @@ typedef unsigned __int64 uint64_t; #if defined(__CUDACC__) #include <curand_kernel.h> #endif +#if __cplusplus >= 201103L +#include <atomic> +#include <unistd.h> +#endif #endif #include "src/Tensor/TensorMacros.h" diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h index 28c6f7626..4f5767bc7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h @@ -42,7 +42,21 @@ static bool m_devicePropInitialized = false; static void initializeDeviceProp() { if (!m_devicePropInitialized) { - if (!m_devicePropInitialized) { + // Attempts to ensure proper behavior in the case of multiple threads + // calling this function simultaneously. This would be trivial to + // implement if we could use std::mutex, but unfortunately mutex don't + // compile with nvcc, so we resort to atomics and thread fences instead. + // Note that if the caller uses a compiler that doesn't support c++11 we + // can't ensure that the initialization is thread safe. +#if __cplusplus >= 201103L + static std::atomic<bool> first(true); + if (first.exchange(false)) { +#else + static bool first = true; + if (first) { + first = false; +#endif + // We're the first thread to reach this point. int num_devices; cudaError_t status = cudaGetDeviceCount(&num_devices); if (status != cudaSuccess) { @@ -63,7 +77,19 @@ static void initializeDeviceProp() { assert(status == cudaSuccess); } } + +#if __cplusplus >= 201103L + std::atomic_thread_fence(std::memory_order_release); +#endif m_devicePropInitialized = true; + } else { + // Wait for the other thread to inititialize the properties. + while (!m_devicePropInitialized) { +#if __cplusplus >= 201103L + std::atomic_thread_fence(std::memory_order_acquire); +#endif + sleep(1); + } } } } |