-rw-r--r-- tensorflow/core/kernels/fft_ops.cc             | 50
-rw-r--r-- tensorflow/core/ops/spectral_ops.cc            | 26
-rw-r--r-- tensorflow/python/kernel_tests/fft_ops_test.py | 69
-rw-r--r-- tensorflow/python/ops/spectral_ops.py          | 52
4 files changed, 179 insertions(+), 18 deletions(-)
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index 639f6a76de..b479956632 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -39,15 +39,15 @@ class FFTBase : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     const Tensor& in = ctx->input(0);
-    const TensorShape& shape = in.shape();
+    const TensorShape& input_shape = in.shape();
     const int fft_rank = Rank();
     OP_REQUIRES(
-        ctx, shape.dims() >= fft_rank,
+        ctx, input_shape.dims() >= fft_rank,
         errors::InvalidArgument("Input must have rank of at least ", fft_rank,
-                                " but got: ", shape.DebugString()));
+                                " but got: ", input_shape.DebugString()));
 
     Tensor* out;
-    TensorShape output_shape = shape;
+    TensorShape output_shape = input_shape;
     uint64 fft_shape[3] = {0, 0, 0};
 
     // In R2C or C2R mode, we use a second input to specify the FFT length
@@ -57,13 +57,29 @@ class FFTBase : public OpKernel {
       OP_REQUIRES(ctx,
                   fft_length.shape().dims() == 1 &&
                       fft_length.shape().dim_size(0) == fft_rank,
-                  errors::InvalidArgument("fft_length must have shape [",
+                  errors::InvalidArgument("fft_length must have shape [",
                                           fft_rank, "]"));
 
       auto fft_length_as_vec = fft_length.vec<int32>();
       for (int i = 0; i < fft_rank; ++i) {
         fft_shape[i] = fft_length_as_vec(i);
-        uint64 dim = IsForward() && i == fft_rank - 1 && fft_shape[i] != 0
+        // Each input dimension must have length of at least fft_shape[i]. For
+        // IRFFTs, the inner-most input dimension must have length of at least
+        // fft_shape[i] / 2 + 1.
+        bool inner_most = (i == fft_rank - 1);
+        uint64 min_input_dim_length =
+            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
+        auto input_index = input_shape.dims() - fft_rank + i;
+        OP_REQUIRES(
+            ctx,
+            // We pass through empty tensors, so special case them here.
+            input_shape.dim_size(input_index) == 0 ||
+                input_shape.dim_size(input_index) >= min_input_dim_length,
+            errors::InvalidArgument(
+                "Input dimension ", input_index,
+                " must have length of at least ", min_input_dim_length,
+                " but got: ", input_shape.dim_size(input_index)));
+        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                          ? fft_shape[i] / 2 + 1
                          : fft_shape[i];
         output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
@@ -76,7 +92,7 @@ class FFTBase : public OpKernel {
     }
 
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
-    if (shape.num_elements() == 0) {
+    if (input_shape.num_elements() == 0) {
       return;
     }
@@ -120,20 +136,32 @@ class FFTCPU : public FFTBase {
     } else {
       if (IsForward()) {
         auto input = (Tensor(in)).flat_inner_dims<float, FFTRank + 1>();
+        auto input_dims = input.dimensions();
+
+        // Slice input to fft_shape on its inner-most dimensions.
+        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
+        input_slice_sizes[0] = input_dims[0];
+        TensorShape temp_shape{input_dims[0]};
+        for (int i = 1; i <= FFTRank; ++i) {
+          input_slice_sizes[i] = fft_shape[i - 1];
+          temp_shape.AddDim(fft_shape[i - 1]);
+        }
+
         auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> startIndices;
+        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
 
         // Compute the full FFT using a temporary tensor.
         Tensor temp;
         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
-                                               in.shape(), &temp));
+                                               temp_shape, &temp));
         auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
         full_fft.device(device) =
-            input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
+            input.slice(zero_start_indices, input_slice_sizes)
+                .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
 
         // Slice away the negative frequency components.
         output.device(device) =
-            full_fft.slice(startIndices, output.dimensions());
+            full_fft.slice(zero_start_indices, output.dimensions());
       } else {
         // TODO: reconstruct the full fft and take the inverse.
         ctx->CtxFailureWithWarning(
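Note on the kernel change: it enforces a minimum input length per FFT dimension and computes the R2C output length from fft_length. A minimal Python sketch of those two rules as the patch states them (the helper names here are hypothetical, for illustration only):

    # Sketch of the shape rules the kernel enforces; not TensorFlow code.
    def expected_output_dim(fft_length, inner_most, forward):
      # A forward R2C FFT keeps only the non-negative frequencies on the
      # inner-most axis: fft_length // 2 + 1 components.
      return fft_length // 2 + 1 if (forward and inner_most) else fft_length

    def min_input_dim(fft_length, inner_most, forward):
      # IRFFT inputs are already half-spectra, so their inner-most axis only
      # needs fft_length // 2 + 1 elements; all other axes need fft_length.
      return fft_length // 2 + 1 if (not forward and inner_most) else fft_length

    assert expected_output_dim(8, inner_most=True, forward=True) == 5
    assert min_input_dim(8, inner_most=True, forward=False) == 5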
diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc
index 09b460fd14..592aaa25c3 100644
--- a/tensorflow/core/ops/spectral_ops.cc
+++ b/tensorflow/core/ops/spectral_ops.cc
@@ -201,6 +201,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the
 `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,
 followed by the `fft_length / 2` positive-frequency terms.
 
+Along the axis `RFFT` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
 input: A float32 tensor.
 fft_length: An int32 tensor of shape [1]. The FFT length.
 output: A complex64 tensor of the same rank as `input`. The inner-most
@@ -230,6 +234,10 @@ dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to
 compute `input` is odd, it should be provided since it cannot be inferred
 properly.
 
+Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller
+than the corresponding dimension of `input`, the dimension is cropped. If it is
+larger, the dimension is padded with zeros.
+
 input: A complex64 tensor.
 fft_length: An int32 tensor of shape [1]. The FFT length.
 output: A float32 tensor of the same rank as `input`. The inner-most
@@ -257,6 +265,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
 of `output`: the zero-frequency term, followed by the `fft_length / 2`
 positive-frequency terms.
 
+Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
 input: A float32 tensor.
 fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
 output: A complex64 tensor of the same rank as `input`. The inner-most 2
@@ -287,6 +299,11 @@ from the size of the inner-most 2 dimensions of `input`. If the FFT length used
 to compute `input` is odd, it should be provided since it cannot be inferred
 properly.
 
+Along each axis `IRFFT2D` is computed on, if `fft_length` (or
+`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is
+larger, the dimension is padded with zeros.
+
 input: A complex64 tensor.
 fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
 output: A float32 tensor of the same rank as `input`. The inner-most 2
@@ -314,6 +331,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
 of `output`: the zero-frequency term, followed by the `fft_length / 2`
 positive-frequency terms.
 
+Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
 input: A float32 tensor.
 fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
 output: A complex64 tensor of the same rank as `input`. The inner-most 3
@@ -344,6 +365,11 @@ from the size of the inner-most 3 dimensions of `input`. If the FFT length used
 to compute `input` is odd, it should be provided since it cannot be inferred
 properly.
 
+Along each axis `IRFFT3D` is computed on, if `fft_length` (or
+`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is
+larger, the dimension is padded with zeros.
+
 input: A complex64 tensor.
 fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
 output: A float32 tensor of the same rank as `input`. The inner-most 3
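Note on the op docs: the crop/pad semantics they specify match the `n` argument of `numpy.fft.rfft`, which can serve as a quick reference model:

    import numpy as np

    x = np.arange(8.0)
    # fft_length smaller than the axis: the input is cropped.
    assert np.allclose(np.fft.rfft(x, n=4), np.fft.rfft(x[:4]))
    # fft_length larger than the axis: the input is zero-padded.
    assert np.allclose(np.fft.rfft(x, n=10),
                       np.fft.rfft(np.concatenate([x, np.zeros(2)])))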
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index 84928bd2e1..2f3c5a6c33 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -22,8 +22,10 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_spectral_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import spectral_ops
@@ -297,6 +299,38 @@ class RFFTOpsTest(BaseFFTOpsTest):
       self._CompareBackward(c2r.astype(np.complex64), rank, (size,) * rank,
                             use_placeholder=True)
 
+  def testFftLength(self):
+    for rank in VALID_FFT_RANKS:
+      for dims in xrange(rank, rank + 3):
+        for size in (5, 6):
+          inner_dim = size // 2 + 1
+          r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
+              (size,) * dims)
+          c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
+                       10).reshape((size,) * (dims - 1) + (inner_dim,))
+
+          # Test truncation (FFT size < dimensions).
+          fft_length = (size - 2,) * rank
+          self._CompareForward(r2c.astype(np.float32), rank, fft_length)
+          self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
+
+          # Confirm it works with unknown shapes as well.
+          self._CompareForward(r2c.astype(np.float32), rank, fft_length,
+                               use_placeholder=True)
+          self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
+                                use_placeholder=True)
+
+          # Test padding (FFT size > dimensions).
+          fft_length = (size + 2,) * rank
+          self._CompareForward(r2c.astype(np.float32), rank, fft_length)
+          self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
+
+          # Confirm it works with unknown shapes as well.
+          self._CompareForward(r2c.astype(np.float32), rank, fft_length,
+                               use_placeholder=True)
+          self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
+                                use_placeholder=True)
+
   def testRandom(self):
     np.random.seed(12345)
@@ -326,10 +360,10 @@ class RFFTOpsTest(BaseFFTOpsTest):
       for dims in xrange(0, rank):
         x = np.zeros((1,) * dims).astype(np.complex64)
         with self.assertRaisesWithPredicateMatch(
-            ValueError, "Shape must be .*rank {}.*".format(rank)):
+            ValueError, "Shape .* must have rank at least {}".format(rank)):
          self._tfFFT(x, rank)
         with self.assertRaisesWithPredicateMatch(
-            ValueError, "Shape must be .*rank {}.*".format(rank)):
+            ValueError, "Shape .* must have rank at least {}".format(rank)):
           self._tfIFFT(x, rank)
       for dims in xrange(rank, rank + 2):
         x = np.zeros((1,) * rank)
@@ -337,10 +371,10 @@ class RFFTOpsTest(BaseFFTOpsTest):
         # Test non-rank-1 fft_length produces an error.
         fft_length = np.zeros((1, 1)).astype(np.int32)
         with self.assertRaisesWithPredicateMatch(ValueError,
-                                                 "Shape must be .*rank 1"):
+                                                 "Shape .* must have rank 1"):
           self._tfFFT(x, rank, fft_length)
         with self.assertRaisesWithPredicateMatch(ValueError,
-                                                 "Shape must be .*rank 1"):
+                                                 "Shape .* must have rank 1"):
           self._tfIFFT(x, rank, fft_length)
 
         # Test wrong fft_length length.
@@ -352,6 +386,29 @@ class RFFTOpsTest(BaseFFTOpsTest):
             ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
           self._tfIFFT(x, rank, fft_length)
 
+        # Test that calling the kernel directly without padding to fft_length
+        # produces an error.
+        rffts_for_rank = {1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft],
+                          2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d],
+                          3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]}
+        rfft_fn, irfft_fn = rffts_for_rank[rank]
+        with self.assertRaisesWithPredicateMatch(
+            errors.InvalidArgumentError,
+            "Input dimension .* must have length of at least 6 but got: 5"):
+          x = np.zeros((5,) * rank).astype(np.float32)
+          fft_length = [6] * rank
+          with self.test_session():
+            rfft_fn(x, fft_length).eval()
+        # TODO(rjryan): Remove when CPU-based IRFFT is supported.
+        if test.is_gpu_available(cuda_only=True):
+          with self.assertRaisesWithPredicateMatch(
+              errors.InvalidArgumentError,
+              "Input dimension .* must have length of at least .* but got: 3"):
+            x = np.zeros((3,) * rank).astype(np.complex64)
+            fft_length = [6] * rank
+            with self.test_session():
+              irfft_fn(x, fft_length).eval()
+
   def testGrad_Simple(self):
     if test.is_gpu_available(cuda_only=True):
       for rank in VALID_FFT_RANKS:
@@ -359,9 +416,7 @@
         if rank == 3:
           continue
         for dims in xrange(rank, rank + 2):
-          for size in (
-              5,
-              6,):
+          for size in (5, 6):
             re = np.ones(shape=(size,) * dims, dtype=np.float32)
             im = -np.ones(shape=(size,) * dims, dtype=np.float32)
             self._checkGradReal(self._tfFFTForRank(rank), re, use_gpu=True)
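Note on the new tests, in miniature: the raw generated ops require the caller to pre-pad, while the public wrappers (patched below) crop or pad automatically. A hedged sketch against the TF 1.x-era API this change targets:

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import gen_spectral_ops

    x = np.zeros(5, np.float32)
    with tf.Session() as sess:
      # The public wrapper zero-pads x from length 5 to 6 before calling the
      # kernel, so this yields a complex64 tensor of shape [6 // 2 + 1] = [4].
      print(sess.run(tf.spectral.rfft(x, fft_length=[6])).shape)
      # The raw kernel does no padding and raises InvalidArgumentError:
      # "Input dimension 0 must have length of at least 6 but got: 5".
      # sess.run(gen_spectral_ops.rfft(x, fft_length=[6]))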
diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py
index 95a2806330..47ff7018f2 100644
--- a/tensorflow/python/ops/spectral_ops.py
+++ b/tensorflow/python/ops/spectral_ops.py
@@ -33,6 +33,7 @@ from __future__ import print_function
 
 from tensorflow.python.framework import dtypes as _dtypes
 from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework import tensor_util as _tensor_util
 from tensorflow.python.ops import array_ops as _array_ops
 from tensorflow.python.ops import gen_spectral_ops
 from tensorflow.python.ops import math_ops as _math_ops
@@ -70,6 +71,52 @@ def _infer_fft_length_for_irfft(input_tensor, fft_rank):
   return _ops.convert_to_tensor(fft_length, _dtypes.int32)
 
 
+def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
+  """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims."""
+  fft_shape = _tensor_util.constant_value_as_shape(fft_length)
+
+  # Edge case: skip padding empty tensors.
+  if (input_tensor.shape.ndims is not None and
+      any(dim.value == 0 for dim in input_tensor.shape)):
+    return input_tensor
+
+  # If we know the shapes ahead of time, we can either skip or pre-compute the
+  # appropriate paddings. Otherwise, fall back to computing paddings in
+  # TensorFlow.
+  if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
+    # Slice the last FFT-rank dimensions from input_tensor's shape.
+    input_fft_shape = input_tensor.shape[-fft_shape.ndims:]
+
+    if input_fft_shape.is_fully_defined():
+      # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
+      if is_reverse:
+        fft_shape = fft_shape[:-1].concatenate(fft_shape[-1].value // 2 + 1)
+
+      paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
+                  for fft_dim, input_dim in zip(fft_shape, input_fft_shape)]
+      if any(pad > 0 for _, pad in paddings):
+        outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
+                                         fft_shape.ndims), 0)
+        return _array_ops.pad(input_tensor, outer_paddings + paddings)
+      return input_tensor
+
+  # If we can't determine the paddings ahead of time, then we have to pad. If
+  # the paddings end up as zero, tf.pad has a special-case that does no work.
+  input_rank = _array_ops.rank(input_tensor)
+  input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
+  outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
+  outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
+  # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
+  if is_reverse:
+    fft_length = _array_ops.concat([fft_length[:-1],
+                                    fft_length[-1:] // 2 + 1], 0)
+  fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
+  paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
+  paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings],
+                              axis=1)
+  return _array_ops.pad(input_tensor, paddings)
+
+
 def _rfft_wrapper(fft_fn, fft_rank, default_name):
   """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
 
@@ -77,10 +124,12 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
     with _ops.name_scope(name, default_name,
                          [input_tensor, fft_length]) as name:
       input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
+      input_tensor.shape.with_rank_at_least(fft_rank)
       if fft_length is None:
         fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
       else:
         fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
+      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
       return fft_fn(input_tensor, fft_length, name)
   _rfft.__doc__ = fft_fn.__doc__
   return _rfft
@@ -93,10 +142,13 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
     with _ops.name_scope(name, default_name,
                          [input_tensor, fft_length]) as name:
       input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
+      input_tensor.shape.with_rank_at_least(fft_rank)
       if fft_length is None:
         fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
       else:
         fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
+      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
+                                         is_reverse=True)
      return ifft_fn(input_tensor, fft_length, name)
   _irfft.__doc__ = ifft_fn.__doc__
   return _irfft
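Note on _maybe_pad_for_rfft: its padding rule reduces to the following NumPy sketch of the static-shape path (pad_for_rfft is a hypothetical name; cropping is left to the kernel, which slices the input down to fft_length):

    import numpy as np

    def pad_for_rfft(x, fft_rank, fft_length, is_reverse=False):
      """Zero-pads x's inner-most fft_rank dims up to fft_length; never crops."""
      fft_length = list(fft_length)
      if is_reverse:
        # IRFFT inputs are half-spectra: the inner-most dimension only needs
        # fft_length // 2 + 1 components.
        fft_length[-1] = fft_length[-1] // 2 + 1
      paddings = [(0, 0)] * (x.ndim - fft_rank)
      paddings += [(0, max(length - dim, 0))
                   for length, dim in zip(fft_length, x.shape[-fft_rank:])]
      return np.pad(x, paddings, mode='constant')

    # A [2, 5] batch padded for a length-8 RFFT becomes [2, 8]; an input that
    # is already long enough passes through unchanged.
    assert pad_for_rfft(np.ones((2, 5)), 1, [8]).shape == (2, 8)
    assert pad_for_rfft(np.ones((2, 8)), 1, [8]).shape == (2, 8)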