aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
blob: f69cae8d2cb08678f9eec8c9b9d653cfce55bd2e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for speech Hotword model using TFLite Ops.

#include <string.h>

#include <memory>
#include <string>

#include "base/logging.h"
#include "testing/base/public/googletest.h"
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/models/test_utils.h"

namespace tflite {
namespace models {

void RunTest(int model_input_tensor, int svdf_layer_state_tensor,
             int model_output_tensor, const string& model_name,
             const string& golden_in_name, const string& golden_out_name) {
  // Read the model.
  string tflite_file_path = StrCat(TestDataPath(), "/", model_name);
  auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
  CHECK(model) << "Failed to read model from file " << tflite_file_path;

  // Initialize the interpreter.
  ops::builtin::BuiltinOpResolver builtins;
  std::unique_ptr<Interpreter> interpreter;
  InterpreterBuilder(*model, builtins)(&interpreter);
  CHECK(interpreter != nullptr);
  interpreter->AllocateTensors();

  // Reset the SVDF layer state.
  memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0,
         interpreter->tensor(svdf_layer_state_tensor)->bytes);

  // Load the input frames.
  Frames input_frames;
  const string input_file_path = StrCat(TestDataPath(), "/", golden_in_name);
  ReadFrames(input_file_path, &input_frames);

  // Load the golden output results.
  Frames output_frames;
  const string output_file_path = StrCat(TestDataPath(), "/", golden_out_name);
  ReadFrames(output_file_path, &output_frames);

  const int speech_batch_size =
      interpreter->tensor(model_input_tensor)->dims->data[0];
  const int speech_input_size =
      interpreter->tensor(model_input_tensor)->dims->data[1];
  const int speech_output_size =
      interpreter->tensor(model_output_tensor)->dims->data[1];
  const int input_sequence_size =
      input_frames[0].size() / (speech_input_size * speech_batch_size);
  float* input_ptr = interpreter->tensor(model_input_tensor)->data.f;
  float* output_ptr = interpreter->tensor(model_output_tensor)->data.f;

  // The first layer (SVDF) input size is 40 (speech_input_size). Each speech
  // input frames for this model is 1600 floats, which can be fed to input in a
  // sequence of size 40 (input_sequence_size).
  for (int i = 0; i < TestInputSize(input_frames); i++) {
    int frame_ptr = 0;
    for (int s = 0; s < input_sequence_size; s++) {
      for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
        input_ptr[k] = input_frames[i][frame_ptr++];
      }
      interpreter->Invoke();
    }
    // After the whole frame (1280 floats) is fed, we can check the output frame
    // matches with the golden output frame.
    for (int k = 0; k < speech_output_size; k++) {
      ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
    }
  }
}

// End-to-end check of the "rank1" hotword model variant against its golden
// output data.
TEST(SpeechHotword, OkGoogleTestRank1) {
  // Tensor indices within speech_hotword_model_rank1.tflite.
  constexpr int kInputIndex = 0;
  constexpr int kSvdfStateIndex = 4;
  constexpr int kOutputIndex = 18;

  // Test data files, resolved by RunTest() relative to TestDataPath().
  constexpr char kModelFile[] = "speech_hotword_model_rank1.tflite";
  constexpr char kGoldenInput[] = "speech_hotword_model_in.csv";
  constexpr char kGoldenOutput[] = "speech_hotword_model_out_rank1.csv";

  RunTest(kInputIndex, kSvdfStateIndex, kOutputIndex, kModelFile, kGoldenInput,
          kGoldenOutput);
}

// End-to-end check of the "rank2" hotword model variant against its golden
// output data.
TEST(SpeechHotword, OkGoogleTestRank2) {
  // Tensor indices within speech_hotword_model_rank2.tflite.
  constexpr int kInputIndex = 17;
  constexpr int kSvdfStateIndex = 1;
  constexpr int kOutputIndex = 18;

  // Test data files, resolved by RunTest() relative to TestDataPath().
  constexpr char kModelFile[] = "speech_hotword_model_rank2.tflite";
  constexpr char kGoldenInput[] = "speech_hotword_model_in.csv";
  constexpr char kGoldenOutput[] = "speech_hotword_model_out_rank2.csv";

  RunTest(kInputIndex, kSvdfStateIndex, kOutputIndex, kModelFile, kGoldenInput,
          kGoldenOutput);
}

}  // namespace models
}  // namespace tflite