aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2017-10-01 22:38:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-01 22:41:57 -0700
commit09d0c5fd8cd815d3bcaa883b0e63535a4a786533 (patch)
treedceddd2f2b3429b747b36283a86ca1cf4bad00de /tensorflow/contrib/signal
parent217e6a70b9a095974ed0e27b1848458edb232a3e (diff)
[tf-signal] Remove checks that frame_length <= fft_length in stft and inverse_stft.
Also add tests for stft/inverse_stft when the shape/rank of the inputs are unknown. Fixes GitHub Issue #13363. PiperOrigin-RevId: 170662530
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py31
-rw-r--r--tensorflow/contrib/signal/python/ops/spectral_ops.py60
2 files changed, 65 insertions, 26 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
index 305a2b2eb9..72d317dc41 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
@@ -59,7 +59,11 @@ class SpectralOpsTest(test.TestCase):
@staticmethod
def _np_inverse_stft(stft, fft_length, hop_length, window_length):
- frames = np.fft.irfft(stft, fft_length)[..., :window_length]
+ frames = np.fft.irfft(stft, fft_length)
+ # Pad or truncate frames's inner dimension to window_length.
+ frames = frames[..., :window_length]
+ frames = np.pad(frames, [[0, 0]] * (frames.ndim - 1) +
+ [[0, max(0, window_length - frames.shape[-1])]], "constant")
window = SpectralOpsTest._np_hann_periodic_window(window_length)
return SpectralOpsTest._np_overlap_add(frames * window, hop_length)
@@ -79,12 +83,27 @@ class SpectralOpsTest(test.TestCase):
self.test_session(use_gpu=True)) as sess:
actual_stft = spectral_ops.stft(
signal, frame_length, frame_step, fft_length, pad_end=False)
+ signal_ph = array_ops.placeholder(dtype=dtypes.as_dtype(signal.dtype))
+ actual_stft_from_ph = spectral_ops.stft(
+ signal_ph, frame_length, frame_step, fft_length, pad_end=False)
actual_inverse_stft = spectral_ops.inverse_stft(
actual_stft, frame_length, frame_step, fft_length)
- actual_stft, actual_inverse_stft = sess.run(
- [actual_stft, actual_inverse_stft])
+ actual_stft, actual_stft_from_ph, actual_inverse_stft = sess.run(
+ [actual_stft, actual_stft_from_ph, actual_inverse_stft],
+ feed_dict={signal_ph: signal})
+
+ actual_stft_ph = array_ops.placeholder(dtype=actual_stft.dtype)
+ actual_inverse_stft_from_ph = sess.run(
+ spectral_ops.inverse_stft(
+ actual_stft_ph, frame_length, frame_step, fft_length),
+ feed_dict={actual_stft_ph: actual_stft})
+
+ # Confirm that there is no difference in output when shape/rank is fully
+ # unknown or known.
+ self.assertAllClose(actual_stft, actual_stft_from_ph)
+ self.assertAllClose(actual_inverse_stft, actual_inverse_stft_from_ph)
expected_stft = SpectralOpsTest._np_stft(
signal, fft_length, frame_step, frame_length)
@@ -142,6 +161,11 @@ class SpectralOpsTest(test.TestCase):
self.assertAllEqual([64, 9], stft.shape.as_list())
self.assertAllEqual([64, 9], stft.eval().shape)
+ stft = spectral_ops.stft(signal, frame_length=16, frame_step=8,
+ fft_length=8, pad_end=True)
+ self.assertAllEqual([64, 5], stft.shape.as_list())
+ self.assertAllEqual([64, 5], stft.eval().shape)
+
stft = np.zeros((32, 9)).astype(np.complex64)
inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8,
@@ -156,6 +180,7 @@ class SpectralOpsTest(test.TestCase):
test_configs = [
(512, 64, 32, 64),
(512, 64, 64, 64),
+ (512, 72, 64, 64),
(512, 64, 25, 64),
(512, 25, 15, 36),
(123, 23, 5, 42),
diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py
index 950d8f471c..5ed109b7dd 100644
--- a/tensorflow/contrib/signal/python/ops/spectral_ops.py
+++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py
@@ -28,6 +28,7 @@ from tensorflow.contrib.signal.python.ops import window_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops
@@ -59,8 +60,7 @@ def stft(signals, frame_length, frame_step, fft_length=None,
Raises:
ValueError: If `signals` is not at least rank 1, `frame_length` is
- not scalar, `frame_step` is not scalar, or `frame_length`
- is greater than `fft_length`.
+ not scalar, or `frame_step` is not scalar.
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
@@ -78,15 +78,6 @@ def stft(signals, frame_length, frame_step, fft_length=None,
else:
fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
- frame_length_static = tensor_util.constant_value(
- frame_length)
- fft_length_static = tensor_util.constant_value(fft_length)
- if (frame_length_static is not None and fft_length_static is not None and
- frame_length_static > fft_length_static):
- raise ValueError('frame_length (%d) may not be larger than '
- 'fft_length (%d)' % (frame_length_static,
- fft_length_static))
-
framed_signals = shape_ops.frame(
signals, frame_length, frame_step, pad_end=pad_end)
@@ -131,8 +122,7 @@ def inverse_stft(stfts,
Raises:
ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
- `frame_step` is not scalar, or `fft_length` is not scalar, or
- `frame_length` is greater than `fft_length`.
+ `frame_step` is not scalar, or `fft_length` is not scalar.
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
@@ -149,16 +139,40 @@ def inverse_stft(stfts,
fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
fft_length.shape.assert_has_rank(0)
- frame_length_static = tensor_util.constant_value(
- frame_length)
- fft_length_static = tensor_util.constant_value(fft_length)
- if (frame_length_static is not None and fft_length_static is not None and
- frame_length_static > fft_length_static):
- raise ValueError('frame_length (%d) may not be larger than '
- 'fft_length (%d)' % (frame_length_static,
- fft_length_static))
-
- real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length]
+ real_frames = spectral_ops.irfft(stfts, [fft_length])
+
+ # frame_length may be larger or smaller than fft_length, so we pad or
+ # truncate real_frames to frame_length.
+ frame_length_static = tensor_util.constant_value(frame_length)
+ # If we don't know the shape of real_frames's inner dimension, pad and
+ # truncate to frame_length.
+ if (frame_length_static is None or
+ real_frames.shape.ndims is None or
+ real_frames.shape[-1].value is None):
+ real_frames = real_frames[..., :frame_length]
+ real_frames_rank = array_ops.rank(real_frames)
+ real_frames_shape = array_ops.shape(real_frames)
+ paddings = array_ops.concat(
+ [array_ops.zeros([real_frames_rank - 1, 2],
+ dtype=frame_length.dtype),
+ [[0, math_ops.maximum(0, frame_length - real_frames_shape[-1])]]], 0)
+ real_frames = array_ops.pad(real_frames, paddings)
+ # We know real_frames's last dimension and frame_length statically. If they
+ # are different, then pad or truncate real_frames to frame_length.
+ elif real_frames.shape[-1].value > frame_length_static:
+ real_frames = real_frames[..., :frame_length_static]
+ elif real_frames.shape[-1].value < frame_length_static:
+ pad_amount = frame_length_static - real_frames.shape[-1].value
+ real_frames = array_ops.pad(real_frames,
+ [[0, 0]] * (real_frames.shape.ndims - 1) +
+ [[0, pad_amount]])
+
+ # The above code pads the inner dimension of real_frames to frame_length,
+ # but it does so in a way that may not be shape-inference friendly.
+ # Restore shape information if we are able to.
+ if frame_length_static is not None and real_frames.shape.ndims is not None:
+ real_frames.set_shape([None] * (real_frames.shape.ndims - 1) +
+ [frame_length_static])
# Optionally window and overlap-add the inner 2 dimensions of real_frames
# into a single [samples] dimension.