aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal/python
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-04-23 21:19:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 21:21:38 -0700
commit22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch)
treed16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/contrib/signal/python
parent24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff)
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/contrib/signal/python')
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py13
-rw-r--r--tensorflow/contrib/signal/python/ops/mel_ops.py16
2 files changed, 21 insertions, 8 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
index 35c4b5bec1..345eb6cfaa 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.signal.python.kernel_tests import test_util
from tensorflow.contrib.signal.python.ops import mel_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
# mel spectrum constants and functions.
@@ -173,6 +174,18 @@ class LinearToMelTest(test.TestCase):
rewritten_graph = test_util.grappler_optimize(g, [mel_matrix])
self.assertEqual(1, len(rewritten_graph.node))
+ def test_num_spectrogram_bins_dynamic(self):
+ with self.test_session(use_gpu=True):
+ num_spectrogram_bins = array_ops.placeholder(shape=(),
+ dtype=dtypes.int32)
+ mel_matrix_np = spectrogram_to_mel_matrix(
+ 20, 129, 8000.0, 125.0, 3800.0)
+ mel_matrix = mel_ops.linear_to_mel_weight_matrix(
+ 20, num_spectrogram_bins, 8000.0, 125.0, 3800.0)
+ self.assertAllClose(
+ mel_matrix_np,
+ mel_matrix.eval(feed_dict={num_spectrogram_bins: 129}), atol=3e-6)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/signal/python/ops/mel_ops.py b/tensorflow/contrib/signal/python/ops/mel_ops.py
index d1a36548d9..1e84006116 100644
--- a/tensorflow/contrib/signal/python/ops/mel_ops.py
+++ b/tensorflow/contrib/signal/python/ops/mel_ops.py
@@ -64,14 +64,11 @@ def _hertz_to_mel(frequencies_hertz, name=None):
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
-def _validate_arguments(num_mel_bins, num_spectrogram_bins, sample_rate,
+def _validate_arguments(num_mel_bins, sample_rate,
lower_edge_hertz, upper_edge_hertz, dtype):
"""Checks the inputs to linear_to_mel_weight_matrix."""
if num_mel_bins <= 0:
raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
- if num_spectrogram_bins <= 0:
- raise ValueError('num_spectrogram_bins must be positive. Got: %s' %
- num_spectrogram_bins)
if sample_rate <= 0.0:
raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
if lower_edge_hertz < 0.0:
@@ -122,9 +119,9 @@ def linear_to_mel_weight_matrix(num_mel_bins=20,
Args:
num_mel_bins: Python int. How many bands in the resulting mel spectrum.
- num_spectrogram_bins: Python int. How many bins there are in the source
- spectrogram data, which is understood to be `fft_size // 2 + 1`, i.e. the
- spectrogram only contains the nonredundant FFT bins.
+ num_spectrogram_bins: An integer `Tensor`. How many bins there are in the
+ source spectrogram data, which is understood to be `fft_size // 2 + 1`,
+ i.e. the spectrogram only contains the nonredundant FFT bins.
sample_rate: Python float. Samples per second of the input signal used to
create the spectrogram. We need this to figure out the actual frequencies
for each spectrogram bin, which dictates how they are mapped into the mel
@@ -148,7 +145,10 @@ def linear_to_mel_weight_matrix(num_mel_bins=20,
[mel]: https://en.wikipedia.org/wiki/Mel_scale
"""
with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name:
- _validate_arguments(num_mel_bins, num_spectrogram_bins, sample_rate,
+ # Note: As num_spectrogram_bins is passed to `math_ops.linspace`
+ # and the validation is already done in linspace (both in shape function
+ # and in kernel), there is no need to validate num_spectrogram_bins here.
+ _validate_arguments(num_mel_bins, sample_rate,
lower_edge_hertz, upper_edge_hertz, dtype)
# To preserve accuracy, we compute the matrix at float64 precision and then