aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py')
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
index 345eb6cfaa..f4348e80ea 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
@@ -53,7 +53,8 @@ def spectrogram_to_mel_matrix(num_mel_bins=20,
num_spectrogram_bins=129,
audio_sample_rate=8000,
lower_edge_hertz=125.0,
- upper_edge_hertz=3800.0):
+ upper_edge_hertz=3800.0,
+ unused_dtype=None):
"""Return a matrix that can post-multiply spectrogram rows to make mel.
Copied from
@@ -132,9 +133,9 @@ class LinearToMelTest(test.TestCase):
# lower_edge_hertz, upper_edge_hertz) to test.
configs = [
# Defaults.
- (20, 129, 8000.0, 125.0, 3800.0),
+ (20, 129, 8000.0, 125.0, 3800.0, dtypes.float64),
# Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
- (80, 1025, 24000.0, 80.0, 12000.0)
+ (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64)
]
with self.test_session(use_gpu=True):
for config in configs:
@@ -143,7 +144,8 @@ class LinearToMelTest(test.TestCase):
self.assertAllClose(mel_matrix_np, mel_matrix.eval(), atol=3e-6)
def test_dtypes(self):
- for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+ # LinSpace is not supported for tf.float16.
+ for dtype in (dtypes.bfloat16, dtypes.float32, dtypes.float64):
self.assertEqual(dtype,
mel_ops.linear_to_mel_weight_matrix(dtype=dtype).dtype)
@@ -167,7 +169,8 @@ class LinearToMelTest(test.TestCase):
def test_constant_folding(self):
"""Mel functions should be constant foldable."""
- for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+ # TODO(rjryan): tf.bloat16 cannot be constant folded by Grappler.
+ for dtype in (dtypes.float32, dtypes.float64):
g = ops.Graph()
with g.as_default():
mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype)