diff options
author | 2018-01-23 15:08:26 -0800 | |
---|---|---|
committer | 2018-01-23 15:12:38 -0800 | |
commit | 58d227d36aa17d038c70c94787f415b1cd64d982 (patch) | |
tree | 8f996478aa6b51345d4620c8d07f31c16da132bd /tensorflow/contrib/lite/models | |
parent | 68a9ee1d4e041d7690a949718b2651b035a6bfad (diff) |
Switch models/ to use new test-specification format
PiperOrigin-RevId: 182999650
Diffstat (limited to 'tensorflow/contrib/lite/models')
-rw-r--r-- | tensorflow/contrib/lite/models/speech_test.cc | 189 | ||||
-rw-r--r-- | tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec | 202 |
2 files changed, 391 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc new file mode 100644 index 0000000000..daa8c3100b --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_test.cc @@ -0,0 +1,189 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech models (Hotword, SpeakerId) using TFLite Ops. + +#include <memory> +#include <string> + +#include <fstream> + +#include "testing/base/public/googletest.h" +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/testing/parse_testdata.h" +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +namespace tflite { +namespace { + +const char kDataPath[] = "third_party/tensorflow/contrib/lite/models/testdata/"; + +bool Init(const string& in_file_name, testing::TfLiteDriver* driver, + std::ifstream* in_file) { + driver->SetModelBaseDir(kDataPath); + in_file->open(string(kDataPath) + in_file_name, std::ifstream::in); + return in_file->is_open(); +} + +// Converts a set of test files provided by the speech team into a single +// test_spec. Input CSV files are supposed to contain a number of sequences per +// line. 
Each sequence maps to a single invocation of the interpreter and the +// output tensor after all sequences have run is compared to the corresponding +// line in the output CSV file. +bool ConvertCsvData(const string& model_name, const string& in_name, + const string& out_name, const string& input_tensor, + const string& output_tensor, + const string& persistent_tensors, int sequence_size, + std::ostream* out) { + auto data_path = [](const string& s) { return string(kDataPath) + s; }; + + *out << "load_model: \"" << data_path(model_name) << "\"" << std::endl; + + *out << "init_state: \"" << persistent_tensors << "\"" << std::endl; + + string in_file_name = data_path(in_name); + std::ifstream in_file(in_file_name); + if (!in_file.is_open()) { + std::cerr << "Failed to open " << in_file_name << std::endl; + return false; + } + string out_file_name = data_path(out_name); + std::ifstream out_file(out_file_name); + if (!out_file.is_open()) { + std::cerr << "Failed to open " << out_file_name << std::endl; + return false; + } + + int invocation_count = 0; + string in_values; + while (std::getline(in_file, in_values, '\n')) { + std::vector<string> input = testing::Split<string>(in_values, ","); + int num_sequences = input.size() / sequence_size; + + for (int j = 0; j < num_sequences; ++j) { + *out << "invoke {" << std::endl; + *out << " id: " << invocation_count << std::endl; + *out << " input: \""; + for (int k = 0; k < sequence_size; ++k) { + *out << input[k + j * sequence_size] << ","; + } + *out << "\"" << std::endl; + + if (j == num_sequences - 1) { + string out_values; + if (!std::getline(out_file, out_values, '\n')) { + std::cerr << "Not enough lines in " << out_file_name << std::endl; + return false; + } + *out << " output: \"" << out_values << "\"" << std::endl; + } + + *out << "}" << std::endl; + ++invocation_count; + } + } + return true; +} + +TEST(SpeechTest, HotwordOkGoogleRank1Test) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + 
"speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank1.csv", /*input_tensor=*/"0", + /*output_tensor=*/"18", /*persistent_tensors=*/"4", + /*sequence_size=*/40, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, HotwordOkGoogleRank2Test) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank2.csv", /*input_tensor=*/"17", + /*output_tensor=*/"18", /*persistent_tensors=*/"1", + /*sequence_size=*/40, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, SpeakerIdOkGoogleTest) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + "speech_speakerid_model.tflite", "speech_speakerid_model_in.csv", + "speech_speakerid_model_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"66", + /*persistent_tensors=*/"19,20,40,41,61,62", + /*sequence_size=*/80, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, AsrAmTest) { + std::stringstream os; + ASSERT_TRUE( + ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv", + "speech_asr_am_model_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"109", + /*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104", + /*sequence_size=*/320, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +// The original version of speech_asr_lm_model_test.cc ran a few sequences +// through the interpreter and stored the sum of all the output, which was then +// compared for 
correctness. In this test we are comparing all the intermediate +// results. +TEST(SpeechTest, AsrLmTest) { + std::ifstream in_file; + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file)); + ASSERT_TRUE(testing::ParseAndRunTests(&in_file, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, EndpointerTest) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + "speech_endpointer_model.tflite", "speech_endpointer_model_in.csv", + "speech_endpointer_model_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"58", + /*persistent_tensors=*/"28,29,49,50", + /*sequence_size=*/320, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, TtsTest) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite", + "speech_tts_model_in.csv", + "speech_tts_model_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"74", + /*persistent_tensors=*/"25,26,46,47,67,68,73", + /*sequence_size=*/334, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec new file mode 100644 index 0000000000..5812de4b30 --- /dev/null +++ b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec @@ -0,0 +1,202 @@ +load_model: "speech_asr_lm_model.tflite" +init_state: "21,22,42,43,63,64" +invoke { + id: 3 + input: "63982" + input: "8409" + output: "-2.75389" +} +invoke { + id: 4 + input: "8409" + input: "1488" + output: "0.601841" +} +invoke { + id: 5 + input: "1488" + input: "63981" + output: "-0.314846" +} +init_state: "21,22,42,43,63,64" +invoke 
{ + id: 6 + input: "63982" + input: "8409" + output: "-2.75389" +} +invoke { + id: 7 + input: "8409" + input: "3082" + output: "-3.63721" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 8 + input: "63982" + input: "8409" + output: "-2.75389" +} +invoke { + id: 9 + input: "8409" + input: "18965" + output: "-6.93985" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 13 + input: "63982" + input: "12516" + output: "-6.20867" +} +invoke { + id: 14 + input: "12516" + input: "914" + output: "-0.407277" +} +invoke { + id: 15 + input: "914" + input: "63981" + output: "-3.82091" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 19 + input: "63982" + input: "12516" + output: "-6.20867" +} +invoke { + id: 20 + input: "12516" + input: "914" + output: "-0.407277" +} +invoke { + id: 21 + input: "914" + input: "48619" + output: "-4.02131" +} +invoke { + id: 22 + input: "48619" + input: "63981" + output: "-0.677399" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 26 + input: "63982" + input: "12516" + output: "-6.20867" +} +invoke { + id: 27 + input: "12516" + input: "914" + output: "-0.407277" +} +invoke { + id: 28 + input: "914" + input: "4700" + output: "-4.056" +} +invoke { + id: 29 + input: "4700" + input: "63981" + output: "0.415889" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 30 + input: "63982" + input: "12516" + output: "-6.20867" +} +invoke { + id: 31 + input: "12516" + input: "914" + output: "-0.407277" +invoke { + id: 32 + input: "914" + input: "51923" + output: "-14.1147" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 34 + input: "63982" + input: "5520" + output: "-4.56971" +} +invoke { + id: 35 + input: "5520" + input: "16318" + output: "-1.54815" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 36 + input: "63982" + input: "5520" + output: "-4.56971" +} +invoke { + id: 37 + input: "5520" + input: "28303" + output: "-14.0947" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 38 + input: "63982" + input: "12451" + output: "-6.24243" +} 
+invoke { + id: 39 + input: "12451" + input: "752" + output: "0.0700736" +} +invoke { + id: 40 + input: "752" + input: "11" + output: "-1.72744" +} +invoke { + id: 41 + input: "11" + input: "19454" + output: "-3.19211" +} +invoke { + id: 42 + input: "19454" + input: "16989" + output: "-4.01684" +} +invoke { + id: 43 + input: "16989" + input: "40168" + output: "-8.91317" +} +invoke { + id: 44 + input: "40168" + input: "63981" + output: "-0.675377" +} |