aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/fft.h
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-07-11 22:48:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-11 22:52:35 -0700
commita05c3ce52ffd4909cb9bf7c155805406b5b1fc06 (patch)
tree3193719801c2aee4bb2e967d59f582ce40dabf2b /tensorflow/stream_executor/fft.h
parent54137ffd39bfb86bc6f259ebaa3c85fecfbe7e93 (diff)
Add scratch allocator option for 1D, 2D, 3D, and batched cufft plan creation.
PiperOrigin-RevId: 161627209
Diffstat (limited to 'tensorflow/stream_executor/fft.h')
-rw-r--r--tensorflow/stream_executor/fft.h121
1 files changed, 89 insertions, 32 deletions
diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h
index 6e921d142b..98cd77e206 100644
--- a/tensorflow/stream_executor/fft.h
+++ b/tensorflow/stream_executor/fft.h
@@ -54,12 +54,14 @@ namespace gputools {
class Stream;
template <typename ElemT>
class DeviceMemory;
+class ScratchAllocator;
namespace fft {
// Specifies FFT input and output types, and the direction.
// R, D, C, and Z stand for SP real, DP real, SP complex, and DP complex.
enum class Type {
+ kInvalid,
kC2CForward,
kC2CInverse,
kC2R,
@@ -103,6 +105,21 @@ class FftSupport {
uint64 num_y, uint64 num_z,
Type type, bool in_place_fft) = 0;
+ // Creates a 1d FFT plan with scratch allocator.
+ virtual std::unique_ptr<Plan> Create1dPlanWithScratchAllocator(
+ Stream *stream, uint64 num_x, Type type, bool in_place_fft,
+ ScratchAllocator *scratch_allocator) = 0;
+
+ // Creates a 2d FFT plan with scratch allocator.
+ virtual std::unique_ptr<Plan> Create2dPlanWithScratchAllocator(
+ Stream *stream, uint64 num_x, uint64 num_y, Type type, bool in_place_fft,
+ ScratchAllocator *scratch_allocator) = 0;
+
+ // Creates a 3d FFT plan with scratch allocator.
+ virtual std::unique_ptr<Plan> Create3dPlanWithScratchAllocator(
+ Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, Type type,
+ bool in_place_fft, ScratchAllocator *scratch_allocator) = 0;
+
// Creates a batched FFT plan.
//
// stream: The GPU stream in which the FFT runs.
@@ -126,6 +143,30 @@ class FftSupport {
uint64 output_stride, uint64 output_distance, Type type,
bool in_place_fft, int batch_count) = 0;
+ // Creates a batched FFT plan with scratch allocator.
+ //
+ // stream: The GPU stream in which the FFT runs.
+ // rank: Dimensionality of the transform (1, 2, or 3).
+ // elem_count: Array of size rank, describing the size of each dimension.
+ // input_embed, output_embed:
+ // Pointer of size rank that indicates the storage dimensions
+ // of the input/output data in memory. If set to nullptr, all
+ // other advanced data layout parameters are ignored.
+ // input_stride: Indicates the distance (number of elements; same below)
+ // between two successive input elements.
+ // input_distance: Indicates the distance between the first element of two
+ // consecutive signals in a batch of the input data.
+ // output_stride: Indicates the distance between two successive output
+ // elements.
+ // output_distance: Indicates the distance between the first element of two
+ // consecutive signals in a batch of the output data.
+ virtual std::unique_ptr<Plan> CreateBatchedPlanWithScratchAllocator(
+ Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
+ uint64 input_stride, uint64 input_distance, uint64 *output_embed,
+ uint64 output_stride, uint64 output_distance, Type type,
+ bool in_place_fft, int batch_count,
+ ScratchAllocator *scratch_allocator) = 0;
+
// Computes complex-to-complex FFT in the transform direction as specified
// by direction parameter.
virtual bool DoFft(Stream *stream, Plan *plan,
@@ -161,38 +202,54 @@ class FftSupport {
// Macro used to quickly declare overrides for abstract virtuals in the
// fft::FftSupport base class. Assumes that it's emitted somewhere inside the
// ::perftools::gputools namespace.
-#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
- std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \
- fft::Type type, bool in_place_fft) \
- override; \
- std::unique_ptr<fft::Plan> Create2dPlan(Stream *stream, uint64 num_x, \
- uint64 num_y, fft::Type type, \
- bool in_place_fft) override; \
- std::unique_ptr<fft::Plan> Create3dPlan( \
- Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \
- fft::Type type, bool in_place_fft) override; \
- std::unique_ptr<fft::Plan> CreateBatchedPlan( \
- Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, \
- uint64 input_stride, uint64 input_distance, uint64 *output_embed, \
- uint64 output_stride, uint64 output_distance, fft::Type type, \
- bool in_place_fft, int batch_count) override; \
- bool DoFft(Stream *stream, fft::Plan *plan, \
- const DeviceMemory<std::complex<float>> &input, \
- DeviceMemory<std::complex<float>> *output) override; \
- bool DoFft(Stream *stream, fft::Plan *plan, \
- const DeviceMemory<std::complex<double>> &input, \
- DeviceMemory<std::complex<double>> *output) override; \
- bool DoFft(Stream *stream, fft::Plan *plan, \
- const DeviceMemory<float> &input, \
- DeviceMemory<std::complex<float>> *output) override; \
- bool DoFft(Stream *stream, fft::Plan *plan, \
- const DeviceMemory<double> &input, \
- DeviceMemory<std::complex<double>> *output) override; \
- bool DoFft(Stream *stream, fft::Plan *plan, \
- const DeviceMemory<std::complex<float>> &input, \
- DeviceMemory<float> *output) override; \
- bool DoFft(Stream *stream, fft::Plan *plan, \
- const DeviceMemory<std::complex<double>> &input, \
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
+ std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \
+ fft::Type type, bool in_place_fft) \
+ override; \
+ std::unique_ptr<fft::Plan> Create2dPlan(Stream *stream, uint64 num_x, \
+ uint64 num_y, fft::Type type, \
+ bool in_place_fft) override; \
+ std::unique_ptr<fft::Plan> Create3dPlan( \
+ Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \
+ fft::Type type, bool in_place_fft) override; \
+ std::unique_ptr<fft::Plan> Create1dPlanWithScratchAllocator( \
+ Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft, \
+ ScratchAllocator *scratch_allocator) override; \
+ std::unique_ptr<fft::Plan> Create2dPlanWithScratchAllocator( \
+ Stream *stream, uint64 num_x, uint64 num_y, fft::Type type, \
+ bool in_place_fft, ScratchAllocator *scratch_allocator) override; \
+ std::unique_ptr<fft::Plan> Create3dPlanWithScratchAllocator( \
+ Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \
+ fft::Type type, bool in_place_fft, ScratchAllocator *scratch_allocator) \
+ override; \
+ std::unique_ptr<fft::Plan> CreateBatchedPlan( \
+ Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, \
+ uint64 input_stride, uint64 input_distance, uint64 *output_embed, \
+ uint64 output_stride, uint64 output_distance, fft::Type type, \
+ bool in_place_fft, int batch_count) override; \
+ std::unique_ptr<fft::Plan> CreateBatchedPlanWithScratchAllocator( \
+ Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, \
+ uint64 input_stride, uint64 input_distance, uint64 *output_embed, \
+ uint64 output_stride, uint64 output_distance, fft::Type type, \
+ bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) \
+ override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<float>> &input, \
+ DeviceMemory<std::complex<float>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<double>> &input, \
+ DeviceMemory<std::complex<double>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<float> &input, \
+ DeviceMemory<std::complex<float>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<double> &input, \
+ DeviceMemory<std::complex<double>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<float>> &input, \
+ DeviceMemory<float> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<double>> &input, \
DeviceMemory<double> *output) override;
} // namespace fft