Diffstat (limited to 'tensorflow/examples/speech_commands/input_data.py')
-rw-r--r-- | tensorflow/examples/speech_commands/input_data.py | 135 |
1 files changed, 93 insertions, 42 deletions
diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py
index 63dd18457f..30f2cfa9fe 100644
--- a/tensorflow/examples/speech_commands/input_data.py
+++ b/tensorflow/examples/speech_commands/input_data.py
@@ -153,14 +153,14 @@ class AudioProcessor(object):
 
   def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
                wanted_words, validation_percentage, testing_percentage,
-               model_settings):
+               model_settings, summaries_dir):
     self.data_dir = data_dir
     self.maybe_download_and_extract_dataset(data_url, data_dir)
     self.prepare_data_index(silence_percentage, unknown_percentage,
                             wanted_words, validation_percentage,
                             testing_percentage)
     self.prepare_background_data()
-    self.prepare_processing_graph(model_settings)
+    self.prepare_processing_graph(model_settings, summaries_dir)
 
   def maybe_download_and_extract_dataset(self, data_url, dest_directory):
     """Download and extract data set tar file.
@@ -325,7 +325,7 @@ class AudioProcessor(object):
       if not self.background_data:
         raise Exception('No background wav files were found in ' + search_path)
 
-  def prepare_processing_graph(self, model_settings):
+  def prepare_processing_graph(self, model_settings, summaries_dir):
     """Builds a TensorFlow graph to apply the input distortions.
 
     Creates a graph that loads a WAVE file, decodes it, scales the volume,
@@ -341,48 +341,88 @@ class AudioProcessor(object):
       - time_shift_offset_placeholder_: How much to move the clip in time.
       - background_data_placeholder_: PCM sample data for background noise.
       - background_volume_placeholder_: Loudness of mixed-in background.
-      - mfcc_: Output 2D fingerprint of processed audio.
+      - output_: Output 2D fingerprint of processed audio.
 
     Args:
       model_settings: Information about the current model being trained.
+      summaries_dir: Path to save training summary information to.
+
+    Raises:
+      ValueError: If the preprocessing mode isn't recognized.
     """
-    desired_samples = model_settings['desired_samples']
-    self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
-    wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
-    wav_decoder = contrib_audio.decode_wav(
-        wav_loader, desired_channels=1, desired_samples=desired_samples)
-    # Allow the audio sample's volume to be adjusted.
-    self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
-    scaled_foreground = tf.multiply(wav_decoder.audio,
-                                    self.foreground_volume_placeholder_)
-    # Shift the sample's start position, and pad any gaps with zeros.
-    self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
-    self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
-    padded_foreground = tf.pad(
-        scaled_foreground,
-        self.time_shift_padding_placeholder_,
-        mode='CONSTANT')
-    sliced_foreground = tf.slice(padded_foreground,
-                                 self.time_shift_offset_placeholder_,
-                                 [desired_samples, -1])
-    # Mix in background noise.
-    self.background_data_placeholder_ = tf.placeholder(tf.float32,
-                                                       [desired_samples, 1])
-    self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
-    background_mul = tf.multiply(self.background_data_placeholder_,
-                                 self.background_volume_placeholder_)
-    background_add = tf.add(background_mul, sliced_foreground)
-    background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
-    # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
-    spectrogram = contrib_audio.audio_spectrogram(
-        background_clamp,
-        window_size=model_settings['window_size_samples'],
-        stride=model_settings['window_stride_samples'],
-        magnitude_squared=True)
-    self.mfcc_ = contrib_audio.mfcc(
-        spectrogram,
-        wav_decoder.sample_rate,
-        dct_coefficient_count=model_settings['dct_coefficient_count'])
+    with tf.get_default_graph().name_scope('data'):
+      desired_samples = model_settings['desired_samples']
+      self.wav_filename_placeholder_ = tf.placeholder(
+          tf.string, [], name='wav_filename')
+      wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
+      wav_decoder = contrib_audio.decode_wav(
+          wav_loader, desired_channels=1, desired_samples=desired_samples)
+      # Allow the audio sample's volume to be adjusted.
+      self.foreground_volume_placeholder_ = tf.placeholder(
+          tf.float32, [], name='foreground_volume')
+      scaled_foreground = tf.multiply(wav_decoder.audio,
+                                      self.foreground_volume_placeholder_)
+      # Shift the sample's start position, and pad any gaps with zeros.
+      self.time_shift_padding_placeholder_ = tf.placeholder(
+          tf.int32, [2, 2], name='time_shift_padding')
+      self.time_shift_offset_placeholder_ = tf.placeholder(
+          tf.int32, [2], name='time_shift_offset')
+      padded_foreground = tf.pad(
+          scaled_foreground,
+          self.time_shift_padding_placeholder_,
+          mode='CONSTANT')
+      sliced_foreground = tf.slice(padded_foreground,
+                                   self.time_shift_offset_placeholder_,
+                                   [desired_samples, -1])
+      # Mix in background noise.
+      self.background_data_placeholder_ = tf.placeholder(
+          tf.float32, [desired_samples, 1], name='background_data')
+      self.background_volume_placeholder_ = tf.placeholder(
+          tf.float32, [], name='background_volume')
+      background_mul = tf.multiply(self.background_data_placeholder_,
+                                   self.background_volume_placeholder_)
+      background_add = tf.add(background_mul, sliced_foreground)
+      background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
+      # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
+      spectrogram = contrib_audio.audio_spectrogram(
+          background_clamp,
+          window_size=model_settings['window_size_samples'],
+          stride=model_settings['window_stride_samples'],
+          magnitude_squared=True)
+      tf.summary.image(
+          'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
+      # The number of buckets in each FFT row in the spectrogram will depend on
+      # how many input samples there are in each window. This can be quite
+      # large, with a 160 sample window producing 127 buckets for example. We
+      # don't need this level of detail for classification, so we often want to
+      # shrink them down to produce a smaller result. That's what this section
+      # implements. One method is to use average pooling to merge adjacent
+      # buckets, but a more sophisticated approach is to apply the MFCC
+      # algorithm to shrink the representation.
+      if model_settings['preprocess'] == 'average':
+        self.output_ = tf.nn.pool(
+            tf.expand_dims(spectrogram, -1),
+            window_shape=[1, model_settings['average_window_width']],
+            strides=[1, model_settings['average_window_width']],
+            pooling_type='AVG',
+            padding='SAME')
+        tf.summary.image('shrunk_spectrogram', self.output_, max_outputs=1)
+      elif model_settings['preprocess'] == 'mfcc':
+        self.output_ = contrib_audio.mfcc(
+            spectrogram,
+            wav_decoder.sample_rate,
+            dct_coefficient_count=model_settings['fingerprint_width'])
+        tf.summary.image(
+            'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
+      else:
+        raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
+                         ' "average")' % (model_settings['preprocess']))
+
+      # Merge all the summaries and write them out to /tmp/retrain_logs (by
+      # default).
+      self.merged_summaries_ = tf.summary.merge_all(scope='data')
+      self.summary_writer_ = tf.summary.FileWriter(summaries_dir + '/data',
+                                                   tf.get_default_graph())
 
   def set_size(self, mode):
     """Calculates the number of samples in the dataset partition.
@@ -418,6 +458,9 @@ class AudioProcessor(object):
 
     Returns:
       List of sample data for the transformed samples, and list of label indexes
+
+    Raises:
+      ValueError: If background samples are too short.
     """
     # Pick one of the partitions to choose samples from.
     candidates = self.data_index[mode]
@@ -460,6 +503,11 @@ class AudioProcessor(object):
       if use_background or sample['label'] == SILENCE_LABEL:
         background_index = np.random.randint(len(self.background_data))
         background_samples = self.background_data[background_index]
+        if len(background_samples) <= model_settings['desired_samples']:
+          raise ValueError(
+              'Background sample is too short! Need more than %d'
+              ' samples but only %d were found' %
+              (model_settings['desired_samples'], len(background_samples)))
         background_offset = np.random.randint(
             0, len(background_samples) - model_settings['desired_samples'])
         background_clipped = background_samples[background_offset:(
@@ -482,7 +530,10 @@ class AudioProcessor(object):
       else:
        input_dict[self.foreground_volume_placeholder_] = 1
       # Run the graph to produce the output audio.
-      data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
+      summary, data_tensor = sess.run(
+          [self.merged_summaries_, self.output_], feed_dict=input_dict)
+      self.summary_writer_.add_summary(summary)
+      data[i - offset, :] = data_tensor.flatten()
       label_index = self.word_to_index[sample['label']]
       labels[i - offset] = label_index
     return data, labels
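The 'average' branch in the diff shrinks the spectrogram by merging groups of adjacent frequency buckets. A rough NumPy sketch of that idea (illustrative only; the bucket count of 128 is chosen to divide evenly, whereas real counts such as the 127 mentioned in the comment rely on the 'SAME' padding that tf.nn.pool applies):

    import numpy as np

    # Fake spectrogram: 98 time frames x 128 frequency buckets of arbitrary values.
    spectrogram = np.random.rand(98, 128).astype(np.float32)

    # Average every `window` adjacent frequency buckets, shrinking the frequency
    # axis from 128 to 128 // window while keeping the time axis intact.
    window = 8
    time_frames, freq_buckets = spectrogram.shape
    shrunk = spectrogram.reshape(time_frames, freq_buckets // window,
                                 window).mean(axis=-1)

    print(shrunk.shape)  # (98, 16) -- a much coarser 'fingerprint' per frame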
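The new summaries plumbing follows the usual TF 1.x pattern: summary ops created inside the 'data' name scope are gathered by tf.summary.merge_all(scope='data') and streamed to disk with a FileWriter. A minimal standalone sketch of that pattern (the tensor shape is illustrative; the log directory mirrors the /tmp/retrain_logs default mentioned in the comment above):

    import numpy as np
    import tensorflow as tf

    with tf.Graph().as_default() as graph:
      with graph.name_scope('data'):
        fingerprint = tf.placeholder(tf.float32, [1, 98, 40, 1],
                                     name='fingerprint')
        tf.summary.image('fingerprint', fingerprint, max_outputs=1)
      # Only summaries created under the 'data' scope are merged here.
      merged = tf.summary.merge_all(scope='data')
      writer = tf.summary.FileWriter('/tmp/retrain_logs/data', graph)
      with tf.Session(graph=graph) as sess:
        summary = sess.run(merged,
                           feed_dict={fingerprint: np.zeros((1, 98, 40, 1))})
        writer.add_summary(summary)  # viewable later with TensorBoard
        writer.flush()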
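Because the constructor signature changed, callers now have to hand in a summaries directory as well. A minimal usage sketch under some assumptions: the model_settings dict below lists only the keys this file reads (the real dict comes from models.prepare_model_settings() and has more entries), the paths and flag values are illustrative, and get_data()'s argument names follow the existing signature, which this diff leaves unchanged:

    import tensorflow as tf

    import input_data  # tensorflow/examples/speech_commands/input_data.py

    # Illustrative subset of model settings for 1-second clips at 16 kHz.
    model_settings = {
        'desired_samples': 16000,       # 1000 ms of audio
        'window_size_samples': 480,     # 30 ms analysis window
        'window_stride_samples': 160,   # 10 ms stride -> 98 frames per clip
        'fingerprint_width': 40,        # MFCC coefficients kept per frame
        'fingerprint_size': 40 * 98,    # flattened feature size used by get_data()
        'preprocess': 'mfcc',           # or 'average' for the pooling branch
        'average_window_width': 6,      # only read when preprocess == 'average'
    }

    audio_processor = input_data.AudioProcessor(
        data_url='',                    # empty URL skips the download step
        data_dir='/tmp/speech_dataset',
        silence_percentage=10, unknown_percentage=10,
        wanted_words=['yes', 'no'],
        validation_percentage=10, testing_percentage=10,
        model_settings=model_settings,
        summaries_dir='/tmp/retrain_logs')  # new argument added by this change

    with tf.Session() as sess:
      # get_data() now also writes the merged 'data' summaries (spectrogram or
      # MFCC images) to summaries_dir so they can be inspected in TensorBoard.
      fingerprints, labels = audio_processor.get_data(
          how_many=100, offset=0, model_settings=model_settings,
          background_frequency=0.8, background_volume_range=0.1,
          time_shift=1600, mode='training', sess=sess)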