aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-25 14:55:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-25 14:59:40 -0700
commit6b3751f660dfa0675cc39d163b8224f2c070694e (patch)
tree8fc5c728fd23235a04e9ffd87c92720b052d263f /tensorflow/contrib/signal
parent27bbbd8d0149724caa7f1295f122f968acb8cd96 (diff)
Make fft_length optional for inverse_stft
PiperOrigin-RevId: 163127500
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/python/ops/spectral_ops.py25
1 files changed, 20 insertions, 5 deletions
diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py
index 75bc0bd21d..950d8f471c 100644
--- a/tensorflow/contrib/signal/python/ops/spectral_ops.py
+++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py
@@ -103,7 +103,7 @@ def stft(signals, frame_length, frame_step, fft_length=None,
def inverse_stft(stfts,
frame_length,
frame_step,
- fft_length,
+ fft_length=None,
window_fn=functools.partial(window_ops.hann_window,
periodic=True),
name=None):
@@ -118,7 +118,8 @@ def inverse_stft(stfts,
frame_length: An integer scalar `Tensor`. The window length in samples.
frame_step: An integer scalar `Tensor`. The number of samples to step.
fft_length: An integer scalar `Tensor`. The size of the FFT that produced
- `stfts`.
+ `stfts`. If not provided, uses the smallest power of 2 enclosing
+ `frame_length`.
window_fn: A callable that takes a window length and a `dtype` keyword
argument and returns a `[window_length]` `Tensor` of samples in the
provided datatype. If set to `None`, no windowing is used.
@@ -130,7 +131,8 @@ def inverse_stft(stfts,
Raises:
ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
- `frame_step` is not scalar, or `fft_length` is not scalar.
+ `frame_step` is not scalar, or `fft_length` is not scalar, or
+ `frame_length` is greater than `fft_length`.
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
@@ -141,8 +143,21 @@ def inverse_stft(stfts,
frame_length.shape.assert_has_rank(0)
frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
frame_step.shape.assert_has_rank(0)
- fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
- fft_length.shape.assert_has_rank(0)
+ if fft_length is None:
+ fft_length = _enclosing_power_of_two(frame_length)
+ else:
+ fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
+ fft_length.shape.assert_has_rank(0)
+
+ frame_length_static = tensor_util.constant_value(
+ frame_length)
+ fft_length_static = tensor_util.constant_value(fft_length)
+ if (frame_length_static is not None and fft_length_static is not None and
+ frame_length_static > fft_length_static):
+ raise ValueError('frame_length (%d) may not be larger than '
+ 'fft_length (%d)' % (frame_length_static,
+ fft_length_static))
+
real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length]
# Optionally window and overlap-add the inner 2 dimensions of real_frames