diff options
author | RJ Ryan <rjryan@google.com> | 2018-02-03 09:46:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-03 09:49:57 -0800 |
commit | a42450a76e43154cc3bf8977c2e9c8afb1d08621 (patch) | |
tree | b8748387a50d9ed5678ad7f49c24a75dfbd7b918 /tensorflow/contrib/signal | |
parent | 34bff30979896879815dd6fc4d77c1a37d9b98a0 (diff) |
[tf-signal] Fix exception when input shape is unknown in mfccs_from_log_mel_spectrograms.
PiperOrigin-RevId: 184400783
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/mfcc_ops.py | 2 |
2 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py index c04f1cf5ba..e7743bdcba 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.signal.python.ops import mfcc_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import spectral_ops_test_util @@ -49,6 +50,14 @@ class MFCCTest(test.TestCase): signal = random_ops.random_normal((2, 3, 5)) mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval() + def test_unknown_shape(self): + """A test that the op runs when shape and rank are unknown.""" + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(use_gpu=True): + signal = array_ops.placeholder_with_default( + random_ops.random_normal((2, 3, 5)), tensor_shape.TensorShape(None)) + self.assertIsNone(signal.shape.ndims) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval() if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/ops/mfcc_ops.py b/tensorflow/contrib/signal/python/ops/mfcc_ops.py index 6cef95f742..4e842f7f10 100644 --- a/tensorflow/contrib/signal/python/ops/mfcc_ops.py +++ b/tensorflow/contrib/signal/python/ops/mfcc_ops.py @@ -105,4 +105,4 @@ def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None): num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1] dct2 = spectral_ops.dct(log_mel_spectrograms) - return dct2 * math_ops.rsqrt(num_mel_bins * 2.0) + return dct2 * math_ops.rsqrt(math_ops.to_float(num_mel_bins) * 2.0) |