diff options
author | 2018-05-02 17:57:27 -0700 | |
---|---|---|
committer | 2018-05-02 18:00:06 -0700 | |
commit | 8f0a90b711480c12716d1a3b1094cc8b34939f2d (patch) | |
tree | fc13cfb0c8bbd942a5cc46be8cd426f9f3a32b02 /tensorflow/core/kernels/fft_ops.cc | |
parent | 7833890a0da5226e4c409b1020155f1718c0edb2 (diff) |
Add complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
NumPy automatically upcasts to complex128 when computing FFTs, leading to issues like #10749.
This change allows users to choose between 32-bit (complex64) and 64-bit (complex128) precision FFTs on CPU and GPU.
PiperOrigin-RevId: 195183206
Diffstat (limited to 'tensorflow/core/kernels/fft_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/fft_ops.cc | 78 |
1 files changed, 58 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 661bf5fc5f..d7105a71bb 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -129,13 +129,23 @@ class FFTCPU : public FFTBase { auto device = ctx->eigen_device<CPUDevice>(); if (!IsReal()) { - auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>(); - // Compute the FFT using eigen. - auto output = out->flat_inner_dims<complex64, FFTRank + 1>(); + // Compute the FFT using Eigen. constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE; - output.device(device) = - input.template fft<Eigen::BothParts, direction>(axes); + if (in.dtype() == DT_COMPLEX64) { + DCHECK_EQ(out->dtype(), DT_COMPLEX64); + auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>(); + auto output = out->flat_inner_dims<complex64, FFTRank + 1>(); + output.device(device) = + input.template fft<Eigen::BothParts, direction>(axes); + } else { + DCHECK_EQ(DT_COMPLEX128, in.dtype()); + DCHECK_EQ(DT_COMPLEX128, out->dtype()); + auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>(); + auto output = out->flat_inner_dims<complex128, FFTRank + 1>(); + output.device(device) = + input.template fft<Eigen::BothParts, direction>(axes); + } } else { if (IsForward()) { auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>(); @@ -392,10 +402,16 @@ class FFTGPUBase : public FFTBase { } constexpr bool kInPlaceFft = false; + const bool is_complex128 = in.dtype() == DT_COMPLEX128; + // complex128 real FFT is not supported yet. + DCHECK(!IsReal() || !is_complex128); + const auto kFftType = IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R) - : (IsForward() ? se::fft::Type::kC2CForward - : se::fft::Type::kC2CInverse); + : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward + : se::fft::Type::kC2CForward) + : (is_complex128 ? 
se::fft::Type::kZ2ZInverse + : se::fft::Type::kC2CInverse)); CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx); auto plan = @@ -428,20 +444,42 @@ class FFTGPUBase : public FFTBase { input_shape.DebugString())); } } else { - auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data()); - auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data()); - OP_REQUIRES( - ctx, stream->ThenFft(plan.get(), src, &dst).ok(), - errors::Internal("fft failed : type=", static_cast<int>(kFftType), - " in.shape=", input_shape.DebugString())); - if (!IsForward()) { - auto alpha = complex64(1.f / output_distance); + if (!is_complex128) { + DCHECK_EQ(in.dtype(), DT_COMPLEX64); + DCHECK_EQ(out->dtype(), DT_COMPLEX64); + auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data()); + auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data()); OP_REQUIRES( - ctx, - stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1) - .ok(), - errors::Internal("BlasScal failed : in.shape=", - input_shape.DebugString())); + ctx, stream->ThenFft(plan.get(), src, &dst).ok(), + errors::Internal("fft failed : type=", static_cast<int>(kFftType), + " in.shape=", input_shape.DebugString())); + if (!IsForward()) { + float alpha = 1.f / output_distance; + OP_REQUIRES( + ctx, + stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1) + .ok(), + errors::Internal("BlasScal failed : in.shape=", + input_shape.DebugString())); + } + } else { + DCHECK_EQ(in.dtype(), DT_COMPLEX128); + DCHECK_EQ(out->dtype(), DT_COMPLEX128); + auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data()); + auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data()); + OP_REQUIRES( + ctx, stream->ThenFft(plan.get(), src, &dst).ok(), + errors::Internal("fft failed : type=", static_cast<int>(kFftType), + " in.shape=", input_shape.DebugString())); + if (!IsForward()) { + double alpha = 1.0 / output_distance; + OP_REQUIRES( + ctx, + 
stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1) + .ok(), + errors::Internal("BlasScal failed : in.shape=", + input_shape.DebugString())); + } } } } |