aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/fft_ops.cc
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2017-06-07 11:04:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-07 11:08:29 -0700
commit7b4c01794fbc2e6dc46e93a42416dd80929ce1e5 (patch)
treed8b498c28c20ef36e6fd261960d3fb3a2bdfd043 /tensorflow/core/kernels/fft_ops.cc
parentfdb8e29354ce93afa8c2335a6287a59eb37d42fc (diff)
Support numpy-style padding and slicing of tf.spectral.rfft/irfft to match the desired FFT length.
Fixes incorrect RFFT/IRFFT results when fft_length does not match the input dimension. PiperOrigin-RevId: 158289991
Diffstat (limited to 'tensorflow/core/kernels/fft_ops.cc')
-rw-r--r--tensorflow/core/kernels/fft_ops.cc50
1 files changed, 39 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index 639f6a76de..b479956632 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -39,15 +39,15 @@ class FFTBase : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& in = ctx->input(0);
- const TensorShape& shape = in.shape();
+ const TensorShape& input_shape = in.shape();
const int fft_rank = Rank();
OP_REQUIRES(
- ctx, shape.dims() >= fft_rank,
+ ctx, input_shape.dims() >= fft_rank,
errors::InvalidArgument("Input must have rank of at least ", fft_rank,
- " but got: ", shape.DebugString()));
+ " but got: ", input_shape.DebugString()));
Tensor* out;
- TensorShape output_shape = shape;
+ TensorShape output_shape = input_shape;
uint64 fft_shape[3] = {0, 0, 0};
// In R2C or C2R mode, we use a second input to specify the FFT length
@@ -57,13 +57,29 @@ class FFTBase : public OpKernel {
OP_REQUIRES(ctx,
fft_length.shape().dims() == 1 &&
fft_length.shape().dim_size(0) == fft_rank,
- errors::InvalidArgument("fft_length must have shape [",
+ errors::InvalidArgument("fft_length must have shape [",
fft_rank, "]"));
auto fft_length_as_vec = fft_length.vec<int32>();
for (int i = 0; i < fft_rank; ++i) {
fft_shape[i] = fft_length_as_vec(i);
- uint64 dim = IsForward() && i == fft_rank - 1 && fft_shape[i] != 0
+ // Each input dimension must have length of at least fft_shape[i]. For
+ // IRFFTs, the inner-most input dimension must have length of at least
+ // fft_shape[i] / 2 + 1.
+ bool inner_most = (i == fft_rank - 1);
+ uint64 min_input_dim_length =
+ !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
+ auto input_index = input_shape.dims() - fft_rank + i;
+ OP_REQUIRES(
+ ctx,
+ // We pass through empty tensors, so special case them here.
+ input_shape.dim_size(input_index) == 0 ||
+ input_shape.dim_size(input_index) >= min_input_dim_length,
+ errors::InvalidArgument(
+ "Input dimension ", input_index,
+ " must have length of at least ", min_input_dim_length,
+ " but got: ", input_shape.dim_size(input_index)));
+ uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
? fft_shape[i] / 2 + 1
: fft_shape[i];
output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
@@ -76,7 +92,7 @@ class FFTBase : public OpKernel {
}
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
- if (shape.num_elements() == 0) {
+ if (input_shape.num_elements() == 0) {
return;
}
@@ -120,20 +136,32 @@ class FFTCPU : public FFTBase {
} else {
if (IsForward()) {
auto input = (Tensor(in)).flat_inner_dims<float, FFTRank + 1>();
+ auto input_dims = input.dimensions();
+
+ // Slice input to fft_shape on its inner-most dimensions.
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
+ input_slice_sizes[0] = input_dims[0];
+ TensorShape temp_shape{input_dims[0]};
+ for (int i = 1; i <= FFTRank; ++i) {
+ input_slice_sizes[i] = fft_shape[i - 1];
+ temp_shape.AddDim(fft_shape[i - 1]);
+ }
+
auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
- Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> startIndices;
+ const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
// Compute the full FFT using a temporary tensor.
Tensor temp;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
- in.shape(), &temp));
+ temp_shape, &temp));
auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
full_fft.device(device) =
- input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
+ input.slice(zero_start_indices, input_slice_sizes)
+ .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
// Slice away the negative frequency components.
output.device(device) =
- full_fft.slice(startIndices, output.dimensions());
+ full_fft.slice(zero_start_indices, output.dimensions());
} else {
// TODO: reconstruct the full fft and take the inverse.
ctx->CtxFailureWithWarning(