diff options
author | RJ Ryan <rjryan@google.com> | 2017-07-19 11:28:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-19 11:32:53 -0700 |
commit | d70b39b5bd6eb60fcf0df7420969474c29fdaad7 (patch) | |
tree | c3bd9c01631d5679a3d3565ad8514016c9705ba4 /tensorflow/contrib/signal | |
parent | b9ea027b991a5d60b3581f2b981357276b8993b6 (diff) |
Add Short-time Fourier Transform (STFT) and inverse STFT support to tf.contrib.signal.
These are implemented using GPU-capable ops and have gradient support.
PiperOrigin-RevId: 162511440
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/BUILD | 21 | ||||
-rw-r--r-- | tensorflow/contrib/signal/__init__.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py | 250 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/spectral_ops.py | 165 |
4 files changed, 440 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 706de58a0a..52813b76fb 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -16,6 +16,7 @@ py_library( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:spectral_ops", "//tensorflow/python:util", "//third_party/py/numpy", ], @@ -55,6 +56,26 @@ cuda_py_tests( ) cuda_py_tests( + name = "spectral_ops_test", + size = "large", + srcs = ["python/kernel_tests/spectral_ops_test.py"], + additional_deps = [ + ":signal_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:spectral_ops_test_util", + ], +) + +cuda_py_tests( name = "window_ops_test", size = "small", srcs = ["python/kernel_tests/window_ops_test.py"], diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index d0f6c1d0c6..c0136346b7 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -17,7 +17,9 @@ @@frame @@hamming_window @@hann_window +@@inverse_stft @@overlap_and_add +@@stft """ from __future__ import absolute_import @@ -29,6 +31,8 @@ from tensorflow.contrib.signal.python.ops.shape_ops import frame # `frame` used to be named `frames`, which is a noun and not a verb. # Keep an alias to `frames` for backwards compatibility. from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames +from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft +from tensorflow.contrib.signal.python.ops.spectral_ops import stft from tensorflow.contrib.signal.python.ops.window_ops import hamming_window from tensorflow.contrib.signal.python.ops.window_ops import hann_window diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py new file mode 100644 index 0000000000..904924b5e4 --- /dev/null +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -0,0 +1,250 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for spectral_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.signal.python.ops import spectral_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import spectral_ops_test_util +from tensorflow.python.platform import test + + +class SpectralOpsTest(test.TestCase): + + @staticmethod + def _np_hann_periodic_window(length): + if length == 1: + return np.ones(1) + odd = length % 2 + if not odd: + length += 1 + window = 0.5 - 0.5 * np.cos(2.0 * np.pi * np.arange(length) / (length - 1)) + if not odd: + window = window[:-1] + return window + + @staticmethod + def _np_frame(data, window_length, hop_length): + num_frames = 1 + int(np.floor((len(data) - window_length) // hop_length)) + shape = (num_frames, window_length) + strides = (data.strides[0] * hop_length, data.strides[0]) + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + @staticmethod + def _np_stft(data, fft_length, hop_length, window_length): + frames = SpectralOpsTest._np_frame(data, window_length, hop_length) + window = SpectralOpsTest._np_hann_periodic_window(window_length) + return np.fft.rfft(frames * window, fft_length) + + @staticmethod + def _np_inverse_stft(stft, fft_length, hop_length, window_length): + frames = np.fft.irfft(stft, fft_length)[..., :window_length] + window = SpectralOpsTest._np_hann_periodic_window(window_length) + return SpectralOpsTest._np_overlap_add(frames * window, hop_length) + + @staticmethod + def _np_overlap_add(stft, hop_length): + num_frames, window_length = np.shape(stft) + # Output length will be one complete window, plus another hop_length's + # worth of points for each additional window. + output_length = window_length + (num_frames - 1) * hop_length + output = np.zeros(output_length) + for i in range(num_frames): + output[i * hop_length:i * hop_length + window_length] += stft[i,] + return output + + def _compare(self, signal, frame_length, frame_step, fft_length): + with spectral_ops_test_util.fft_kernel_label_map(), ( + self.test_session(use_gpu=True)) as sess: + actual_stft = spectral_ops.stft( + signal, frame_length, frame_step, fft_length, pad_end=False) + + actual_inverse_stft = spectral_ops.inverse_stft( + actual_stft, frame_length, frame_step, fft_length) + + actual_stft, actual_inverse_stft = sess.run( + [actual_stft, actual_inverse_stft]) + + expected_stft = SpectralOpsTest._np_stft( + signal, fft_length, frame_step, frame_length) + self.assertAllClose(expected_stft, actual_stft, 1e-4, 1e-4) + + expected_inverse_stft = SpectralOpsTest._np_inverse_stft( + expected_stft, fft_length, frame_step, frame_length) + self.assertAllClose( + expected_inverse_stft, actual_inverse_stft, 1e-4, 1e-4) + + def _compare_round_trip(self, signal, frame_length, frame_step, fft_length): + with spectral_ops_test_util.fft_kernel_label_map(), ( + self.test_session(use_gpu=True)) as sess: + stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length, + pad_end=False) + inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step, + fft_length) + signal, inverse_stft = sess.run([signal, inverse_stft]) + + # Since the shapes can differ due to padding, pad both signals to the max + # of their lengths. + max_length = max(signal.shape[0], inverse_stft.shape[0]) + signal = np.pad(signal, (0, max_length - signal.shape[0]), "constant") + inverse_stft = np.pad(inverse_stft, + (0, max_length - inverse_stft.shape[0]), "constant") + + # Ignore the frame_length samples at either edge. + start = frame_length + end = signal.shape[0] - frame_length + ratio = signal[start:end] / inverse_stft[start:end] + + # Check that the inverse and original signal are equal up to a constant + # factor. + self.assertLess(np.var(ratio), 2e-5) + + def test_shapes(self): + with spectral_ops_test_util.fft_kernel_label_map(), ( + self.test_session(use_gpu=True)): + signal = np.zeros((512,)).astype(np.float32) + + # If fft_length is not provided, the smallest enclosing power of 2 of + # frame_length (8) is used. + stft = spectral_ops.stft(signal, frame_length=7, frame_step=8, + pad_end=True) + self.assertAllEqual([64, 5], stft.shape.as_list()) + self.assertAllEqual([64, 5], stft.eval().shape) + + stft = spectral_ops.stft(signal, frame_length=8, frame_step=8, + pad_end=True) + self.assertAllEqual([64, 5], stft.shape.as_list()) + self.assertAllEqual([64, 5], stft.eval().shape) + + stft = spectral_ops.stft(signal, frame_length=8, frame_step=8, + fft_length=16, pad_end=True) + self.assertAllEqual([64, 9], stft.shape.as_list()) + self.assertAllEqual([64, 9], stft.eval().shape) + + stft = np.zeros((32, 9)).astype(np.complex64) + + inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, + fft_length=16, frame_step=8) + expected_length = (stft.shape[0] - 1) * 8 + 8 + self.assertAllEqual([None], inverse_stft.shape.as_list()) + self.assertAllEqual([expected_length], inverse_stft.eval().shape) + + def test_stft_and_inverse_stft(self): + """Test that spectral_ops.stft/inverse_stft match a NumPy implementation.""" + # Tuples of (signal_length, frame_length, frame_step, fft_length). + test_configs = [ + (512, 64, 32, 64), + (512, 64, 64, 64), + (512, 64, 25, 64), + (512, 25, 15, 36), + (123, 23, 5, 42), + ] + + for signal_length, frame_length, frame_step, fft_length in test_configs: + signal = np.random.random(signal_length).astype(np.float32) + self._compare(signal, frame_length, frame_step, fft_length) + + def test_stft_round_trip(self): + # Tuples of (signal_length, frame_length, frame_step, fft_length). + test_configs = [ + # 87.5% overlap. + (4096, 256, 32, 256), + # 75% overlap. + (4096, 256, 64, 256), + # Odd frame hop. + (4096, 128, 25, 128), + # Odd frame length. + (4096, 127, 32, 128), + ] + + for signal_length, frame_length, frame_step, fft_length in test_configs: + # Generate a 440Hz signal at 8kHz sample rate. + signal = math_ops.sin(2 * np.pi * 440 / 8000 * + math_ops.to_float(math_ops.range(signal_length))) + self._compare_round_trip(signal, frame_length, frame_step, fft_length) + + @staticmethod + def _compute_stft_gradient(signal, frame_length=32, frame_step=16, + fft_length=32): + """Computes the gradient of the STFT with respect to `signal`.""" + stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length) + magnitude_stft = math_ops.abs(stft) + loss = math_ops.reduce_sum(magnitude_stft) + return gradients_impl.gradients([loss], [signal])[0] + + def test_gradients(self): + """Test that spectral_ops.stft has a working gradient.""" + with spectral_ops_test_util.fft_kernel_label_map(), ( + self.test_session(use_gpu=True)) as sess: + signal_length = 512 + + # An all-zero signal has all zero gradients with respect to the sum of the + # magnitude STFT. + empty_signal = array_ops.zeros([signal_length], dtype=dtypes.float32) + empty_signal_gradient = sess.run( + self._compute_stft_gradient(empty_signal)) + self.assertTrue((empty_signal_gradient == 0.0).all()) + + # A sinusoid will have non-zero components of its gradient with respect to + # the sum of the magnitude STFT. + sinusoid = math_ops.sin( + 2 * np.pi * math_ops.linspace(0.0, 1.0, signal_length)) + sinusoid_gradient = sess.run(self._compute_stft_gradient(sinusoid)) + self.assertFalse((sinusoid_gradient == 0.0).all()) + + def test_gradients_numerical(self): + with spectral_ops_test_util.fft_kernel_label_map(), ( + self.test_session(use_gpu=True)): + # Tuples of (signal_length, frame_length, frame_step, fft_length, + # stft_bound, inverse_stft_bound). + # TODO(rjryan): Investigate why STFT gradient error is so high. + test_configs = [ + (512, 64, 32, 64, 2e-3, 3e-5), + (512, 64, 64, 64, 2e-3, 3e-5), + (512, 64, 25, 64, 2e-3, 3e-5), + (512, 25, 15, 36, 8e-4, 3e-5), + (123, 23, 5, 42, 8e-4, 4e-5), + ] + + for (signal_length, frame_length, frame_step, fft_length, + stft_bound, inverse_stft_bound) in test_configs: + signal_shape = [signal_length] + signal = random_ops.random_uniform(signal_shape) + stft_shape = [max(0, 1 + (signal_length - frame_length) // frame_step), + fft_length // 2 + 1] + stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length, + pad_end=False) + inverse_stft_shape = [(stft_shape[0] - 1) * frame_step + frame_length] + inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step, + fft_length) + stft_error = test.compute_gradient_error(signal, [signal_length], + stft, stft_shape) + inverse_stft_error = test.compute_gradient_error( + stft, stft_shape, inverse_stft, inverse_stft_shape) + self.assertLess(stft_error, stft_bound) + self.assertLess(inverse_stft_error, inverse_stft_bound) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py new file mode 100644 index 0000000000..1457f13604 --- /dev/null +++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py @@ -0,0 +1,165 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Spectral operations (e.g. Short-time Fourier Transform).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy as np + +from tensorflow.contrib.signal.python.ops import reconstruction_ops +from tensorflow.contrib.signal.python.ops import shape_ops +from tensorflow.contrib.signal.python.ops import window_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import spectral_ops + + +def stft(signal, frame_length, frame_step, fft_length=None, + window_fn=functools.partial(window_ops.hann_window, periodic=True), + pad_end=True, name=None): + """Computes the Short-time Fourier Transform of a batch of real signals. + + https://en.wikipedia.org/wiki/Short-time_Fourier_transform + + Implemented with GPU-compatible ops and supports gradients. + + Args: + signal: A `[..., samples]` `float32` `Tensor` of real-valued signals. + frame_length: An integer scalar `Tensor`. The window length in samples. + frame_step: An integer scalar `Tensor`. The number of samples to step. + fft_length: An integer scalar `Tensor`. The size of the FFT to apply. + If not provided, uses the smallest power of 2 enclosing `frame_length`. + window_fn: A callable that takes a window length and a `dtype` keyword + argument and returns a `[window_length]` `Tensor` of samples in the + provided datatype. If set to `None`, no windowing is used. + pad_end: Whether to pad the end of signal with zeros when the provided + window length and hop produces a window that lies partially past the end + of `signal`. + name: An optional name for the operation. + + Returns: + A `[..., frames, fft_unique_bins]` `Tensor` of `complex64` STFT values where + `fft_unique_bins` is `fft_length / 2 + 1` (the unique components of the + FFT). + + Raises: + ValueError: If `signal` is not at least rank 1, `frame_length` is + not scalar, `frame_step` is not scalar, or `frame_length` + is greater than `fft_length`. + """ + with ops.name_scope(name, 'stft', [signal, frame_length, + frame_step]): + signal = ops.convert_to_tensor(signal, name='signal') + signal.shape.with_rank_at_least(1) + frame_length = ops.convert_to_tensor(frame_length, name='frame_length') + frame_length.shape.assert_has_rank(0) + frame_step = ops.convert_to_tensor(frame_step, name='frame_step') + frame_step.shape.assert_has_rank(0) + + if fft_length is None: + fft_length = _enclosing_power_of_two(frame_length) + else: + fft_length = ops.convert_to_tensor(fft_length, name='fft_length') + + frame_length_static = tensor_util.constant_value( + frame_length) + fft_length_static = tensor_util.constant_value(fft_length) + if (frame_length_static is not None and fft_length_static is not None and + frame_length_static > fft_length_static): + raise ValueError('frame_length (%d) may not be larger than ' + 'fft_length (%d)' % (frame_length_static, + fft_length_static)) + + framed_signal = shape_ops.frame( + signal, frame_length, frame_step, pad_end=pad_end) + + # Optionally window the framed signal. + if window_fn is not None: + window = window_fn(frame_length, dtype=framed_signal.dtype) + framed_signal *= window + + # spectral_ops.rfft produces the (fft_length/2 + 1) unique components of the + # FFT of the real windowed signals in framed_signal. + return spectral_ops.rfft(framed_signal, [fft_length]) + + +def inverse_stft(stfts, + frame_length, + frame_step, + fft_length, + window_fn=functools.partial(window_ops.hann_window, + periodic=True), + name=None): + """Computes the inverse Short-time Fourier Transform of a batch of STFTs. + + https://en.wikipedia.org/wiki/Short-time_Fourier_transform + + Implemented with GPU-compatible ops and supports gradients. + + Args: + stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT bins + representing a batch of `fft_length`-point STFTs. + frame_length: An integer scalar `Tensor`. The window length in samples. + frame_step: An integer scalar `Tensor`. The number of samples to step. + fft_length: An integer scalar `Tensor`. The size of the FFT that produced + `stfts`. + window_fn: A callable that takes a window length and a `dtype` keyword + argument and returns a `[window_length]` `Tensor` of samples in the + provided datatype. If set to `None`, no windowing is used. + name: An optional name for the operation. + + Returns: + A `[..., samples]` `Tensor` of `float32` signals representing the inverse + STFT for each input STFT in `stfts`. + + Raises: + ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar, + `frame_step` is not scalar, or `fft_length` is not scalar. + """ + with ops.name_scope(name, 'inverse_stft', [stfts]): + stfts = ops.convert_to_tensor(stfts, name='stfts') + stfts.shape.with_rank_at_least(2) + frame_length = ops.convert_to_tensor(frame_length, name='frame_length') + frame_length.shape.assert_has_rank(0) + frame_step = ops.convert_to_tensor(frame_step, name='frame_step') + frame_step.shape.assert_has_rank(0) + fft_length = ops.convert_to_tensor(fft_length, name='fft_length') + fft_length.shape.assert_has_rank(0) + real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length] + + # Optionally window and overlap-add the inner 2 dimensions of real_frames + # into a single [samples] dimension. + if window_fn is not None: + window = window_fn(frame_length, dtype=stfts.dtype.real_dtype) + real_frames *= window + return reconstruction_ops.overlap_and_add(real_frames, frame_step) + + +def _enclosing_power_of_two(value): + """Return 2**N for integer N such that 2**N >= value.""" + value_static = tensor_util.constant_value(value) + if value_static is not None: + return constant_op.constant( + int(2**np.ceil(np.log(value_static) / np.log(2.0))), value.dtype) + return math_ops.cast( + math_ops.pow(2.0, math_ops.ceil( + math_ops.log(math_ops.to_float(value)) / math_ops.log(2.0))), + value.dtype) |