aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal/python/ops/mel_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/signal/python/ops/mel_ops.py')
-rw-r--r--tensorflow/contrib/signal/python/ops/mel_ops.py24
1 files changed, 10 insertions, 14 deletions
diff --git a/tensorflow/contrib/signal/python/ops/mel_ops.py b/tensorflow/contrib/signal/python/ops/mel_ops.py
index 1e84006116..062d84aea1 100644
--- a/tensorflow/contrib/signal/python/ops/mel_ops.py
+++ b/tensorflow/contrib/signal/python/ops/mel_ops.py
@@ -151,22 +151,21 @@ def linear_to_mel_weight_matrix(num_mel_bins=20,
_validate_arguments(num_mel_bins, sample_rate,
lower_edge_hertz, upper_edge_hertz, dtype)
- # To preserve accuracy, we compute the matrix at float64 precision and then
- # cast to `dtype` at the end. This function can be constant folded by graph
- # optimization since there are no Tensor inputs.
+ # This function can be constant folded by graph optimization since there are
+ # no Tensor inputs.
sample_rate = ops.convert_to_tensor(
- sample_rate, dtypes.float64, name='sample_rate')
+ sample_rate, dtype, name='sample_rate')
lower_edge_hertz = ops.convert_to_tensor(
- lower_edge_hertz, dtypes.float64, name='lower_edge_hertz')
+ lower_edge_hertz, dtype, name='lower_edge_hertz')
upper_edge_hertz = ops.convert_to_tensor(
- upper_edge_hertz, dtypes.float64, name='upper_edge_hertz')
- zero_float64 = ops.convert_to_tensor(0.0, dtypes.float64)
+ upper_edge_hertz, dtype, name='upper_edge_hertz')
+ zero = ops.convert_to_tensor(0.0, dtype)
# HTK excludes the spectrogram DC bin.
bands_to_zero = 1
nyquist_hertz = sample_rate / 2.0
linear_frequencies = math_ops.linspace(
- zero_float64, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:]
+ zero, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:]
spectrogram_bins_mel = array_ops.expand_dims(
_hertz_to_mel(linear_frequencies), 1)
@@ -193,11 +192,8 @@ def linear_to_mel_weight_matrix(num_mel_bins=20,
# Intersect the line segments with each other and zero.
mel_weights_matrix = math_ops.maximum(
- zero_float64, math_ops.minimum(lower_slopes, upper_slopes))
+ zero, math_ops.minimum(lower_slopes, upper_slopes))
# Re-add the zeroed lower bins we sliced out above.
- mel_weights_matrix = array_ops.pad(
- mel_weights_matrix, [[bands_to_zero, 0], [0, 0]])
-
- # Cast to the desired type.
- return math_ops.cast(mel_weights_matrix, dtype, name=name)
+ return array_ops.pad(
+ mel_weights_matrix, [[bands_to_zero, 0], [0, 0]], name=name)