diff options
author | Yifei Feng <yifeif@google.com> | 2018-04-23 21:19:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-23 21:21:38 -0700 |
commit | 22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch) | |
tree | d16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/contrib/signal/python | |
parent | 24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff) |
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/contrib/signal/python')
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py | 13 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/mel_ops.py | 16 |
2 files changed, 21 insertions, 8 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py index 35c4b5bec1..345eb6cfaa 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import mel_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test # mel spectrum constants and functions. @@ -173,6 +174,18 @@ class LinearToMelTest(test.TestCase): rewritten_graph = test_util.grappler_optimize(g, [mel_matrix]) self.assertEqual(1, len(rewritten_graph.node)) + def test_num_spectrogram_bins_dynamic(self): + with self.test_session(use_gpu=True): + num_spectrogram_bins = array_ops.placeholder(shape=(), + dtype=dtypes.int32) + mel_matrix_np = spectrogram_to_mel_matrix( + 20, 129, 8000.0, 125.0, 3800.0) + mel_matrix = mel_ops.linear_to_mel_weight_matrix( + 20, num_spectrogram_bins, 8000.0, 125.0, 3800.0) + self.assertAllClose( + mel_matrix_np, + mel_matrix.eval(feed_dict={num_spectrogram_bins: 129}), atol=3e-6) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/ops/mel_ops.py b/tensorflow/contrib/signal/python/ops/mel_ops.py index d1a36548d9..1e84006116 100644 --- a/tensorflow/contrib/signal/python/ops/mel_ops.py +++ b/tensorflow/contrib/signal/python/ops/mel_ops.py @@ -64,14 +64,11 @@ def _hertz_to_mel(frequencies_hertz, name=None): 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) -def _validate_arguments(num_mel_bins, num_spectrogram_bins, sample_rate, +def _validate_arguments(num_mel_bins, sample_rate, lower_edge_hertz, upper_edge_hertz, dtype): """Checks the inputs to linear_to_mel_weight_matrix.""" if num_mel_bins <= 0: raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins) - if num_spectrogram_bins <= 0: - raise ValueError('num_spectrogram_bins must be positive. Got: %s' % - num_spectrogram_bins) if sample_rate <= 0.0: raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) if lower_edge_hertz < 0.0: @@ -122,9 +119,9 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, Args: num_mel_bins: Python int. How many bands in the resulting mel spectrum. - num_spectrogram_bins: Python int. How many bins there are in the source - spectrogram data, which is understood to be `fft_size // 2 + 1`, i.e. the - spectrogram only contains the nonredundant FFT bins. + num_spectrogram_bins: An integer `Tensor`. How many bins there are in the + source spectrogram data, which is understood to be `fft_size // 2 + 1`, + i.e. the spectrogram only contains the nonredundant FFT bins. sample_rate: Python float. Samples per second of the input signal used to create the spectrogram. We need this to figure out the actual frequencies for each spectrogram bin, which dictates how they are mapped into the mel @@ -148,7 +145,10 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, [mel]: https://en.wikipedia.org/wiki/Mel_scale """ with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name: - _validate_arguments(num_mel_bins, num_spectrogram_bins, sample_rate, + # Note: As num_spectrogram_bins is passed to `math_ops.linspace` + # and the validation is already done in linspace (both in shape function + # and in kernel), there is no need to validate num_spectrogram_bins here. + _validate_arguments(num_mel_bins, sample_rate, lower_edge_hertz, upper_edge_hertz, dtype) # To preserve accuracy, we compute the matrix at float64 precision and then |