aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-02-03 09:46:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-03 09:49:57 -0800
commita42450a76e43154cc3bf8977c2e9c8afb1d08621 (patch)
treeb8748387a50d9ed5678ad7f49c24a75dfbd7b918 /tensorflow/contrib/signal
parent34bff30979896879815dd6fc4d77c1a37d9b98a0 (diff)
[tf-signal] Fix exception when input shape is unknown in mfccs_from_log_mel_spectrograms.
PiperOrigin-RevId: 184400783
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py9
-rw-r--r--tensorflow/contrib/signal/python/ops/mfcc_ops.py2
2 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py
index c04f1cf5ba..e7743bdcba 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.signal.python.ops import mfcc_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import spectral_ops_test_util
@@ -49,6 +50,14 @@ class MFCCTest(test.TestCase):
signal = random_ops.random_normal((2, 3, 5))
mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval()
+ def test_unknown_shape(self):
+ """A test that the op runs when shape and rank are unknown."""
+ with spectral_ops_test_util.fft_kernel_label_map():
+ with self.test_session(use_gpu=True):
+ signal = array_ops.placeholder_with_default(
+ random_ops.random_normal((2, 3, 5)), tensor_shape.TensorShape(None))
+ self.assertIsNone(signal.shape.ndims)
+ mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval()
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/signal/python/ops/mfcc_ops.py b/tensorflow/contrib/signal/python/ops/mfcc_ops.py
index 6cef95f742..4e842f7f10 100644
--- a/tensorflow/contrib/signal/python/ops/mfcc_ops.py
+++ b/tensorflow/contrib/signal/python/ops/mfcc_ops.py
@@ -105,4 +105,4 @@ def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1]
dct2 = spectral_ops.dct(log_mel_spectrograms)
- return dct2 * math_ops.rsqrt(num_mel_bins * 2.0)
+ return dct2 * math_ops.rsqrt(math_ops.to_float(num_mel_bins) * 2.0)