aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/speech_commands/models_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/speech_commands/models_test.py')
-rw-r--r--tensorflow/examples/speech_commands/models_test.py40
1 files changed, 34 insertions, 6 deletions
diff --git a/tensorflow/examples/speech_commands/models_test.py b/tensorflow/examples/speech_commands/models_test.py
index 80c795367f..0c373967ed 100644
--- a/tensorflow/examples/speech_commands/models_test.py
+++ b/tensorflow/examples/speech_commands/models_test.py
@@ -26,12 +26,29 @@ from tensorflow.python.platform import test
class ModelsTest(test.TestCase):
+ def _modelSettings(self):
+ return models.prepare_model_settings(
+ label_count=10,
+ sample_rate=16000,
+ clip_duration_ms=1000,
+ window_size_ms=20,
+ window_stride_ms=10,
+ feature_bin_count=40,
+ preprocess="mfcc")
+
def testPrepareModelSettings(self):
self.assertIsNotNone(
- models.prepare_model_settings(10, 16000, 1000, 20, 10, 40))
+ models.prepare_model_settings(
+ label_count=10,
+ sample_rate=16000,
+ clip_duration_ms=1000,
+ window_size_ms=20,
+ window_stride_ms=10,
+ feature_bin_count=40,
+ preprocess="mfcc"))
def testCreateModelConvTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(fingerprint_input,
@@ -42,7 +59,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelConvInference(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits = models.create_model(fingerprint_input, model_settings, "conv",
@@ -51,7 +68,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
def testCreateModelLowLatencyConvTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
@@ -62,7 +79,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelFullyConnectedTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
@@ -73,7 +90,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelBadArchitecture(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session():
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
with self.assertRaises(Exception) as e:
@@ -81,6 +98,17 @@ class ModelsTest(test.TestCase):
"bad_architecture", True)
self.assertTrue("not recognized" in str(e.exception))
+ def testCreateModelTinyConvTraining(self):
+ model_settings = self._modelSettings()
+ with self.test_session() as sess:
+ fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
+ logits, dropout_prob = models.create_model(
+ fingerprint_input, model_settings, "tiny_conv", True)
+ self.assertIsNotNone(logits)
+ self.assertIsNotNone(dropout_prob)
+ self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
+
if __name__ == "__main__":
test.main()