From 82b5d883774382f0e9cddc828f0d2f85af70f2ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 11 Jan 2018 12:15:24 -0800 Subject: Add speech endpointer model test for TF-lite. PiperOrigin-RevId: 181644178 --- .../lite/models/speech_endpointer_model_test.cc | 104 +++++++++++++++++++++ .../contrib/lite/models/testdata/g3doc/README.md | 17 ++++ .../lite/models/testdata/g3doc/endpointer.svg | 4 + 3 files changed, 125 insertions(+) create mode 100644 tensorflow/contrib/lite/models/speech_endpointer_model_test.cc create mode 100644 tensorflow/contrib/lite/models/testdata/g3doc/endpointer.svg (limited to 'tensorflow/contrib/lite/models') diff --git a/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc b/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc new file mode 100644 index 0000000000..f7e136113a --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech EndPointer model using TFLite Ops. 
+ +#include + +#include +#include + +#include "base/logging.h" +#include "testing/base/public/googletest.h" +#include +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 28; +constexpr int kLstmLayer1CellStateTensor = 29; +constexpr int kLstmLayer2OutputStateTensor = 49; +constexpr int kLstmLayer2CellStateTensor = 50; +constexpr int kModelOutputTensor = 58; + +TEST(SpeechEndpointer, EndpointerTest) { + // Read the model. + string tflite_file_path = + StrCat(TestDataPath(), "/", "speech_endpointer_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to read model from file " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + StrCat(TestDataPath(), "/", "speech_endpointer_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. 
+ Frames output_frames; + const string output_file_path = + StrCat(TestDataPath(), "/", "speech_endpointer_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. + for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/contrib/lite/models/testdata/g3doc/README.md index 46b24248f0..8a023ea869 100644 --- a/tensorflow/contrib/lite/models/testdata/g3doc/README.md +++ b/tensorflow/contrib/lite/models/testdata/g3doc/README.md @@ -75,6 +75,20 @@ The corresponding parameters as shown in the figure. 
![asr_lm_model](asr_lm.svg "ASR LM model")
+### Endpointer Model
+
+The endpointer model is the neural network model for predicting end of speech
+in an utterance. More precisely, it generates posterior probabilities of various
+events that allow detection of speech start and end events.
+It has an input size of 40 (float) which are speech frontend features
+(log-mel filterbanks), and an output size of four corresponding to:
+speech, intermediate non-speech, initial non-speech, and final non-speech.
+The model consists of a convolutional layer, followed by a fully-connected
+layer, two LSTM layers, and two additional fully-connected layers.
+The corresponding parameters are shown in the figure.
+![endpointer_model](endpointer.svg "Endpointer model")
+
+
## Speech models test input/output generation

As mentioned above the input to models are generated from a pre-processing
@@ -115,6 +129,9 @@ test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/li

[ASR AM model
test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc)

+[Endpointer model
+test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc)
+
## Android Support

The models have been tested on Android phones, using the following tests:
diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/endpointer.svg b/tensorflow/contrib/lite/models/testdata/g3doc/endpointer.svg
new file mode 100644
index 0000000000..6033bdc529
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/g3doc/endpointer.svg
@@ -0,0 +1,4 @@
+
+
+
+
-- cgit v1.2.3