diff options
author | Yangzihao Wang <yangzihao@google.com> | 2017-07-11 22:48:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-11 22:52:35 -0700 |
commit | a05c3ce52ffd4909cb9bf7c155805406b5b1fc06 (patch) | |
tree | 3193719801c2aee4bb2e967d59f582ce40dabf2b /tensorflow/stream_executor/fft.h | |
parent | 54137ffd39bfb86bc6f259ebaa3c85fecfbe7e93 (diff) |
Add scratch allocator option for 1D, 2D, 3D, and batched cufft plan creation.
PiperOrigin-RevId: 161627209
Diffstat (limited to 'tensorflow/stream_executor/fft.h')
-rw-r--r-- | tensorflow/stream_executor/fft.h | 121 |
1 file changed, 89 insertions, 32 deletions
diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h index 6e921d142b..98cd77e206 100644 --- a/tensorflow/stream_executor/fft.h +++ b/tensorflow/stream_executor/fft.h @@ -54,12 +54,14 @@ namespace gputools { class Stream; template <typename ElemT> class DeviceMemory; +class ScratchAllocator; namespace fft { // Specifies FFT input and output types, and the direction. // R, D, C, and Z stand for SP real, DP real, SP complex, and DP complex. enum class Type { + kInvalid, kC2CForward, kC2CInverse, kC2R, @@ -103,6 +105,21 @@ class FftSupport { uint64 num_y, uint64 num_z, Type type, bool in_place_fft) = 0; + // Creates a 1d FFT plan with scratch allocator. + virtual std::unique_ptr<Plan> Create1dPlanWithScratchAllocator( + Stream *stream, uint64 num_x, Type type, bool in_place_fft, + ScratchAllocator *scratch_allocator) = 0; + + // Creates a 2d FFT plan with scratch allocator. + virtual std::unique_ptr<Plan> Create2dPlanWithScratchAllocator( + Stream *stream, uint64 num_x, uint64 num_y, Type type, bool in_place_fft, + ScratchAllocator *scratch_allocator) = 0; + + // Creates a 3d FFT plan with scratch allocator. + virtual std::unique_ptr<Plan> Create3dPlanWithScratchAllocator( + Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, Type type, + bool in_place_fft, ScratchAllocator *scratch_allocator) = 0; + // Creates a batched FFT plan. // // stream: The GPU stream in which the FFT runs. @@ -126,6 +143,30 @@ class FftSupport { uint64 output_stride, uint64 output_distance, Type type, bool in_place_fft, int batch_count) = 0; + // Creates a batched FFT plan with scratch allocator. + // + // stream: The GPU stream in which the FFT runs. + // rank: Dimensionality of the transform (1, 2, or 3). + // elem_count: Array of size rank, describing the size of each dimension. + // input_embed, output_embed: + // Pointer of size rank that indicates the storage dimensions + // of the input/output data in memory. 
If set to null_ptr all + // other advanced data layout parameters are ignored. + // input_stride: Indicates the distance (number of elements; same below) + // between two successive input elements. + // input_distance: Indicates the distance between the first element of two + // consecutive signals in a batch of the input data. + // output_stride: Indicates the distance between two successive output + // elements. + // output_distance: Indicates the distance between the first element of two + // consecutive signals in a batch of the output data. + virtual std::unique_ptr<Plan> CreateBatchedPlanWithScratchAllocator( + Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, + uint64 input_stride, uint64 input_distance, uint64 *output_embed, + uint64 output_stride, uint64 output_distance, Type type, + bool in_place_fft, int batch_count, + ScratchAllocator *scratch_allocator) = 0; + // Computes complex-to-complex FFT in the transform direction as specified // by direction parameter. virtual bool DoFft(Stream *stream, Plan *plan, @@ -161,38 +202,54 @@ class FftSupport { // Macro used to quickly declare overrides for abstract virtuals in the // fft::FftSupport base class. Assumes that it's emitted somewhere inside the // ::perftools::gputools namespace. 
-#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \ - std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \ - fft::Type type, bool in_place_fft) \ - override; \ - std::unique_ptr<fft::Plan> Create2dPlan(Stream *stream, uint64 num_x, \ - uint64 num_y, fft::Type type, \ - bool in_place_fft) override; \ - std::unique_ptr<fft::Plan> Create3dPlan( \ - Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \ - fft::Type type, bool in_place_fft) override; \ - std::unique_ptr<fft::Plan> CreateBatchedPlan( \ - Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, \ - uint64 input_stride, uint64 input_distance, uint64 *output_embed, \ - uint64 output_stride, uint64 output_distance, fft::Type type, \ - bool in_place_fft, int batch_count) override; \ - bool DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory<std::complex<float>> &input, \ - DeviceMemory<std::complex<float>> *output) override; \ - bool DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory<std::complex<double>> &input, \ - DeviceMemory<std::complex<double>> *output) override; \ - bool DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory<float> &input, \ - DeviceMemory<std::complex<float>> *output) override; \ - bool DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory<double> &input, \ - DeviceMemory<std::complex<double>> *output) override; \ - bool DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory<std::complex<float>> &input, \ - DeviceMemory<float> *output) override; \ - bool DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory<std::complex<double>> &input, \ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \ + std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \ + fft::Type type, bool in_place_fft) \ + override; \ + std::unique_ptr<fft::Plan> Create2dPlan(Stream *stream, uint64 num_x, \ + uint64 num_y, fft::Type type, \ + bool in_place_fft) override; \ + 
std::unique_ptr<fft::Plan> Create3dPlan( \ + Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \ + fft::Type type, bool in_place_fft) override; \ + std::unique_ptr<fft::Plan> Create1dPlanWithScratchAllocator( \ + Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft, \ + ScratchAllocator *scratch_allocator) override; \ + std::unique_ptr<fft::Plan> Create2dPlanWithScratchAllocator( \ + Stream *stream, uint64 num_x, uint64 num_y, fft::Type type, \ + bool in_place_fft, ScratchAllocator *scratch_allocator) override; \ + std::unique_ptr<fft::Plan> Create3dPlanWithScratchAllocator( \ + Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \ + fft::Type type, bool in_place_fft, ScratchAllocator *scratch_allocator) \ + override; \ + std::unique_ptr<fft::Plan> CreateBatchedPlan( \ + Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, \ + uint64 input_stride, uint64 input_distance, uint64 *output_embed, \ + uint64 output_stride, uint64 output_distance, fft::Type type, \ + bool in_place_fft, int batch_count) override; \ + std::unique_ptr<fft::Plan> CreateBatchedPlanWithScratchAllocator( \ + Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, \ + uint64 input_stride, uint64 input_distance, uint64 *output_embed, \ + uint64 output_stride, uint64 output_distance, fft::Type type, \ + bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) \ + override; \ + bool DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<std::complex<float>> &input, \ + DeviceMemory<std::complex<float>> *output) override; \ + bool DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<std::complex<double>> &input, \ + DeviceMemory<std::complex<double>> *output) override; \ + bool DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<float> &input, \ + DeviceMemory<std::complex<float>> *output) override; \ + bool DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<double> &input, \ + 
DeviceMemory<std::complex<double>> *output) override; \ + bool DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<std::complex<float>> &input, \ + DeviceMemory<float> *output) override; \ + bool DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<std::complex<double>> &input, \ DeviceMemory<double> *output) override; } // namespace fft |