diff options
author | 2017-06-07 11:04:18 -0700 | |
---|---|---|
committer | 2017-06-07 11:08:29 -0700 | |
commit | 7b4c01794fbc2e6dc46e93a42416dd80929ce1e5 (patch) | |
tree | d8b498c28c20ef36e6fd261960d3fb3a2bdfd043 /tensorflow/core/kernels/fft_ops.cc | |
parent | fdb8e29354ce93afa8c2335a6287a59eb37d42fc (diff) |
Support numpy-style padding and slicing of tf.spectral.rfft/irfft to match the desired FFT length.
Fixes incorrect RFFT/IRFFT results when fft_length does not match the input dimension.
PiperOrigin-RevId: 158289991
Diffstat (limited to 'tensorflow/core/kernels/fft_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/fft_ops.cc | 50 |
1 files changed, 39 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 639f6a76de..b479956632 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -39,15 +39,15 @@ class FFTBase : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& in = ctx->input(0); - const TensorShape& shape = in.shape(); + const TensorShape& input_shape = in.shape(); const int fft_rank = Rank(); OP_REQUIRES( - ctx, shape.dims() >= fft_rank, + ctx, input_shape.dims() >= fft_rank, errors::InvalidArgument("Input must have rank of at least ", fft_rank, - " but got: ", shape.DebugString())); + " but got: ", input_shape.DebugString())); Tensor* out; - TensorShape output_shape = shape; + TensorShape output_shape = input_shape; uint64 fft_shape[3] = {0, 0, 0}; // In R2C or C2R mode, we use a second input to specify the FFT length @@ -57,13 +57,29 @@ class FFTBase : public OpKernel { OP_REQUIRES(ctx, fft_length.shape().dims() == 1 && fft_length.shape().dim_size(0) == fft_rank, - errors::InvalidArgument("fft_length must have shape [", + errors::InvalidArgument("fft_length must have shape [", fft_rank, "]")); auto fft_length_as_vec = fft_length.vec<int32>(); for (int i = 0; i < fft_rank; ++i) { fft_shape[i] = fft_length_as_vec(i); - uint64 dim = IsForward() && i == fft_rank - 1 && fft_shape[i] != 0 + // Each input dimension must have length of at least fft_shape[i]. For + // IRFFTs, the inner-most input dimension must have length of at least + // fft_shape[i] / 2 + 1. + bool inner_most = (i == fft_rank - 1); + uint64 min_input_dim_length = + !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i]; + auto input_index = input_shape.dims() - fft_rank + i; + OP_REQUIRES( + ctx, + // We pass through empty tensors, so special case them here. + input_shape.dim_size(input_index) == 0 || + input_shape.dim_size(input_index) >= min_input_dim_length, + errors::InvalidArgument( + "Input dimension ", input_index, + " must have length of at least ", min_input_dim_length, + " but got: ", input_shape.dim_size(input_index))); + uint64 dim = IsForward() && inner_most && fft_shape[i] != 0 ? fft_shape[i] / 2 + 1 : fft_shape[i]; output_shape.set_dim(output_shape.dims() - fft_rank + i, dim); @@ -76,7 +92,7 @@ class FFTBase : public OpKernel { } OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out)); - if (shape.num_elements() == 0) { + if (input_shape.num_elements() == 0) { return; } @@ -120,20 +136,32 @@ class FFTCPU : public FFTBase { } else { if (IsForward()) { auto input = (Tensor(in)).flat_inner_dims<float, FFTRank + 1>(); + auto input_dims = input.dimensions(); + + // Slice input to fft_shape on its inner-most dimensions. + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes; + input_slice_sizes[0] = input_dims[0]; + TensorShape temp_shape{input_dims[0]}; + for (int i = 1; i <= FFTRank; ++i) { + input_slice_sizes[i] = fft_shape[i - 1]; + temp_shape.AddDim(fft_shape[i - 1]); + } + auto output = out->flat_inner_dims<complex64, FFTRank + 1>(); - Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> startIndices; + const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices; // Compute the full FFT using a temporary tensor. Tensor temp; OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(), - in.shape(), &temp)); + temp_shape, &temp)); auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>(); full_fft.device(device) = - input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes); + input.slice(zero_start_indices, input_slice_sizes) + .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes); // Slice away the negative frequency components. output.device(device) = - full_fft.slice(startIndices, output.dimensions()); + full_fft.slice(zero_start_indices, output.dimensions()); } else { // TODO: reconstruct the full fft and take the inverse. ctx->CtxFailureWithWarning( |