aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/models
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-21 09:35:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-21 09:42:07 -0800
commit982c183dee45efe27f02702b53d304cdd0e32ed4 (patch)
treef4b175508171adbad5a75a25de28a2f7b031686b /tensorflow/contrib/lite/models
parent04ec855825289caa5ef76a2cb370bc351f10bd74 (diff)
Internal Change
PiperOrigin-RevId: 186472818
Diffstat (limited to 'tensorflow/contrib/lite/models')
-rw-r--r--tensorflow/contrib/lite/models/speech_test.cc44
1 files changed, 30 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc
index daa8c3100b..a354179a94 100644
--- a/tensorflow/contrib/lite/models/speech_test.cc
+++ b/tensorflow/contrib/lite/models/speech_test.cc
@@ -97,7 +97,12 @@ bool ConvertCsvData(const string& model_name, const string& in_name,
return true;
}
-TEST(SpeechTest, HotwordOkGoogleRank1Test) {
+class SpeechTest : public ::testing::TestWithParam<int> {
+ protected:
+ int GetMaxInvocations() { return GetParam(); }
+};
+
+TEST_P(SpeechTest, HotwordOkGoogleRank1Test) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
@@ -105,11 +110,11 @@ TEST(SpeechTest, HotwordOkGoogleRank1Test) {
/*output_tensor=*/"18", /*persistent_tensors=*/"4",
/*sequence_size=*/40, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
- ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
<< test_driver.GetErrorMessage();
}
-TEST(SpeechTest, HotwordOkGoogleRank2Test) {
+TEST_P(SpeechTest, HotwordOkGoogleRank2Test) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
@@ -117,11 +122,11 @@ TEST(SpeechTest, HotwordOkGoogleRank2Test) {
/*output_tensor=*/"18", /*persistent_tensors=*/"1",
/*sequence_size=*/40, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
- ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
<< test_driver.GetErrorMessage();
}
-TEST(SpeechTest, SpeakerIdOkGoogleTest) {
+TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
@@ -130,11 +135,11 @@ TEST(SpeechTest, SpeakerIdOkGoogleTest) {
/*persistent_tensors=*/"19,20,40,41,61,62",
/*sequence_size=*/80, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
- ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
<< test_driver.GetErrorMessage();
}
-TEST(SpeechTest, AsrAmTest) {
+TEST_P(SpeechTest, AsrAmTest) {
std::stringstream os;
ASSERT_TRUE(
ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
@@ -143,7 +148,7 @@ TEST(SpeechTest, AsrAmTest) {
/*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104",
/*sequence_size=*/320, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
- ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
<< test_driver.GetErrorMessage();
}
@@ -151,15 +156,16 @@ TEST(SpeechTest, AsrAmTest) {
// through the interpreter and stored the sum of all the output, which was them
// compared for correctness. In this test we are comparing all the intermediate
// results.
-TEST(SpeechTest, AsrLmTest) {
+TEST_P(SpeechTest, AsrLmTest) {
std::ifstream in_file;
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
- ASSERT_TRUE(testing::ParseAndRunTests(&in_file, &test_driver))
+ ASSERT_TRUE(
+ testing::ParseAndRunTests(&in_file, &test_driver, GetMaxInvocations()))
<< test_driver.GetErrorMessage();
}
-TEST(SpeechTest, EndpointerTest) {
+TEST_P(SpeechTest, EndpointerTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
@@ -168,11 +174,11 @@ TEST(SpeechTest, EndpointerTest) {
/*persistent_tensors=*/"28,29,49,50",
/*sequence_size=*/320, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
- ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
<< test_driver.GetErrorMessage();
}
-TEST(SpeechTest, TtsTest) {
+TEST_P(SpeechTest, TtsTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
"speech_tts_model_in.csv",
@@ -181,9 +187,19 @@ TEST(SpeechTest, TtsTest) {
/*persistent_tensors=*/"25,26,46,47,67,68,73",
/*sequence_size=*/334, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
- ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
<< test_driver.GetErrorMessage();
}
+// Define two instantiations. The "ShortTests" instantiations is used when
+// running the tests on Android, in order to prevent timeouts (It takes about
+// 200s just to bring up the Android emulator.)
+static const int kAllInvocations = -1;
+static const int kFirstFewInvocations = 10;
+INSTANTIATE_TEST_CASE_P(LongTests, SpeechTest,
+ ::testing::Values(kAllInvocations));
+INSTANTIATE_TEST_CASE_P(ShortTests, SpeechTest,
+ ::testing::Values(kFirstFewInvocations));
+
} // namespace
} // namespace tflite