aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/fft.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-22 10:32:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-22 10:37:06 -0800
commit4c0adf2c26345dd63e2b883317f7efb464862532 (patch)
treed3ce582a580347cbd3e854c33979a4bea9480e5f /tensorflow/stream_executor/fft.h
parentdddbe5f43a7a0688089cf4fc9472ca8893460b3d (diff)
Adds support in the stream executor interface for updating the scratch allocator used with a cuFFT plan. This enables plan reuse without requiring that the scratch allocation be kept alive between executions.
PiperOrigin-RevId: 179939994
Diffstat (limited to 'tensorflow/stream_executor/fft.h')
-rw-r--r--tensorflow/stream_executor/fft.h12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h
index 408516a416..6b1728829a 100644
--- a/tensorflow/stream_executor/fft.h
+++ b/tensorflow/stream_executor/fft.h
@@ -167,6 +167,15 @@ class FftSupport {
bool in_place_fft, int batch_count,
ScratchAllocator *scratch_allocator) = 0;
+ // Updates the work area of `plan` with space allocated by the new
+ // `scratch_allocator`. This lets a plan be reused across executions without
+ // keeping the previous scratch allocation alive in between.
+ //
+ // Precondition: `plan` must originally have been created using a scratch
+ // allocator; otherwise cuFFT will have allocated scratch space internally.
+ virtual void UpdatePlanWithScratchAllocator(
+ Stream *stream, Plan *plan, ScratchAllocator *scratch_allocator) = 0;
+
// Computes complex-to-complex FFT in the transform direction as specified
// by direction parameter.
virtual bool DoFft(Stream *stream, Plan *plan,
@@ -233,6 +242,9 @@ class FftSupport {
uint64 output_stride, uint64 output_distance, fft::Type type, \
bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) \
override; \
+ void UpdatePlanWithScratchAllocator(Stream *stream, fft::Plan *plan, \
+ ScratchAllocator *scratch_allocator) \
+ override; \
bool DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<float>> &input, \
DeviceMemory<std::complex<float>> *output) override; \