aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/speech_commands/freeze_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/speech_commands/freeze_test.py')
-rw-r--r--tensorflow/examples/speech_commands/freeze_test.py54
1 files changed, 51 insertions, 3 deletions
diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py
index 97c6eac675..c8de6c2152 100644
--- a/tensorflow/examples/speech_commands/freeze_test.py
+++ b/tensorflow/examples/speech_commands/freeze_test.py
@@ -24,14 +24,62 @@ from tensorflow.python.platform import test
class FreezeTest(test.TestCase):
- def testCreateInferenceGraph(self):
+ def testCreateInferenceGraphWithMfcc(self):
with self.test_session() as sess:
- freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 30.0, 30.0, 10.0,
- 40, 'conv')
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=40,
+ model_architecture='conv',
+ preprocess='mfcc')
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
self.assertIsNotNone(
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(1, ops.count('Mfcc'))
+
+ def testCreateInferenceGraphWithoutMfcc(self):
+ with self.test_session() as sess:
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=40,
+ model_architecture='conv',
+ preprocess='average')
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
+ self.assertIsNotNone(
+ sess.graph.get_tensor_by_name('decoded_sample_data:0'))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(0, ops.count('Mfcc'))
+
+ def testFeatureBinCount(self):
+ with self.test_session() as sess:
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=80,
+ model_architecture='conv',
+ preprocess='average')
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
+ self.assertIsNotNone(
+ sess.graph.get_tensor_by_name('decoded_sample_data:0'))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(0, ops.count('Mfcc'))
if __name__ == '__main__':