diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-25 14:55:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-25 14:59:40 -0700 |
commit | 6b3751f660dfa0675cc39d163b8224f2c070694e (patch) | |
tree | 8fc5c728fd23235a04e9ffd87c92720b052d263f /tensorflow/contrib/signal | |
parent | 27bbbd8d0149724caa7f1295f122f968acb8cd96 (diff) |
Make fft_length optional for inverse_stft
PiperOrigin-RevId: 163127500
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/python/ops/spectral_ops.py | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py index 75bc0bd21d..950d8f471c 100644 --- a/tensorflow/contrib/signal/python/ops/spectral_ops.py +++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py @@ -103,7 +103,7 @@ def stft(signals, frame_length, frame_step, fft_length=None, def inverse_stft(stfts, frame_length, frame_step, - fft_length, + fft_length=None, window_fn=functools.partial(window_ops.hann_window, periodic=True), name=None): @@ -118,7 +118,8 @@ def inverse_stft(stfts, frame_length: An integer scalar `Tensor`. The window length in samples. frame_step: An integer scalar `Tensor`. The number of samples to step. fft_length: An integer scalar `Tensor`. The size of the FFT that produced - `stfts`. + `stfts`. If not provided, uses the smallest power of 2 enclosing + `frame_length`. window_fn: A callable that takes a window length and a `dtype` keyword argument and returns a `[window_length]` `Tensor` of samples in the provided datatype. If set to `None`, no windowing is used. @@ -130,7 +131,8 @@ def inverse_stft(stfts, Raises: ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar, - `frame_step` is not scalar, or `fft_length` is not scalar. + `frame_step` is not scalar, or `fft_length` is not scalar, or + `frame_length` is greater than `fft_length`. [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -141,8 +143,21 @@ def inverse_stft(stfts, frame_length.shape.assert_has_rank(0) frame_step = ops.convert_to_tensor(frame_step, name='frame_step') frame_step.shape.assert_has_rank(0) - fft_length = ops.convert_to_tensor(fft_length, name='fft_length') - fft_length.shape.assert_has_rank(0) + if fft_length is None: + fft_length = _enclosing_power_of_two(frame_length) + else: + fft_length = ops.convert_to_tensor(fft_length, name='fft_length') + fft_length.shape.assert_has_rank(0) + + frame_length_static = tensor_util.constant_value( + frame_length) + fft_length_static = tensor_util.constant_value(fft_length) + if (frame_length_static is not None and fft_length_static is not None and + frame_length_static > fft_length_static): + raise ValueError('frame_length (%d) may not be larger than ' + 'fft_length (%d)' % (frame_length_static, + fft_length_static)) + real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length] # Optionally window and overlap-add the inner 2 dimensions of real_frames |