aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author RJ Ryan <rjryan@google.com> 2018-07-20 14:47:04 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-20 14:52:33 -0700
commit 1711a9a08ce29029c66924f880fa1e619aed10aa (patch)
tree 6ddc66b9da44efd45477dc8f280f911b6bf544af
parent 7e8a83543b7eb36647894453129f15eeec60b3ba (diff)
Remove float64 math in linear_to_mel_weight_matrix.
This was causing portability problems for platforms that do not support float64. Callers who want higher precision can simply pass tf.float64 as the dtype. PiperOrigin-RevId: 205457007
-rw-r--r-- tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py 13
-rw-r--r-- tensorflow/contrib/signal/python/ops/mel_ops.py 24
2 files changed, 18 insertions, 19 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
index 345eb6cfaa..f4348e80ea 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
@@ -53,7 +53,8 @@ def spectrogram_to_mel_matrix(num_mel_bins=20,
num_spectrogram_bins=129,
audio_sample_rate=8000,
lower_edge_hertz=125.0,
- upper_edge_hertz=3800.0):
+ upper_edge_hertz=3800.0,
+ unused_dtype=None):
"""Return a matrix that can post-multiply spectrogram rows to make mel.
Copied from
@@ -132,9 +133,9 @@ class LinearToMelTest(test.TestCase):
# lower_edge_hertz, upper_edge_hertz) to test.
configs = [
# Defaults.
- (20, 129, 8000.0, 125.0, 3800.0),
+ (20, 129, 8000.0, 125.0, 3800.0, dtypes.float64),
# Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
- (80, 1025, 24000.0, 80.0, 12000.0)
+ (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64)
]
with self.test_session(use_gpu=True):
for config in configs:
@@ -143,7 +144,8 @@ class LinearToMelTest(test.TestCase):
self.assertAllClose(mel_matrix_np, mel_matrix.eval(), atol=3e-6)
def test_dtypes(self):
- for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+ # LinSpace is not supported for tf.float16.
+ for dtype in (dtypes.bfloat16, dtypes.float32, dtypes.float64):
self.assertEqual(dtype,
mel_ops.linear_to_mel_weight_matrix(dtype=dtype).dtype)
@@ -167,7 +169,8 @@ class LinearToMelTest(test.TestCase):
def test_constant_folding(self):
"""Mel functions should be constant foldable."""
- for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+ # TODO(rjryan): tf.bfloat16 cannot be constant folded by Grappler.
+ for dtype in (dtypes.float32, dtypes.float64):
g = ops.Graph()
with g.as_default():
mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype)
diff --git a/tensorflow/contrib/signal/python/ops/mel_ops.py b/tensorflow/contrib/signal/python/ops/mel_ops.py
index 1e84006116..062d84aea1 100644
--- a/tensorflow/contrib/signal/python/ops/mel_ops.py
+++ b/tensorflow/contrib/signal/python/ops/mel_ops.py
@@ -151,22 +151,21 @@ def linear_to_mel_weight_matrix(num_mel_bins=20,
_validate_arguments(num_mel_bins, sample_rate,
lower_edge_hertz, upper_edge_hertz, dtype)
- # To preserve accuracy, we compute the matrix at float64 precision and then
- # cast to `dtype` at the end. This function can be constant folded by graph
- # optimization since there are no Tensor inputs.
+ # This function can be constant folded by graph optimization since there are
+ # no Tensor inputs.
sample_rate = ops.convert_to_tensor(
- sample_rate, dtypes.float64, name='sample_rate')
+ sample_rate, dtype, name='sample_rate')
lower_edge_hertz = ops.convert_to_tensor(
- lower_edge_hertz, dtypes.float64, name='lower_edge_hertz')
+ lower_edge_hertz, dtype, name='lower_edge_hertz')
upper_edge_hertz = ops.convert_to_tensor(
- upper_edge_hertz, dtypes.float64, name='upper_edge_hertz')
- zero_float64 = ops.convert_to_tensor(0.0, dtypes.float64)
+ upper_edge_hertz, dtype, name='upper_edge_hertz')
+ zero = ops.convert_to_tensor(0.0, dtype)
# HTK excludes the spectrogram DC bin.
bands_to_zero = 1
nyquist_hertz = sample_rate / 2.0
linear_frequencies = math_ops.linspace(
- zero_float64, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:]
+ zero, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:]
spectrogram_bins_mel = array_ops.expand_dims(
_hertz_to_mel(linear_frequencies), 1)
@@ -193,11 +192,8 @@ def linear_to_mel_weight_matrix(num_mel_bins=20,
# Intersect the line segments with each other and zero.
mel_weights_matrix = math_ops.maximum(
- zero_float64, math_ops.minimum(lower_slopes, upper_slopes))
+ zero, math_ops.minimum(lower_slopes, upper_slopes))
# Re-add the zeroed lower bins we sliced out above.
- mel_weights_matrix = array_ops.pad(
- mel_weights_matrix, [[bands_to_zero, 0], [0, 0]])
-
- # Cast to the desired type.
- return math_ops.cast(mel_weights_matrix, dtype, name=name)
+ return array_ops.pad(
+ mel_weights_matrix, [[bands_to_zero, 0], [0, 0]], name=name)