Diffstat (limited to 'tensorflow/examples/speech_commands/label_wav_dir.py')
-rw-r--r--  tensorflow/examples/speech_commands/label_wav_dir.py  136
1 file changed, 136 insertions, 0 deletions
diff --git a/tensorflow/examples/speech_commands/label_wav_dir.py b/tensorflow/examples/speech_commands/label_wav_dir.py
new file mode 100644
index 0000000000..a34db512dd
--- /dev/null
+++ b/tensorflow/examples/speech_commands/label_wav_dir.py
@@ -0,0 +1,136 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Runs a trained audio graph against WAVE files and reports the results.
+
+The model and labels specified in the arguments will be loaded, every .wav
+file in the given directory will be run through the model, and the resulting
+predictions will be printed to the console. This is a useful script for
+sanity-checking trained models, and an example of how to use an audio model
+from Python.
+
+Here's an example of running it:
+
+python tensorflow/examples/speech_commands/label_wav_dir.py \
+--graph=/tmp/my_frozen_graph.pb \
+--labels=/tmp/speech_commands_train/conv_labels.txt \
+--wav_dir=/tmp/speech_dataset/left
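+
+If you don't yet have a frozen graph, one can be produced from a training
+checkpoint with tensorflow/examples/speech_commands/freeze.py.
+
+For each .wav file in the directory, the script prints the file name followed
+by the top predictions, along these lines (file name and scores illustrative):
+
+a5d485dc_nohash_0.wav
+left (score = 0.81234)
+_unknown_ (score = 0.11234)
+right (score = 0.05432)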
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import glob
+import sys
+
+import tensorflow as tf
+
+# pylint: disable=unused-import
+from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
+# pylint: enable=unused-import
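+# The import above registers the audio ops (such as DecodeWav) that the
+# frozen graph needs in order to load and run.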
+
+FLAGS = None
+
+
+def load_graph(filename):
+  """Unpersists graph from file as default graph."""
+  with tf.gfile.FastGFile(filename, 'rb') as f:
+    graph_def = tf.GraphDef()
+    graph_def.ParseFromString(f.read())
+    # Import with an empty name prefix so that node names like 'wav_data:0'
+    # and 'labels_softmax:0' keep the names they had when the graph was frozen.
+    tf.import_graph_def(graph_def, name='')
+
+
+def load_labels(filename):
+  """Read in labels, one label per line."""
+  return [line.rstrip() for line in tf.gfile.GFile(filename)]
+
+
+def run_graph(wav_dir, labels, input_layer_name, output_layer_name,
+              num_top_predictions):
+  """Runs the audio data through the graph and prints predictions."""
+  with tf.Session() as sess:
+    # The output tensor is the same for every file, so look it up once.
+    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
+    for wav_path in glob.glob(wav_dir + '/*.wav'):
+      if not wav_path or not tf.gfile.Exists(wav_path):
+        tf.logging.fatal('Audio file does not exist %s', wav_path)
+
+      with open(wav_path, 'rb') as wav_file:
+        wav_data = wav_file.read()
+
+      # Feed the raw WAVE data as input to the graph. sess.run returns a
+      # [1, num_classes] array; the trailing comma unpacks the single row
+      # into a one-dimensional array of per-class scores.
+      predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})
+
+      # Sort to show labels in order of confidence.
+      print('\n%s' % (wav_path.split('/')[-1]))
+      top_k = predictions.argsort()[-num_top_predictions:][::-1]
+      for node_id in top_k:
+        human_string = labels[node_id]
+        score = predictions[node_id]
+        print('%s (score = %.5f)' % (human_string, score))
+
+    return 0
+
+
+def label_wav(wav_dir, labels, graph, input_name, output_name, how_many_labels):
+  """Loads the model and labels, and runs the inference to print predictions."""
+  if not labels or not tf.gfile.Exists(labels):
+    tf.logging.fatal('Labels file does not exist %s', labels)
+
+  if not graph or not tf.gfile.Exists(graph):
+    tf.logging.fatal('Graph file does not exist %s', graph)
+
+  labels_list = load_labels(labels)
+
+  # Load the model, which is imported into the default graph.
+  load_graph(graph)
+
+  run_graph(wav_dir, labels_list, input_name, output_name, how_many_labels)
+
+
+def main(_):
+  """Entry point for script, converts flags to arguments."""
+  label_wav(FLAGS.wav_dir, FLAGS.labels, FLAGS.graph, FLAGS.input_name,
+            FLAGS.output_name, FLAGS.how_many_labels)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--wav_dir',
+      type=str,
+      default='',
+      help='Directory of .wav files to be identified.')
+  parser.add_argument(
+      '--graph', type=str, default='', help='Model to use for identification.')
+  parser.add_argument(
+      '--labels', type=str, default='', help='Path to file containing labels.')
+  parser.add_argument(
+      '--input_name',
+      type=str,
+      default='wav_data:0',
+      help='Name of WAVE data input node in model.')
+  parser.add_argument(
+      '--output_name',
+      type=str,
+      default='labels_softmax:0',
+      help='Name of node outputting a prediction in the model.')
+  parser.add_argument(
+      '--how_many_labels',
+      type=int,
+      default=3,
+      help='Number of results to show.')
+
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)