author     Pete Warden <petewarden@google.com>              2018-07-11 15:48:43 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-11 15:51:45 -0700
commit     72e0e3d383321d160c5ae7bffcf66bed726d8a62 (patch)
tree       491f63d2a399164aae40ac58eae590eec9de9854 /tensorflow/examples
parent     f2fa55c8d2f94bd186fc6c47b8ce00fb87c22aaf (diff)
Add support for microcontroller-scale audio speech models
PiperOrigin-RevId: 204204001
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--  tensorflow/examples/speech_commands/BUILD                           |   1
-rw-r--r--  tensorflow/examples/speech_commands/freeze.py                       |  64
-rw-r--r--  tensorflow/examples/speech_commands/freeze_test.py                  |  54
-rw-r--r--  tensorflow/examples/speech_commands/generate_streaming_test_wav.py  |  10
-rw-r--r--  tensorflow/examples/speech_commands/input_data.py                   | 135
-rw-r--r--  tensorflow/examples/speech_commands/input_data_test.py              |  87
-rw-r--r--  tensorflow/examples/speech_commands/models.py                       | 302
-rw-r--r--  tensorflow/examples/speech_commands/models_test.py                  |  40
-rw-r--r--  tensorflow/examples/speech_commands/train.py                        |  58
9 files changed, 562 insertions, 189 deletions
diff --git a/tensorflow/examples/speech_commands/BUILD b/tensorflow/examples/speech_commands/BUILD
index 13bca34a86..7a44e2ee4f 100644
--- a/tensorflow/examples/speech_commands/BUILD
+++ b/tensorflow/examples/speech_commands/BUILD
@@ -56,6 +56,7 @@ tf_py_test(
srcs = ["input_data_test.py"],
additional_deps = [
":input_data",
+ ":models",
"//tensorflow/python:client_testlib",
],
)
diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py
index c8671d9c41..7657b23c60 100644
--- a/tensorflow/examples/speech_commands/freeze.py
+++ b/tensorflow/examples/speech_commands/freeze.py
@@ -54,7 +54,7 @@ FLAGS = None
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
clip_stride_ms, window_size_ms, window_stride_ms,
- dct_coefficient_count, model_architecture):
+ feature_bin_count, model_architecture, preprocess):
"""Creates an audio model with the nodes needed for inference.
Uses the supplied arguments to create a model, and inserts the input and
@@ -67,14 +67,19 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
clip_stride_ms: How often to run recognition. Useful for models with cache.
window_size_ms: Time slice duration to estimate frequencies from.
window_stride_ms: How far apart time slices should be.
- dct_coefficient_count: Number of frequency bands to analyze.
+ feature_bin_count: Number of frequency bands to analyze.
model_architecture: Name of the kind of model to generate.
+ preprocess: How the spectrogram is processed to produce features, for
+ example 'mfcc' or 'average'.
+
+ Raises:
+ Exception: If the preprocessing mode isn't recognized.
"""
words_list = input_data.prepare_words_list(wanted_words.split(','))
model_settings = models.prepare_model_settings(
len(words_list), sample_rate, clip_duration_ms, window_size_ms,
- window_stride_ms, dct_coefficient_count)
+ window_stride_ms, feature_bin_count, preprocess)
runtime_settings = {'clip_stride_ms': clip_stride_ms}
wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
@@ -88,15 +93,25 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
window_size=model_settings['window_size_samples'],
stride=model_settings['window_stride_samples'],
magnitude_squared=True)
- fingerprint_input = contrib_audio.mfcc(
- spectrogram,
- decoded_sample_data.sample_rate,
- dct_coefficient_count=dct_coefficient_count)
- fingerprint_frequency_size = model_settings['dct_coefficient_count']
- fingerprint_time_size = model_settings['spectrogram_length']
- reshaped_input = tf.reshape(fingerprint_input, [
- -1, fingerprint_time_size * fingerprint_frequency_size
- ])
+
+ if preprocess == 'average':
+ fingerprint_input = tf.nn.pool(
+ tf.expand_dims(spectrogram, -1),
+ window_shape=[1, model_settings['average_window_width']],
+ strides=[1, model_settings['average_window_width']],
+ pooling_type='AVG',
+ padding='SAME')
+ elif preprocess == 'mfcc':
+ fingerprint_input = contrib_audio.mfcc(
+ spectrogram,
+ sample_rate,
+ dct_coefficient_count=model_settings['fingerprint_width'])
+ else:
+ raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (preprocess))
+
+ fingerprint_size = model_settings['fingerprint_size']
+ reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size])
logits = models.create_model(
reshaped_input, model_settings, model_architecture, is_training=False,
@@ -110,10 +125,12 @@ def main(_):
# Create the model and load its weights.
sess = tf.InteractiveSession()
- create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
- FLAGS.clip_duration_ms, FLAGS.clip_stride_ms,
- FLAGS.window_size_ms, FLAGS.window_stride_ms,
- FLAGS.dct_coefficient_count, FLAGS.model_architecture)
+ create_inference_graph(
+ FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
+ FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
+ FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
+ if FLAGS.quantize:
+ tf.contrib.quantize.create_training_graph(quant_delay=0)
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
# Turn all the variables into inline constants inside the graph and save it.
@@ -155,10 +172,11 @@ if __name__ == '__main__':
default=10.0,
help='How long the stride is between spectrogram timeslices',)
parser.add_argument(
- '--dct_coefficient_count',
+ '--feature_bin_count',
type=int,
default=40,
- help='How many bins to use for the MFCC fingerprint',)
+ help='How many bins to use for the MFCC fingerprint',
+ )
parser.add_argument(
'--start_checkpoint',
type=str,
@@ -176,5 +194,15 @@ if __name__ == '__main__':
help='Words to use (others will be added to an unknown label)',)
parser.add_argument(
'--output_file', type=str, help='Where to save the frozen graph.')
+ parser.add_argument(
+ '--quantize',
+ type=bool,
+ default=False,
+      help='Whether to prepare the frozen graph for eight-bit deployment')
+ parser.add_argument(
+ '--preprocess',
+ type=str,
+ default='mfcc',
+ help='Spectrogram processing mode. Can be "mfcc" or "average"')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
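
For orientation, the frozen GraphDef produced by this script can be exercised with a few lines of TF 1.x. This is an illustrative sketch rather than part of the commit: the file paths are placeholders, while the tensor names 'wav_data:0' and 'labels_softmax:0' are the ones created by create_inference_graph above.

import tensorflow as tf

graph_path = '/tmp/my_frozen_graph.pb'  # placeholder path to the frozen graph
wav_path = '/tmp/example.wav'           # placeholder 16 kHz mono WAV clip

with tf.gfile.GFile(graph_path, 'rb') as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

with tf.gfile.GFile(wav_path, 'rb') as f:
  wav_data = f.read()

with tf.Session() as sess:
  softmax = sess.graph.get_tensor_by_name('labels_softmax:0')
  # The graph decodes, preprocesses, and classifies the raw WAV bytes itself.
  scores = sess.run(softmax, feed_dict={'wav_data:0': wav_data})
  print(scores)
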
diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py
index 97c6eac675..c8de6c2152 100644
--- a/tensorflow/examples/speech_commands/freeze_test.py
+++ b/tensorflow/examples/speech_commands/freeze_test.py
@@ -24,14 +24,62 @@ from tensorflow.python.platform import test
class FreezeTest(test.TestCase):
- def testCreateInferenceGraph(self):
+ def testCreateInferenceGraphWithMfcc(self):
with self.test_session() as sess:
- freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 30.0, 30.0, 10.0,
- 40, 'conv')
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=40,
+ model_architecture='conv',
+ preprocess='mfcc')
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
self.assertIsNotNone(
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(1, ops.count('Mfcc'))
+
+ def testCreateInferenceGraphWithoutMfcc(self):
+ with self.test_session() as sess:
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=40,
+ model_architecture='conv',
+ preprocess='average')
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
+ self.assertIsNotNone(
+ sess.graph.get_tensor_by_name('decoded_sample_data:0'))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(0, ops.count('Mfcc'))
+
+ def testFeatureBinCount(self):
+ with self.test_session() as sess:
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=80,
+ model_architecture='conv',
+ preprocess='average')
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
+ self.assertIsNotNone(
+ sess.graph.get_tensor_by_name('decoded_sample_data:0'))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(0, ops.count('Mfcc'))
if __name__ == '__main__':
diff --git a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
index 053206ae2f..9858906927 100644
--- a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
+++ b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
@@ -87,11 +87,12 @@ def main(_):
words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(','))
model_settings = models.prepare_model_settings(
len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms,
- FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
+ FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.feature_bin_count,
+ 'mfcc')
audio_processor = input_data.AudioProcessor(
'', FLAGS.data_dir, FLAGS.silence_percentage, 10,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
- FLAGS.testing_percentage, model_settings)
+ FLAGS.testing_percentage, model_settings, FLAGS.data_dir)
output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds
output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32)
@@ -242,10 +243,11 @@ if __name__ == '__main__':
default=10.0,
help='How long the stride is between spectrogram timeslices',)
parser.add_argument(
- '--dct_coefficient_count',
+ '--feature_bin_count',
type=int,
default=40,
- help='How many bins to use for the MFCC fingerprint',)
+ help='How many bins to use for the MFCC fingerprint',
+ )
parser.add_argument(
'--wanted_words',
type=str,
diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py
index 63dd18457f..30f2cfa9fe 100644
--- a/tensorflow/examples/speech_commands/input_data.py
+++ b/tensorflow/examples/speech_commands/input_data.py
@@ -153,14 +153,14 @@ class AudioProcessor(object):
def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
wanted_words, validation_percentage, testing_percentage,
- model_settings):
+ model_settings, summaries_dir):
self.data_dir = data_dir
self.maybe_download_and_extract_dataset(data_url, data_dir)
self.prepare_data_index(silence_percentage, unknown_percentage,
wanted_words, validation_percentage,
testing_percentage)
self.prepare_background_data()
- self.prepare_processing_graph(model_settings)
+ self.prepare_processing_graph(model_settings, summaries_dir)
def maybe_download_and_extract_dataset(self, data_url, dest_directory):
"""Download and extract data set tar file.
@@ -325,7 +325,7 @@ class AudioProcessor(object):
if not self.background_data:
raise Exception('No background wav files were found in ' + search_path)
- def prepare_processing_graph(self, model_settings):
+ def prepare_processing_graph(self, model_settings, summaries_dir):
"""Builds a TensorFlow graph to apply the input distortions.
Creates a graph that loads a WAVE file, decodes it, scales the volume,
@@ -341,48 +341,88 @@ class AudioProcessor(object):
- time_shift_offset_placeholder_: How much to move the clip in time.
- background_data_placeholder_: PCM sample data for background noise.
- background_volume_placeholder_: Loudness of mixed-in background.
- - mfcc_: Output 2D fingerprint of processed audio.
+ - output_: Output 2D fingerprint of processed audio.
Args:
model_settings: Information about the current model being trained.
+ summaries_dir: Path to save training summary information to.
+
+ Raises:
+ ValueError: If the preprocessing mode isn't recognized.
"""
- desired_samples = model_settings['desired_samples']
- self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
- wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
- wav_decoder = contrib_audio.decode_wav(
- wav_loader, desired_channels=1, desired_samples=desired_samples)
- # Allow the audio sample's volume to be adjusted.
- self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
- scaled_foreground = tf.multiply(wav_decoder.audio,
- self.foreground_volume_placeholder_)
- # Shift the sample's start position, and pad any gaps with zeros.
- self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
- self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
- padded_foreground = tf.pad(
- scaled_foreground,
- self.time_shift_padding_placeholder_,
- mode='CONSTANT')
- sliced_foreground = tf.slice(padded_foreground,
- self.time_shift_offset_placeholder_,
- [desired_samples, -1])
- # Mix in background noise.
- self.background_data_placeholder_ = tf.placeholder(tf.float32,
- [desired_samples, 1])
- self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
- background_mul = tf.multiply(self.background_data_placeholder_,
- self.background_volume_placeholder_)
- background_add = tf.add(background_mul, sliced_foreground)
- background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
- # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
- spectrogram = contrib_audio.audio_spectrogram(
- background_clamp,
- window_size=model_settings['window_size_samples'],
- stride=model_settings['window_stride_samples'],
- magnitude_squared=True)
- self.mfcc_ = contrib_audio.mfcc(
- spectrogram,
- wav_decoder.sample_rate,
- dct_coefficient_count=model_settings['dct_coefficient_count'])
+ with tf.get_default_graph().name_scope('data'):
+ desired_samples = model_settings['desired_samples']
+ self.wav_filename_placeholder_ = tf.placeholder(
+ tf.string, [], name='wav_filename')
+ wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
+ wav_decoder = contrib_audio.decode_wav(
+ wav_loader, desired_channels=1, desired_samples=desired_samples)
+ # Allow the audio sample's volume to be adjusted.
+ self.foreground_volume_placeholder_ = tf.placeholder(
+ tf.float32, [], name='foreground_volume')
+ scaled_foreground = tf.multiply(wav_decoder.audio,
+ self.foreground_volume_placeholder_)
+ # Shift the sample's start position, and pad any gaps with zeros.
+ self.time_shift_padding_placeholder_ = tf.placeholder(
+ tf.int32, [2, 2], name='time_shift_padding')
+ self.time_shift_offset_placeholder_ = tf.placeholder(
+ tf.int32, [2], name='time_shift_offset')
+ padded_foreground = tf.pad(
+ scaled_foreground,
+ self.time_shift_padding_placeholder_,
+ mode='CONSTANT')
+ sliced_foreground = tf.slice(padded_foreground,
+ self.time_shift_offset_placeholder_,
+ [desired_samples, -1])
+ # Mix in background noise.
+ self.background_data_placeholder_ = tf.placeholder(
+ tf.float32, [desired_samples, 1], name='background_data')
+ self.background_volume_placeholder_ = tf.placeholder(
+ tf.float32, [], name='background_volume')
+ background_mul = tf.multiply(self.background_data_placeholder_,
+ self.background_volume_placeholder_)
+ background_add = tf.add(background_mul, sliced_foreground)
+ background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
+ # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
+ spectrogram = contrib_audio.audio_spectrogram(
+ background_clamp,
+ window_size=model_settings['window_size_samples'],
+ stride=model_settings['window_stride_samples'],
+ magnitude_squared=True)
+ tf.summary.image(
+ 'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
+ # The number of buckets in each FFT row in the spectrogram will depend on
+ # how many input samples there are in each window. This can be quite
+      # large, with a 160 sample window producing 129 buckets for example. We
+ # don't need this level of detail for classification, so we often want to
+ # shrink them down to produce a smaller result. That's what this section
+ # implements. One method is to use average pooling to merge adjacent
+ # buckets, but a more sophisticated approach is to apply the MFCC
+ # algorithm to shrink the representation.
+ if model_settings['preprocess'] == 'average':
+ self.output_ = tf.nn.pool(
+ tf.expand_dims(spectrogram, -1),
+ window_shape=[1, model_settings['average_window_width']],
+ strides=[1, model_settings['average_window_width']],
+ pooling_type='AVG',
+ padding='SAME')
+ tf.summary.image('shrunk_spectrogram', self.output_, max_outputs=1)
+ elif model_settings['preprocess'] == 'mfcc':
+ self.output_ = contrib_audio.mfcc(
+ spectrogram,
+ wav_decoder.sample_rate,
+ dct_coefficient_count=model_settings['fingerprint_width'])
+ tf.summary.image(
+ 'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
+ else:
+ raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (model_settings['preprocess']))
+
+ # Merge all the summaries and write them out to /tmp/retrain_logs (by
+ # default)
+ self.merged_summaries_ = tf.summary.merge_all(scope='data')
+ self.summary_writer_ = tf.summary.FileWriter(summaries_dir + '/data',
+ tf.get_default_graph())
def set_size(self, mode):
"""Calculates the number of samples in the dataset partition.
@@ -418,6 +458,9 @@ class AudioProcessor(object):
Returns:
List of sample data for the transformed samples, and list of label indexes
+
+ Raises:
+ ValueError: If background samples are too short.
"""
# Pick one of the partitions to choose samples from.
candidates = self.data_index[mode]
@@ -460,6 +503,11 @@ class AudioProcessor(object):
if use_background or sample['label'] == SILENCE_LABEL:
background_index = np.random.randint(len(self.background_data))
background_samples = self.background_data[background_index]
+ if len(background_samples) <= model_settings['desired_samples']:
+ raise ValueError(
+ 'Background sample is too short! Need more than %d'
+ ' samples but only %d were found' %
+ (model_settings['desired_samples'], len(background_samples)))
background_offset = np.random.randint(
0, len(background_samples) - model_settings['desired_samples'])
background_clipped = background_samples[background_offset:(
@@ -482,7 +530,10 @@ class AudioProcessor(object):
else:
input_dict[self.foreground_volume_placeholder_] = 1
# Run the graph to produce the output audio.
- data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
+ summary, data_tensor = sess.run(
+ [self.merged_summaries_, self.output_], feed_dict=input_dict)
+ self.summary_writer_.add_summary(summary)
+ data[i - offset, :] = data_tensor.flatten()
label_index = self.word_to_index[sample['label']]
labels[i - offset] = label_index
return data, labels
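
To make the new 'average' branch concrete, here is a standalone sketch of the same tf.nn.pool call on a dummy spectrogram. The shapes assume a one-second 16 kHz clip with a 30 ms window and 10 ms stride (98 frames and a 512-point FFT giving 257 bins) and feature_bin_count=40, which makes average_window_width 6; these numbers are illustrative, not mandated by the commit.

import numpy as np
import tensorflow as tf

# Dummy spectrogram in the layout audio_spectrogram produces:
# [channels, time_frames, fft_bins].
spectrogram = tf.placeholder(tf.float32, [1, 98, 257])
average_window_width = 6  # floor(257 / 40)
shrunk = tf.nn.pool(
    tf.expand_dims(spectrogram, -1),  # -> [1, 98, 257, 1]
    window_shape=[1, average_window_width],
    strides=[1, average_window_width],
    pooling_type='AVG',
    padding='SAME')

with tf.Session() as sess:
  out = sess.run(shrunk,
                 feed_dict={spectrogram: np.zeros([1, 98, 257], np.float32)})
  print(out.shape)  # (1, 98, 43, 1): frequency bins averaged in groups of 6

The MFCC branch instead keeps exactly fingerprint_width coefficients per frame, so both paths produce a 2D fingerprint the model code can treat uniformly.
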
diff --git a/tensorflow/examples/speech_commands/input_data_test.py b/tensorflow/examples/speech_commands/input_data_test.py
index 13f294d39d..2e551be9a2 100644
--- a/tensorflow/examples/speech_commands/input_data_test.py
+++ b/tensorflow/examples/speech_commands/input_data_test.py
@@ -25,6 +25,7 @@ import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.examples.speech_commands import input_data
+from tensorflow.examples.speech_commands import models
from tensorflow.python.platform import test
@@ -32,7 +33,7 @@ class InputDataTest(test.TestCase):
def _getWavData(self):
with self.test_session() as sess:
- sample_data = tf.zeros([1000, 2])
+ sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
return wav_data
@@ -57,9 +58,31 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
- "dct_coefficient_count": 40,
+ "fingerprint_width": 40,
+ "preprocess": "mfcc",
}
+ def _runGetDataTest(self, preprocess, window_length_ms):
+ tmp_dir = self.get_temp_dir()
+ wav_dir = os.path.join(tmp_dir, "wavs")
+ os.mkdir(wav_dir)
+ self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
+ background_dir = os.path.join(wav_dir, "_background_noise_")
+ os.mkdir(background_dir)
+ wav_data = self._getWavData()
+ for i in range(10):
+ file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
+ self._saveTestWavFile(file_path, wav_data)
+ model_settings = models.prepare_model_settings(
+ 4, 16000, 1000, window_length_ms, 20, 40, preprocess)
+ with self.test_session() as sess:
+ audio_processor = input_data.AudioProcessor(
+ "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
+ result_data, result_labels = audio_processor.get_data(
+ 10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
+ self.assertEqual(10, len(result_data))
+ self.assertEqual(10, len(result_labels))
+
def testPrepareWordsList(self):
words_list = ["a", "b"]
self.assertGreater(
@@ -76,8 +99,9 @@ class InputDataTest(test.TestCase):
def testPrepareDataIndex(self):
tmp_dir = self.get_temp_dir()
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
- audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
- 10, 10, self._model_settings())
+ audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
+ ["a", "b"], 10, 10,
+ self._model_settings(), tmp_dir)
self.assertLess(0, audio_processor.set_size("training"))
self.assertTrue("training" in audio_processor.data_index)
self.assertTrue("validation" in audio_processor.data_index)
@@ -90,7 +114,7 @@ class InputDataTest(test.TestCase):
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0)
with self.assertRaises(Exception) as e:
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
- self._model_settings())
+ self._model_settings(), tmp_dir)
self.assertTrue("No .wavs found" in str(e.exception))
def testPrepareDataIndexMissing(self):
@@ -98,7 +122,7 @@ class InputDataTest(test.TestCase):
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
with self.assertRaises(Exception) as e:
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
- 10, self._model_settings())
+ 10, self._model_settings(), tmp_dir)
self.assertTrue("Expected to find" in str(e.exception))
def testPrepareBackgroundData(self):
@@ -110,8 +134,9 @@ class InputDataTest(test.TestCase):
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
- audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
- 10, 10, self._model_settings())
+ audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
+ ["a", "b"], 10, 10,
+ self._model_settings(), tmp_dir)
self.assertEqual(10, len(audio_processor.background_data))
def testLoadWavFile(self):
@@ -148,44 +173,27 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
- "dct_coefficient_count": 40,
+ "fingerprint_width": 40,
+ "preprocess": "mfcc",
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
- 10, 10, model_settings)
+ 10, 10, model_settings, tmp_dir)
self.assertIsNotNone(audio_processor.wav_filename_placeholder_)
self.assertIsNotNone(audio_processor.foreground_volume_placeholder_)
self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_)
self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_)
self.assertIsNotNone(audio_processor.background_data_placeholder_)
self.assertIsNotNone(audio_processor.background_volume_placeholder_)
- self.assertIsNotNone(audio_processor.mfcc_)
+ self.assertIsNotNone(audio_processor.output_)
- def testGetData(self):
- tmp_dir = self.get_temp_dir()
- wav_dir = os.path.join(tmp_dir, "wavs")
- os.mkdir(wav_dir)
- self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
- background_dir = os.path.join(wav_dir, "_background_noise_")
- os.mkdir(background_dir)
- wav_data = self._getWavData()
- for i in range(10):
- file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
- self._saveTestWavFile(file_path, wav_data)
- model_settings = {
- "desired_samples": 160,
- "fingerprint_size": 40,
- "label_count": 4,
- "window_size_samples": 100,
- "window_stride_samples": 100,
- "dct_coefficient_count": 40,
- }
- audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
- 10, 10, model_settings)
- with self.test_session() as sess:
- result_data, result_labels = audio_processor.get_data(
- 10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
- self.assertEqual(10, len(result_data))
- self.assertEqual(10, len(result_labels))
+ def testGetDataAverage(self):
+ self._runGetDataTest("average", 10)
+
+ def testGetDataAverageLongWindow(self):
+ self._runGetDataTest("average", 30)
+
+ def testGetDataMfcc(self):
+ self._runGetDataTest("mfcc", 30)
def testGetUnprocessedData(self):
tmp_dir = self.get_temp_dir()
@@ -198,10 +206,11 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
- "dct_coefficient_count": 40,
+ "fingerprint_width": 40,
+ "preprocess": "mfcc",
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
- 10, 10, model_settings)
+ 10, 10, model_settings, tmp_dir)
result_data, result_labels = audio_processor.get_unprocessed_data(
10, model_settings, "training")
self.assertEqual(10, len(result_data))
diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py
index ab611f414a..65ae3b1511 100644
--- a/tensorflow/examples/speech_commands/models.py
+++ b/tensorflow/examples/speech_commands/models.py
@@ -24,9 +24,21 @@ import math
import tensorflow as tf
+def _next_power_of_two(x):
+ """Calculates the smallest enclosing power of two for an input.
+
+ Args:
+ x: Positive float or integer number.
+
+ Returns:
+ Next largest power of two integer.
+ """
+ return 1 if x == 0 else 2**(int(x) - 1).bit_length()
+
+
def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
- window_size_ms, window_stride_ms,
- dct_coefficient_count):
+ window_size_ms, window_stride_ms, feature_bin_count,
+ preprocess):
"""Calculates common settings needed for all models.
Args:
@@ -35,10 +47,14 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
clip_duration_ms: Length of each audio clip to be analyzed.
window_size_ms: Duration of frequency analysis window.
window_stride_ms: How far to move in time between frequency windows.
- dct_coefficient_count: Number of frequency bins to use for analysis.
+ feature_bin_count: Number of frequency bins to use for analysis.
+ preprocess: How the spectrogram is processed to produce features.
Returns:
Dictionary containing common settings.
+
+ Raises:
+ ValueError: If the preprocessing mode isn't recognized.
"""
desired_samples = int(sample_rate * clip_duration_ms / 1000)
window_size_samples = int(sample_rate * window_size_ms / 1000)
@@ -48,16 +64,28 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
spectrogram_length = 0
else:
spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
- fingerprint_size = dct_coefficient_count * spectrogram_length
+ if preprocess == 'average':
+ fft_bin_count = 1 + (_next_power_of_two(window_size_samples) / 2)
+ average_window_width = int(math.floor(fft_bin_count / feature_bin_count))
+ fingerprint_width = int(math.ceil(fft_bin_count / average_window_width))
+ elif preprocess == 'mfcc':
+ average_window_width = -1
+ fingerprint_width = feature_bin_count
+ else:
+ raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (preprocess))
+ fingerprint_size = fingerprint_width * spectrogram_length
return {
'desired_samples': desired_samples,
'window_size_samples': window_size_samples,
'window_stride_samples': window_stride_samples,
'spectrogram_length': spectrogram_length,
- 'dct_coefficient_count': dct_coefficient_count,
+ 'fingerprint_width': fingerprint_width,
'fingerprint_size': fingerprint_size,
'label_count': label_count,
'sample_rate': sample_rate,
+ 'preprocess': preprocess,
+ 'average_window_width': average_window_width,
}
@@ -106,10 +134,14 @@ def create_model(fingerprint_input, model_settings, model_architecture,
elif model_architecture == 'low_latency_svdf':
return create_low_latency_svdf_model(fingerprint_input, model_settings,
is_training, runtime_settings)
+ elif model_architecture == 'tiny_conv':
+ return create_tiny_conv_model(fingerprint_input, model_settings,
+ is_training)
else:
raise Exception('model_architecture argument "' + model_architecture +
'" not recognized, should be one of "single_fc", "conv",' +
- ' "low_latency_conv, or "low_latency_svdf"')
+                  ' "low_latency_conv", "low_latency_svdf",' +
+ ' or "tiny_conv"')
def load_variables_from_checkpoint(sess, start_checkpoint):
@@ -152,9 +184,12 @@ def create_single_fc_model(fingerprint_input, model_settings, is_training):
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
- weights = tf.Variable(
- tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
- bias = tf.Variable(tf.zeros([label_count]))
+ weights = tf.get_variable(
+ name='weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.001),
+ shape=[fingerprint_size, label_count])
+ bias = tf.get_variable(
+ name='bias', initializer=tf.zeros_initializer, shape=[label_count])
logits = tf.matmul(fingerprint_input, weights) + bias
if is_training:
return logits, dropout_prob
@@ -212,18 +247,21 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
first_filter_width = 8
first_filter_height = 20
first_filter_count = 64
- first_weights = tf.Variable(
- tf.truncated_normal(
- [first_filter_height, first_filter_width, 1, first_filter_count],
- stddev=0.01))
- first_bias = tf.Variable(tf.zeros([first_filter_count]))
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
'SAME') + first_bias
first_relu = tf.nn.relu(first_conv)
@@ -235,14 +273,17 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
second_filter_width = 4
second_filter_height = 10
second_filter_count = 64
- second_weights = tf.Variable(
- tf.truncated_normal(
- [
- second_filter_height, second_filter_width, first_filter_count,
- second_filter_count
- ],
- stddev=0.01))
- second_bias = tf.Variable(tf.zeros([second_filter_count]))
+ second_weights = tf.get_variable(
+ name='second_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[
+ second_filter_height, second_filter_width, first_filter_count,
+ second_filter_count
+ ])
+ second_bias = tf.get_variable(
+ name='second_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_filter_count])
second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
'SAME') + second_bias
second_relu = tf.nn.relu(second_conv)
@@ -259,10 +300,14 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
flattened_second_conv = tf.reshape(second_dropout,
[-1, second_conv_element_count])
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_conv_element_count, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+      initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_conv_element_count, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
@@ -318,7 +363,7 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
@@ -327,11 +372,14 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
first_filter_count = 186
first_filter_stride_x = 1
first_filter_stride_y = 1
- first_weights = tf.Variable(
- tf.truncated_normal(
- [first_filter_height, first_filter_width, 1, first_filter_count],
- stddev=0.01))
- first_bias = tf.Variable(tf.zeros([first_filter_count]))
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [
1, first_filter_stride_y, first_filter_stride_x, 1
], 'VALID') + first_bias
@@ -351,30 +399,42 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
flattened_first_conv = tf.reshape(first_dropout,
[-1, first_conv_element_count])
first_fc_output_channels = 128
- first_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_conv_element_count, first_fc_output_channels], stddev=0.01))
- first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
+ first_fc_weights = tf.get_variable(
+ name='first_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_conv_element_count, first_fc_output_channels])
+ first_fc_bias = tf.get_variable(
+ name='first_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_fc_output_channels])
first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 128
- second_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
- second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
+ second_fc_weights = tf.get_variable(
+ name='second_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_fc_output_channels, second_fc_output_channels])
+ second_fc_bias = tf.get_variable(
+ name='second_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_fc_output_channels])
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_fc_output_channels, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_fc_output_channels, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
@@ -422,7 +482,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
The node is expected to produce a 2D Tensor of shape:
- [batch, model_settings['dct_coefficient_count'] *
+ [batch, model_settings['fingerprint_width'] *
model_settings['spectrogram_length']]
with the features corresponding to the same time slot arranged contiguously,
and the oldest slot at index [:, 0], and newest at [:, -1].
@@ -440,7 +500,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
# Validation.
@@ -462,8 +522,11 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
num_filters = rank * num_units
# Create the runtime memory: [num_filters, batch, input_time_size]
batch = 1
- memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]),
- trainable=False, name='runtime-memory')
+ memory = tf.get_variable(
+ initializer=tf.zeros_initializer,
+ shape=[num_filters, batch, input_time_size],
+ trainable=False,
+ name='runtime-memory')
# Determine the number of new frames in the input, such that we only operate
# on those. For training we do not use the memory, and thus use all frames
# provided in the input.
@@ -483,8 +546,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)
# Create the frequency filters.
- weights_frequency = tf.Variable(
- tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01))
+ weights_frequency = tf.get_variable(
+ name='weights_frequency',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[input_frequency_size, num_filters])
# Expand to add input channels dimensions.
# weights_frequency: [input_frequency_size, 1, num_filters]
weights_frequency = tf.expand_dims(weights_frequency, 1)
@@ -506,8 +571,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
activations_time = new_memory
# Create the time filters.
- weights_time = tf.Variable(
- tf.truncated_normal([num_filters, input_time_size], stddev=0.01))
+ weights_time = tf.get_variable(
+ name='weights_time',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[num_filters, input_time_size])
# Apply the time filter on the outputs of the feature filters.
# weights_time: [num_filters, input_time_size, 1]
# outputs: [num_filters, batch, 1]
@@ -524,7 +591,8 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
units_output = tf.transpose(units_output)
# Apply bias.
- bias = tf.Variable(tf.zeros([num_units]))
+ bias = tf.get_variable(
+ name='bias', initializer=tf.zeros_initializer, shape=[num_units])
first_bias = tf.nn.bias_add(units_output, bias)
# Relu.
@@ -536,31 +604,135 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
first_dropout = first_relu
first_fc_output_channels = 256
- first_fc_weights = tf.Variable(
- tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01))
- first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
+ first_fc_weights = tf.get_variable(
+ name='first_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[num_units, first_fc_output_channels])
+ first_fc_bias = tf.get_variable(
+ name='first_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_fc_output_channels])
first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 256
- second_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
- second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
+ second_fc_weights = tf.get_variable(
+ name='second_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_fc_output_channels, second_fc_output_channels])
+ second_fc_bias = tf.get_variable(
+ name='second_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_fc_output_channels])
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_fc_output_channels, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+      initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_fc_output_channels, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
else:
return final_fc
+
+
+def create_tiny_conv_model(fingerprint_input, model_settings, is_training):
+ """Builds a convolutional model aimed at microcontrollers.
+
+ Devices like DSPs and microcontrollers can have very small amounts of
+ memory and limited processing power. This model is designed to use less
+ than 20KB of working RAM, and fit within 32KB of read-only (flash) memory.
+
+ Here's the layout of the graph:
+
+ (fingerprint_input)
+ v
+ [Conv2D]<-(weights)
+ v
+ [BiasAdd]<-(bias)
+ v
+ [Relu]
+ v
+ [MatMul]<-(weights)
+ v
+ [BiasAdd]<-(bias)
+ v
+
+ This doesn't produce particularly accurate results, but it's designed to be
+ used as the first stage of a pipeline, running on a low-energy piece of
+  hardware that can always be on, waking higher-power chips only when a
+  possible utterance has been found so that more accurate analysis can be done.
+
+ During training, a dropout node is introduced after the relu, controlled by a
+ placeholder.
+
+ Args:
+ fingerprint_input: TensorFlow node that will output audio feature vectors.
+ model_settings: Dictionary of information about the model.
+ is_training: Whether the model is going to be used for training.
+
+ Returns:
+ TensorFlow node outputting logits results, and optionally a dropout
+ placeholder.
+ """
+ if is_training:
+ dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+ input_frequency_size = model_settings['fingerprint_width']
+ input_time_size = model_settings['spectrogram_length']
+ fingerprint_4d = tf.reshape(fingerprint_input,
+ [-1, input_time_size, input_frequency_size, 1])
+ first_filter_width = 8
+ first_filter_height = 10
+ first_filter_count = 8
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
+ first_conv_stride_x = 2
+ first_conv_stride_y = 2
+ first_conv = tf.nn.conv2d(fingerprint_4d, first_weights,
+ [1, first_conv_stride_y, first_conv_stride_x, 1],
+ 'SAME') + first_bias
+ first_relu = tf.nn.relu(first_conv)
+ if is_training:
+ first_dropout = tf.nn.dropout(first_relu, dropout_prob)
+ else:
+ first_dropout = first_relu
+ first_dropout_shape = first_dropout.get_shape()
+ first_dropout_output_width = first_dropout_shape[2]
+ first_dropout_output_height = first_dropout_shape[1]
+ first_dropout_element_count = int(
+ first_dropout_output_width * first_dropout_output_height *
+ first_filter_count)
+ flattened_first_dropout = tf.reshape(first_dropout,
+ [-1, first_dropout_element_count])
+ label_count = model_settings['label_count']
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_dropout_element_count, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
+ final_fc = (
+ tf.matmul(flattened_first_dropout, final_fc_weights) + final_fc_bias)
+ if is_training:
+ return final_fc, dropout_prob
+ else:
+ return final_fc
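
The 'average' sizing arithmetic in prepare_model_settings can be traced in plain Python. The 16 kHz sample rate, 30 ms window, 10 ms stride, and feature_bin_count of 40 below are illustrative defaults, not values fixed by the commit.

import math

def next_power_of_two(x):
  # Mirrors models._next_power_of_two.
  return 1 if x == 0 else 2**(int(x) - 1).bit_length()

window_size_samples = int(16000 * 30 / 1000)                      # 480
window_stride_samples = int(16000 * 10 / 1000)                    # 160
fft_bin_count = 1 + next_power_of_two(window_size_samples) // 2   # 257
average_window_width = int(math.floor(fft_bin_count / 40.0))      # 6
fingerprint_width = int(math.ceil(fft_bin_count / float(average_window_width)))  # 43
spectrogram_length = 1 + (16000 - window_size_samples) // window_stride_samples  # 98
print(fingerprint_width * spectrogram_length)                     # 4214 = fingerprint_size

In 'mfcc' mode fingerprint_width is simply feature_bin_count, so the same settings would give a 40 x 98 fingerprint of 3920 values.
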
diff --git a/tensorflow/examples/speech_commands/models_test.py b/tensorflow/examples/speech_commands/models_test.py
index 80c795367f..0c373967ed 100644
--- a/tensorflow/examples/speech_commands/models_test.py
+++ b/tensorflow/examples/speech_commands/models_test.py
@@ -26,12 +26,29 @@ from tensorflow.python.platform import test
class ModelsTest(test.TestCase):
+ def _modelSettings(self):
+ return models.prepare_model_settings(
+ label_count=10,
+ sample_rate=16000,
+ clip_duration_ms=1000,
+ window_size_ms=20,
+ window_stride_ms=10,
+ feature_bin_count=40,
+ preprocess="mfcc")
+
def testPrepareModelSettings(self):
self.assertIsNotNone(
- models.prepare_model_settings(10, 16000, 1000, 20, 10, 40))
+ models.prepare_model_settings(
+ label_count=10,
+ sample_rate=16000,
+ clip_duration_ms=1000,
+ window_size_ms=20,
+ window_stride_ms=10,
+ feature_bin_count=40,
+ preprocess="mfcc"))
def testCreateModelConvTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(fingerprint_input,
@@ -42,7 +59,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelConvInference(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits = models.create_model(fingerprint_input, model_settings, "conv",
@@ -51,7 +68,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
def testCreateModelLowLatencyConvTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
@@ -62,7 +79,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelFullyConnectedTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
@@ -73,7 +90,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelBadArchitecture(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session():
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
with self.assertRaises(Exception) as e:
@@ -81,6 +98,17 @@ class ModelsTest(test.TestCase):
"bad_architecture", True)
self.assertTrue("not recognized" in str(e.exception))
+ def testCreateModelTinyConvTraining(self):
+ model_settings = self._modelSettings()
+ with self.test_session() as sess:
+ fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
+ logits, dropout_prob = models.create_model(
+ fingerprint_input, model_settings, "tiny_conv", True)
+ self.assertIsNotNone(logits)
+ self.assertIsNotNone(dropout_prob)
+ self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
+
if __name__ == "__main__":
test.main()
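
To give a feel for the memory budget quoted in create_tiny_conv_model's docstring, here is a rough parameter count. The fingerprint_width of 43, the 49 spectrogram frames (20 ms stride), and the four output labels are assumptions chosen for illustration, not values from the commit.

import math

fingerprint_width = 43   # 'average' preprocessing with feature_bin_count=40, assumed
spectrogram_length = 49  # 1 s clip, 30 ms window, 20 ms stride, assumed
label_count = 4          # two wanted words plus silence and unknown, assumed

conv_params = 10 * 8 * 1 * 8 + 8                       # conv weights + bias: 648
conv_out_h = int(math.ceil(spectrogram_length / 2.0))  # stride 2, SAME padding: 25
conv_out_w = int(math.ceil(fingerprint_width / 2.0))   # 22
flattened = conv_out_h * conv_out_w * 8                # 4400
fc_params = flattened * label_count + label_count      # 17604
print(conv_params + fc_params)                         # 18252 parameters in total

At eight bits per weight that is roughly 18 KB, consistent with the docstring's goal of fitting within 32 KB of read-only memory.
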
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index fc28eb0631..eca34f8812 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -98,12 +98,12 @@ def main(_):
model_settings = models.prepare_model_settings(
len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
- FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
+ FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
audio_processor = input_data.AudioProcessor(
- FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
- FLAGS.unknown_percentage,
+ FLAGS.data_url, FLAGS.data_dir,
+ FLAGS.silence_percentage, FLAGS.unknown_percentage,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
- FLAGS.testing_percentage, model_settings)
+ FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
@@ -122,8 +122,25 @@ def main(_):
'lists, but are %d and %d long instead' % (len(training_steps_list),
len(learning_rates_list)))
- fingerprint_input = tf.placeholder(
+ input_placeholder = tf.placeholder(
tf.float32, [None, fingerprint_size], name='fingerprint_input')
+ if FLAGS.quantize:
+ # TODO(petewarden): These values have been derived from the observed ranges
+ # of spectrogram and MFCC inputs. If the preprocessing pipeline changes,
+ # they may need to be updated.
+ if FLAGS.preprocess == 'average':
+ fingerprint_min = 0.0
+ fingerprint_max = 2048.0
+ elif FLAGS.preprocess == 'mfcc':
+ fingerprint_min = -247.0
+ fingerprint_max = 30.0
+ else:
+ raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (FLAGS.preprocess))
+ fingerprint_input = tf.fake_quant_with_min_max_args(
+ input_placeholder, fingerprint_min, fingerprint_max)
+ else:
+ fingerprint_input = input_placeholder
logits, dropout_prob = models.create_model(
fingerprint_input,
@@ -146,7 +163,8 @@ def main(_):
with tf.name_scope('cross_entropy'):
cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
labels=ground_truth_input, logits=logits)
- tf.summary.scalar('cross_entropy', cross_entropy_mean)
+ if FLAGS.quantize:
+ tf.contrib.quantize.create_training_graph(quant_delay=0)
with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
learning_rate_input = tf.placeholder(
tf.float32, [], name='learning_rate_input')
@@ -157,7 +175,9 @@ def main(_):
confusion_matrix = tf.confusion_matrix(
ground_truth_input, predicted_indices, num_classes=label_count)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- tf.summary.scalar('accuracy', evaluation_step)
+ with tf.get_default_graph().name_scope('eval'):
+ tf.summary.scalar('cross_entropy', cross_entropy_mean)
+ tf.summary.scalar('accuracy', evaluation_step)
global_step = tf.train.get_or_create_global_step()
increment_global_step = tf.assign(global_step, global_step + 1)
@@ -165,7 +185,7 @@ def main(_):
saver = tf.train.Saver(tf.global_variables())
# Merge all the summaries and write them out to /tmp/retrain_logs (by default)
- merged_summaries = tf.summary.merge_all()
+ merged_summaries = tf.summary.merge_all(scope='eval')
train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
sess.graph)
validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
@@ -207,8 +227,11 @@ def main(_):
# Run the graph with this batch of training data.
train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
[
- merged_summaries, evaluation_step, cross_entropy_mean, train_step,
- increment_global_step
+ merged_summaries,
+ evaluation_step,
+ cross_entropy_mean,
+ train_step,
+ increment_global_step,
],
feed_dict={
fingerprint_input: train_fingerprints,
@@ -364,10 +387,11 @@ if __name__ == '__main__':
default=10.0,
help='How far to move in time between spectrogram timeslices.',)
parser.add_argument(
- '--dct_coefficient_count',
+ '--feature_bin_count',
type=int,
default=40,
- help='How many bins to use for the MFCC fingerprint',)
+ help='How many bins to use for the MFCC fingerprint',
+ )
parser.add_argument(
'--how_many_training_steps',
type=str,
@@ -423,6 +447,16 @@ if __name__ == '__main__':
type=bool,
default=False,
help='Whether to check for invalid numbers during processing')
+ parser.add_argument(
+ '--quantize',
+ type=bool,
+ default=False,
+ help='Whether to train the model for eight-bit deployment')
+ parser.add_argument(
+ '--preprocess',
+ type=str,
+ default='mfcc',
+ help='Spectrogram processing mode. Can be "mfcc" or "average"')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
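
Finally, a minimal sketch of the quantization-aware wiring that train.py now performs when --quantize is set. It assumes TF 1.x with contrib available; the dense layer stands in for models.create_model, and the 0.0 to 2048.0 range is the hard-coded 'average' range from the code above.

import tensorflow as tf

fingerprint_size = 3920  # fingerprint_width * spectrogram_length, illustrative
input_placeholder = tf.placeholder(
    tf.float32, [None, fingerprint_size], name='fingerprint_input')
# Pin the input features to their observed range so they quantize cleanly.
fingerprint_input = tf.fake_quant_with_min_max_args(
    input_placeholder, min=0.0, max=2048.0)
logits = tf.layers.dense(fingerprint_input, 12)  # stand-in for the real model
# Rewrite the graph in place, inserting fake-quant ops around weights and
# activations so the trained model can later be converted to eight bits.
tf.contrib.quantize.create_training_graph(quant_delay=0)
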