aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/speech_commands/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/speech_commands/models.py')
-rw-r--r--tensorflow/examples/speech_commands/models.py302
1 files changed, 237 insertions, 65 deletions
diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py
index ab611f414a..65ae3b1511 100644
--- a/tensorflow/examples/speech_commands/models.py
+++ b/tensorflow/examples/speech_commands/models.py
@@ -24,9 +24,21 @@ import math
import tensorflow as tf
+def _next_power_of_two(x):
+ """Calculates the smallest enclosing power of two for an input.
+
+ Args:
+ x: Positive float or integer number.
+
+ Returns:
+ Next largest power of two integer.
+ """
+ return 1 if x == 0 else 2**(int(x) - 1).bit_length()
+
+
def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
- window_size_ms, window_stride_ms,
- dct_coefficient_count):
+ window_size_ms, window_stride_ms, feature_bin_count,
+ preprocess):
"""Calculates common settings needed for all models.
Args:
@@ -35,10 +47,14 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
clip_duration_ms: Length of each audio clip to be analyzed.
window_size_ms: Duration of frequency analysis window.
window_stride_ms: How far to move in time between frequency windows.
- dct_coefficient_count: Number of frequency bins to use for analysis.
+ feature_bin_count: Number of frequency bins to use for analysis.
+ preprocess: How the spectrogram is processed to produce features.
Returns:
Dictionary containing common settings.
+
+ Raises:
+ ValueError: If the preprocessing mode isn't recognized.
"""
desired_samples = int(sample_rate * clip_duration_ms / 1000)
window_size_samples = int(sample_rate * window_size_ms / 1000)
@@ -48,16 +64,28 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
spectrogram_length = 0
else:
spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
- fingerprint_size = dct_coefficient_count * spectrogram_length
+ if preprocess == 'average':
+ fft_bin_count = 1 + (_next_power_of_two(window_size_samples) / 2)
+ average_window_width = int(math.floor(fft_bin_count / feature_bin_count))
+ fingerprint_width = int(math.ceil(fft_bin_count / average_window_width))
+ elif preprocess == 'mfcc':
+ average_window_width = -1
+ fingerprint_width = feature_bin_count
+ else:
+ raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (preprocess))
+ fingerprint_size = fingerprint_width * spectrogram_length
return {
'desired_samples': desired_samples,
'window_size_samples': window_size_samples,
'window_stride_samples': window_stride_samples,
'spectrogram_length': spectrogram_length,
- 'dct_coefficient_count': dct_coefficient_count,
+ 'fingerprint_width': fingerprint_width,
'fingerprint_size': fingerprint_size,
'label_count': label_count,
'sample_rate': sample_rate,
+ 'preprocess': preprocess,
+ 'average_window_width': average_window_width,
}
@@ -106,10 +134,14 @@ def create_model(fingerprint_input, model_settings, model_architecture,
elif model_architecture == 'low_latency_svdf':
return create_low_latency_svdf_model(fingerprint_input, model_settings,
is_training, runtime_settings)
+ elif model_architecture == 'tiny_conv':
+ return create_tiny_conv_model(fingerprint_input, model_settings,
+ is_training)
else:
raise Exception('model_architecture argument "' + model_architecture +
'" not recognized, should be one of "single_fc", "conv",' +
- ' "low_latency_conv, or "low_latency_svdf"')
+                   ' "low_latency_conv", "low_latency_svdf",' +
+ ' or "tiny_conv"')
def load_variables_from_checkpoint(sess, start_checkpoint):
@@ -152,9 +184,12 @@ def create_single_fc_model(fingerprint_input, model_settings, is_training):
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
- weights = tf.Variable(
- tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
- bias = tf.Variable(tf.zeros([label_count]))
+ weights = tf.get_variable(
+ name='weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.001),
+ shape=[fingerprint_size, label_count])
+ bias = tf.get_variable(
+ name='bias', initializer=tf.zeros_initializer, shape=[label_count])
logits = tf.matmul(fingerprint_input, weights) + bias
if is_training:
return logits, dropout_prob
@@ -212,18 +247,21 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
first_filter_width = 8
first_filter_height = 20
first_filter_count = 64
- first_weights = tf.Variable(
- tf.truncated_normal(
- [first_filter_height, first_filter_width, 1, first_filter_count],
- stddev=0.01))
- first_bias = tf.Variable(tf.zeros([first_filter_count]))
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
'SAME') + first_bias
first_relu = tf.nn.relu(first_conv)
@@ -235,14 +273,17 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
second_filter_width = 4
second_filter_height = 10
second_filter_count = 64
- second_weights = tf.Variable(
- tf.truncated_normal(
- [
- second_filter_height, second_filter_width, first_filter_count,
- second_filter_count
- ],
- stddev=0.01))
- second_bias = tf.Variable(tf.zeros([second_filter_count]))
+ second_weights = tf.get_variable(
+ name='second_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[
+ second_filter_height, second_filter_width, first_filter_count,
+ second_filter_count
+ ])
+ second_bias = tf.get_variable(
+ name='second_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_filter_count])
second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
'SAME') + second_bias
second_relu = tf.nn.relu(second_conv)
@@ -259,10 +300,14 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
flattened_second_conv = tf.reshape(second_dropout,
[-1, second_conv_element_count])
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_conv_element_count, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+      initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_conv_element_count, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
@@ -318,7 +363,7 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
@@ -327,11 +372,14 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
first_filter_count = 186
first_filter_stride_x = 1
first_filter_stride_y = 1
- first_weights = tf.Variable(
- tf.truncated_normal(
- [first_filter_height, first_filter_width, 1, first_filter_count],
- stddev=0.01))
- first_bias = tf.Variable(tf.zeros([first_filter_count]))
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [
1, first_filter_stride_y, first_filter_stride_x, 1
], 'VALID') + first_bias
@@ -351,30 +399,42 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
flattened_first_conv = tf.reshape(first_dropout,
[-1, first_conv_element_count])
first_fc_output_channels = 128
- first_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_conv_element_count, first_fc_output_channels], stddev=0.01))
- first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
+ first_fc_weights = tf.get_variable(
+ name='first_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_conv_element_count, first_fc_output_channels])
+ first_fc_bias = tf.get_variable(
+ name='first_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_fc_output_channels])
first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 128
- second_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
- second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
+ second_fc_weights = tf.get_variable(
+ name='second_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_fc_output_channels, second_fc_output_channels])
+ second_fc_bias = tf.get_variable(
+ name='second_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_fc_output_channels])
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_fc_output_channels, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_fc_output_channels, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
@@ -422,7 +482,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
The node is expected to produce a 2D Tensor of shape:
- [batch, model_settings['dct_coefficient_count'] *
+ [batch, model_settings['fingerprint_width'] *
model_settings['spectrogram_length']]
with the features corresponding to the same time slot arranged contiguously,
and the oldest slot at index [:, 0], and newest at [:, -1].
@@ -440,7 +500,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
# Validation.
@@ -462,8 +522,11 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
num_filters = rank * num_units
# Create the runtime memory: [num_filters, batch, input_time_size]
batch = 1
- memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]),
- trainable=False, name='runtime-memory')
+ memory = tf.get_variable(
+ initializer=tf.zeros_initializer,
+ shape=[num_filters, batch, input_time_size],
+ trainable=False,
+ name='runtime-memory')
# Determine the number of new frames in the input, such that we only operate
# on those. For training we do not use the memory, and thus use all frames
# provided in the input.
@@ -483,8 +546,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)
# Create the frequency filters.
- weights_frequency = tf.Variable(
- tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01))
+ weights_frequency = tf.get_variable(
+ name='weights_frequency',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[input_frequency_size, num_filters])
# Expand to add input channels dimensions.
# weights_frequency: [input_frequency_size, 1, num_filters]
weights_frequency = tf.expand_dims(weights_frequency, 1)
@@ -506,8 +571,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
activations_time = new_memory
# Create the time filters.
- weights_time = tf.Variable(
- tf.truncated_normal([num_filters, input_time_size], stddev=0.01))
+ weights_time = tf.get_variable(
+ name='weights_time',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[num_filters, input_time_size])
# Apply the time filter on the outputs of the feature filters.
# weights_time: [num_filters, input_time_size, 1]
# outputs: [num_filters, batch, 1]
@@ -524,7 +591,8 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
units_output = tf.transpose(units_output)
   # Apply bias.
- bias = tf.Variable(tf.zeros([num_units]))
+ bias = tf.get_variable(
+ name='bias', initializer=tf.zeros_initializer, shape=[num_units])
first_bias = tf.nn.bias_add(units_output, bias)
# Relu.
@@ -536,31 +604,135 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
first_dropout = first_relu
first_fc_output_channels = 256
- first_fc_weights = tf.Variable(
- tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01))
- first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
+ first_fc_weights = tf.get_variable(
+ name='first_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[num_units, first_fc_output_channels])
+ first_fc_bias = tf.get_variable(
+ name='first_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_fc_output_channels])
first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 256
- second_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
- second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
+ second_fc_weights = tf.get_variable(
+ name='second_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_fc_output_channels, second_fc_output_channels])
+ second_fc_bias = tf.get_variable(
+ name='second_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_fc_output_channels])
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_fc_output_channels, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+      initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_fc_output_channels, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
else:
return final_fc
+
+
+def create_tiny_conv_model(fingerprint_input, model_settings, is_training):
+ """Builds a convolutional model aimed at microcontrollers.
+
+ Devices like DSPs and microcontrollers can have very small amounts of
+ memory and limited processing power. This model is designed to use less
+ than 20KB of working RAM, and fit within 32KB of read-only (flash) memory.
+
+ Here's the layout of the graph:
+
+ (fingerprint_input)
+ v
+ [Conv2D]<-(weights)
+ v
+ [BiasAdd]<-(bias)
+ v
+ [Relu]
+ v
+ [MatMul]<-(weights)
+ v
+ [BiasAdd]<-(bias)
+ v
+
+ This doesn't produce particularly accurate results, but it's designed to be
+ used as the first stage of a pipeline, running on a low-energy piece of
+ hardware that can always be on, and then wake higher-power chips when a
+ possible utterance has been found, so that more accurate analysis can be done.
+
+ During training, a dropout node is introduced after the relu, controlled by a
+ placeholder.
+
+ Args:
+ fingerprint_input: TensorFlow node that will output audio feature vectors.
+ model_settings: Dictionary of information about the model.
+ is_training: Whether the model is going to be used for training.
+
+ Returns:
+ TensorFlow node outputting logits results, and optionally a dropout
+ placeholder.
+ """
+ if is_training:
+ dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+ input_frequency_size = model_settings['fingerprint_width']
+ input_time_size = model_settings['spectrogram_length']
+ fingerprint_4d = tf.reshape(fingerprint_input,
+ [-1, input_time_size, input_frequency_size, 1])
+ first_filter_width = 8
+ first_filter_height = 10
+ first_filter_count = 8
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
+ first_conv_stride_x = 2
+ first_conv_stride_y = 2
+ first_conv = tf.nn.conv2d(fingerprint_4d, first_weights,
+ [1, first_conv_stride_y, first_conv_stride_x, 1],
+ 'SAME') + first_bias
+ first_relu = tf.nn.relu(first_conv)
+ if is_training:
+ first_dropout = tf.nn.dropout(first_relu, dropout_prob)
+ else:
+ first_dropout = first_relu
+ first_dropout_shape = first_dropout.get_shape()
+ first_dropout_output_width = first_dropout_shape[2]
+ first_dropout_output_height = first_dropout_shape[1]
+ first_dropout_element_count = int(
+ first_dropout_output_width * first_dropout_output_height *
+ first_filter_count)
+ flattened_first_dropout = tf.reshape(first_dropout,
+ [-1, first_dropout_element_count])
+ label_count = model_settings['label_count']
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_dropout_element_count, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
+ final_fc = (
+ tf.matmul(flattened_first_dropout, final_fc_weights) + final_fc_bias)
+ if is_training:
+ return final_fc, dropout_prob
+ else:
+ return final_fc