diff options
Diffstat (limited to 'tensorflow/core/kernels/fft_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/fft_ops.cc | 63 |
1 file changed, 59 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index b479956632..593fa487c9 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -17,7 +17,6 @@ limitations under the License. // See docs in ../ops/spectral_ops.cc. -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -26,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/work_sharder.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #if GOOGLE_CUDA #include "tensorflow/core/platform/stream_executor.h" @@ -163,9 +163,58 @@ class FFTCPU : public FFTBase { output.device(device) = full_fft.slice(zero_start_indices, output.dimensions()); } else { - // TODO: reconstruct the full fft and take the inverse. - ctx->CtxFailureWithWarning( - errors::Unimplemented("IRFFT is not implemented as a CPU kernel")); + // Reconstruct the full fft and take the inverse. + auto input = ((Tensor)in).flat_inner_dims<complex64, FFTRank + 1>(); + auto output = out->flat_inner_dims<float, FFTRank + 1>(); + + auto sizes = input.dimensions(); + + // Calculate the shape of full-fft temporary tensor. + TensorShape fullShape; + fullShape.AddDim(sizes[0]); + for (auto i = 1; i <= FFTRank; i++) { + fullShape.AddDim(fft_shape[i - 1]); + } + + Tensor temp; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(), + fullShape, &temp)); + auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>(); + + // Calculate the starting point and range of the source of + // negative frequency part. 
+ auto negSizes = input.dimensions(); + negSizes[FFTRank] = fft_shape[FFTRank - 1] - sizes[FFTRank]; + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> negTargetIndices; + negTargetIndices[FFTRank] = sizes[FFTRank]; + + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> startIndices, + negStartIndices; + negStartIndices[FFTRank] = 1; + + full_fft.slice(startIndices, sizes) = input.slice(startIndices, sizes); + + // First, conduct FFT on outer dimensions. + auto outerAxes = Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1); + full_fft = full_fft.template fft<Eigen::BothParts, Eigen::FFT_REVERSE>( + outerAxes); + + // Reconstruct the full fft by appending reversed and conjugated + // spectrum as the negative frequency part. + Eigen::array<bool, FFTRank + 1> reversedAxis; + for (auto i = 0; i <= FFTRank; i++) { + reversedAxis[i] = i == FFTRank; + } + + full_fft.slice(negTargetIndices, negSizes) = + full_fft.slice(negStartIndices, negSizes) + .reverse(reversedAxis) + .conjugate(); + + auto innerAxis = Eigen::array<int, 1>{FFTRank}; + output.device(device) = + full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>( + innerAxis); } } } @@ -194,10 +243,16 @@ REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL), REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU).Label(FFT_LABEL), FFTCPU<true, true, 1>); +REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<false, true, 1>); REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL), FFTCPU<true, true, 2>); +REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<false, true, 2>); REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL), FFTCPU<true, true, 3>); +REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<false, true, 3>); #undef FFT_LABEL |