diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_helpers.h')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_helpers.h | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_helpers.h b/tensorflow/stream_executor/cuda/cuda_helpers.h new file mode 100644 index 0000000000..2c5311cb3b --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_helpers.h @@ -0,0 +1,95 @@ +// Common helper functions used for dealing with CUDA API datatypes. +// +// These are typically placed here for use by multiple source components (for +// example, BLAS and executor components). + +#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_ + +#include <stddef.h> +#include <complex> + +#include "third_party/gpus/cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cuda.h" + +namespace perftools { +namespace gputools { + +class Stream; +template <typename ElemT> +class DeviceMemory; + +namespace cuda { + +// Converts a const DeviceMemory reference to its underlying typed pointer in +// CUDA +// device memory. +template <typename T> +const T *CUDAMemory(const DeviceMemory<T> &mem) { + return static_cast<const T *>(mem.opaque()); +} + +// Converts a (non-const) DeviceMemory pointer reference to its underlying typed +// pointer in CUDA device device memory. +template <typename T> +T *CUDAMemoryMutable(DeviceMemory<T> *mem) { + return static_cast<T *>(mem->opaque()); +} + +CUstream AsCUDAStreamValue(Stream *stream); + +static_assert(sizeof(std::complex<float>) == sizeof(cuComplex), + "std::complex<float> and cuComplex should have the same size"); +static_assert(offsetof(cuComplex, x) == 0, + "The real part of cuComplex should appear first."); +static_assert(sizeof(std::complex<double>) == sizeof(cuDoubleComplex), + "std::complex<double> and cuDoubleComplex should have the same " + "size"); +static_assert(offsetof(cuDoubleComplex, x) == 0, + "The real part of cuDoubleComplex should appear first."); + +// Type traits to get CUDA complex types from std::complex<>. + +template <typename T> +struct CUDAComplexT { + typedef T type; +}; + +template <> +struct CUDAComplexT<std::complex<float>> { + typedef cuComplex type; +}; + +template <> +struct CUDAComplexT<std::complex<double>> { + typedef cuDoubleComplex type; +}; + +// Converts pointers of std::complex<> to pointers of +// cuComplex/cuDoubleComplex. No type conversion for non-complex types. + +template <typename T> +inline const typename CUDAComplexT<T>::type *CUDAComplex(const T *p) { + return reinterpret_cast<const typename CUDAComplexT<T>::type *>(p); +} + +template <typename T> +inline typename CUDAComplexT<T>::type *CUDAComplex(T *p) { + return reinterpret_cast<typename CUDAComplexT<T>::type *>(p); +} + +// Converts values of std::complex<float/double> to values of +// cuComplex/cuDoubleComplex. +inline cuComplex CUDAComplexValue(std::complex<float> val) { + return {val.real(), val.imag()}; +} + +inline cuDoubleComplex CUDAComplexValue(std::complex<double> val) { + return {val.real(), val.imag()}; +} + +} // namespace cuda +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_ |