aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
author RJ Ryan <rjryan@google.com> 2017-09-13 10:37:41 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-09-13 10:41:44 -0700
commit a4f6e7c1afd130d97759b99ba88e69138c59107c (patch)
tree 6b2a66ad636e34e4806be4dd34d0f13fd6dac408 /tensorflow/contrib/signal
parent b00b6d23c8a243a3da7e7b145467053321747e92 (diff)
Add mel-scale conversion matrix support to tf.contrib.signal.
PiperOrigin-RevId: 168560255
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- tensorflow/contrib/signal/BUILD | 10
-rw-r--r-- tensorflow/contrib/signal/__init__.py | 3
-rw-r--r-- tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py | 164
-rw-r--r-- tensorflow/contrib/signal/python/ops/mel_ops.py | 199
4 files changed, 376 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index 05e4fbfbbc..8c11cf0d64 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -25,6 +25,16 @@ py_library(
)
cuda_py_tests(
+ name = "mel_ops_test",
+ srcs = ["python/kernel_tests/mel_ops_test.py"],
+ additional_deps = [
+ ":signal_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+cuda_py_tests(
name = "reconstruction_ops_test",
srcs = ["python/kernel_tests/reconstruction_ops_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index 17e6379b9e..25123b097e 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -20,11 +20,13 @@ See the @{$python/contrib.signal} guide.
@@hamming_window
@@hann_window
@@inverse_stft
+@@linear_to_mel_weight_matrix
@@overlap_and_add
@@stft
[hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window
[hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window
+[mel]: https://en.wikipedia.org/wiki/Mel_scale
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
@@ -32,6 +34,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.signal.python.ops.mel_ops import linear_to_mel_weight_matrix
from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add
from tensorflow.contrib.signal.python.ops.shape_ops import frame
# `frame` used to be named `frames`, which is a noun and not a verb.
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
new file mode 100644
index 0000000000..0448ff5bb0
--- /dev/null
+++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
@@ -0,0 +1,164 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for mel_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.signal.python.ops import mel_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import test
+
+# Constants of the HTK formula mel = Q * ln(1 + hertz / BREAK), and helpers.
+_MEL_BREAK_FREQUENCY_HERTZ = 700.0
+_MEL_HIGH_FREQUENCY_Q = 1127.0
+
+
+def hertz_to_mel(frequencies_hertz):
+  """Maps frequencies in hertz onto the mel scale (HTK formula).
+
+  Copied from
+  https://github.com/tensorflow/models/blob/master/audioset/mel_features.py.
+
+  Args:
+    frequencies_hertz: Scalar or np.array of frequencies in hertz.
+
+  Returns:
+    Object of the same size as frequencies_hertz holding the corresponding
+    mel-scale values.
+  """
+  ratio = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
+  return _MEL_HIGH_FREQUENCY_Q * np.log(1.0 + ratio)
+
+
+def spectrogram_to_mel_matrix(num_mel_bins=20,
+                              num_spectrogram_bins=129,
+                              audio_sample_rate=8000,
+                              lower_edge_hertz=125.0,
+                              upper_edge_hertz=3800.0):
+  """Builds the numpy reference linear-to-mel weight matrix.
+
+  Copied from
+  https://github.com/tensorflow/models/blob/master/audioset/mel_features.py.
+
+  The returned np.array A post-multiplies a matrix S of spectrogram values
+  (STFT magnitudes) laid out as frames x bins, producing a "mel spectrogram"
+  M of frames x num_mel_bins. M = S A.
+
+  The classic HTK algorithm exploits the complementarity of adjacent mel bands:
+  each FFT bin is multiplied by a single mel weight and then added, with
+  positive and negative signs, to the two adjacent mel bands it contributes
+  to. Expressing the operation as a matrix multiply instead raises the cost
+  from num_fft multiplies per frame (plus around 2*num_fft adds) to around
+  num_fft^2 multiplies and adds. Because all of that presumably happens inside
+  a single np.dot() call, it is not clear which approach is faster in Python.
+  The matrix form is more general and flexible, and much easier to read.
+
+  Args:
+    num_mel_bins: Number of bands in the resulting mel spectrum, i.e. the
+      number of columns of the output matrix.
+    num_spectrogram_bins: Number of bins in the source spectrogram data,
+      understood to be fft_size/2 + 1, i.e. the spectrogram only contains the
+      nonredundant FFT bins.
+    audio_sample_rate: Samples per second of the audio at the input to the
+      spectrogram. Needed to work out the actual frequency of each
+      spectrogram bin, which dictates how the bins are mapped into mel.
+    lower_edge_hertz: Lower bound on the frequencies included in the mel
+      spectrum. This is the lower edge of the lowest triangular band.
+    upper_edge_hertz: The desired top edge of the highest frequency band.
+
+  Returns:
+    An np.array with shape (num_spectrogram_bins, num_mel_bins).
+
+  Raises:
+    ValueError: if frequency edges are incorrectly ordered.
+  """
+  nyquist_hertz = audio_sample_rate / 2.
+  if lower_edge_hertz >= upper_edge_hertz:
+    raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
+                     (lower_edge_hertz, upper_edge_hertz))
+  # Mel-scale position of every linearly spaced spectrogram bin.
+  bin_mels = hertz_to_mel(
+      np.linspace(0.0, nyquist_hertz, num_spectrogram_bins))
+  # num_mel_bins triangular bands need num_mel_bins + 2 equally spaced mel
+  # edges: band i rises from edges[i], peaks at edges[i + 1] and falls back to
+  # zero at edges[i + 2].
+  edges = np.linspace(hertz_to_mel(lower_edge_hertz),
+                      hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
+  # Matrix that post-multiplies feature arrays whose rows hold
+  # num_spectrogram_bins spectrogram values.
+  weights = np.empty((num_spectrogram_bins, num_mel_bins))
+  for band, (low, peak, high) in enumerate(
+      zip(edges[:-2], edges[1:-1], edges[2:])):
+    # Both ramps are linear in the *mel* domain, not hertz.
+    rising = (bin_mels - low) / (peak - low)
+    falling = (high - bin_mels) / (high - peak)
+    # The triangle is the lower of the two ramps, floored at zero.
+    weights[:, band] = np.maximum(0.0, np.minimum(rising, falling))
+  # HTK excludes the spectrogram DC bin; force its coefficients to zero.
+  weights[0, :] = 0.0
+  return weights
+
+
+class LinearToMelTest(test.TestCase):
+  """Tests mel_ops.linear_to_mel_weight_matrix."""
+
+  def test_matches_reference_implementation(self):
+    # Each config is (num_mel_bins, num_spectrogram_bins, sample_rate,
+    # lower_edge_hertz, upper_edge_hertz).
+    default_config = (20, 129, 8000.0, 125.0, 3800.0)
+    # Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
+    tacotron_config = (80, 1025, 24000.0, 80.0, 12000.0)
+    with self.test_session(use_gpu=True):
+      for config in (default_config, tacotron_config):
+        expected = spectrogram_to_mel_matrix(*config)
+        actual = mel_ops.linear_to_mel_weight_matrix(*config)
+        self.assertAllClose(expected, actual.eval(), atol=3e-6)
+
+  def test_dtypes(self):
+    for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+      matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype)
+      self.assertEqual(dtype, matrix.dtype)
+
+  def test_error(self):
+    # Every entry violates exactly one argument constraint.
+    invalid_kwargs = [
+        {"num_mel_bins": 0},
+        {"num_spectrogram_bins": 0},
+        {"sample_rate": 0.0},
+        {"lower_edge_hertz": -1},
+        {"lower_edge_hertz": 100, "upper_edge_hertz": 10},
+        {"dtype": dtypes.int32},
+    ]
+    for kwargs in invalid_kwargs:
+      with self.assertRaises(ValueError):
+        mel_ops.linear_to_mel_weight_matrix(**kwargs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/signal/python/ops/mel_ops.py b/tensorflow/contrib/signal/python/ops/mel_ops.py
new file mode 100644
index 0000000000..2ad07027aa
--- /dev/null
+++ b/tensorflow/contrib/signal/python/ops/mel_ops.py
@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""mel conversion ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.signal.python.ops import shape_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+# Constants of the HTK mel formula: mel = Q * ln(1 + hertz / BREAK).
+_MEL_BREAK_FREQUENCY_HERTZ = 700.0
+_MEL_HIGH_FREQUENCY_Q = 1127.0
+
+
+def _mel_to_hertz(mel_values, name=None):
+  """Maps mel-scale frequencies in `mel_values` back to the linear scale.
+
+  Args:
+    mel_values: A `Tensor` of frequencies in the mel scale.
+    name: An optional name for the operation.
+
+  Returns:
+    A `Tensor` with the same shape and type as `mel_values` holding the
+    corresponding linear scale frequencies in Hertz.
+  """
+  with ops.name_scope(name, 'mel_to_hertz', [mel_values]):
+    mel_values = ops.convert_to_tensor(mel_values)
+    # Inverse of the HTK formula: hertz = BREAK * (exp(mel / Q) - 1).
+    growth = math_ops.exp(mel_values / _MEL_HIGH_FREQUENCY_Q)
+    return _MEL_BREAK_FREQUENCY_HERTZ * (growth - 1.0)
+
+
+def _hertz_to_mel(frequencies_hertz, name=None):
+  """Maps frequencies in Hertz to the mel scale: Q * ln(1 + hertz / BREAK).
+
+  Args:
+    frequencies_hertz: A `Tensor` of frequencies in Hertz.
+    name: An optional name for the operation.
+
+  Returns:
+    A `Tensor` with the same shape and type as `frequencies_hertz` holding
+    the corresponding frequencies in the mel scale.
+  """
+  with ops.name_scope(name, 'hertz_to_mel', [frequencies_hertz]):
+    frequencies_hertz = ops.convert_to_tensor(frequencies_hertz)
+    scaled = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
+    return _MEL_HIGH_FREQUENCY_Q * math_ops.log(1.0 + scaled)
+
+
+def _validate_arguments(num_mel_bins, num_spectrogram_bins, sample_rate,
+                        lower_edge_hertz, upper_edge_hertz, dtype):
+  """Checks the inputs to linear_to_mel_weight_matrix.
+
+  Args:
+    num_mel_bins: Number of mel bands; must be positive.
+    num_spectrogram_bins: Number of spectrogram bins; must be positive.
+    sample_rate: Samples per second of the input signal; must be positive.
+    lower_edge_hertz: Lower edge of the lowest band; must be non-negative and
+      strictly below `upper_edge_hertz`.
+    upper_edge_hertz: Upper edge of the highest band; must not exceed the
+      Nyquist frequency `sample_rate / 2`.
+    dtype: The requested result `DType`; must be a floating point type.
+
+  Raises:
+    ValueError: If any of the constraints above is violated.
+  """
+  if num_mel_bins <= 0:
+    raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
+  if num_spectrogram_bins <= 0:
+    raise ValueError('num_spectrogram_bins must be positive. Got: %s' %
+                     num_spectrogram_bins)
+  if sample_rate <= 0.0:
+    raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
+  if lower_edge_hertz < 0.0:
+    raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
+                     lower_edge_hertz)
+  if lower_edge_hertz >= upper_edge_hertz:
+    raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
+                     (lower_edge_hertz, upper_edge_hertz))
+  # The spectrogram bins are spaced over [0, sample_rate / 2] only, so an upper
+  # edge above the Nyquist frequency would place part of the mel range beyond
+  # every spectrogram bin. Equality is allowed.
+  if upper_edge_hertz > sample_rate / 2.0:
+    raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
+                     'frequency (sample_rate / 2). Got: %s for sample_rate: %s'
+                     % (upper_edge_hertz, sample_rate))
+  if not dtype.is_floating:
+    raise ValueError('dtype must be a floating point type. Got: %s' % dtype)
+
+
+def linear_to_mel_weight_matrix(num_mel_bins=20,
+ num_spectrogram_bins=129,
+ sample_rate=8000,
+ lower_edge_hertz=125.0,
+ upper_edge_hertz=3800.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Returns a matrix to warp linear scale spectrograms to the [mel scale][mel].
+
+ Returns a weight matrix that can be used to re-weight a `Tensor` containing
+ `num_spectrogram_bins` linearly sampled frequency information from
+ `[0, sample_rate / 2]` into `num_mel_bins` frequency information from
+ `[lower_edge_hertz, upper_edge_hertz]` on the [mel scale][mel].
+
+ For example, the returned matrix `A` can be used to right-multiply a
+ spectrogram `S` of shape `[frames, num_spectrogram_bins]` of linear
+ scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogram"
+ `M` of shape `[frames, num_mel_bins]`.
+
+ # `S` has shape [frames, num_spectrogram_bins]
+ # `M` has shape [frames, num_mel_bins]
+ M = tf.matmul(S, A)
+
+ The matrix can be used with @{tf.tensordot} to convert an arbitrary rank
+ `Tensor` of linear-scale spectral bins into the mel scale.
+
+ # S has shape [..., num_spectrogram_bins].
+ # M has shape [..., num_mel_bins].
+ M = tf.tensordot(S, A, 1)
+ # tf.tensordot does not support shape inference for this case yet.
+ M.set_shape(S.shape[:-1].concatenate(A.shape[-1:]))
+
+ Args:
+ num_mel_bins: Python int. How many bands in the resulting mel spectrum.
+ num_spectrogram_bins: Python int. How many bins there are in the source
+ spectrogram data, which is understood to be `fft_size // 2 + 1`, i.e. the
+ spectrogram only contains the nonredundant FFT bins.
+ sample_rate: Python float. Samples per second of the input signal used to
+ create the spectrogram. We need this to figure out the actual frequencies
+ for each spectrogram bin, which dictates how they are mapped into the mel
+ scale.
+ lower_edge_hertz: Python float. Lower bound on the frequencies to be
+ included in the mel spectrum. This corresponds to the lower edge of the
+ lowest triangular band.
+ upper_edge_hertz: Python float. The desired top edge of the highest
+ frequency band.
+ dtype: The `DType` of the result matrix. Must be a floating point type.
+ name: An optional name for the operation.
+
+ Returns:
+ A `Tensor` of shape `[num_spectrogram_bins, num_mel_bins]`.
+
+ Raises:
+ ValueError: If num_mel_bins/num_spectrogram_bins/sample_rate are not
+ positive, lower_edge_hertz is negative, frequency edges are incorrectly
+ ordered, or dtype is not a floating point type.
+
+ [mel]: https://en.wikipedia.org/wiki/Mel_scale
+ """
+ with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name:
+ _validate_arguments(num_mel_bins, num_spectrogram_bins, sample_rate,
+ lower_edge_hertz, upper_edge_hertz, dtype)
+
+ # To preserve accuracy, we compute the matrix at float64 precision and then
+ # cast to `dtype` at the end. This function can be constant folded by graph
+ # optimization since there are no Tensor inputs.
+ sample_rate = ops.convert_to_tensor(
+ sample_rate, dtypes.float64, name='sample_rate')
+ lower_edge_hertz = ops.convert_to_tensor(
+ lower_edge_hertz, dtypes.float64, name='lower_edge_hertz')
+ upper_edge_hertz = ops.convert_to_tensor(
+ upper_edge_hertz, dtypes.float64, name='upper_edge_hertz')
+ zero_float64 = ops.convert_to_tensor(0.0, dtypes.float64)
+
+ # HTK excludes the spectrogram DC bin; drop it here, zero-pad it back below.
+ bands_to_zero = 1
+ nyquist_hertz = sample_rate / 2.0
+ linear_frequencies = math_ops.linspace(
+ zero_float64, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:]
+ spectrogram_bins_mel = array_ops.expand_dims(
+ _hertz_to_mel(linear_frequencies), 1)
+
+ # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
+ # center of each band is the lower and upper edge of the adjacent bands.
+ # Accordingly, we take num_mel_bins + 2 points linearly spaced on the mel
+ # scale between the two band edges and frame them into overlapping triples.
+ band_edges_mel = shape_ops.frame(
+ math_ops.linspace(_hertz_to_mel(lower_edge_hertz),
+ _hertz_to_mel(upper_edge_hertz),
+ num_mel_bins + 2), frame_length=3, frame_step=1)
+
+ # Split the triples up and reshape them into [1, num_mel_bins] tensors.
+ lower_edge_mel, center_mel, upper_edge_mel = tuple(array_ops.reshape(
+ t, [1, num_mel_bins]) for t in array_ops.split(
+ band_edges_mel, 3, axis=1))
+
+ # Calculate lower and upper slopes for every spectrogram bin.
+ # Line segments are linear in the mel domain, not Hertz.
+ lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
+ center_mel - lower_edge_mel)
+ upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
+ upper_edge_mel - center_mel)
+
+ # Intersect the line segments with each other and zero.
+ mel_weights_matrix = math_ops.maximum(
+ zero_float64, math_ops.minimum(lower_slopes, upper_slopes))
+
+ # Re-add the zeroed lower bins we sliced out above.
+ mel_weights_matrix = array_ops.pad(
+ mel_weights_matrix, [[bands_to_zero, 0], [0, 0]])
+
+ # Cast to the desired type, naming the op with the name_scope string.
+ return math_ops.cast(mel_weights_matrix, dtype, name=name)