aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2017-07-19 11:28:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-19 11:32:53 -0700
commitd70b39b5bd6eb60fcf0df7420969474c29fdaad7 (patch)
treec3bd9c01631d5679a3d3565ad8514016c9705ba4 /tensorflow/contrib/signal
parentb9ea027b991a5d60b3581f2b981357276b8993b6 (diff)
Add Short-time Fourier Transform (STFT) and inverse STFT support to tf.contrib.signal.
These are implemented using GPU-capable ops and have gradient support.

PiperOrigin-RevId: 162511440
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/BUILD21
-rw-r--r--tensorflow/contrib/signal/__init__.py4
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py250
-rw-r--r--tensorflow/contrib/signal/python/ops/spectral_ops.py165
4 files changed, 440 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index 706de58a0a..52813b76fb 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -16,6 +16,7 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:spectral_ops",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
@@ -55,6 +56,26 @@ cuda_py_tests(
)
cuda_py_tests(
+ name = "spectral_ops_test",
+ size = "large",
+ srcs = ["python/kernel_tests/spectral_ops_test.py"],
+ additional_deps = [
+ ":signal_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:spectral_ops_test_util",
+ ],
+)
+
+cuda_py_tests(
name = "window_ops_test",
size = "small",
srcs = ["python/kernel_tests/window_ops_test.py"],
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index d0f6c1d0c6..c0136346b7 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -17,7 +17,9 @@
@@frame
@@hamming_window
@@hann_window
+@@inverse_stft
@@overlap_and_add
+@@stft
"""
from __future__ import absolute_import
@@ -29,6 +31,8 @@ from tensorflow.contrib.signal.python.ops.shape_ops import frame
# `frame` used to be named `frames`, which is a noun and not a verb.
# Keep an alias to `frames` for backwards compatibility.
from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames
+from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft
+from tensorflow.contrib.signal.python.ops.spectral_ops import stft
from tensorflow.contrib.signal.python.ops.window_ops import hamming_window
from tensorflow.contrib.signal.python.ops.window_ops import hann_window
diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
new file mode 100644
index 0000000000..904924b5e4
--- /dev/null
+++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
@@ -0,0 +1,250 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for spectral_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.signal.python.ops import spectral_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import spectral_ops_test_util
+from tensorflow.python.platform import test
+
+
+class SpectralOpsTest(test.TestCase):
+
+  @staticmethod
+  def _np_hann_periodic_window(length):
+    """Returns a NumPy reference Hann window of `length` samples.
+
+    Even lengths yield a window whose cosine period equals `length`; odd
+    lengths yield one whose period is `length - 1`, so the last sample
+    mirrors the first. A length of 1 yields a single 1.0.
+    """
+    if length == 1:
+      return np.ones(1)
+    # An even-length window equals the first `length` samples of the
+    # symmetric window of `length + 1` samples, whose period is `length`.
+    period = length if length % 2 == 0 else length - 1
+    return 0.5 - 0.5 * np.cos(2.0 * np.pi * np.arange(length) / period)
+
+  @staticmethod
+  def _np_frame(data, window_length, hop_length):
+    """Slices 1-D `data` into overlapping frames without copying.
+
+    Args:
+      data: A 1-D NumPy array of samples.
+      window_length: Number of samples in each frame.
+      hop_length: Number of samples between the starts of adjacent frames.
+
+    Returns:
+      A `[num_frames, window_length]` strided view of `data`. Trailing samples
+      that do not fill a whole frame are dropped. The result has zero frames
+      when `data` is shorter than `window_length`.
+    """
+    # Clamp at zero: floor division of a negative difference could otherwise
+    # produce a negative frame count, which as_strided rejects.
+    num_frames = max(0, 1 + (len(data) - window_length) // hop_length)
+    shape = (num_frames, window_length)
+    strides = (data.strides[0] * hop_length, data.strides[0])
+    return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
+
+  @staticmethod
+  def _np_stft(data, fft_length, hop_length, window_length):
+    """NumPy reference STFT: frame `data`, apply the window, take the rFFT."""
+    windowed_frames = (
+        SpectralOpsTest._np_frame(data, window_length, hop_length) *
+        SpectralOpsTest._np_hann_periodic_window(window_length))
+    return np.fft.rfft(windowed_frames, fft_length)
+
+  @staticmethod
+  def _np_inverse_stft(stft, fft_length, hop_length, window_length):
+    """NumPy reference inverse STFT: irFFT, trim, window, then overlap-add."""
+    time_frames = np.fft.irfft(stft, fft_length)
+    # Keep only the first window_length samples of each inverse-FFT frame.
+    time_frames = time_frames[..., :window_length]
+    weights = SpectralOpsTest._np_hann_periodic_window(window_length)
+    return SpectralOpsTest._np_overlap_add(time_frames * weights, hop_length)
+
+  @staticmethod
+  def _np_overlap_add(stft, hop_length):
+    """Sums the rows of `stft`, each offset `hop_length` past the previous.
+
+    Despite its name, the `stft` argument is unpacked as a
+    `[num_frames, window_length]` array of frames to be summed.
+    """
+    num_frames, window_length = np.shape(stft)
+    # One complete window, plus hop_length more samples per additional frame.
+    output = np.zeros(window_length + (num_frames - 1) * hop_length)
+    for index, frame in enumerate(stft):
+      offset = index * hop_length
+      output[offset:offset + window_length] += frame
+    return output
+
+  def _compare(self, signal, frame_length, frame_step, fft_length):
+    """Checks stft/inverse_stft on `signal` against the NumPy reference."""
+    rtol = atol = 1e-4
+    with spectral_ops_test_util.fft_kernel_label_map(), (
+        self.test_session(use_gpu=True)) as sess:
+      stft_op = spectral_ops.stft(
+          signal, frame_length, frame_step, fft_length, pad_end=False)
+      inverse_stft_op = spectral_ops.inverse_stft(
+          stft_op, frame_length, frame_step, fft_length)
+      # Evaluate both ops in a single session run.
+      actual_stft, actual_inverse_stft = sess.run([stft_op, inverse_stft_op])
+
+    expected_stft = SpectralOpsTest._np_stft(
+        signal, fft_length, frame_step, frame_length)
+    # The reference inverse is computed from the reference forward transform.
+    expected_inverse_stft = SpectralOpsTest._np_inverse_stft(
+        expected_stft, fft_length, frame_step, frame_length)
+
+    self.assertAllClose(expected_stft, actual_stft, rtol, atol)
+    self.assertAllClose(expected_inverse_stft, actual_inverse_stft, rtol, atol)
+
+  def _compare_round_trip(self, signal, frame_length, frame_step, fft_length):
+    """Checks that inverse_stft(stft(signal)) is a scaled copy of `signal`.
+
+    The first and last `frame_length` samples are excluded from the
+    comparison.
+    """
+    with spectral_ops_test_util.fft_kernel_label_map(), (
+        self.test_session(use_gpu=True)) as sess:
+      forward = spectral_ops.stft(
+          signal, frame_length, frame_step, fft_length, pad_end=False)
+      reconstructed = spectral_ops.inverse_stft(
+          forward, frame_length, frame_step, fft_length)
+      signal, reconstructed = sess.run([signal, reconstructed])
+
+    # The two arrays can differ in length, so zero-pad the shorter one at its
+    # end to match the longer one.
+    common_length = max(signal.shape[0], reconstructed.shape[0])
+    signal, reconstructed = [
+        np.pad(x, (0, common_length - x.shape[0]), "constant")
+        for x in (signal, reconstructed)]
+
+    # Ignore the frame_length samples at either edge.
+    interior = slice(frame_length, signal.shape[0] - frame_length)
+    ratio = signal[interior] / reconstructed[interior]
+
+    # If the two differ only by a constant factor, the element-wise ratio
+    # has (nearly) zero variance.
+    self.assertLess(np.var(ratio), 2e-5)
+
+  def test_shapes(self):
+    """Checks the static and evaluated shapes of stft and inverse_stft."""
+    with spectral_ops_test_util.fft_kernel_label_map(), (
+        self.test_session(use_gpu=True)):
+      signal = np.zeros((512,)).astype(np.float32)
+
+      # Pairs of (stft keyword arguments, expected [frames, bins] shape).
+      stft_cases = [
+          # If fft_length is not provided, the smallest enclosing power of 2
+          # of frame_length (8) is used.
+          (dict(frame_length=7), [64, 5]),
+          (dict(frame_length=8), [64, 5]),
+          (dict(frame_length=8, fft_length=16), [64, 9]),
+      ]
+      for stft_kwargs, expected_shape in stft_cases:
+        stft = spectral_ops.stft(signal, frame_step=8, pad_end=True,
+                                 **stft_kwargs)
+        self.assertAllEqual(expected_shape, stft.shape.as_list())
+        self.assertAllEqual(expected_shape, stft.eval().shape)
+
+      stft = np.zeros((32, 9)).astype(np.complex64)
+      inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8,
+                                               fft_length=16, frame_step=8)
+      # One full frame plus frame_step samples for each additional frame.
+      expected_length = (stft.shape[0] - 1) * 8 + 8
+      self.assertAllEqual([None], inverse_stft.shape.as_list())
+      self.assertAllEqual([expected_length], inverse_stft.eval().shape)
+
+  def test_stft_and_inverse_stft(self):
+    """Test that spectral_ops.stft/inverse_stft match a NumPy implementation."""
+    # Each config is (signal_length, frame_length, frame_step, fft_length).
+    for config in [
+        (512, 64, 32, 64),
+        (512, 64, 64, 64),
+        (512, 64, 25, 64),
+        (512, 25, 15, 36),
+        (123, 23, 5, 42),
+    ]:
+      signal_length, stft_args = config[0], config[1:]
+      signal = np.random.random(signal_length).astype(np.float32)
+      self._compare(signal, *stft_args)
+
+  def test_stft_round_trip(self):
+    """Runs the stft -> inverse_stft round-trip check on a 440Hz sinusoid."""
+    # Radians per sample for a 440Hz signal at 8kHz sample rate.
+    angular_step = 2 * np.pi * 440 / 8000
+
+    # Tuples of (signal_length, frame_length, frame_step, fft_length).
+    test_configs = (
+        (4096, 256, 32, 256),  # 87.5% overlap.
+        (4096, 256, 64, 256),  # 75% overlap.
+        (4096, 128, 25, 128),  # Odd frame hop.
+        (4096, 127, 32, 128),  # Odd frame length.
+    )
+
+    for signal_length, frame_length, frame_step, fft_length in test_configs:
+      sample_index = math_ops.to_float(math_ops.range(signal_length))
+      signal = math_ops.sin(angular_step * sample_index)
+      self._compare_round_trip(signal, frame_length, frame_step, fft_length)
+
+  @staticmethod
+  def _compute_stft_gradient(signal, frame_length=32, frame_step=16,
+                             fft_length=32):
+    """Returns the gradient of sum(|stft(signal)|) with respect to `signal`."""
+    magnitudes = math_ops.abs(
+        spectral_ops.stft(signal, frame_length, frame_step, fft_length))
+    total_magnitude = math_ops.reduce_sum(magnitudes)
+    # gradients() returns one entry per input; there is exactly one here.
+    (gradient,) = gradients_impl.gradients([total_magnitude], [signal])
+    return gradient
+
+  def test_gradients(self):
+    """Test that spectral_ops.stft has a working gradient."""
+    with spectral_ops_test_util.fft_kernel_label_map(), (
+        self.test_session(use_gpu=True)) as sess:
+      signal_length = 512
+      silence = array_ops.zeros([signal_length], dtype=dtypes.float32)
+      sinusoid = math_ops.sin(
+          2 * np.pi * math_ops.linspace(0.0, 1.0, signal_length))
+
+      # An all-zero signal has all zero gradients with respect to the sum of
+      # the magnitude STFT.
+      silence_gradient = sess.run(self._compute_stft_gradient(silence))
+      self.assertTrue(np.all(silence_gradient == 0.0))
+
+      # A sinusoid will have non-zero components of its gradient with respect
+      # to the sum of the magnitude STFT.
+      sinusoid_gradient = sess.run(self._compute_stft_gradient(sinusoid))
+      self.assertFalse(np.all(sinusoid_gradient == 0.0))
+
+  def test_gradients_numerical(self):
+    """Checks stft and inverse_stft gradient error against per-config bounds."""
+    with spectral_ops_test_util.fft_kernel_label_map(), (
+        self.test_session(use_gpu=True)):
+      # Tuples of (signal_length, frame_length, frame_step, fft_length,
+      # stft_bound, inverse_stft_bound).
+      # TODO(rjryan): Investigate why STFT gradient error is so high.
+      test_configs = [
+          (512, 64, 32, 64, 2e-3, 3e-5),
+          (512, 64, 64, 64, 2e-3, 3e-5),
+          (512, 64, 25, 64, 2e-3, 3e-5),
+          (512, 25, 15, 36, 8e-4, 3e-5),
+          (123, 23, 5, 42, 8e-4, 4e-5),
+      ]
+
+      for config in test_configs:
+        (signal_length, frame_length, frame_step, fft_length,
+         stft_bound, inverse_stft_bound) = config
+
+        # Expected shapes of the signal, its STFT and the reconstruction.
+        signal_shape = [signal_length]
+        num_frames = max(0, 1 + (signal_length - frame_length) // frame_step)
+        stft_shape = [num_frames, fft_length // 2 + 1]
+        inverse_stft_shape = [(num_frames - 1) * frame_step + frame_length]
+
+        signal = random_ops.random_uniform(signal_shape)
+        stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length,
+                                 pad_end=False)
+        inverse_stft = spectral_ops.inverse_stft(
+            stft, frame_length, frame_step, fft_length)
+
+        stft_error = test.compute_gradient_error(
+            signal, signal_shape, stft, stft_shape)
+        inverse_stft_error = test.compute_gradient_error(
+            stft, stft_shape, inverse_stft, inverse_stft_shape)
+        self.assertLess(stft_error, stft_bound)
+        self.assertLess(inverse_stft_error, inverse_stft_bound)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py
new file mode 100644
index 0000000000..1457f13604
--- /dev/null
+++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py
@@ -0,0 +1,165 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Spectral operations (e.g. Short-time Fourier Transform)."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy as np
+
+from tensorflow.contrib.signal.python.ops import reconstruction_ops
+from tensorflow.contrib.signal.python.ops import shape_ops
+from tensorflow.contrib.signal.python.ops import window_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import spectral_ops
+
+
+def stft(signal, frame_length, frame_step, fft_length=None,
+         window_fn=functools.partial(window_ops.hann_window, periodic=True),
+         pad_end=True, name=None):
+  """Computes the Short-time Fourier Transform of a batch of real signals.
+
+  https://en.wikipedia.org/wiki/Short-time_Fourier_transform
+
+  Implemented with GPU-compatible ops and supports gradients.
+
+  Args:
+    signal: A `[..., samples]` `float32` `Tensor` of real-valued signals.
+    frame_length: An integer scalar `Tensor`. The window length in samples.
+    frame_step: An integer scalar `Tensor`. The number of samples to step.
+    fft_length: An integer scalar `Tensor`. The size of the FFT to apply.
+      If not provided, uses the smallest power of 2 enclosing `frame_length`.
+    window_fn: A callable that takes a window length and a `dtype` keyword
+      argument and returns a `[window_length]` `Tensor` of samples in the
+      provided datatype. If set to `None`, no windowing is used.
+    pad_end: Whether to pad the end of signal with zeros when the provided
+      window length and hop produces a window that lies partially past the end
+      of `signal`.
+    name: An optional name for the operation.
+
+  Returns:
+    A `[..., frames, fft_unique_bins]` `Tensor` of `complex64` STFT values where
+    `fft_unique_bins` is `fft_length // 2 + 1` (the unique components of the
+    FFT).
+
+  Raises:
+    ValueError: If `signal` is not at least rank 1, `frame_length` is
+      not scalar, `frame_step` is not scalar, `fft_length` is not scalar, or
+      `frame_length` is statically known to be greater than `fft_length`.
+  """
+  with ops.name_scope(name, 'stft', [signal, frame_length,
+                                     frame_step]):
+    signal = ops.convert_to_tensor(signal, name='signal')
+    signal.shape.with_rank_at_least(1)
+    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
+    frame_length.shape.assert_has_rank(0)
+    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
+    frame_step.shape.assert_has_rank(0)
+
+    if fft_length is None:
+      fft_length = _enclosing_power_of_two(frame_length)
+    else:
+      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
+    # Validate fft_length the same way inverse_stft does, so a non-scalar
+    # value is rejected here rather than inside rfft.
+    fft_length.shape.assert_has_rank(0)
+
+    # The length comparison is only possible when both values are known at
+    # graph construction time.
+    frame_length_static = tensor_util.constant_value(
+        frame_length)
+    fft_length_static = tensor_util.constant_value(fft_length)
+    if (frame_length_static is not None and fft_length_static is not None and
+        frame_length_static > fft_length_static):
+      raise ValueError('frame_length (%d) may not be larger than '
+                       'fft_length (%d)' % (frame_length_static,
+                                            fft_length_static))
+
+    framed_signal = shape_ops.frame(
+        signal, frame_length, frame_step, pad_end=pad_end)
+
+    # Optionally window the framed signal.
+    if window_fn is not None:
+      window = window_fn(frame_length, dtype=framed_signal.dtype)
+      framed_signal *= window
+
+    # spectral_ops.rfft produces the (fft_length/2 + 1) unique components of the
+    # FFT of the real windowed signals in framed_signal.
+    return spectral_ops.rfft(framed_signal, [fft_length])
+
+
+def inverse_stft(stfts,
+                 frame_length,
+                 frame_step,
+                 fft_length,
+                 window_fn=functools.partial(window_ops.hann_window,
+                                             periodic=True),
+                 name=None):
+  """Computes the inverse Short-time Fourier Transform of a batch of STFTs.
+
+  https://en.wikipedia.org/wiki/Short-time_Fourier_transform
+
+  Implemented with GPU-compatible ops and supports gradients.
+
+  Args:
+    stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT bins
+      representing a batch of `fft_length`-point STFTs.
+    frame_length: An integer scalar `Tensor`. The window length in samples.
+    frame_step: An integer scalar `Tensor`. The number of samples to step.
+    fft_length: An integer scalar `Tensor`. The size of the FFT that produced
+      `stfts`.
+    window_fn: A callable that takes a window length and a `dtype` keyword
+      argument and returns a `[window_length]` `Tensor` of samples in the
+      provided datatype. If set to `None`, no windowing is used.
+    name: An optional name for the operation.
+
+  Returns:
+    A `[..., samples]` `Tensor` of `float32` signals representing the inverse
+    STFT for each input STFT in `stfts`.
+
+  Raises:
+    ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
+      `frame_step` is not scalar, or `fft_length` is not scalar.
+  """
+  def _as_scalar_tensor(value, arg_name):
+    """Converts `value` to a `Tensor` and checks that it is a scalar."""
+    tensor = ops.convert_to_tensor(value, name=arg_name)
+    tensor.shape.assert_has_rank(0)
+    return tensor
+
+  with ops.name_scope(name, 'inverse_stft', [stfts]):
+    stfts = ops.convert_to_tensor(stfts, name='stfts')
+    stfts.shape.with_rank_at_least(2)
+    frame_length = _as_scalar_tensor(frame_length, 'frame_length')
+    frame_step = _as_scalar_tensor(frame_step, 'frame_step')
+    fft_length = _as_scalar_tensor(fft_length, 'fft_length')
+
+    # Keep only the first frame_length samples of each inverse-FFT frame.
+    frames = spectral_ops.irfft(stfts, [fft_length])
+    frames = frames[..., :frame_length]
+
+    # Optionally window the frames before overlap-adding the inner 2
+    # dimensions into a single [samples] dimension.
+    if window_fn is not None:
+      frames *= window_fn(frame_length, dtype=stfts.dtype.real_dtype)
+    return reconstruction_ops.overlap_and_add(frames, frame_step)
+
+
+def _enclosing_power_of_two(value):
+  """Return 2**N for integer N such that 2**N >= value.
+
+  Args:
+    value: A `Tensor`; expected to be an integer scalar (it is called with the
+      validated `frame_length` in `stft`).
+
+  Returns:
+    A `Tensor` with the dtype of `value`. It is a constant when `value` is
+    statically known and is computed in the graph otherwise.
+  """
+  value_static = tensor_util.constant_value(value)
+  if value_static is not None:
+    # Use exact integer arithmetic. A floating-point log(value) / log(2) can
+    # land just above an integer for an exact power of two (e.g. 2**29), and
+    # the ceil then rounds up to the next power. Values below 1 yield 2**0.
+    exponent = max(int(value_static) - 1, 0).bit_length()
+    return constant_op.constant(1 << exponent, value.dtype)
+  # NOTE: the dynamic path still uses floating-point logarithms, so it may be
+  # subject to the same rounding for large exact powers of two.
+  return math_ops.cast(
+      math_ops.pow(2.0, math_ops.ceil(
+          math_ops.log(math_ops.to_float(value)) / math_ops.log(2.0))),
+      value.dtype)