diff options
author | 2017-10-03 17:50:55 -0700 | |
---|---|---|
committer | 2017-10-03 17:54:01 -0700 | |
commit | add6d2d03cd89668eb515b8c012abece2bfaab85 (patch) | |
tree | 0498deb35afa80dc6460691a2797c93f0fff4322 /tensorflow/contrib/signal | |
parent | b959da92f945129596d2cec5bf0c727b213beacf (diff) |
[tf-signal] Use tf.spectral.dct in mfccs_from_log_mel_spectrograms instead of a private implementation.
PiperOrigin-RevId: 170943986
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py | 63 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/mfcc_ops.py | 35 |
2 files changed, 3 insertions, 95 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py
index b3a8d40c13..c04f1cf5ba 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py
@@ -18,75 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import importlib
-
-import numpy as np
-
-
 from tensorflow.contrib.signal.python.ops import mfcc_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import spectral_ops_test_util
 from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
-
-
-# TODO(rjryan): Add scipy.fftpack to the TensorFlow build.
-def try_import(name):  # pylint: disable=invalid-name
-  module = None
-  try:
-    module = importlib.import_module(name)
-  except ImportError as e:
-    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
-  return module
-
-
-fftpack = try_import("scipy.fftpack")
-
-
-class DCTTest(test.TestCase):
-
-  def _np_dct2(self, signals, norm=None):
-    """Computes the DCT-II manually with NumPy."""
-    # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
-    dct_size = signals.shape[-1]
-    dct = np.zeros_like(signals)
-    for k in range(dct_size):
-      phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
-      dct[..., k] = np.sum(signals * phi, axis=-1)
-    # SciPy's `dct` has a scaling factor of 2.0 which we follow.
-    # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
-    if norm == "ortho":
-      # The orthogonal scaling includes a factor of 0.5 which we combine with
-      # the overall scaling of 2.0 to cancel.
-      dct[..., 0] *= np.sqrt(1.0 / dct_size)
-      dct[..., 1:] *= np.sqrt(2.0 / dct_size)
-    else:
-      dct *= 2.0
-    return dct
-
-  def test_compare_to_numpy(self):
-    """Compare dct against a manual DCT-II implementation."""
-    with spectral_ops_test_util.fft_kernel_label_map():
-      with self.test_session(use_gpu=True):
-        for size in range(1, 23):
-          signals = np.random.rand(size).astype(np.float32)
-          actual_dct = mfcc_ops._dct2_1d(signals).eval()
-          expected_dct = self._np_dct2(signals)
-          self.assertAllClose(expected_dct, actual_dct, atol=5e-4, rtol=5e-4)
-
-  def test_compare_to_fftpack(self):
-    """Compare dct against scipy.fftpack.dct."""
-    if not fftpack:
-      return
-    with spectral_ops_test_util.fft_kernel_label_map():
-      with self.test_session(use_gpu=True):
-        for size in range(1, 23):
-          signal = np.random.rand(size).astype(np.float32)
-          actual_dct = mfcc_ops._dct2_1d(signal).eval()
-          expected_dct = fftpack.dct(signal, type=2)
-          self.assertAllClose(expected_dct, actual_dct, atol=5e-4, rtol=5e-4)
 
 
 # TODO(rjryan): We have no open source tests for MFCCs at the moment. Internally
diff --git a/tensorflow/contrib/signal/python/ops/mfcc_ops.py b/tensorflow/contrib/signal/python/ops/mfcc_ops.py
index 35b6d3ad45..7bc7b57cd4 100644
--- a/tensorflow/contrib/signal/python/ops/mfcc_ops.py
+++ b/tensorflow/contrib/signal/python/ops/mfcc_ops.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -27,35 +25,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import spectral_ops
 
 
-# TODO(rjryan): Remove once tf.spectral.dct exists.
-def _dct2_1d(signals, name=None):
-  """Computes the type II 1D Discrete Cosine Transform (DCT) of `signals`.
-
-  Args:
-    signals: A `[..., samples]` `float32` `Tensor` containing the signals to
-      take the DCT of.
-    name: An optional name for the operation.
-
-  Returns:
-    A `[..., samples]` `float32` `Tensor` containing the DCT of `signals`.
-
-  """
-  with ops.name_scope(name, 'dct', [signals]):
-    # We use the FFT to compute the DCT and TensorFlow only supports float32 for
-    # FFTs at the moment.
-    signals = ops.convert_to_tensor(signals, dtype=dtypes.float32)
-
-    axis_dim = signals.shape[-1].value or array_ops.shape(signals)[-1]
-    axis_dim_float = math_ops.to_float(axis_dim)
-    scale = 2.0 * math_ops.exp(math_ops.complex(
-        0.0, -math.pi * math_ops.range(axis_dim_float) /
-        (2.0 * axis_dim_float)))
-
-    rfft = spectral_ops.rfft(signals, fft_length=[2 * axis_dim])[..., :axis_dim]
-    dct2 = math_ops.real(rfft * scale)
-    return dct2
-
-
 def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
   """Computes [MFCCs][mfcc] of `log_mel_spectrograms`.
 
@@ -134,4 +103,6 @@ def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
                          log_mel_spectrograms)
     else:
       num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1]
-    return _dct2_1d(log_mel_spectrograms) * math_ops.rsqrt(num_mel_bins * 2.0)
+
+    dct2 = spectral_ops.dct(log_mel_spectrograms)
+    return dct2 * math_ops.rsqrt(num_mel_bins * 2.0)