aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/speech_commands/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/speech_commands/train.py')
-rw-r--r--tensorflow/examples/speech_commands/train.py58
1 files changed, 46 insertions, 12 deletions
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index fc28eb0631..eca34f8812 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -98,12 +98,12 @@ def main(_):
model_settings = models.prepare_model_settings(
len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
- FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
+ FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
audio_processor = input_data.AudioProcessor(
- FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
- FLAGS.unknown_percentage,
+ FLAGS.data_url, FLAGS.data_dir,
+ FLAGS.silence_percentage, FLAGS.unknown_percentage,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
- FLAGS.testing_percentage, model_settings)
+ FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
@@ -122,8 +122,25 @@ def main(_):
'lists, but are %d and %d long instead' % (len(training_steps_list),
len(learning_rates_list)))
- fingerprint_input = tf.placeholder(
+ input_placeholder = tf.placeholder(
tf.float32, [None, fingerprint_size], name='fingerprint_input')
+ if FLAGS.quantize:
+ # TODO(petewarden): These values have been derived from the observed ranges
+ # of spectrogram and MFCC inputs. If the preprocessing pipeline changes,
+ # they may need to be updated.
+ if FLAGS.preprocess == 'average':
+ fingerprint_min = 0.0
+ fingerprint_max = 2048.0
+ elif FLAGS.preprocess == 'mfcc':
+ fingerprint_min = -247.0
+ fingerprint_max = 30.0
+ else:
+ raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (FLAGS.preprocess))
+ fingerprint_input = tf.fake_quant_with_min_max_args(
+ input_placeholder, fingerprint_min, fingerprint_max)
+ else:
+ fingerprint_input = input_placeholder
logits, dropout_prob = models.create_model(
fingerprint_input,
@@ -146,7 +163,8 @@ def main(_):
with tf.name_scope('cross_entropy'):
cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
labels=ground_truth_input, logits=logits)
- tf.summary.scalar('cross_entropy', cross_entropy_mean)
+ if FLAGS.quantize:
+ tf.contrib.quantize.create_training_graph(quant_delay=0)
with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
learning_rate_input = tf.placeholder(
tf.float32, [], name='learning_rate_input')
@@ -157,7 +175,9 @@ def main(_):
confusion_matrix = tf.confusion_matrix(
ground_truth_input, predicted_indices, num_classes=label_count)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- tf.summary.scalar('accuracy', evaluation_step)
+ with tf.get_default_graph().name_scope('eval'):
+ tf.summary.scalar('cross_entropy', cross_entropy_mean)
+ tf.summary.scalar('accuracy', evaluation_step)
global_step = tf.train.get_or_create_global_step()
increment_global_step = tf.assign(global_step, global_step + 1)
@@ -165,7 +185,7 @@ def main(_):
saver = tf.train.Saver(tf.global_variables())
# Merge all the summaries and write them out to /tmp/retrain_logs (by default)
- merged_summaries = tf.summary.merge_all()
+ merged_summaries = tf.summary.merge_all(scope='eval')
train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
sess.graph)
validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
@@ -207,8 +227,11 @@ def main(_):
# Run the graph with this batch of training data.
train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
[
- merged_summaries, evaluation_step, cross_entropy_mean, train_step,
- increment_global_step
+ merged_summaries,
+ evaluation_step,
+ cross_entropy_mean,
+ train_step,
+ increment_global_step,
],
feed_dict={
fingerprint_input: train_fingerprints,
@@ -364,10 +387,11 @@ if __name__ == '__main__':
default=10.0,
help='How far to move in time between spectogram timeslices.',)
parser.add_argument(
- '--dct_coefficient_count',
+ '--feature_bin_count',
type=int,
default=40,
- help='How many bins to use for the MFCC fingerprint',)
+ help='How many bins to use for the MFCC fingerprint',
+ )
parser.add_argument(
'--how_many_training_steps',
type=str,
@@ -423,6 +447,16 @@ if __name__ == '__main__':
type=bool,
default=False,
help='Whether to check for invalid numbers during processing')
+ parser.add_argument(
+ '--quantize',
+ type=bool,
+ default=False,
+ help='Whether to train the model for eight-bit deployment')
+ parser.add_argument(
+ '--preprocess',
+ type=str,
+ default='mfcc',
+ help='Spectrogram processing mode. Can be "mfcc" or "average"')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)