aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/models
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-23 15:08:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-23 15:12:38 -0800
commit58d227d36aa17d038c70c94787f415b1cd64d982 (patch)
tree8f996478aa6b51345d4620c8d07f31c16da132bd /tensorflow/contrib/lite/models
parent68a9ee1d4e041d7690a949718b2651b035a6bfad (diff)
Switch models/ to use new test-specification format
PiperOrigin-RevId: 182999650
Diffstat (limited to 'tensorflow/contrib/lite/models')
-rw-r--r--tensorflow/contrib/lite/models/speech_test.cc189
-rw-r--r--tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec202
2 files changed, 391 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc
new file mode 100644
index 0000000000..daa8c3100b
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_test.cc
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech models (Hotword, SpeakerId) using TFLite Ops.
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+#include "tensorflow/contrib/lite/testing/split.h"
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+namespace tflite {
+namespace {
+
+const char kDataPath[] = "third_party/tensorflow/contrib/lite/models/testdata/";
+
+// Sets kDataPath as the model base directory of `driver` and opens the file
+// named `in_file_name` from that same directory into `in_file`.
+// Returns true iff the file could be opened.
+bool Init(const string& in_file_name, testing::TfLiteDriver* driver,
+          std::ifstream* in_file) {
+  driver->SetModelBaseDir(kDataPath);
+  const string spec_path = string(kDataPath) + in_file_name;
+  in_file->open(spec_path, std::ifstream::in);
+  const bool opened = in_file->is_open();
+  return opened;
+}
+
+// Converts a set of test files provided by the speech team into a single
+// test_spec, which is written to `out`.
+//
+// The input CSV (`in_name`) is supposed to contain a number of sequences per
+// line, each made of `sequence_size` comma-separated values. Every sequence
+// becomes one `invoke` entry, i.e. a single invocation of the interpreter.
+// Only the last sequence of a line carries an `output` entry: the
+// corresponding line of the output CSV (`out_name`), which is compared to the
+// output tensor after all sequences of that line have run. Trailing values
+// that do not fill a whole sequence are dropped.
+//
+// `model_name`, `in_name` and `out_name` are resolved relative to kDataPath.
+// `persistent_tensors` is emitted verbatim as the `init_state` entry.
+// `input_tensor` and `output_tensor` are currently not used.
+//
+// Returns false, after logging to stderr, if `sequence_size` is not positive,
+// if either CSV file cannot be opened, or if the output CSV has too few lines.
+// In the latter two cases `out` may already hold a partial spec.
+bool ConvertCsvData(const string& model_name, const string& in_name,
+                    const string& out_name, const string& input_tensor,
+                    const string& output_tensor,
+                    const string& persistent_tensors, int sequence_size,
+                    std::ostream* out) {
+  // A zero size would divide by zero below and a negative one is meaningless,
+  // so reject both before anything is written to `out`.
+  if (sequence_size <= 0) {
+    std::cerr << "Invalid sequence size " << sequence_size << std::endl;
+    return false;
+  }
+
+  auto data_path = [](const string& s) { return string(kDataPath) + s; };
+
+  // "\n" instead of std::endl: same bytes, without a flush after every line.
+  *out << "load_model: \"" << data_path(model_name) << "\"\n";
+  *out << "init_state: \"" << persistent_tensors << "\"\n";
+
+  const string in_file_name = data_path(in_name);
+  std::ifstream in_file(in_file_name);
+  if (!in_file.is_open()) {
+    std::cerr << "Failed to open " << in_file_name << std::endl;
+    return false;
+  }
+  const string out_file_name = data_path(out_name);
+  std::ifstream out_file(out_file_name);
+  if (!out_file.is_open()) {
+    std::cerr << "Failed to open " << out_file_name << std::endl;
+    return false;
+  }
+
+  // Invocation ids keep counting across input lines.
+  int invocation_count = 0;
+  string in_values;
+  while (std::getline(in_file, in_values)) {
+    const std::vector<string> input = testing::Split<string>(in_values, ",");
+    // Explicit cast keeps the division in signed arithmetic.
+    const int num_sequences = static_cast<int>(input.size()) / sequence_size;
+
+    for (int j = 0; j < num_sequences; ++j) {
+      *out << "invoke {\n";
+      *out << " id: " << invocation_count << "\n";
+      *out << " input: \"";
+      for (int k = 0; k < sequence_size; ++k) {
+        *out << input[k + j * sequence_size] << ",";
+      }
+      *out << "\"\n";
+
+      // The expected output is only attached to the last sequence of a line.
+      if (j == num_sequences - 1) {
+        string out_values;
+        if (!std::getline(out_file, out_values)) {
+          std::cerr << "Not enough lines in " << out_file_name << std::endl;
+          return false;
+        }
+        *out << " output: \"" << out_values << "\"\n";
+      }
+
+      *out << "}\n";
+      ++invocation_count;
+    }
+  }
+  return true;
+}
+
+// Builds a test spec for speech_hotword_model_rank1.tflite from the hotword
+// CSV files (40 input values per invocation) and runs it through a
+// TfLiteDriver with NNAPI disabled.
+TEST(SpeechTest, HotwordOkGoogleRank1Test) {
+ std::stringstream os;
+ // Generate the spec in memory; the test stops here if conversion fails.
+ ASSERT_TRUE(ConvertCsvData(
+ "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank1.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"18", /*persistent_tensors=*/"4",
+ /*sequence_size=*/40, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ // On failure the driver's error message is appended to the gtest output.
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+// Builds a test spec for speech_hotword_model_rank2.tflite from the hotword
+// CSV files (40 input values per invocation) and runs it through a
+// TfLiteDriver with NNAPI disabled.
+TEST(SpeechTest, HotwordOkGoogleRank2Test) {
+ std::stringstream os;
+ // Generate the spec in memory; the test stops here if conversion fails.
+ ASSERT_TRUE(ConvertCsvData(
+ "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank2.csv", /*input_tensor=*/"17",
+ /*output_tensor=*/"18", /*persistent_tensors=*/"1",
+ /*sequence_size=*/40, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ // On failure the driver's error message is appended to the gtest output.
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+// Builds a test spec for speech_speakerid_model.tflite from its CSV files
+// (80 input values per invocation) and runs it through a TfLiteDriver with
+// NNAPI disabled.
+TEST(SpeechTest, SpeakerIdOkGoogleTest) {
+ std::stringstream os;
+ // Generate the spec in memory; the test stops here if conversion fails.
+ ASSERT_TRUE(ConvertCsvData(
+ "speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
+ "speech_speakerid_model_out.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"66",
+ /*persistent_tensors=*/"19,20,40,41,61,62",
+ /*sequence_size=*/80, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ // On failure the driver's error message is appended to the gtest output.
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+// Builds a test spec for speech_asr_am_model.tflite from its CSV files
+// (320 input values per invocation) and runs it through a TfLiteDriver with
+// NNAPI disabled.
+TEST(SpeechTest, AsrAmTest) {
+ std::stringstream os;
+ // Generate the spec in memory; the test stops here if conversion fails.
+ ASSERT_TRUE(
+ ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
+ "speech_asr_am_model_out.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"109",
+ /*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104",
+ /*sequence_size=*/320, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ // On failure the driver's error message is appended to the gtest output.
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+// The original version of speech_asr_lm_model_test.cc ran a few sequences
+// through the interpreter and stored the sum of all the output, which was then
+// compared for correctness. In this test we are comparing all the intermediate
+// results.
+//
+// The spec is not generated from CSV data here: it is read directly from
+// speech_asr_lm_model.test_spec in the test data directory.
+TEST(SpeechTest, AsrLmTest) {
+ std::ifstream in_file;
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ // Init() also sets the driver's model base directory to the data directory.
+ ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
+ // On failure the driver's error message is appended to the gtest output.
+ ASSERT_TRUE(testing::ParseAndRunTests(&in_file, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+// Builds a test spec for speech_endpointer_model.tflite from its CSV files
+// (320 input values per invocation) and runs it through a TfLiteDriver with
+// NNAPI disabled.
+TEST(SpeechTest, EndpointerTest) {
+ std::stringstream os;
+ // Generate the spec in memory; the test stops here if conversion fails.
+ ASSERT_TRUE(ConvertCsvData(
+ "speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
+ "speech_endpointer_model_out.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"58",
+ /*persistent_tensors=*/"28,29,49,50",
+ /*sequence_size=*/320, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ // On failure the driver's error message is appended to the gtest output.
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+// Builds a test spec for speech_tts_model.tflite from its CSV files
+// (334 input values per invocation) and runs it through a TfLiteDriver with
+// NNAPI disabled.
+TEST(SpeechTest, TtsTest) {
+ std::stringstream os;
+ // Generate the spec in memory; the test stops here if conversion fails.
+ ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
+ "speech_tts_model_in.csv",
+ "speech_tts_model_out.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"74",
+ /*persistent_tensors=*/"25,26,46,47,67,68,73",
+ /*sequence_size=*/334, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ // On failure the driver's error message is appended to the gtest output.
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
new file mode 100644
index 0000000000..5812de4b30
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
@@ -0,0 +1,202 @@
+load_model: "speech_asr_lm_model.tflite"
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 3
+ input: "63982"
+ input: "8409"
+ output: "-2.75389"
+}
+invoke {
+ id: 4
+ input: "8409"
+ input: "1488"
+ output: "0.601841"
+}
+invoke {
+ id: 5
+ input: "1488"
+ input: "63981"
+ output: "-0.314846"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 6
+ input: "63982"
+ input: "8409"
+ output: "-2.75389"
+}
+invoke {
+ id: 7
+ input: "8409"
+ input: "3082"
+ output: "-3.63721"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 8
+ input: "63982"
+ input: "8409"
+ output: "-2.75389"
+}
+invoke {
+ id: 9
+ input: "8409"
+ input: "18965"
+ output: "-6.93985"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 13
+ input: "63982"
+ input: "12516"
+ output: "-6.20867"
+}
+invoke {
+ id: 14
+ input: "12516"
+ input: "914"
+ output: "-0.407277"
+}
+invoke {
+ id: 15
+ input: "914"
+ input: "63981"
+ output: "-3.82091"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 19
+ input: "63982"
+ input: "12516"
+ output: "-6.20867"
+}
+invoke {
+ id: 20
+ input: "12516"
+ input: "914"
+ output: "-0.407277"
+}
+invoke {
+ id: 21
+ input: "914"
+ input: "48619"
+ output: "-4.02131"
+}
+invoke {
+ id: 22
+ input: "48619"
+ input: "63981"
+ output: "-0.677399"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 26
+ input: "63982"
+ input: "12516"
+ output: "-6.20867"
+}
+invoke {
+ id: 27
+ input: "12516"
+ input: "914"
+ output: "-0.407277"
+}
+invoke {
+ id: 28
+ input: "914"
+ input: "4700"
+ output: "-4.056"
+}
+invoke {
+ id: 29
+ input: "4700"
+ input: "63981"
+ output: "0.415889"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 30
+ input: "63982"
+ input: "12516"
+ output: "-6.20867"
+}
+invoke {
+ id: 31
+ input: "12516"
+ input: "914"
+ output: "-0.407277"
+}
+invoke {
+ id: 32
+ input: "914"
+ input: "51923"
+ output: "-14.1147"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 34
+ input: "63982"
+ input: "5520"
+ output: "-4.56971"
+}
+invoke {
+ id: 35
+ input: "5520"
+ input: "16318"
+ output: "-1.54815"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 36
+ input: "63982"
+ input: "5520"
+ output: "-4.56971"
+}
+invoke {
+ id: 37
+ input: "5520"
+ input: "28303"
+ output: "-14.0947"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 38
+ input: "63982"
+ input: "12451"
+ output: "-6.24243"
+}
+invoke {
+ id: 39
+ input: "12451"
+ input: "752"
+ output: "0.0700736"
+}
+invoke {
+ id: 40
+ input: "752"
+ input: "11"
+ output: "-1.72744"
+}
+invoke {
+ id: 41
+ input: "11"
+ input: "19454"
+ output: "-3.19211"
+}
+invoke {
+ id: 42
+ input: "19454"
+ input: "16989"
+ output: "-4.01684"
+}
+invoke {
+ id: 43
+ input: "16989"
+ input: "40168"
+ output: "-8.91317"
+}
+invoke {
+ id: 44
+ input: "40168"
+ input: "63981"
+ output: "-0.675377"
+}