diff options
author | RJ Ryan <rjryan@google.com> | 2017-10-01 22:38:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-01 22:41:57 -0700 |
commit | 09d0c5fd8cd815d3bcaa883b0e63535a4a786533 (patch) | |
tree | dceddd2f2b3429b747b36283a86ca1cf4bad00de /tensorflow/contrib/signal | |
parent | 217e6a70b9a095974ed0e27b1848458edb232a3e (diff) |
[tf-signal] Remove checks that frame_length <= fft_length in stft and inverse_stft.
Also add tests for stft/inverse_stft when the shape/rank of the inputs is unknown.
Fixes GitHub Issue #13363.
PiperOrigin-RevId: 170662530
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py | 31 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/spectral_ops.py | 60 |
2 files changed, 65 insertions, 26 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py index 305a2b2eb9..72d317dc41 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -59,7 +59,11 @@ class SpectralOpsTest(test.TestCase): @staticmethod def _np_inverse_stft(stft, fft_length, hop_length, window_length): - frames = np.fft.irfft(stft, fft_length)[..., :window_length] + frames = np.fft.irfft(stft, fft_length) + # Pad or truncate frames's inner dimension to window_length. + frames = frames[..., :window_length] + frames = np.pad(frames, [[0, 0]] * (frames.ndim - 1) + + [[0, max(0, window_length - frames.shape[-1])]], "constant") window = SpectralOpsTest._np_hann_periodic_window(window_length) return SpectralOpsTest._np_overlap_add(frames * window, hop_length) @@ -79,12 +83,27 @@ class SpectralOpsTest(test.TestCase): self.test_session(use_gpu=True)) as sess: actual_stft = spectral_ops.stft( signal, frame_length, frame_step, fft_length, pad_end=False) + signal_ph = array_ops.placeholder(dtype=dtypes.as_dtype(signal.dtype)) + actual_stft_from_ph = spectral_ops.stft( + signal_ph, frame_length, frame_step, fft_length, pad_end=False) actual_inverse_stft = spectral_ops.inverse_stft( actual_stft, frame_length, frame_step, fft_length) - actual_stft, actual_inverse_stft = sess.run( - [actual_stft, actual_inverse_stft]) + actual_stft, actual_stft_from_ph, actual_inverse_stft = sess.run( + [actual_stft, actual_stft_from_ph, actual_inverse_stft], + feed_dict={signal_ph: signal}) + + actual_stft_ph = array_ops.placeholder(dtype=actual_stft.dtype) + actual_inverse_stft_from_ph = sess.run( + spectral_ops.inverse_stft( + actual_stft_ph, frame_length, frame_step, fft_length), + feed_dict={actual_stft_ph: actual_stft}) + + # Confirm that there is no difference in output when shape/rank is fully + # unknown or known. 
+ self.assertAllClose(actual_stft, actual_stft_from_ph) + self.assertAllClose(actual_inverse_stft, actual_inverse_stft_from_ph) expected_stft = SpectralOpsTest._np_stft( signal, fft_length, frame_step, frame_length) @@ -142,6 +161,11 @@ class SpectralOpsTest(test.TestCase): self.assertAllEqual([64, 9], stft.shape.as_list()) self.assertAllEqual([64, 9], stft.eval().shape) + stft = spectral_ops.stft(signal, frame_length=16, frame_step=8, + fft_length=8, pad_end=True) + self.assertAllEqual([64, 5], stft.shape.as_list()) + self.assertAllEqual([64, 5], stft.eval().shape) + stft = np.zeros((32, 9)).astype(np.complex64) inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, @@ -156,6 +180,7 @@ class SpectralOpsTest(test.TestCase): test_configs = [ (512, 64, 32, 64), (512, 64, 64, 64), + (512, 72, 64, 64), (512, 64, 25, 64), (512, 25, 15, 36), (123, 23, 5, 42), diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py index 950d8f471c..5ed109b7dd 100644 --- a/tensorflow/contrib/signal/python/ops/spectral_ops.py +++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py @@ -28,6 +28,7 @@ from tensorflow.contrib.signal.python.ops import window_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import spectral_ops @@ -59,8 +60,7 @@ def stft(signals, frame_length, frame_step, fft_length=None, Raises: ValueError: If `signals` is not at least rank 1, `frame_length` is - not scalar, `frame_step` is not scalar, or `frame_length` - is greater than `fft_length`. + not scalar, or `frame_step` is not scalar. 
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -78,15 +78,6 @@ def stft(signals, frame_length, frame_step, fft_length=None, else: fft_length = ops.convert_to_tensor(fft_length, name='fft_length') - frame_length_static = tensor_util.constant_value( - frame_length) - fft_length_static = tensor_util.constant_value(fft_length) - if (frame_length_static is not None and fft_length_static is not None and - frame_length_static > fft_length_static): - raise ValueError('frame_length (%d) may not be larger than ' - 'fft_length (%d)' % (frame_length_static, - fft_length_static)) - framed_signals = shape_ops.frame( signals, frame_length, frame_step, pad_end=pad_end) @@ -131,8 +122,7 @@ def inverse_stft(stfts, Raises: ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar, - `frame_step` is not scalar, or `fft_length` is not scalar, or - `frame_length` is greater than `fft_length`. + `frame_step` is not scalar, or `fft_length` is not scalar. [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -149,16 +139,40 @@ def inverse_stft(stfts, fft_length = ops.convert_to_tensor(fft_length, name='fft_length') fft_length.shape.assert_has_rank(0) - frame_length_static = tensor_util.constant_value( - frame_length) - fft_length_static = tensor_util.constant_value(fft_length) - if (frame_length_static is not None and fft_length_static is not None and - frame_length_static > fft_length_static): - raise ValueError('frame_length (%d) may not be larger than ' - 'fft_length (%d)' % (frame_length_static, - fft_length_static)) - - real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length] + real_frames = spectral_ops.irfft(stfts, [fft_length]) + + # frame_length may be larger or smaller than fft_length, so we pad or + # truncate real_frames to frame_length. 
+ frame_length_static = tensor_util.constant_value(frame_length) + # If we don't know the shape of real_frames's inner dimension, pad and + # truncate to frame_length. + if (frame_length_static is None or + real_frames.shape.ndims is None or + real_frames.shape[-1].value is None): + real_frames = real_frames[..., :frame_length] + real_frames_rank = array_ops.rank(real_frames) + real_frames_shape = array_ops.shape(real_frames) + paddings = array_ops.concat( + [array_ops.zeros([real_frames_rank - 1, 2], + dtype=frame_length.dtype), + [[0, math_ops.maximum(0, frame_length - real_frames_shape[-1])]]], 0) + real_frames = array_ops.pad(real_frames, paddings) + # We know real_frames's last dimension and frame_length statically. If they + # are different, then pad or truncate real_frames to frame_length. + elif real_frames.shape[-1].value > frame_length_static: + real_frames = real_frames[..., :frame_length_static] + elif real_frames.shape[-1].value < frame_length_static: + pad_amount = frame_length_static - real_frames.shape[-1].value + real_frames = array_ops.pad(real_frames, + [[0, 0]] * (real_frames.shape.ndims - 1) + + [[0, pad_amount]]) + + # The above code pads the inner dimension of real_frames to frame_length, + # but it does so in a way that may not be shape-inference friendly. + # Restore shape information if we are able to. + if frame_length_static is not None and real_frames.shape.ndims is not None: + real_frames.set_shape([None] * (real_frames.shape.ndims - 1) + + [frame_length_static]) # Optionally window and overlap-add the inner 2 dimensions of real_frames # into a single [samples] dimension. |