-rw-r--r--  tensorflow/core/kernels/fft_ops.cc               50
-rw-r--r--  tensorflow/core/ops/spectral_ops.cc              26
-rw-r--r--  tensorflow/python/kernel_tests/fft_ops_test.py   69
-rw-r--r--  tensorflow/python/ops/spectral_ops.py            52
4 files changed, 179 insertions(+), 18 deletions(-)
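
In short, this commit lets the real FFT ops crop or zero-pad their input to
`fft_length` instead of requiring the shapes to match exactly. A minimal
sketch of the resulting user-visible behavior, assuming the TF 1.x
`tf.spectral` entry points (the exact module path may vary by release):

    import numpy as np
    import tensorflow as tf

    signal = tf.constant(np.arange(8, dtype=np.float32))

    # fft_length < input length: the input is cropped to 4 samples, so the
    # output carries 4 // 2 + 1 == 3 frequency bins.
    cropped = tf.spectral.rfft(signal, fft_length=[4])

    # fft_length > input length: the input is zero-padded to 16 samples, so
    # the output carries 16 // 2 + 1 == 9 frequency bins.
    padded = tf.spectral.rfft(signal, fft_length=[16])

    with tf.Session() as sess:
        print(sess.run(cropped).shape)  # (3,)
        print(sess.run(padded).shape)   # (9,)
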
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index 639f6a76de..b479956632 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -39,15 +39,15 @@ class FFTBase : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& in = ctx->input(0);
- const TensorShape& shape = in.shape();
+ const TensorShape& input_shape = in.shape();
const int fft_rank = Rank();
OP_REQUIRES(
- ctx, shape.dims() >= fft_rank,
+ ctx, input_shape.dims() >= fft_rank,
errors::InvalidArgument("Input must have rank of at least ", fft_rank,
- " but got: ", shape.DebugString()));
+ " but got: ", input_shape.DebugString()));
Tensor* out;
- TensorShape output_shape = shape;
+ TensorShape output_shape = input_shape;
uint64 fft_shape[3] = {0, 0, 0};
// In R2C or C2R mode, we use a second input to specify the FFT length
@@ -57,13 +57,29 @@ class FFTBase : public OpKernel {
OP_REQUIRES(ctx,
fft_length.shape().dims() == 1 &&
fft_length.shape().dim_size(0) == fft_rank,
- errors::InvalidArgument("fft_length must have shape [",
+ errors::InvalidArgument("fft_length must have shape [",
fft_rank, "]"));
auto fft_length_as_vec = fft_length.vec<int32>();
for (int i = 0; i < fft_rank; ++i) {
fft_shape[i] = fft_length_as_vec(i);
- uint64 dim = IsForward() && i == fft_rank - 1 && fft_shape[i] != 0
+ // Each input dimension must have length of at least fft_shape[i]. For
+ // IRFFTs, the inner-most input dimension must have length of at least
+ // fft_shape[i] / 2 + 1.
+ bool inner_most = (i == fft_rank - 1);
+ uint64 min_input_dim_length =
+ !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
+ auto input_index = input_shape.dims() - fft_rank + i;
+ OP_REQUIRES(
+ ctx,
+ // We pass through empty tensors, so special case them here.
+ input_shape.dim_size(input_index) == 0 ||
+ input_shape.dim_size(input_index) >= min_input_dim_length,
+ errors::InvalidArgument(
+ "Input dimension ", input_index,
+ " must have length of at least ", min_input_dim_length,
+ " but got: ", input_shape.dim_size(input_index)));
+ uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
? fft_shape[i] / 2 + 1
: fft_shape[i];
output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
@@ -76,7 +92,7 @@ class FFTBase : public OpKernel {
}
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
- if (shape.num_elements() == 0) {
+ if (input_shape.num_elements() == 0) {
return;
}
@@ -120,20 +136,32 @@ class FFTCPU : public FFTBase {
} else {
if (IsForward()) {
auto input = (Tensor(in)).flat_inner_dims<float, FFTRank + 1>();
+ auto input_dims = input.dimensions();
+
+ // Slice input to fft_shape on its inner-most dimensions.
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
+ input_slice_sizes[0] = input_dims[0];
+ TensorShape temp_shape{input_dims[0]};
+ for (int i = 1; i <= FFTRank; ++i) {
+ input_slice_sizes[i] = fft_shape[i - 1];
+ temp_shape.AddDim(fft_shape[i - 1]);
+ }
+
auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
- Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> startIndices;
+ const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
// Compute the full FFT using a temporary tensor.
Tensor temp;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
- in.shape(), &temp));
+ temp_shape, &temp));
auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
full_fft.device(device) =
- input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
+ input.slice(zero_start_indices, input_slice_sizes)
+ .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
// Slice away the negative frequency components.
output.device(device) =
- full_fft.slice(startIndices, output.dimensions());
+ full_fft.slice(zero_start_indices, output.dimensions());
} else {
// TODO: reconstruct the full fft and take the inverse.
ctx->CtxFailureWithWarning(
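
The validation and output-shape logic added to the kernel above is dense; as
a hedged reference, here are the same rules sketched in plain Python (the
name `fft_output_shape` and the `is_forward` flag are illustrative, not from
the kernel):

    def fft_output_shape(input_shape, fft_length, is_forward):
        """Sketch of the kernel's rules: each transformed input dim must be
        at least fft_length[i] (or fft_length[-1] // 2 + 1 for the
        inner-most dim of an IRFFT, since that is all the bins an IRFFT
        consumes); a forward transform emits fft_length[-1] // 2 + 1 bins
        on the inner-most axis."""
        rank = len(fft_length)
        output_shape = list(input_shape)
        for i, n in enumerate(fft_length):
            inner_most = (i == rank - 1)
            min_len = n // 2 + 1 if (not is_forward and inner_most) else n
            index = len(input_shape) - rank + i
            dim = input_shape[index]
            if dim != 0 and dim < min_len:  # empty tensors pass through
                raise ValueError(
                    "Input dimension %d must have length of at least %d "
                    "but got: %d" % (index, min_len, dim))
            output_shape[index] = (
                n // 2 + 1 if (is_forward and inner_most and n != 0) else n)
        return output_shape

For example, fft_output_shape([2, 20], [16], is_forward=True) returns
[2, 9] (the kernel crops the 20-sample axis to 16 and keeps 9 bins), while
fft_output_shape([2, 9], [16], is_forward=False) returns [2, 16].
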
diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc
index 09b460fd14..592aaa25c3 100644
--- a/tensorflow/core/ops/spectral_ops.cc
+++ b/tensorflow/core/ops/spectral_ops.cc
@@ -201,6 +201,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the
`fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,
followed by the `fft_length / 2` positive-frequency terms.
+Along the axis `RFFT` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A float32 tensor.
fft_length: An int32 tensor of shape [1]. The FFT length.
output: A complex64 tensor of the same rank as `input`. The inner-most
@@ -230,6 +234,10 @@ dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to
compute `input` is odd, it should be provided since it cannot be inferred
properly.
+Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller
+than the corresponding dimension of `input`, the dimension is cropped. If it is
+larger, the dimension is padded with zeros.
+
input: A complex64 tensor.
fft_length: An int32 tensor of shape [1]. The FFT length.
output: A float32 tensor of the same rank as `input`. The inner-most
@@ -257,6 +265,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
+Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A float32 tensor.
fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
output: A complex64 tensor of the same rank as `input`. The inner-most 2
@@ -287,6 +299,11 @@ from the size of the inner-most 2 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.
+Along each axis `IRFFT2D` is computed on, if `fft_length` (or
+`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A complex64 tensor.
fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
output: A float32 tensor of the same rank as `input`. The inner-most 2
@@ -314,6 +331,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
+Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A float32 tensor.
fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
output: A complex64 tensor of the same rank as `input`. The inner-most 3
@@ -344,6 +365,11 @@ from the size of the inner-most 3 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.
+Along each axis `IRFFT3D` is computed on, if `fft_length` (or
+`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A complex64 tensor.
fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
output: A float32 tensor of the same rank as `input`. The inner-most 3
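
The crop/pad semantics documented above mirror NumPy's `n`/`s` arguments,
which makes for an easy sanity check (a standalone sketch, not part of the
commit):

    import numpy as np

    x = np.random.rand(10).astype(np.float32)
    spectrum = np.fft.rfft(x)  # 10 // 2 + 1 == 6 complex bins

    # IRFFT with fft_length=4 keeps only 4 // 2 + 1 == 3 input bins
    # (cropping); fft_length=16 zero-pads the 6 bins up to
    # 16 // 2 + 1 == 9 before inverting.
    short = np.fft.irfft(spectrum, n=4)    # shape (4,)
    long_ = np.fft.irfft(spectrum, n=16)   # shape (16,)
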
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index 84928bd2e1..2f3c5a6c33 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -22,8 +22,10 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops
@@ -297,6 +299,38 @@ class RFFTOpsTest(BaseFFTOpsTest):
self._CompareBackward(c2r.astype(np.complex64), rank, (size,) * rank,
use_placeholder=True)
+ def testFftLength(self):
+ for rank in VALID_FFT_RANKS:
+ for dims in xrange(rank, rank + 3):
+ for size in (5, 6):
+ inner_dim = size // 2 + 1
+ r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
+ (size,) * dims)
+ c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
+ 10).reshape((size,) * (dims - 1) + (inner_dim,))
+
+ # Test truncation (FFT size < dimensions).
+ fft_length = (size - 2,) * rank
+ self._CompareForward(r2c.astype(np.float32), rank, fft_length)
+ self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
+
+ # Confirm it works with unknown shapes as well.
+ self._CompareForward(r2c.astype(np.float32), rank, fft_length,
+ use_placeholder=True)
+ self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
+ use_placeholder=True)
+
+ # Test padding (FFT size > dimensions).
+ fft_length = (size + 2,) * rank
+ self._CompareForward(r2c.astype(np.float32), rank, fft_length)
+ self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
+
+ # Confirm it works with unknown shapes as well.
+ self._CompareForward(r2c.astype(np.float32), rank, fft_length,
+ use_placeholder=True)
+ self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
+ use_placeholder=True)
+
def testRandom(self):
np.random.seed(12345)
@@ -326,10 +360,10 @@ class RFFTOpsTest(BaseFFTOpsTest):
for dims in xrange(0, rank):
x = np.zeros((1,) * dims).astype(np.complex64)
with self.assertRaisesWithPredicateMatch(
- ValueError, "Shape must be .*rank {}.*".format(rank)):
+ ValueError, "Shape .* must have rank at least {}".format(rank)):
self._tfFFT(x, rank)
with self.assertRaisesWithPredicateMatch(
- ValueError, "Shape must be .*rank {}.*".format(rank)):
+ ValueError, "Shape .* must have rank at least {}".format(rank)):
self._tfIFFT(x, rank)
for dims in xrange(rank, rank + 2):
x = np.zeros((1,) * rank)
@@ -337,10 +371,10 @@ class RFFTOpsTest(BaseFFTOpsTest):
# Test non-rank-1 fft_length produces an error.
fft_length = np.zeros((1, 1)).astype(np.int32)
with self.assertRaisesWithPredicateMatch(ValueError,
- "Shape must be .*rank 1"):
+ "Shape .* must have rank 1"):
self._tfFFT(x, rank, fft_length)
with self.assertRaisesWithPredicateMatch(ValueError,
- "Shape must be .*rank 1"):
+ "Shape .* must have rank 1"):
self._tfIFFT(x, rank, fft_length)
# Test wrong fft_length length.
@@ -352,6 +386,29 @@ class RFFTOpsTest(BaseFFTOpsTest):
ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
self._tfIFFT(x, rank, fft_length)
+ # Test that calling the kernel directly without padding to fft_length
+ # produces an error.
+ rffts_for_rank = {1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft],
+ 2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d],
+ 3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]}
+ rfft_fn, irfft_fn = rffts_for_rank[rank]
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ "Input dimension .* must have length of at least 6 but got: 5"):
+ x = np.zeros((5,) * rank).astype(np.float32)
+ fft_length = [6] * rank
+ with self.test_session():
+ rfft_fn(x, fft_length).eval()
+ # TODO(rjryan): Remove when CPU-based IRFFT is supported.
+ if test.is_gpu_available(cuda_only=True):
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ "Input dimension .* must have length of at least .* but got: 3"):
+ x = np.zeros((3,) * rank).astype(np.complex64)
+ fft_length = [6] * rank
+ with self.test_session():
+ irfft_fn(x, fft_length).eval()
+
def testGrad_Simple(self):
if test.is_gpu_available(cuda_only=True):
for rank in VALID_FFT_RANKS:
@@ -359,9 +416,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
if rank == 3:
continue
for dims in xrange(rank, rank + 2):
- for size in (
- 5,
- 6,):
+ for size in (5, 6):
re = np.ones(shape=(size,) * dims, dtype=np.float32)
im = -np.ones(shape=(size,) * dims, dtype=np.float32)
self._checkGradReal(self._tfFFTForRank(rank), re, use_gpu=True)
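
The `_Compare*` helpers are defined outside this hunk; assuming they check
against NumPy, the reference for the new `testFftLength` cases would look
roughly like this (an illustrative sketch, with `np_rfft_reference` being a
hypothetical name):

    import numpy as np

    def np_rfft_reference(x, rank, fft_length):
        # np.fft.rfftn crops or zero-pads each transformed axis to s,
        # matching the semantics this commit adds to the TensorFlow ops.
        axes = tuple(range(x.ndim - rank, x.ndim))
        return np.fft.rfftn(x, s=fft_length, axes=axes)

    x = np.mod(np.arange(25), 10).reshape(5, 5).astype(np.float32)
    print(np_rfft_reference(x, rank=2, fft_length=(3, 3)).shape)  # (3, 2)
    print(np_rfft_reference(x, rank=2, fft_length=(7, 7)).shape)  # (7, 4)
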
diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py
index 95a2806330..47ff7018f2 100644
--- a/tensorflow/python/ops/spectral_ops.py
+++ b/tensorflow/python/ops/spectral_ops.py
@@ -33,6 +33,7 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework import tensor_util as _tensor_util
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import math_ops as _math_ops
@@ -70,6 +71,52 @@ def _infer_fft_length_for_irfft(input_tensor, fft_rank):
return _ops.convert_to_tensor(fft_length, _dtypes.int32)
+def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
+ """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims."""
+ fft_shape = _tensor_util.constant_value_as_shape(fft_length)
+
+ # Edge case: skip padding empty tensors.
+ if (input_tensor.shape.ndims is not None and
+ any(dim.value == 0 for dim in input_tensor.shape)):
+ return input_tensor
+
+ # If we know the shapes ahead of time, we can either skip or pre-compute the
+ # appropriate paddings. Otherwise, fall back to computing paddings in
+ # TensorFlow.
+ if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
+ # Slice the last FFT-rank dimensions from input_tensor's shape.
+ input_fft_shape = input_tensor.shape[-fft_shape.ndims:]
+
+ if input_fft_shape.is_fully_defined():
+ # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
+ if is_reverse:
+ fft_shape = fft_shape[:-1].concatenate(fft_shape[-1].value // 2 + 1)
+
+ paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
+ for fft_dim, input_dim in zip(fft_shape, input_fft_shape)]
+ if any(pad > 0 for _, pad in paddings):
+ outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
+ fft_shape.ndims), 0)
+ return _array_ops.pad(input_tensor, outer_paddings + paddings)
+ return input_tensor
+
+ # If we can't determine the paddings ahead of time, then we have to pad. If
+ # the paddings end up as zero, tf.pad has a special-case that does no work.
+ input_rank = _array_ops.rank(input_tensor)
+ input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
+ outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
+ outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
+ # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
+ if is_reverse:
+ fft_length = _array_ops.concat([fft_length[:-1],
+ fft_length[-1:] // 2 + 1], 0)
+ fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
+ paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
+ paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings],
+ axis=1)
+ return _array_ops.pad(input_tensor, paddings)
+
+
def _rfft_wrapper(fft_fn, fft_rank, default_name):
"""Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
@@ -77,10 +124,12 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
with _ops.name_scope(name, default_name,
[input_tensor, fft_length]) as name:
input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
+ input_tensor.shape.with_rank_at_least(fft_rank)
if fft_length is None:
fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
else:
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
+ input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
return fft_fn(input_tensor, fft_length, name)
_rfft.__doc__ = fft_fn.__doc__
return _rfft
@@ -93,10 +142,13 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
with _ops.name_scope(name, default_name,
[input_tensor, fft_length]) as name:
input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
+ input_tensor.shape.with_rank_at_least(fft_rank)
if fft_length is None:
fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
else:
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
+ input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
+ is_reverse=True)
return ifft_fn(input_tensor, fft_length, name)
_irfft.__doc__ = ifft_fn.__doc__
return _irfft
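
For a sense of what `_maybe_pad_for_rfft` produces in its dynamic-shape
branch, here is the paddings computation sketched in plain Python
(`compute_paddings` is an illustrative name, not from the commit):

    def compute_paddings(input_shape, fft_length, is_reverse=False):
        # Mirrors the dynamic-shape branch: outer dims get [0, 0]; each
        # transformed dim is padded up to fft_length[i] (or, for the
        # inner-most dim of an IRFFT, fft_length[-1] // 2 + 1).
        fft_length = list(fft_length)
        if is_reverse:
            fft_length[-1] = fft_length[-1] // 2 + 1
        rank = len(fft_length)
        outer = [[0, 0]] * max(len(input_shape) - rank, 0)
        inner = [[0, max(n - d, 0)]
                 for n, d in zip(fft_length, input_shape[-rank:])]
        return outer + inner

    # A [2, 5] float input being RFFT'd with fft_length=[8] pads the last
    # dim by 3; an IRFFT with fft_length=[8] pads a [2, 3] complex input up
    # to 8 // 2 + 1 == 5 bins.
    print(compute_paddings([2, 5], [8]))                   # [[0, 0], [0, 3]]
    print(compute_paddings([2, 3], [8], is_reverse=True))  # [[0, 0], [0, 2]]
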