aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/fft_ops.cc
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-05-02 17:57:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 18:00:06 -0700
commit8f0a90b711480c12716d1a3b1094cc8b34939f2d (patch)
treefc13cfb0c8bbd942a5cc46be8cd426f9f3a32b02 /tensorflow/core/kernels/fft_ops.cc
parent7833890a0da5226e4c409b1020155f1718c0edb2 (diff)
Add complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
NumPy automatically upcasts to complex128 when computing FFTs, leading to issues like: #10749 This change allows users to choose between 32-bit and 64-bit precision FFTs on CPU and GPU. PiperOrigin-RevId: 195183206
Diffstat (limited to 'tensorflow/core/kernels/fft_ops.cc')
-rw-r--r--tensorflow/core/kernels/fft_ops.cc78
1 files changed, 58 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index 661bf5fc5f..d7105a71bb 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -129,13 +129,23 @@ class FFTCPU : public FFTBase {
auto device = ctx->eigen_device<CPUDevice>();
if (!IsReal()) {
- auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
- // Compute the FFT using eigen.
- auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
+ // Compute the FFT using Eigen.
constexpr auto direction =
Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
- output.device(device) =
- input.template fft<Eigen::BothParts, direction>(axes);
+ if (in.dtype() == DT_COMPLEX64) {
+ DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+ auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
+ auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
+ output.device(device) =
+ input.template fft<Eigen::BothParts, direction>(axes);
+ } else {
+ DCHECK_EQ(DT_COMPLEX128, in.dtype());
+ DCHECK_EQ(DT_COMPLEX128, out->dtype());
+ auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
+ auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
+ output.device(device) =
+ input.template fft<Eigen::BothParts, direction>(axes);
+ }
} else {
if (IsForward()) {
auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
@@ -392,10 +402,16 @@ class FFTGPUBase : public FFTBase {
}
constexpr bool kInPlaceFft = false;
+ const bool is_complex128 = in.dtype() == DT_COMPLEX128;
+ // complex128 real FFT is not supported yet.
+ DCHECK(!IsReal() || !is_complex128);
+
const auto kFftType =
IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
- : (IsForward() ? se::fft::Type::kC2CForward
- : se::fft::Type::kC2CInverse);
+ : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
+ : se::fft::Type::kC2CForward)
+ : (is_complex128 ? se::fft::Type::kZ2ZInverse
+ : se::fft::Type::kC2CInverse));
CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
auto plan =
@@ -428,20 +444,42 @@ class FFTGPUBase : public FFTBase {
input_shape.DebugString()));
}
} else {
- auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
- auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
- OP_REQUIRES(
- ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
- errors::Internal("fft failed : type=", static_cast<int>(kFftType),
- " in.shape=", input_shape.DebugString()));
- if (!IsForward()) {
- auto alpha = complex64(1.f / output_distance);
+ if (!is_complex128) {
+ DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+ DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+ auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
+ auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
OP_REQUIRES(
- ctx,
- stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
- .ok(),
- errors::Internal("BlasScal failed : in.shape=",
- input_shape.DebugString()));
+ ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+ errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+ " in.shape=", input_shape.DebugString()));
+ if (!IsForward()) {
+ float alpha = 1.f / output_distance;
+ OP_REQUIRES(
+ ctx,
+ stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+ .ok(),
+ errors::Internal("BlasScal failed : in.shape=",
+ input_shape.DebugString()));
+ }
+ } else {
+ DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+ DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+ auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data());
+ auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data());
+ OP_REQUIRES(
+ ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+ errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+ " in.shape=", input_shape.DebugString()));
+ if (!IsForward()) {
+ double alpha = 1.0 / output_distance;
+ OP_REQUIRES(
+ ctx,
+ stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+ .ok(),
+ errors::Internal("BlasScal failed : in.shape=",
+ input_shape.DebugString()));
+ }
}
}
}