diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_fft.h')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_fft.h | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_fft.h b/tensorflow/stream_executor/cuda/cuda_fft.h new file mode 100644 index 0000000000..2577c2952e --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_fft.h @@ -0,0 +1,95 @@ +// CUDA-specific support for FFT functionality -- this wraps the cuFFT library +// capabilities, and is only included into CUDA implementation code -- it will +// not introduce cuda headers into other code. + +#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ +#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ + +#include "tensorflow/stream_executor/fft.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/plugin_registry.h" +#include "third_party/gpus/cuda/include/cufft.h" + +namespace perftools { +namespace gputools { + +class Stream; + +namespace cuda { + +class CUDAExecutor; + +// Opaque and unique indentifier for the cuFFT plugin. +extern const PluginId kCuFftPlugin; + +class CUDAFftPlan : public fft::Plan { + public: + // Constructor creating 1d FFT plan. + CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, fft::Type type); + // Constructor creating 2d FFT plan. + CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y, fft::Type type); + // Constructor creating 3d FFT plan. + CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y, uint64 num_z, + fft::Type type); + // Constructor creating batched FFT plan. + CUDAFftPlan(CUDAExecutor *parent, int rank, uint64 *elem_count, + uint64 *input_embed, uint64 input_stride, uint64 input_distance, + uint64 *output_embed, uint64 output_stride, + uint64 output_distance, fft::Type type, int batch_count); + ~CUDAFftPlan() override; + + // Get FFT direction in cuFFT based on FFT type. + int GetFftDirection() const; + cufftHandle GetPlan() const { return plan_; } + + private: + CUDAExecutor *parent_; + cufftHandle plan_; + fft::Type fft_type_; +}; + +// FFT support for CUDA platform via cuFFT library. +// +// This satisfies the platform-agnostic FftSupport interface. +// +// Note that the cuFFT handle that this encapsulates is implicitly tied to the +// context (and, as a result, the device) that the parent CUDAExecutor is tied +// to. This simply happens as an artifact of creating the cuFFT handle when a +// CUDA context is active. +// +// Thread-safe. The CUDA context associated with all operations is the CUDA +// context of parent_, so all context is explicit. +class CUDAFft : public fft::FftSupport { + public: + explicit CUDAFft(CUDAExecutor *parent) : parent_(parent) {} + ~CUDAFft() override {} + + TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES + + private: + CUDAExecutor *parent_; + + // Two helper functions that execute dynload::cufftExec?2?. + + // This is for complex to complex FFT, when the direction is required. + template <typename FuncT, typename InputT, typename OutputT> + bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan, + FuncT cufft_exec, + const DeviceMemory<InputT> &input, + DeviceMemory<OutputT> *output); + + // This is for complex to real or real to complex FFT, when the direction + // is implied. + template <typename FuncT, typename InputT, typename OutputT> + bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufft_exec, + const DeviceMemory<InputT> &input, + DeviceMemory<OutputT> *output); + + SE_DISALLOW_COPY_AND_ASSIGN(CUDAFft); +}; + +} // namespace cuda +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ |