about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/fft_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/fft_ops.cc')
-rw-r--r-- tensorflow/core/kernels/fft_ops.cc 63
1 files changed, 59 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index b479956632..593fa487c9 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -17,7 +17,6 @@ limitations under the License.
// See docs in ../ops/spectral_ops.cc.
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -26,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
@@ -163,9 +163,58 @@ class FFTCPU : public FFTBase {
output.device(device) =
full_fft.slice(zero_start_indices, output.dimensions());
} else {
- // TODO: reconstruct the full fft and take the inverse.
- ctx->CtxFailureWithWarning(
- errors::Unimplemented("IRFFT is not implemented as a CPU kernel"));
+ // Reconstruct the full fft and take the inverse.
+ auto input = ((Tensor)in).flat_inner_dims<complex64, FFTRank + 1>();
+ auto output = out->flat_inner_dims<float, FFTRank + 1>();
+
+ auto sizes = input.dimensions();
+
+ // Calculate the shape of full-fft temporary tensor.
+ TensorShape fullShape;
+ fullShape.AddDim(sizes[0]);
+ for (auto i = 1; i <= FFTRank; i++) {
+ fullShape.AddDim(fft_shape[i - 1]);
+ }
+
+ Tensor temp;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
+ fullShape, &temp));
+ auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
+
+ // Calculate the starting point and range of the source of
+ // negative frequency part.
+ auto negSizes = input.dimensions();
+ negSizes[FFTRank] = fft_shape[FFTRank - 1] - sizes[FFTRank];
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> negTargetIndices;
+ negTargetIndices[FFTRank] = sizes[FFTRank];
+
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> startIndices,
+ negStartIndices;
+ negStartIndices[FFTRank] = 1;
+
+ full_fft.slice(startIndices, sizes) = input.slice(startIndices, sizes);
+
+ // First, conduct FFT on outer dimensions.
+ auto outerAxes = Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
+ full_fft = full_fft.template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(
+ outerAxes);
+
+ // Reconstruct the full fft by appending reversed and conjugated
+ // spectrum as the negative frequency part.
+ Eigen::array<bool, FFTRank + 1> reversedAxis;
+ for (auto i = 0; i <= FFTRank; i++) {
+ reversedAxis[i] = i == FFTRank;
+ }
+
+ full_fft.slice(negTargetIndices, negSizes) =
+ full_fft.slice(negStartIndices, negSizes)
+ .reverse(reversedAxis)
+ .conjugate();
+
+ auto innerAxis = Eigen::array<int, 1>{FFTRank};
+ output.device(device) =
+ full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(
+ innerAxis);
}
}
}
@@ -194,10 +243,16 @@ REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
FFTCPU<true, true, 1>);
+REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
+ FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
FFTCPU<true, true, 2>);
+REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
+ FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
FFTCPU<true, true, 3>);
+REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
+ FFTCPU<false, true, 3>);
#undef FFT_LABEL