about summary refs log tree commit diff homepage
path: root/tensorflow/examples/speech_commands/input_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/speech_commands/input_data.py')
-rw-r--r-- tensorflow/examples/speech_commands/input_data.py 135
1 files changed, 93 insertions, 42 deletions
diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py
index 63dd18457f..30f2cfa9fe 100644
--- a/tensorflow/examples/speech_commands/input_data.py
+++ b/tensorflow/examples/speech_commands/input_data.py
@@ -153,14 +153,14 @@ class AudioProcessor(object):
def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
wanted_words, validation_percentage, testing_percentage,
- model_settings):
+ model_settings, summaries_dir):
self.data_dir = data_dir
self.maybe_download_and_extract_dataset(data_url, data_dir)
self.prepare_data_index(silence_percentage, unknown_percentage,
wanted_words, validation_percentage,
testing_percentage)
self.prepare_background_data()
- self.prepare_processing_graph(model_settings)
+ self.prepare_processing_graph(model_settings, summaries_dir)
def maybe_download_and_extract_dataset(self, data_url, dest_directory):
"""Download and extract data set tar file.
@@ -325,7 +325,7 @@ class AudioProcessor(object):
if not self.background_data:
raise Exception('No background wav files were found in ' + search_path)
- def prepare_processing_graph(self, model_settings):
+ def prepare_processing_graph(self, model_settings, summaries_dir):
"""Builds a TensorFlow graph to apply the input distortions.
Creates a graph that loads a WAVE file, decodes it, scales the volume,
@@ -341,48 +341,88 @@ class AudioProcessor(object):
- time_shift_offset_placeholder_: How much to move the clip in time.
- background_data_placeholder_: PCM sample data for background noise.
- background_volume_placeholder_: Loudness of mixed-in background.
- - mfcc_: Output 2D fingerprint of processed audio.
+ - output_: Output 2D fingerprint of processed audio.
Args:
model_settings: Information about the current model being trained.
+ summaries_dir: Path to save training summary information to.
+
+ Raises:
+ ValueError: If the preprocessing mode isn't recognized.
"""
- desired_samples = model_settings['desired_samples']
- self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
- wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
- wav_decoder = contrib_audio.decode_wav(
- wav_loader, desired_channels=1, desired_samples=desired_samples)
- # Allow the audio sample's volume to be adjusted.
- self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
- scaled_foreground = tf.multiply(wav_decoder.audio,
- self.foreground_volume_placeholder_)
- # Shift the sample's start position, and pad any gaps with zeros.
- self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
- self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
- padded_foreground = tf.pad(
- scaled_foreground,
- self.time_shift_padding_placeholder_,
- mode='CONSTANT')
- sliced_foreground = tf.slice(padded_foreground,
- self.time_shift_offset_placeholder_,
- [desired_samples, -1])
- # Mix in background noise.
- self.background_data_placeholder_ = tf.placeholder(tf.float32,
- [desired_samples, 1])
- self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
- background_mul = tf.multiply(self.background_data_placeholder_,
- self.background_volume_placeholder_)
- background_add = tf.add(background_mul, sliced_foreground)
- background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
- # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
- spectrogram = contrib_audio.audio_spectrogram(
- background_clamp,
- window_size=model_settings['window_size_samples'],
- stride=model_settings['window_stride_samples'],
- magnitude_squared=True)
- self.mfcc_ = contrib_audio.mfcc(
- spectrogram,
- wav_decoder.sample_rate,
- dct_coefficient_count=model_settings['dct_coefficient_count'])
+ with tf.get_default_graph().name_scope('data'):
+ desired_samples = model_settings['desired_samples']
+ self.wav_filename_placeholder_ = tf.placeholder(
+ tf.string, [], name='wav_filename')
+ wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
+ wav_decoder = contrib_audio.decode_wav(
+ wav_loader, desired_channels=1, desired_samples=desired_samples)
+ # Allow the audio sample's volume to be adjusted.
+ self.foreground_volume_placeholder_ = tf.placeholder(
+ tf.float32, [], name='foreground_volume')
+ scaled_foreground = tf.multiply(wav_decoder.audio,
+ self.foreground_volume_placeholder_)
+ # Shift the sample's start position, and pad any gaps with zeros.
+ self.time_shift_padding_placeholder_ = tf.placeholder(
+ tf.int32, [2, 2], name='time_shift_padding')
+ self.time_shift_offset_placeholder_ = tf.placeholder(
+ tf.int32, [2], name='time_shift_offset')
+ padded_foreground = tf.pad(
+ scaled_foreground,
+ self.time_shift_padding_placeholder_,
+ mode='CONSTANT')
+ sliced_foreground = tf.slice(padded_foreground,
+ self.time_shift_offset_placeholder_,
+ [desired_samples, -1])
+ # Mix in background noise.
+ self.background_data_placeholder_ = tf.placeholder(
+ tf.float32, [desired_samples, 1], name='background_data')
+ self.background_volume_placeholder_ = tf.placeholder(
+ tf.float32, [], name='background_volume')
+ background_mul = tf.multiply(self.background_data_placeholder_,
+ self.background_volume_placeholder_)
+ background_add = tf.add(background_mul, sliced_foreground)
+ background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
+ # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
+ spectrogram = contrib_audio.audio_spectrogram(
+ background_clamp,
+ window_size=model_settings['window_size_samples'],
+ stride=model_settings['window_stride_samples'],
+ magnitude_squared=True)
+ tf.summary.image(
+ 'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
+ # The number of buckets in each FFT row in the spectrogram will depend on
+ # how many input samples there are in each window. This can be quite
+ # large, with a 160 sample window producing 127 buckets for example. We
+ # don't need this level of detail for classification, so we often want to
+ # shrink them down to produce a smaller result. That's what this section
+ # implements. One method is to use average pooling to merge adjacent
+ # buckets, but a more sophisticated approach is to apply the MFCC
+ # algorithm to shrink the representation.
+ if model_settings['preprocess'] == 'average':
+ self.output_ = tf.nn.pool(
+ tf.expand_dims(spectrogram, -1),
+ window_shape=[1, model_settings['average_window_width']],
+ strides=[1, model_settings['average_window_width']],
+ pooling_type='AVG',
+ padding='SAME')
+ tf.summary.image('shrunk_spectrogram', self.output_, max_outputs=1)
+ elif model_settings['preprocess'] == 'mfcc':
+ self.output_ = contrib_audio.mfcc(
+ spectrogram,
+ wav_decoder.sample_rate,
+ dct_coefficient_count=model_settings['fingerprint_width'])
+ tf.summary.image(
+ 'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
+ else:
+ raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (model_settings['preprocess']))
+
+ # Merge all the summaries created in the 'data' scope and write them out to
+ # the 'data' subdirectory of summaries_dir.
+ self.merged_summaries_ = tf.summary.merge_all(scope='data')
+ self.summary_writer_ = tf.summary.FileWriter(summaries_dir + '/data',
+ tf.get_default_graph())
def set_size(self, mode):
"""Calculates the number of samples in the dataset partition.
@@ -418,6 +458,9 @@ class AudioProcessor(object):
Returns:
List of sample data for the transformed samples, and list of label indexes
+
+ Raises:
+ ValueError: If background samples are too short.
"""
# Pick one of the partitions to choose samples from.
candidates = self.data_index[mode]
@@ -460,6 +503,11 @@ class AudioProcessor(object):
if use_background or sample['label'] == SILENCE_LABEL:
background_index = np.random.randint(len(self.background_data))
background_samples = self.background_data[background_index]
+ if len(background_samples) <= model_settings['desired_samples']:
+ raise ValueError(
+ 'Background sample is too short! Need more than %d'
+ ' samples but only %d were found' %
+ (model_settings['desired_samples'], len(background_samples)))
background_offset = np.random.randint(
0, len(background_samples) - model_settings['desired_samples'])
background_clipped = background_samples[background_offset:(
@@ -482,7 +530,10 @@ class AudioProcessor(object):
else:
input_dict[self.foreground_volume_placeholder_] = 1
# Run the graph to produce the output audio.
- data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
+ summary, data_tensor = sess.run(
+ [self.merged_summaries_, self.output_], feed_dict=input_dict)
+ self.summary_writer_.add_summary(summary)
+ data[i - offset, :] = data_tensor.flatten()
label_index = self.word_to_index[sample['label']]
labels[i - offset] = label_index
return data, labels