diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-22 10:32:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-22 10:37:06 -0800 |
commit | 4c0adf2c26345dd63e2b883317f7efb464862532 (patch) | |
tree | d3ce582a580347cbd3e854c33979a4bea9480e5f /tensorflow/stream_executor/fft.h | |
parent | dddbe5f43a7a0688089cf4fc9472ca8893460b3d (diff) |
Adds support in stream executor interface to update the scratch allocator used with a cuFFT plan. This enables plan reuse without requiring we keep the scratch allocation alive between executions.
PiperOrigin-RevId: 179939994
Diffstat (limited to 'tensorflow/stream_executor/fft.h')
-rw-r--r-- | tensorflow/stream_executor/fft.h | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h index 408516a416..6b1728829a 100644 --- a/tensorflow/stream_executor/fft.h +++ b/tensorflow/stream_executor/fft.h @@ -167,6 +167,15 @@ class FftSupport { bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) = 0; + // Updates the plan's work area with space allocated by a new scratch + // allocator. This facilitates plan reuse with scratch allocators. + // + // This requires that the plan was originally created using a scratch + // allocator, as otherwise scratch space will have been allocated internally + // by cuFFT. + virtual void UpdatePlanWithScratchAllocator( + Stream *stream, Plan *plan, ScratchAllocator *scratch_allocator) = 0; + // Computes complex-to-complex FFT in the transform direction as specified // by direction parameter. virtual bool DoFft(Stream *stream, Plan *plan, @@ -233,6 +242,9 @@ class FftSupport { uint64 output_stride, uint64 output_distance, fft::Type type, \ bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) \ override; \ + void UpdatePlanWithScratchAllocator(Stream *stream, fft::Plan *plan, \ + ScratchAllocator *scratch_allocator) \ + override; \ bool DoFft(Stream *stream, fft::Plan *plan, \ const DeviceMemory<std::complex<float>> &input, \ DeviceMemory<std::complex<float>> *output) override; \ |