diff options
Diffstat (limited to 'tensorflow/examples/speech_commands/freeze_test.py')
-rw-r--r-- | tensorflow/examples/speech_commands/freeze_test.py | 54 |
1 files changed, 51 insertions, 3 deletions
diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py index 97c6eac675..c8de6c2152 100644 --- a/tensorflow/examples/speech_commands/freeze_test.py +++ b/tensorflow/examples/speech_commands/freeze_test.py @@ -24,14 +24,62 @@ from tensorflow.python.platform import test class FreezeTest(test.TestCase): - def testCreateInferenceGraph(self): + def testCreateInferenceGraphWithMfcc(self): with self.test_session() as sess: - freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 30.0, 30.0, 10.0, - 40, 'conv') + freeze.create_inference_graph( + wanted_words='a,b,c,d', + sample_rate=16000, + clip_duration_ms=1000.0, + clip_stride_ms=30.0, + window_size_ms=30.0, + window_stride_ms=10.0, + feature_bin_count=40, + model_architecture='conv', + preprocess='mfcc') self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) self.assertIsNotNone( sess.graph.get_tensor_by_name('decoded_sample_data:0')) self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) + ops = [node.op for node in sess.graph_def.node] + self.assertEqual(1, ops.count('Mfcc')) + + def testCreateInferenceGraphWithoutMfcc(self): + with self.test_session() as sess: + freeze.create_inference_graph( + wanted_words='a,b,c,d', + sample_rate=16000, + clip_duration_ms=1000.0, + clip_stride_ms=30.0, + window_size_ms=30.0, + window_stride_ms=10.0, + feature_bin_count=40, + model_architecture='conv', + preprocess='average') + self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) + self.assertIsNotNone( + sess.graph.get_tensor_by_name('decoded_sample_data:0')) + self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) + ops = [node.op for node in sess.graph_def.node] + self.assertEqual(0, ops.count('Mfcc')) + + def testFeatureBinCount(self): + with self.test_session() as sess: + freeze.create_inference_graph( + wanted_words='a,b,c,d', + sample_rate=16000, + clip_duration_ms=1000.0, + clip_stride_ms=30.0, + window_size_ms=30.0, + window_stride_ms=10.0, + feature_bin_count=80, + model_architecture='conv', + preprocess='average') + self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) + self.assertIsNotNone( + sess.graph.get_tensor_by_name('decoded_sample_data:0')) + self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) + ops = [node.op for node in sess.graph_def.node] + self.assertEqual(0, ops.count('Mfcc')) if __name__ == '__main__': |