aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/fft.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/fft.h')
-rw-r--r--tensorflow/stream_executor/fft.h187
1 files changed, 187 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h
new file mode 100644
index 0000000000..b47921d8f2
--- /dev/null
+++ b/tensorflow/stream_executor/fft.h
@@ -0,0 +1,187 @@
+// Exposes the family of FFT routines as pre-canned high performance calls for
+// use in conjunction with the StreamExecutor abstraction.
+//
+// Note that this interface is optionally supported by platforms; see
+// StreamExecutor::SupportsFft() for details.
+//
+// This abstraction makes it simple to entrain FFT operations on GPU data into
+// a Stream -- users typically will not use this API directly, but will use the
+// Stream builder methods to entrain these operations "under the hood". For
+// example:
+//
+// DeviceMemory<std::complex<float>> x =
+// stream_exec->AllocateArray<std::complex<float>>(1024);
+// DeviceMemory<std::complex<float>> y =
+// stream_exec->AllocateArray<std::complex<float>>(1024);
+// // ... populate x and y ...
+// Stream stream{stream_exec};
+// std::unique_ptr<Plan> plan =
+// stream_exec.AsFft()->Create1dPlan(&stream, 1024, Type::kC2CForward);
+// stream
+// .Init()
+// .ThenFft(plan.get(), x, &y)
+// .BlockHostUntilDone();
+//
+// By using stream operations in this manner the user can easily intermix custom
+// kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned FFT
+// routines.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_FFT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_FFT_H_
+
+#include <complex>
+#include <memory>
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+template <typename ElemT>
+class DeviceMemory;
+
+namespace fft {
+
+// Specifies FFT input and output types, and the direction.
+// R, D, C, and Z stand for SP real, DP real, SP complex, and DP complex.
+enum class Type {
+ kC2CForward,
+ kC2CInverse,
+ kC2R,
+ kR2C,
+ kZ2ZForward,
+ kZ2ZInverse,
+ kZ2D,
+ kD2Z
+};
+
+// FFT plan class. Each FFT implementation should define a plan class that is
+// derived from this class. It does not provide any interface but serves
+// as a common type that is used to execute the plan.
+class Plan {
+ public:
+ virtual ~Plan() {}
+};
+
+// FFT support interface -- this can be derived from a GPU executor when the
+// underlying platform has an FFT library implementation available. See
+// StreamExecutor::AsFft().
+//
+// This support interface is not generally thread-safe; it is only thread-safe
+// for the CUDA platform (cuFFT) usage; host side FFT support is known
+// thread-compatible, but not thread-safe.
+class FftSupport {
+ public:
+ virtual ~FftSupport() {}
+
+ // Creates a 1d FFT plan.
+ virtual std::unique_ptr<Plan> Create1dPlan(Stream *stream, uint64 num_x,
+ Type type, bool in_place_fft) = 0;
+
+ // Creates a 2d FFT plan.
+ virtual std::unique_ptr<Plan> Create2dPlan(Stream *stream, uint64 num_x,
+ uint64 num_y, Type type,
+ bool in_place_fft) = 0;
+
+ // Creates a 3d FFT plan.
+ virtual std::unique_ptr<Plan> Create3dPlan(Stream *stream, uint64 num_x,
+ uint64 num_y, uint64 num_z,
+ Type type, bool in_place_fft) = 0;
+
+ // Creates a batched FFT plan.
+ //
+ // stream: The GPU stream in which the FFT runs.
+ // rank: Dimensionality of the transform (1, 2, or 3).
+ // elem_count: Array of size rank, describing the size of each dimension.
+ // input_embed, output_embed:
+ // Pointer of size rank that indicates the storage dimensions
+ // of the input/output data in memory. If set to null_ptr all
+ // other advanced data layout parameters are ignored.
+ // input_stride: Indicates the distance (number of elements; same below)
+ // between two successive input elements.
+ // input_distance: Indicates the distance between the first element of two
+ // consecutive signals in a batch of the input data.
+ // output_stride: Indicates the distance between two successive output
+ // elements.
+ // output_distance: Indicates the distance between the first element of two
+ // consecutive signals in a batch of the output data.
+ virtual std::unique_ptr<Plan> CreateBatchedPlan(
+ Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
+ uint64 input_stride, uint64 input_distance, uint64 *output_embed,
+ uint64 output_stride, uint64 output_distance, Type type,
+ bool in_place_fft, int batch_count) = 0;
+
+ // Computes complex-to-complex FFT in the transform direction as specified
+ // by direction parameter.
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<std::complex<float>> *output) = 0;
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<std::complex<double>> *output) = 0;
+
+ // Computes real-to-complex FFT in forward direction.
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<float> &input,
+ DeviceMemory<std::complex<float>> *output) = 0;
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<double> &input,
+ DeviceMemory<std::complex<double>> *output) = 0;
+
+ // Computes complex-to-real FFT in inverse direction.
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<float> *output) = 0;
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<double> *output) = 0;
+
+ protected:
+ FftSupport() {}
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(FftSupport);
+};
+
+// Macro used to quickly declare overrides for abstract virtuals in the
+// fft::FftSupport base class. Assumes that it's emitted somewhere inside the
+// ::perftools::gputools namespace.
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
+ std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \
+ fft::Type type, bool in_place_fft) \
+ override; \
+ std::unique_ptr<fft::Plan> Create2dPlan(Stream *stream, uint64 num_x, \
+ uint64 num_y, fft::Type type, \
+ bool in_place_fft) override; \
+ std::unique_ptr<fft::Plan> Create3dPlan( \
+ Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \
+ fft::Type type, bool in_place_fft) override; \
+ std::unique_ptr<fft::Plan> CreateBatchedPlan( \
+ Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, \
+ uint64 input_stride, uint64 input_distance, uint64 *output_embed, \
+ uint64 output_stride, uint64 output_distance, fft::Type type, \
+ bool in_place_fft, int batch_count) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<float>> &input, \
+ DeviceMemory<std::complex<float>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<double>> &input, \
+ DeviceMemory<std::complex<double>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<float> &input, \
+ DeviceMemory<std::complex<float>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<double> &input, \
+ DeviceMemory<std::complex<double>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<float>> &input, \
+ DeviceMemory<float> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<double>> &input, \
+ DeviceMemory<double> *output) override;
+
+} // namespace fft
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_FFT_H_