aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-10-02 14:29:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-02 14:42:33 -0700
commit6d2244e4f7b519301b8d7619330ce0f95ac4d5f9 (patch)
tree71db40d33d709c82b5c46f0beeaf5e1e4e901809 /tensorflow/contrib/signal
parentee4f13d04dd31833e34acd5ebe061c561bb5a9a1 (diff)
Improve a text comment related to MonitoredSession's hooks.
session_run_hooks.py talks about "monitors", but I'm guessing what's meant is in fact "hooks". Am I right? PiperOrigin-RevId: 170753935
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/BUILD14
-rw-r--r--tensorflow/contrib/signal/__init__.py3
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py117
-rw-r--r--tensorflow/contrib/signal/python/ops/mfcc_ops.py137
4 files changed, 271 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index 8c11cf0d64..6025ec5b57 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -35,6 +35,20 @@ cuda_py_tests(
)
cuda_py_tests(
+ name = "mfcc_ops_test",
+ srcs = ["python/kernel_tests/mfcc_ops_test.py"],
+ additional_deps = [
+ ":signal_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:spectral_ops_test_util",
+ ],
+)
+
+cuda_py_tests(
name = "reconstruction_ops_test",
srcs = ["python/kernel_tests/reconstruction_ops_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index 25123b097e..0f2592b0b0 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -20,6 +20,7 @@ See the @{$python/contrib.signal} guide.
@@hamming_window
@@hann_window
@@inverse_stft
+@@mfccs_from_log_mel_spectrograms
@@linear_to_mel_weight_matrix
@@overlap_and_add
@@stft
@@ -27,6 +28,7 @@ See the @{$python/contrib.signal} guide.
[hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window
[hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window
[mel]: https://en.wikipedia.org/wiki/Mel_scale
+[mfcc]: https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
@@ -35,6 +37,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.signal.python.ops.mel_ops import linear_to_mel_weight_matrix
+from tensorflow.contrib.signal.python.ops.mfcc_ops import mfccs_from_log_mel_spectrograms
from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add
from tensorflow.contrib.signal.python.ops.shape_ops import frame
# `frame` used to be named `frames`, which is a noun and not a verb.
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py
new file mode 100644
index 0000000000..b3a8d40c13
--- /dev/null
+++ b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py
@@ -0,0 +1,117 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for mfcc_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+
+import numpy as np
+
+
+from tensorflow.contrib.signal.python.ops import mfcc_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import spectral_ops_test_util
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+# TODO(rjryan): Add scipy.fftpack to the TensorFlow build.
+def try_import(name): # pylint: disable=invalid-name
+ module = None
+ try:
+ module = importlib.import_module(name)
+ except ImportError as e:
+ tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+ return module
+
+
+fftpack = try_import("scipy.fftpack")
+
+
+class DCTTest(test.TestCase):
+
+ def _np_dct2(self, signals, norm=None):
+ """Computes the DCT-II manually with NumPy."""
+ # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
+ dct_size = signals.shape[-1]
+ dct = np.zeros_like(signals)
+ for k in range(dct_size):
+ phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
+ dct[..., k] = np.sum(signals * phi, axis=-1)
+ # SciPy's `dct` has a scaling factor of 2.0 which we follow.
+ # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ if norm == "ortho":
+ # The orthogonal scaling includes a factor of 0.5 which we combine with
+ # the overall scaling of 2.0 to cancel.
+ dct[..., 0] *= np.sqrt(1.0 / dct_size)
+ dct[..., 1:] *= np.sqrt(2.0 / dct_size)
+ else:
+ dct *= 2.0
+ return dct
+
+ def test_compare_to_numpy(self):
+ """Compare dct against a manual DCT-II implementation."""
+ with spectral_ops_test_util.fft_kernel_label_map():
+ with self.test_session(use_gpu=True):
+ for size in range(1, 23):
+ signals = np.random.rand(size).astype(np.float32)
+ actual_dct = mfcc_ops._dct2_1d(signals).eval()
+ expected_dct = self._np_dct2(signals)
+ self.assertAllClose(expected_dct, actual_dct, atol=5e-4, rtol=5e-4)
+
+ def test_compare_to_fftpack(self):
+ """Compare dct against scipy.fftpack.dct."""
+ if not fftpack:
+ return
+ with spectral_ops_test_util.fft_kernel_label_map():
+ with self.test_session(use_gpu=True):
+ for size in range(1, 23):
+ signal = np.random.rand(size).astype(np.float32)
+ actual_dct = mfcc_ops._dct2_1d(signal).eval()
+ expected_dct = fftpack.dct(signal, type=2)
+ self.assertAllClose(expected_dct, actual_dct, atol=5e-4, rtol=5e-4)
+
+
+# TODO(rjryan): We have no open source tests for MFCCs at the moment. Internally
+# at Google, this code is tested against a reference implementation that follows
+# HTK conventions.
+class MFCCTest(test.TestCase):
+
+ def test_error(self):
+ # num_mel_bins must be positive.
+ with self.assertRaises(ValueError):
+ signal = array_ops.zeros((2, 3, 0))
+ mfcc_ops.mfccs_from_log_mel_spectrograms(signal)
+
+ # signal must be float32
+ with self.assertRaises(ValueError):
+ signal = array_ops.zeros((2, 3, 5), dtype=dtypes.float64)
+ mfcc_ops.mfccs_from_log_mel_spectrograms(signal)
+
+ def test_basic(self):
+ """A basic test that the op runs on random input."""
+ with spectral_ops_test_util.fft_kernel_label_map():
+ with self.test_session(use_gpu=True):
+ signal = random_ops.random_normal((2, 3, 5))
+ mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval()
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/signal/python/ops/mfcc_ops.py b/tensorflow/contrib/signal/python/ops/mfcc_ops.py
new file mode 100644
index 0000000000..35b6d3ad45
--- /dev/null
+++ b/tensorflow/contrib/signal/python/ops/mfcc_ops.py
@@ -0,0 +1,137 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Mel-Frequency Cepstral Coefficients (MFCCs) ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import spectral_ops
+
+
+# TODO(rjryan): Remove once tf.spectral.dct exists.
+def _dct2_1d(signals, name=None):
+ """Computes the type II 1D Discrete Cosine Transform (DCT) of `signals`.
+
+ Args:
+ signals: A `[..., samples]` `float32` `Tensor` containing the signals to
+ take the DCT of.
+ name: An optional name for the operation.
+
+ Returns:
+ A `[..., samples]` `float32` `Tensor` containing the DCT of `signals`.
+
+ """
+ with ops.name_scope(name, 'dct', [signals]):
+ # We use the FFT to compute the DCT and TensorFlow only supports float32 for
+ # FFTs at the moment.
+ signals = ops.convert_to_tensor(signals, dtype=dtypes.float32)
+
+ axis_dim = signals.shape[-1].value or array_ops.shape(signals)[-1]
+ axis_dim_float = math_ops.to_float(axis_dim)
+ scale = 2.0 * math_ops.exp(math_ops.complex(
+ 0.0, -math.pi * math_ops.range(axis_dim_float) /
+ (2.0 * axis_dim_float)))
+
+ rfft = spectral_ops.rfft(signals, fft_length=[2 * axis_dim])[..., :axis_dim]
+ dct2 = math_ops.real(rfft * scale)
+ return dct2
+
+
+def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
+ """Computes [MFCCs][mfcc] of `log_mel_spectrograms`.
+
+ Implemented with GPU-compatible ops and supports gradients.
+
+ [Mel-Frequency Cepstral Coefficient (MFCC)][mfcc] calculation consists of
+ taking the DCT-II of a log-magnitude mel-scale spectrogram. [HTK][htk]'s MFCCs
+ use a particular scaling of the DCT-II which is almost orthogonal
+ normalization. We follow this convention.
+
+ All `num_mel_bins` MFCCs are returned and it is up to the caller to select
+ a subset of the MFCCs based on their application. For example, it is typical
+ to only use the first few for speech recognition, as this results in
+ an approximately pitch-invariant representation of the signal.
+
+ For example:
+
+ ```python
+ sample_rate = 16000.0
+ # A Tensor of [batch_size, num_samples] mono PCM samples in the range [-1, 1].
+ pcm = tf.placeholder(tf.float32, [None, None])
+
+ # A 1024-point STFT with frames of 64 ms and 75% overlap.
+ stfts = tf.contrib.signal.stft(pcm, frame_length=1024, frame_step=256,
+ fft_length=1024)
+ spectrograms = tf.abs(stft)
+
+ # Warp the linear scale spectrograms into the mel-scale.
+ num_spectrogram_bins = stfts.shape[-1].value
+ lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 80
+ linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
+ num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
+ upper_edge_hertz)
+ mel_spectrograms = tf.tensordot(
+ spectrograms, linear_to_mel_weight_matrix, 1)
+ mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
+ linear_to_mel_weight_matrix.shape[-1:]))
+
+ # Compute a stabilized log to get log-magnitude mel-scale spectrograms.
+ log_mel_spectrograms = tf.log(mel_spectrograms + 1e-6)
+
+ # Compute MFCCs from log_mel_spectrograms and take the first 13.
+ mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(
+ log_mel_spectrograms)[..., :13]
+ ```
+
+ Args:
+ log_mel_spectrograms: A `[..., num_mel_bins]` `float32` `Tensor` of
+ log-magnitude mel-scale spectrograms.
+ name: An optional name for the operation.
+ Returns:
+ A `[..., num_mel_bins]` `float32` `Tensor` of the MFCCs of
+ `log_mel_spectrograms`.
+
+ Raises:
+ ValueError: If `num_mel_bins` is not positive.
+
+ [mfcc]: https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
+ [htk]: https://en.wikipedia.org/wiki/HTK_(software)
+ """
+ with ops.name_scope(name, 'mfccs_from_log_mel_spectrograms',
+ [log_mel_spectrograms]):
+ # Compute the DCT-II of the resulting log-magnitude mel-scale spectrogram.
+ # The DCT used in HTK scales every basis vector by sqrt(2/N), which is the
+ # scaling required for an "orthogonal" DCT-II *except* in the 0th bin, where
+ # the true orthogonal DCT (as implemented by scipy) scales by sqrt(1/N). For
+ # this reason, we don't apply orthogonal normalization and scale the DCT by
+ # `0.5 * sqrt(2/N)` manually.
+ log_mel_spectrograms = ops.convert_to_tensor(log_mel_spectrograms,
+ dtype=dtypes.float32)
+ if (log_mel_spectrograms.shape.ndims and
+ log_mel_spectrograms.shape[-1].value is not None):
+ num_mel_bins = log_mel_spectrograms.shape[-1].value
+ if num_mel_bins == 0:
+ raise ValueError('num_mel_bins must be positive. Got: %s' %
+ log_mel_spectrograms)
+ else:
+ num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1]
+ return _dct2_1d(log_mel_spectrograms) * math_ops.rsqrt(num_mel_bins * 2.0)