author     Andrew Selle <aselle@google.com>     2017-11-10 10:35:35 -0800
committer  Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:42 -0800
commit     0b15439f8f0f2d4755587f4096c3ea04cb199d23 (patch)
tree       9aa4fc8162bf9b4ee50112a7b85703f70ca4df08 /tensorflow/contrib/lite/models
parent     7ac140a5845553275427162aabd9d54987144b4a (diff)
Internal Change.
PiperOrigin-RevId: 175307445
Diffstat (limited to 'tensorflow/contrib/lite/models')
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/BUILD                        |  15
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc       | 119
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc  | 100
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/ops/normalize.cc             | 105
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc        |  90
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/ops/predict.cc               | 174
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc          | 183
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/predictor.cc                 | 116
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/predictor.h                  |  80
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/predictor_test.cc            | 150
-rw-r--r--  tensorflow/contrib/lite/models/speech_hotword_model_test.cc            | 115
-rw-r--r--  tensorflow/contrib/lite/models/speech_speakerid_model_test.cc          | 114
-rw-r--r--  tensorflow/contrib/lite/models/speech_terse_am_model_test.cc           | 127
-rw-r--r--  tensorflow/contrib/lite/models/speech_tts_model_test.cc                | 116
-rw-r--r--  tensorflow/contrib/lite/models/test_utils.h                            |  84
15 files changed, 1688 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD
new file mode 100644
index 0000000000..fbdf19f205
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/BUILD
@@ -0,0 +1,15 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
new file mode 100644
index 0000000000..1c422b659a
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Convert a list of strings to integers via hashing.
+// Input:
+// Input[0]: A list of ngrams. string[num of input]
+//
+// Output:
+// Output[0]: Hashed features. int32[num of input]
+// Output[1]: Weights. float[num of input]
+
+#include <algorithm>
+#include <map>
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include <farmhash.h>
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace extract {
+
+static const int kMaxDimension = 1000000;
+static const std::vector<string> kBlacklistNgram = {"<S>", "<E>", "<S> <E>"};
+
+bool Equals(const string& x, const tflite::StringRef& strref) {
+ if (strref.len != x.length()) {
+ return false;
+ }
+ if (strref.len > 0) {
+ int r = memcmp(strref.str, x.data(), strref.len);
+ return r == 0;
+ }
+ return true;
+}
+
+bool IsValidNgram(const tflite::StringRef& strref) {
+ for (const auto& s : kBlacklistNgram) {
+ if (Equals(s, strref)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1);
+ TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ int dim = input->dims->data[0];
+ if (dim == 0) {
+ // TFLite non-string output should have size greater than 0.
+ dim = 1;
+ }
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteString);
+ outputSize1->data[0] = dim;
+ outputSize2->data[0] = dim;
+ context->ResizeTensor(context, GetOutput(context, node, 0), outputSize1);
+ context->ResizeTensor(context, GetOutput(context, node, 1), outputSize2);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ int num_strings = tflite::GetStringCount(input);
+ TfLiteTensor* label = GetOutput(context, node, 0);
+ TfLiteTensor* weight = GetOutput(context, node, 1);
+
+ std::map<int64, int> feature_id_counts;
+ for (int i = 0; i < num_strings; i++) {
+ // Use fingerprint of feature name as id.
+ auto strref = tflite::GetString(input, i);
+ if (!IsValidNgram(strref)) {
+ label->data.i32[i] = 0;
+ weight->data.f[i] = 0.0f;
+ continue;
+ }
+
+ int64 feature_id =
+ ::util::Fingerprint64(strref.str, strref.len) % kMaxDimension;
+
+ label->data.i32[i] = static_cast<int32>(feature_id);
+ weight->data.f[i] =
+ std::count(strref.str, strref.str + strref.len, ' ') + 1;
+ }
+ // Explicitly set an empty result so downstream ops can still run.
+ if (num_strings == 0) {
+ label->data.i32[0] = 0;
+ weight->data.f[0] = 0.0f;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace extract
+
+TfLiteRegistration* Register_EXTRACT_FEATURES() {
+ static TfLiteRegistration r = {nullptr, nullptr, extract::Prepare,
+ extract::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
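For reference, the id/weight contract above can be exercised outside TFLite. Below is a minimal standalone sketch (not part of this commit), assuming farmhash is on the include path as it is for the kernel; it mirrors what extract::Eval computes per ngram:

// Fingerprint each ngram into [0, kMaxDimension) and weight it by its token
// count; blacklisted sentence markers map to id 0 with weight 0.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
#include <farmhash.h>

int main() {
  const int kMaxDimension = 1000000;  // Same modulus as the kernel.
  const std::vector<std::string> ngrams = {"<S>", "<S> Hi", "Hi !"};
  for (const std::string& ngram : ngrams) {
    if (ngram == "<S>" || ngram == "<E>" || ngram == "<S> <E>") {
      std::cout << ngram << " -> id=0 weight=0\n";
      continue;
    }
    int32_t id =
        static_cast<int32_t>(::util::Fingerprint64(ngram) % kMaxDimension);
    float weight = std::count(ngram.begin(), ngram.end(), ' ') + 1;
    std::cout << ngram << " -> id=" << id << " weight=" << weight << "\n";
  }
  return 0;
}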
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
new file mode 100644
index 0000000000..9b8676bab6
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include <farmhash.h>
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_EXTRACT_FEATURES();
+
+namespace {
+
+using ::testing::ElementsAre;
+
+class ExtractFeatureOpModel : public SingleOpModel {
+ public:
+ explicit ExtractFeatureOpModel(const std::vector<string>& input) {
+ input_ = AddInput(TensorType_STRING);
+ signature_ = AddOutput(TensorType_INT32);
+ weight_ = AddOutput(TensorType_FLOAT32);
+
+ SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES);
+ BuildInterpreter({{static_cast<int>(input.size())}});
+ PopulateStringTensor(input_, input);
+ }
+
+ std::vector<int> GetSignature() { return ExtractVector<int>(signature_); }
+ std::vector<float> GetWeight() { return ExtractVector<float>(weight_); }
+
+ private:
+ int input_;
+ int signature_;
+ int weight_;
+};
+
+int CalcFeature(const string& str) {
+ return ::util::Fingerprint64(str) % 1000000;
+}
+
+TEST(ExtractFeatureOpTest, RegularInput) {
+ ExtractFeatureOpModel m({"<S>", "<S> Hi", "Hi", "Hi !", "!", "! <E>", "<E>"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(),
+ ElementsAre(0, CalcFeature("<S> Hi"), CalcFeature("Hi"),
+ CalcFeature("Hi !"), CalcFeature("!"),
+ CalcFeature("! <E>"), 0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0));
+}
+
+TEST(ExtractFeatureOpTest, OneInput) {
+ ExtractFeatureOpModel m({"Hi"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi")));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(1));
+}
+
+TEST(ExtractFeatureOpTest, ZeroInput) {
+ ExtractFeatureOpModel m({});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0));
+}
+
+TEST(ExtractFeatureOpTest, AllBlacklistInput) {
+ ExtractFeatureOpModel m({"<S>", "<E>"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
new file mode 100644
index 0000000000..d0dc2a35a7
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Normalize the string input.
+//
+// Input:
+// Input[0]: One sentence. string[1]
+//
+// Output:
+// Output[0]: Normalized sentence. string[1]
+//
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/strip.h"
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace normalize {
+
+// Predictor transforms.
+const char kPunctuationsRegex[] = "[.*()\"]";
+
+const std::map<string, string>* kRegexTransforms =
+ new std::map<string, string>({
+ {"([^\\s]+)n't", "\\1 not"},
+ {"([^\\s]+)'nt", "\\1 not"},
+ {"([^\\s]+)'ll", "\\1 will"},
+ {"([^\\s]+)'re", "\\1 are"},
+ {"([^\\s]+)'ve", "\\1 have"},
+ {"i'm", "i am"},
+ });
+
+static const char kStartToken[] = "<S>";
+static const char kEndToken[] = "<E>";
+static const int32 kMaxInputChars = 300;
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);
+
+ string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len)));
+ absl::StripAsciiWhitespace(&result);
+ // Do not remove commas, semicolons, or colons from the sentence, as they
+ // can indicate the beginning of a new clause.
+ RE2::GlobalReplace(&result, kPunctuationsRegex, "");
+ RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])",
+ "\\1\\2");
+ RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1");
+ for (const auto& transform : *kRegexTransforms) {
+ RE2::GlobalReplace(&result, transform.first, transform.second);
+ }
+
+ // Treat questions & interjections as special cases.
+ RE2::GlobalReplace(&result, "([?])+", "\\1");
+ RE2::GlobalReplace(&result, "([!])+", "\\1");
+ RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 ");
+ RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2");
+
+ RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", "");
+ RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", "");
+ absl::StripAsciiWhitespace(&result);
+
+ // Add the start token, and the end token only when the sentence fits;
+ // otherwise truncate the input to the maximum allowed size.
+ if (result.length() <= kMaxInputChars) {
+ absl::StrAppend(&result, " ", kEndToken);
+ } else {
+ result = result.substr(0, kMaxInputChars);
+ }
+ result = absl::StrCat(kStartToken, " ", result);
+
+ tflite::DynamicBuffer buf;
+ buf.AddString(result.data(), result.length());
+ buf.WriteToTensor(GetOutput(context, node, 0));
+ return kTfLiteOk;
+}
+
+} // namespace normalize
+
+TfLiteRegistration* Register_NORMALIZE() {
+ static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
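As a worked example (not part of this commit), the normalization above reduces to a chain of RE2 rewrites. This sketch applies a subset of them to one sentence and prints the normalized form the tests below expect; it assumes RE2 and Abseil are available as they are for the kernel:

#include <iostream>
#include <string>
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "re2/re2.h"

int main() {
  std::string s = "I'm good; you're welcome!!";
  s = absl::AsciiStrToLower(s);                           // lowercase
  RE2::GlobalReplace(&s, "[.*()\"]", "");                 // kPunctuationsRegex
  RE2::GlobalReplace(&s, "([^\\s]+)'re", "\\1 are");      // expand contraction
  RE2::GlobalReplace(&s, "i'm", "i am");                  // kRegexTransforms entry
  RE2::GlobalReplace(&s, "([!])+", "!");                  // collapse repeats
  RE2::GlobalReplace(&s, "([^?!]+)([?!])", "\\1 \\2 ");   // split off '!'
  absl::StripAsciiWhitespace(&s);
  s = absl::StrCat("<S> ", s, " <E>");                    // sentence markers
  std::cout << s << "\n";  // "<S> i am good; you are welcome ! <E>"
  return 0;
}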
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc
new file mode 100644
index 0000000000..4d35dba9a6
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_NORMALIZE();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class NormalizeOpModel : public SingleOpModel {
+ public:
+ explicit NormalizeOpModel(const string& input) {
+ input_ = AddInput(TensorType_STRING);
+ output_ = AddOutput(TensorType_STRING);
+
+ SetCustomOp("Normalize", {}, Register_NORMALIZE);
+ BuildInterpreter({{static_cast<int>(input.size())}});
+ PopulateStringTensor(input_, {input});
+ }
+
+ std::vector<string> GetStringOutput() {
+ TfLiteTensor* output = interpreter_->tensor(output_);
+ int num = GetStringCount(output);
+ std::vector<string> result(num);
+ for (int i = 0; i < num; i++) {
+ auto ref = GetString(output, i);
+ result[i] = string(ref.str, ref.len);
+ }
+ return result;
+ }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(NormalizeOpTest, RegularInput) {
+ NormalizeOpModel m("I'm good; you're welcome");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(),
+ ElementsAreArray({"<S> i am good; you are welcome <E>"}));
+}
+
+TEST(NormalizeOpTest, OneInput) {
+ NormalizeOpModel m("Hi!!!!");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> hi ! <E>"}));
+}
+
+TEST(NormalizeOpTest, EmptyInput) {
+ NormalizeOpModel m("");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> <E>"}));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc
new file mode 100644
index 0000000000..7b23adb990
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc
@@ -0,0 +1,174 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Lookup projected hash signatures in Predictor model,
+// output predicted labels and weights in decreasing order.
+//
+// Input:
+// Input[0]: A list of hash signatures. int32[num of input]
+// Input[1]: Hash signature keys in the model. int32[keys of model]
+// Input[2]: Labels in the model. int32[keys of model, item per entry]
+// Input[3]: Weights in the model. float[keys of model, item per entry]
+//
+// Output:
+// Output[0]: Predicted labels. int32[num of output]
+// Output[1]: Predicted weights. float[num of output]
+//
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace predict {
+
+struct PredictOption {
+ int32_t num_output;
+ float weight_threshold;
+
+ static PredictOption* Cast(void* ptr) {
+ return reinterpret_cast<PredictOption*>(ptr);
+ }
+};
+
+bool WeightGreater(const std::pair<int32_t, float>& a,
+ const std::pair<int32_t, float>& b) {
+ return a.second > b.second;
+}
+
+void* Init(TfLiteContext* context, const char* custom_option, size_t length) {
+ if (custom_option == nullptr || length != sizeof(PredictOption)) {
+ fprintf(stderr, "No custom option set\n");
+ exit(1);
+ }
+ PredictOption* option = new PredictOption;
+ int offset = 0;
+ option->num_output =
+ *reinterpret_cast<const int32_t*>(custom_option + offset);
+ offset += sizeof(int32_t);
+ option->weight_threshold =
+ *reinterpret_cast<const float*>(custom_option + offset);
+ return reinterpret_cast<void*>(option);
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete PredictOption::Cast(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
+ TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
+ model_label->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
+ model_weight->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, model_label->dims->data[1],
+ model_weight->dims->data[1]);
+
+ PredictOption* option = PredictOption::Cast(node->user_data);
+ TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
+ TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32);
+
+ TfLiteIntArray* label_size = TfLiteIntArrayCreate(1);
+ label_size->data[0] = option->num_output;
+ TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1);
+ weight_size->data[0] = option->num_output;
+ TfLiteStatus status =
+ context->ResizeTensor(context, output_label, label_size);
+ if (status != kTfLiteOk) {
+ return status;
+ }
+ return context->ResizeTensor(context, output_weight, weight_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
+ TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
+
+ // Aggregate by key
+ std::unordered_map<int32_t, float> aggregation;
+ const int num_input = lookup->dims->data[0];
+ const int num_rows = model_key->dims->data[0];
+ const int items = model_label->dims->data[1];
+ int* model_key_end = model_key->data.i32 + num_rows;
+
+ for (int i = 0; i < num_input; i++) {
+ int* ptr = std::lower_bound(model_key->data.i32, model_key_end,
+ lookup->data.i32[i]);
+ if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) {
+ int idx = ptr - model_key->data.i32;
+ for (int j = 0; j < items; j++) {
+ aggregation[model_label->data.i32[idx * items + j]] +=
+ model_weight->data.f[idx * items + j] / num_input;
+ }
+ }
+ }
+
+ // Sort by value
+ std::vector<std::pair<int32_t, float>> sorted_labels(aggregation.begin(),
+ aggregation.end());
+ std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater);
+
+ PredictOption* option = PredictOption::Cast(node->user_data);
+ TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
+ for (int i = 0; i < output_label->dims->data[0]; i++) {
+ if (i >= sorted_labels.size() ||
+ sorted_labels[i].second < option->weight_threshold) {
+ // Set the label to -1 so it cannot be confused with id 0, which is
+ // reserved for backoff messages.
+ output_label->data.i32[i] = -1;
+ output_weight->data.f[i] = 0.0f;
+ } else {
+ output_label->data.i32[i] = sorted_labels[i].first;
+ output_weight->data.f[i] = sorted_labels[i].second;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace predict
+
+TfLiteRegistration* Register_PREDICT() {
+ static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare,
+ predict::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
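The Eval() logic above is a binary search over sorted model keys followed by per-label aggregation and ranking. A standalone sketch (not part of this commit) running the same data as the AllLabelsAreValid test below:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // Model rows: sorted keys with `items` labels/weights per key.
  const std::vector<int32_t> keys = {1, 2, 4, 6, 7};
  const std::vector<int32_t> labels = {11, 12, 11, 12, 11, 12, 11, 12, 11, 12};
  const std::vector<float> weights = {0.1f, 0.2f, 0.1f, 0.2f, 0.1f,
                                      0.2f, 0.1f, 0.2f, 0.1f, 0.2f};
  const std::vector<int32_t> lookup = {1, 3, 7, 9};
  const int items = 2;

  // Aggregate weights per label over matched keys, normalized by the number
  // of lookups, as Eval() does.
  std::unordered_map<int32_t, float> aggregation;
  for (int32_t key : lookup) {
    auto it = std::lower_bound(keys.begin(), keys.end(), key);
    if (it == keys.end() || *it != key) continue;
    int idx = static_cast<int>(it - keys.begin());
    for (int j = 0; j < items; j++) {
      aggregation[labels[idx * items + j]] +=
          weights[idx * items + j] / lookup.size();
    }
  }

  // Rank labels by aggregated weight, descending.
  std::vector<std::pair<int32_t, float>> sorted(aggregation.begin(),
                                                aggregation.end());
  std::sort(sorted.begin(), sorted.end(),
            [](const auto& a, const auto& b) { return a.second > b.second; });
  for (const auto& p : sorted) {
    std::cout << "label " << p.first << " weight " << p.second << "\n";
  }
  // Prints: label 12 weight 0.1, then label 11 weight 0.05.
  return 0;
}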
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc
new file mode 100644
index 0000000000..e97c58cbd1
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_PREDICT();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class PredictOpModel : public SingleOpModel {
+ public:
+ PredictOpModel(std::initializer_list<int> input_signature_shape,
+ std::initializer_list<int> key_shape,
+ std::initializer_list<int> labelweight_shape, int num_output,
+ float threshold) {
+ input_signature_ = AddInput(TensorType_INT32);
+ model_key_ = AddInput(TensorType_INT32);
+ model_label_ = AddInput(TensorType_INT32);
+ model_weight_ = AddInput(TensorType_FLOAT32);
+ output_label_ = AddOutput(TensorType_INT32);
+ output_weight_ = AddOutput(TensorType_FLOAT32);
+
+ std::vector<uint8_t> predict_option;
+ writeInt32(num_output, &predict_option);
+ writeFloat32(threshold, &predict_option);
+ SetCustomOp("Predict", predict_option, Register_PREDICT);
+ BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape,
+ labelweight_shape}});
+ }
+
+ void SetInputSignature(std::initializer_list<int> data) {
+ PopulateTensor<int>(input_signature_, data);
+ }
+
+ void SetModelKey(std::initializer_list<int> data) {
+ PopulateTensor<int>(model_key_, data);
+ }
+
+ void SetModelLabel(std::initializer_list<int> data) {
+ PopulateTensor<int>(model_label_, data);
+ }
+
+ void SetModelWeight(std::initializer_list<float> data) {
+ PopulateTensor<float>(model_weight_, data);
+ }
+
+ std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); }
+ std::vector<float> GetWeight() {
+ return ExtractVector<float>(output_weight_);
+ }
+
+ void writeFloat32(float value, std::vector<uint8_t>* data) {
+ union {
+ float v;
+ uint8_t r[4];
+ } float_to_raw;
+ float_to_raw.v = value;
+ for (unsigned char i : float_to_raw.r) {
+ data->push_back(i);
+ }
+ }
+
+ void writeInt32(int32_t value, std::vector<uint8_t>* data) {
+ union {
+ int32_t v;
+ uint8_t r[4];
+ } int32_to_raw;
+ int32_to_raw.v = value;
+ for (unsigned char i : int32_to_raw.r) {
+ data->push_back(i);
+ }
+ }
+
+ private:
+ int input_signature_;
+ int model_key_;
+ int model_label_;
+ int model_weight_;
+ int output_label_;
+ int output_weight_;
+};
+
+TEST(PredictOpTest, AllLabelsAreValid) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05})));
+}
+
+TEST(PredictOpTest, MoreLabelsThanRequired) {
+ PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1})));
+}
+
+TEST(PredictOpTest, OneLabelDoesNotPassThreshold) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0})));
+}
+
+TEST(PredictOpTest, NoneLabelPassThreshold) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
+}
+
+TEST(PredictOpTest, OnlyOneLabelGenerated) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0});
+ m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0})));
+}
+
+TEST(PredictOpTest, NoLabelGenerated) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({5, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0});
+ m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
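A design note on the option encoding: writeInt32/writeFloat32 above rely on union type punning to serialize the bytes that predict::Init() reads back (an int32 num_output followed by a float weight_threshold). A hypothetical memcpy-based equivalent, shown only as a sketch, produces the same byte layout without the punning:

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper (not in the commit): packs the custom-option bytes
// that predict::Init() expects.
std::vector<uint8_t> PackPredictOption(int32_t num_output, float threshold) {
  std::vector<uint8_t> data(sizeof(int32_t) + sizeof(float));
  std::memcpy(data.data(), &num_output, sizeof(num_output));
  std::memcpy(data.data() + sizeof(num_output), &threshold, sizeof(threshold));
  return data;
}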
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc
new file mode 100644
index 0000000000..a28222213e
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
+
+#include "absl/strings/str_split.h"
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+
+// Split sentence into segments (using punctuation).
+std::vector<string> SplitSentence(const string& input) {
+ string result(input);
+
+ RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
+ RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t");
+ RE2::GlobalReplace(&result, "[ ]+", " ");
+ RE2::GlobalReplace(&result, "\t+$", "");
+
+ return strings::Split(result, '\t');
+}
+
+// Predict with TfLite model.
+void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
+ std::map<string, float>* response_map) {
+ {
+ TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
+ tflite::DynamicBuffer buf;
+ buf.AddString(sentence.data(), sentence.length());
+ buf.WriteToTensor(input);
+ interpreter->AllocateTensors();
+
+ interpreter->Invoke();
+
+ TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
+ TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);
+
+ for (int i = 0; i < confidence->dims->data[0]; i++) {
+ float weight = confidence->data.f[i];
+ auto response_text = tflite::GetString(messages, i);
+ if (response_text.len > 0) {
+ (*response_map)[string(response_text.str, response_text.len)] += weight;
+ }
+ }
+ }
+}
+
+void GetSegmentPredictions(
+ const std::vector<string>& input, const ::tflite::FlatBufferModel& model,
+ const SmartReplyConfig& config,
+ std::vector<PredictorResponse>* predictor_responses) {
+ // Initialize interpreter
+ std::unique_ptr<::tflite::Interpreter> interpreter;
+ ::tflite::MutableOpResolver resolver;
+ RegisterSelectedOps(&resolver);
+ ::tflite::InterpreterBuilder(model, resolver)(&interpreter);
+
+ if (!model.initialized()) {
+ fprintf(stderr, "Failed to mmap model\n");
+ return;
+ }
+
+ // Execute the TfLite model.
+ std::map<string, float> response_map;
+ std::vector<string> sentences;
+ for (const string& str : input) {
+ std::vector<string> split_str = SplitSentence(str);
+ sentences.insert(sentences.end(), split_str.begin(), split_str.end());
+ }
+ for (const auto& sentence : sentences) {
+ ExecuteTfLite(sentence, interpreter.get(), &response_map);
+ }
+
+ // Generate the result.
+ for (const auto& iter : response_map) {
+ PredictorResponse prediction(iter.first, iter.second);
+ predictor_responses->emplace_back(prediction);
+ }
+ std::sort(predictor_responses->begin(), predictor_responses->end(),
+ [](const PredictorResponse& a, const PredictorResponse& b) {
+ return a.GetScore() > b.GetScore();
+ });
+
+ // Add backoff response.
+ for (const string& backoff : config.backoff_responses) {
+ if (predictor_responses->size() >= config.num_response) {
+ break;
+ }
+ predictor_responses->push_back({backoff, config.backoff_confidence});
+ }
+}
+
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h
new file mode 100644
index 0000000000..3b9a2b32e1
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.h
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+
+const int kDefaultNumResponse = 10;
+const float kDefaultBackoffConfidence = 1e-4;
+
+class PredictorResponse;
+struct SmartReplyConfig;
+
+ // Given input strings, predicts responses with a TfLite model.
+ // When config.backoff_responses is not empty, predictor_responses will be
+ // padded with messages from the backoff responses.
+void GetSegmentPredictions(const std::vector<string>& input,
+ const ::tflite::FlatBufferModel& model,
+ const SmartReplyConfig& config,
+ std::vector<PredictorResponse>* predictor_responses);
+
+ // Data object used to hold a single predictor response.
+ // It includes the response text and a confidence score.
+class PredictorResponse {
+ public:
+ PredictorResponse(const string& response_text, float score) {
+ response_text_ = response_text;
+ prediction_score_ = score;
+ }
+
+ // Accessor methods.
+ const string& GetText() const { return response_text_; }
+ float GetScore() const { return prediction_score_; }
+
+ private:
+ string response_text_ = "";
+ float prediction_score_ = 0.0;
+};
+
+// Configurations for SmartReply.
+struct SmartReplyConfig {
+ // Maximum responses to return.
+ int num_response;
+ // Default confidence for backoff responses.
+ float backoff_confidence;
+ // Backoff responses are used when predicted responses cannot fulfill the
+ // list.
+ const std::vector<string>& backoff_responses;
+
+ // Takes backoff_responses by const reference so the reference member above
+ // does not bind to a by-value parameter that dies with the constructor.
+ SmartReplyConfig(const std::vector<string>& backoff_responses)
+ : num_response(kDefaultNumResponse),
+ backoff_confidence(kDefaultBackoffConfidence),
+ backoff_responses(backoff_responses) {}
+};
+
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
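A hedged usage sketch of the API above (not part of this commit; the model file name is taken from predictor_test.cc below):

#include <iostream>
#include <string>
#include <vector>
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/models/smartreply/predictor.h"

int main() {
  auto model = tflite::FlatBufferModel::BuildFromFile(
      "smartreply_ondevice_model.bin");
  if (!model) return 1;

  // Keep the backoff list alive: SmartReplyConfig holds it by reference.
  std::vector<std::string> backoff = {"Ok", "Yes"};
  tflite::custom::smartreply::SmartReplyConfig config(backoff);

  std::vector<tflite::custom::smartreply::PredictorResponse> responses;
  tflite::custom::smartreply::GetSegmentPredictions(
      {"How are you?"}, *model, config, &responses);

  for (const auto& r : responses) {
    std::cout << r.GetText() << " (" << r.GetScore() << ")\n";
  }
  return 0;
}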
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
new file mode 100644
index 0000000000..2fa9923bc9
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
@@ -0,0 +1,150 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
+
+#include <fstream>
+#include <unordered_set>
+
+#include "base/logging.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+namespace {
+
+const char kModelName[] = "smartreply_ondevice_model.bin";
+const char kSamples[] = "smartreply_samples.tsv";
+
+MATCHER_P(IncludeAnyResponsesIn, expected_response, "contains an expected response") {
+ bool has_expected_response = false;
+ for (const auto &item : *arg) {
+ const string &response = item.GetText();
+ if (expected_response.find(response) != expected_response.end()) {
+ has_expected_response = true;
+ break;
+ }
+ }
+ return has_expected_response;
+}
+
+class PredictorTest : public ::testing::Test {
+ protected:
+ PredictorTest() {
+ model_ = tflite::FlatBufferModel::BuildFromFile(
+ StrCat(TestDataPath(), "/", kModelName).c_str());
+ CHECK(model_);
+ }
+ ~PredictorTest() override {}
+
+ std::unique_ptr<::tflite::FlatBufferModel> model_;
+};
+
+TEST_F(PredictorTest, GetSegmentPredictions) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions);
+ EXPECT_GT(predictions.size(), 0);
+
+ float max = 0;
+ for (const auto &item : predictions) {
+ LOG(INFO) << "Response: " << item.GetText();
+ if (item.GetScore() > max) {
+ max = item.GetScore();
+ }
+ }
+
+ EXPECT_GT(max, 0.3);
+ EXPECT_THAT(
+ &predictions,
+ IncludeAnyResponsesIn(std::unordered_set<string>({"Thanks very much"})));
+}
+
+TEST_F(PredictorTest, TestTwoSentences) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}},
+ &predictions);
+ EXPECT_GT(predictions.size(), 0);
+
+ float max = 0;
+ for (const auto &item : predictions) {
+ LOG(INFO) << "Response: " << item.GetText();
+ if (item.GetScore() > max) {
+ max = item.GetScore();
+ }
+ }
+
+ EXPECT_GT(max, 0.3);
+ EXPECT_THAT(&predictions, IncludeAnyResponsesIn(std::unordered_set<string>(
+ {"Hi, how are you doing?"})));
+}
+
+TEST_F(PredictorTest, TestBackoff) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"你好"}, *model_, /*config=*/{{}}, &predictions);
+ EXPECT_EQ(predictions.size(), 0);
+
+ // Backoff responses are returned in order.
+ GetSegmentPredictions({"你好"}, *model_, /*config=*/{{"Yes", "Ok"}},
+ &predictions);
+ EXPECT_EQ(predictions.size(), 2);
+ EXPECT_EQ(predictions[0].GetText(), "Yes");
+ EXPECT_EQ(predictions[1].GetText(), "Ok");
+}
+
+TEST_F(PredictorTest, BatchTest) {
+ int total_items = 0;
+ int total_responses = 0;
+ int total_triggers = 0;
+
+ string line;
+ std::ifstream fin(StrCat(TestDataPath(), "/", kSamples));
+ while (std::getline(fin, line)) {
+ const std::vector<string> &fields = strings::Split(line, '\t');
+ if (fields.empty()) {
+ continue;
+ }
+
+ // Parse sample file and predict
+ const string &msg = fields[0];
+ std::vector<PredictorResponse> predictions;
+ GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions);
+
+ // Validate response and generate stats.
+ total_items++;
+ total_responses += predictions.size();
+ if (!predictions.empty()) {
+ total_triggers++;
+ }
+ EXPECT_THAT(&predictions, IncludeAnyResponsesIn(std::unordered_set<string>(
+ fields.begin() + 1, fields.end())));
+ }
+
+ LOG(INFO) << "Responses: " << total_responses << " / " << total_items;
+ LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items;
+ EXPECT_EQ(total_triggers, total_items);
+}
+
+} // namespace
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
new file mode 100644
index 0000000000..f5d1f436bc
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech Hotword model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+void RunTest(int model_input_tensor, int svdf_layer_state_tensor,
+ int model_output_tensor, const string& model_name,
+ const string& golden_in_name, const string& golden_out_name) {
+ // Read the model.
+ string tflite_file_path = file::JoinPath(TestDataPath(), model_name);
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to read model from file " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Reset the SVDF layer state.
+ memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0,
+ interpreter->tensor(svdf_layer_state_tensor)->bytes);
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path = file::JoinPath(TestDataPath(), golden_in_name);
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), golden_out_name);
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(model_input_tensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(model_input_tensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(model_output_tensor)->dims->data[1];
+ const int input_sequence_size =
+ input_frames[0].size() / (speech_input_size * speech_batch_size);
+ float* input_ptr = interpreter->tensor(model_input_tensor)->data.f;
+ float* output_ptr = interpreter->tensor(model_output_tensor)->data.f;
+
+ // The first layer (SVDF) input size is 40 (speech_input_size). Each speech
+ // input frame for this model is 1280 floats, which is fed to the input as a
+ // sequence of 32 chunks (input_sequence_size).
+ for (int i = 0; i < TestInputSize(input_frames); i++) {
+ int frame_ptr = 0;
+ for (int s = 0; s < input_sequence_size; s++) {
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ interpreter->Invoke();
+ }
+ // After the whole frame (1280 floats) is fed, check that the output frame
+ // matches the golden output frame.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+TEST(SpeechHotword, OkGoogleTestRank1) {
+ constexpr int kModelInputTensor = 0;
+ constexpr int kSvdfLayerStateTensor = 4;
+ constexpr int kModelOutputTensor = 18;
+
+ RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor,
+ "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank1.csv");
+}
+
+TEST(SpeechHotword, OkGoogleTestRank2) {
+ constexpr int kModelInputTensor = 17;
+ constexpr int kSvdfLayerStateTensor = 1;
+ constexpr int kModelOutputTensor = 18;
+ RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor,
+ "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank2.csv");
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc
new file mode 100644
index 0000000000..687cfab0b2
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc
@@ -0,0 +1,114 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech SpeakerId model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 19;
+constexpr int kLstmLayer1CellStateTensor = 20;
+constexpr int kLstmLayer2OutputStateTensor = 40;
+constexpr int kLstmLayer2CellStateTensor = 41;
+constexpr int kLstmLayer3OutputStateTensor = 61;
+constexpr int kLstmLayer3CellStateTensor = 62;
+constexpr int kModelOutputTensor = 66;
+
+TEST(SpeechSpeakerId, OkGoogleTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to read model from file " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ::tflite::MutableOpResolver resolver;
+ RegisterSelectedOps(&resolver);
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, resolver)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc
new file mode 100644
index 0000000000..30d89a1354
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech TERSE AM model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 19;
+constexpr int kLstmLayer1CellStateTensor = 20;
+constexpr int kLstmLayer2OutputStateTensor = 40;
+constexpr int kLstmLayer2CellStateTensor = 41;
+constexpr int kLstmLayer3OutputStateTensor = 61;
+constexpr int kLstmLayer3CellStateTensor = 62;
+constexpr int kLstmLayer4OutputStateTensor = 82;
+constexpr int kLstmLayer4CellStateTensor = 83;
+constexpr int kLstmLayer5OutputStateTensor = 103;
+constexpr int kLstmLayer5CellStateTensor = 104;
+constexpr int kModelOutputTensor = 109;
+
+TEST(SpeechTerseAm, RandomIOTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to mmap model " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer4OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer4OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer4CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer4CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer5OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer5OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer5CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer5CellStateTensor)->bytes);
+
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 5.2e-4);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_tts_model_test.cc b/tensorflow/contrib/lite/models/speech_tts_model_test.cc
new file mode 100644
index 0000000000..e6f2673a42
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_tts_model_test.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech TTS model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 25;
+constexpr int kLstmLayer1CellStateTensor = 26;
+constexpr int kLstmLayer2OutputStateTensor = 46;
+constexpr int kLstmLayer2CellStateTensor = 47;
+constexpr int kLstmLayer3OutputStateTensor = 67;
+constexpr int kLstmLayer3CellStateTensor = 68;
+constexpr int kRnnLayerHiddenStateTensor = 73;
+constexpr int kModelOutputTensor = 74;
+
+TEST(SpeechTTS, RandomIOTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to mmap model " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kRnnLayerHiddenStateTensor)->data.raw, 0,
+ interpreter->tensor(kRnnLayerHiddenStateTensor)->bytes);
+
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/test_utils.h b/tensorflow/contrib/lite/models/test_utils.h
new file mode 100644
index 0000000000..b2596babd0
--- /dev/null
+++ b/tensorflow/contrib/lite/models/test_utils.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <fstream>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace models {
+using Frames = std::vector<std::vector<float>>;
+} // namespace models
+} // namespace tflite
+
+#ifndef __ANDROID__
+#include "file/base/path.h"
+#include "tensorflow/core/platform/test.h"
+
+inline string TestDataPath() {
+ return string(file::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
+ "contrib/lite/models/testdata/"));
+}
+inline int TestInputSize(const tflite::models::Frames& input_frames) {
+ return input_frames.size();
+}
+#else
+inline string TestDataPath() {
+ return string("third_party/tensorflow/contrib/lite/models/testdata/");
+}
+
+inline int TestInputSize(const tflite::models::Frames& input_frames) {
+ // Android TAP is very slow, we only test the first 20 frames.
+ return 20;
+}
+#endif
+
+namespace tflite {
+namespace models {
+
+ // Reads float data from a comma-separated file.
+ // Each line is read into a float vector, and the result is returned as a
+ // vector of float vectors.
+void ReadFrames(const string& csv_file_path, Frames* frames) {
+ std::ifstream csv_file(csv_file_path);
+ string line;
+ while (std::getline(csv_file, line, '\n')) {
+ std::vector<float> fields;
+ // Used by strtok_r internally for successive calls on the same string.
+ char* save_ptr = nullptr;
+
+ // Tokenize the line.
+ char* next_token = strtok_r(&line[0], ",", &save_ptr);
+ while (next_token != nullptr) {
+ float f = strtod(next_token, nullptr);
+ fields.push_back(f);
+ next_token = strtok_r(nullptr, ",", &save_ptr);
+ }
+ frames->push_back(fields);
+ }
+ csv_file.close();
+}
+
+} // namespace models
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
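For completeness, a small usage sketch of ReadFrames (not part of this commit; the CSV name mirrors the speech tests above, and TestDataPath() already ends with a slash):

#include <iostream>
#include "tensorflow/contrib/lite/models/test_utils.h"

int main() {
  tflite::models::Frames frames;
  tflite::models::ReadFrames(TestDataPath() + "speech_tts_model_in.csv",
                             &frames);
  std::cout << frames.size() << " frames";
  if (!frames.empty()) std::cout << ", " << frames[0].size() << " floats each";
  std::cout << "\n";
  return 0;
}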