diff options
Diffstat (limited to 'tensorflow/contrib/lite/models')
15 files changed, 1688 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD new file mode 100644 index 0000000000..fbdf19f205 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc new file mode 100644 index 0000000000..1c422b659a --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Convert a list of strings to integers via hashing. +// Input: +// Input[0]: A list of ngrams. string[num of input] +// +// Output: +// Output[0]: Hashed features. int32[num of input] +// Output[1]: Weights. 
float[num of input] + +#include <algorithm> +#include <map> +#include "re2/re2.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/string_util.h" +#include <farmhash.h> + +namespace tflite { +namespace ops { +namespace custom { + +namespace extract { + +static const int kMaxDimension = 1000000; +static const std::vector<string> kBlacklistNgram = {"<S>", "<E>", "<S> <E>"}; + +bool Equals(const string& x, const tflite::StringRef& strref) { + if (strref.len != x.length()) { + return false; + } + if (strref.len > 0) { + int r = memcmp(strref.str, x.data(), strref.len); + return r == 0; + } + return true; +} + +bool IsValidNgram(const tflite::StringRef& strref) { + for (const auto& s : kBlacklistNgram) { + if (Equals(s, strref)) { + return false; + } + } + return true; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1); + TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1); + TfLiteTensor* input = GetInput(context, node, 0); + int dim = input->dims->data[0]; + if (dim == 0) { + // TFLite non-string output should have size greater than 0. + dim = 1; + } + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteString); + outputSize1->data[0] = dim; + outputSize2->data[0] = dim; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, GetOutput(context, node, 0), outputSize1)); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, GetOutput(context, node, 1), outputSize2)); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + int num_strings = tflite::GetStringCount(input); + TfLiteTensor* label = GetOutput(context, node, 0); + TfLiteTensor* weight = GetOutput(context, node, 1); + + std::map<int64, int> feature_id_counts; + for (int i = 0; i < num_strings; i++) { + // Use fingerprint of feature name as id. 
+ auto strref = tflite::GetString(input, i); + if (!IsValidNgram(strref)) { + label->data.i32[i] = 0; + weight->data.f[i] = 0.0f; + continue; + } + + int64 feature_id = + ::util::Fingerprint64(strref.str, strref.len) % kMaxDimension; + + label->data.i32[i] = static_cast<int32>(feature_id); + weight->data.f[i] = + std::count(strref.str, strref.str + strref.len, ' ') + 1; + } + // Explicitly set an empty result to make preceding ops run. + if (num_strings == 0) { + label->data.i32[0] = 0; + weight->data.f[0] = 0.0f; + } + return kTfLiteOk; +} + +} // namespace extract + +TfLiteRegistration* Register_EXTRACT_FEATURES() { + static TfLiteRegistration r = {nullptr, nullptr, extract::Prepare, + extract::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc new file mode 100644 index 0000000000..9b8676bab6 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include <farmhash.h> + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_EXTRACT_FEATURES(); + +namespace { + +using ::testing::ElementsAre; + +class ExtractFeatureOpModel : public SingleOpModel { + public: + explicit ExtractFeatureOpModel(const std::vector<string>& input) { + input_ = AddInput(TensorType_STRING); + signature_ = AddOutput(TensorType_INT32); + weight_ = AddOutput(TensorType_FLOAT32); + + SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES); + BuildInterpreter({{static_cast<int>(input.size())}}); + PopulateStringTensor(input_, input); + } + + std::vector<int> GetSignature() { return ExtractVector<int>(signature_); } + std::vector<float> GetWeight() { return ExtractVector<float>(weight_); } + + private: + int input_; + int signature_; + int weight_; +}; + +int CalcFeature(const string& str) { + return ::util::Fingerprint64(str) % 1000000; +} + +TEST(ExtractFeatureOpTest, RegularInput) { + ExtractFeatureOpModel m({"<S>", "<S> Hi", "Hi", "Hi !", "!", "! <E>", "<E>"}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), + ElementsAre(0, CalcFeature("<S> Hi"), CalcFeature("Hi"), + CalcFeature("Hi !"), CalcFeature("!"), + CalcFeature("! 
<E>"), 0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0)); +} + +TEST(ExtractFeatureOpTest, OneInput) { + ExtractFeatureOpModel m({"Hi"}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi"))); + EXPECT_THAT(m.GetWeight(), ElementsAre(1)); +} + +TEST(ExtractFeatureOpTest, ZeroInput) { + ExtractFeatureOpModel m({}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0)); +} + +TEST(ExtractFeatureOpTest, AllBlacklistInput) { + ExtractFeatureOpModel m({"<S>", "<E>"}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0)); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc new file mode 100644 index 0000000000..d0dc2a35a7 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Normalize the string input. +// +// Input: +// Input[0]: One sentence. 
string[1] +// +// Output: +// Output[0]: Normalized sentence. string[1] +// +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" +#include "re2/re2.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace custom { + +namespace normalize { + +// Predictor transforms. +const char kPunctuationsRegex[] = "[.*()\"]"; + +const std::map<string, string>* kRegexTransforms = + new std::map<string, string>({ + {"([^\\s]+)n't", "\\1 not"}, + {"([^\\s]+)'nt", "\\1 not"}, + {"([^\\s]+)'ll", "\\1 will"}, + {"([^\\s]+)'re", "\\1 are"}, + {"([^\\s]+)'ve", "\\1 have"}, + {"i'm", "i am"}, + }); + +static const char kStartToken[] = "<S>"; +static const char kEndToken[] = "<E>"; +static const int32 kMaxInputChars = 300; + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0); + + string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len))); + absl::StripAsciiWhitespace(&result); + // Do not remove commas, semi-colons or colons from the sentences as they can + // indicate the beginning of a new clause. + RE2::GlobalReplace(&result, kPunctuationsRegex, ""); + RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])", + "\\1\\2"); + RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1"); + for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end(); + iter++) { + RE2::GlobalReplace(&result, iter->first, iter->second); + } + + // Treat questions & interjections as special cases. 
+ RE2::GlobalReplace(&result, "([?])+", "\\1"); + RE2::GlobalReplace(&result, "([!])+", "\\1"); + RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 "); + RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2"); + + RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", ""); + RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", ""); + absl::StripAsciiWhitespace(&result); + + // Add start and end token. + // Truncate input to maximum allowed size. + if (result.length() <= kMaxInputChars) { + absl::StrAppend(&result, " ", kEndToken); + } else { + result = result.substr(0, kMaxInputChars); + } + result = absl::StrCat(kStartToken, " ", result); + + tflite::DynamicBuffer buf; + buf.AddString(result.data(), result.length()); + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; +} + +} // namespace normalize + +TfLiteRegistration* Register_NORMALIZE() { + static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc new file mode 100644 index 0000000000..4d35dba9a6 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_NORMALIZE(); + +namespace { + +using ::testing::ElementsAreArray; + +class NormalizeOpModel : public SingleOpModel { + public: + explicit NormalizeOpModel(const string& input) { + input_ = AddInput(TensorType_STRING); + output_ = AddOutput(TensorType_STRING); + + SetCustomOp("Normalize", {}, Register_NORMALIZE); + BuildInterpreter({{static_cast<int>(input.size())}}); + PopulateStringTensor(input_, {input}); + } + + std::vector<string> GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector<string> result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, ref.len); + } + return result; + } + + private: + int input_; + int output_; +}; + +TEST(NormalizeOpTest, RegularInput) { + NormalizeOpModel m("I'm good; you're welcome"); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), + ElementsAreArray({"<S> i am good; you are welcome <E>"})); +} + +TEST(NormalizeOpTest, OneInput) { + NormalizeOpModel m("Hi!!!!"); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> hi ! 
<E>"})); +} + +TEST(NormalizeOpTest, EmptyInput) { + NormalizeOpModel m(""); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> <E>"})); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc new file mode 100644 index 0000000000..7b23adb990 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc @@ -0,0 +1,174 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Lookup projected hash signatures in Predictor model, +// output predicted labels and weights in decreasing order. +// +// Input: +// Input[0]: A list of hash signatures. int32[num of input] +// Input[1]: Hash signature keys in the model. int32[keys of model] +// Input[2]: Labels in the model. int32[keys of model, item per entry] +// Input[3]: Weights in the model. float[keys of model, item per entry] +// +// Output: +// Output[0]: Predicted labels. int32[num of output] +// Output[1]: Predicted weights. 
float[num of output] +// + +#include <algorithm> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace ops { +namespace custom { + +namespace predict { + +struct PredictOption { + int32_t num_output; + float weight_threshold; + + static PredictOption* Cast(void* ptr) { + return reinterpret_cast<PredictOption*>(ptr); + } +}; + +bool WeightGreater(const std::pair<int32_t, float>& a, + const std::pair<int32_t, float>& b) { + return a.second > b.second; +} + +void* Init(TfLiteContext* context, const char* custom_option, size_t length) { + if (custom_option == nullptr || length != sizeof(PredictOption)) { + fprintf(stderr, "No Custom option set\n"); + exit(1); + } + PredictOption* option = new PredictOption; + int offset = 0; + option->num_output = + *reinterpret_cast<const int32_t*>(custom_option + offset); + offset += sizeof(int32_t); + option->weight_threshold = + *reinterpret_cast<const float*>(custom_option + offset); + return reinterpret_cast<void*>(option); +} + +void Free(TfLiteContext* context, void* buffer) { + delete PredictOption::Cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]]; + TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]]; + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1); + TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1); + TF_LITE_ENSURE_EQ(context, 
model_label->dims->size, 2); + TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2); + TF_LITE_ENSURE_EQ(context, model_key->dims->data[0], + model_label->dims->data[0]); + TF_LITE_ENSURE_EQ(context, model_key->dims->data[0], + model_weight->dims->data[0]); + TF_LITE_ENSURE_EQ(context, model_label->dims->data[1], + model_weight->dims->data[1]); + + PredictOption* option = PredictOption::Cast(node->user_data); + TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]]; + TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32); + + TfLiteIntArray* label_size = TfLiteIntArrayCreate(1); + label_size->data[0] = option->num_output; + TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1); + weight_size->data[0] = option->num_output; + TfLiteStatus status = + context->ResizeTensor(context, output_label, label_size); + if (status != kTfLiteOk) { + return status; + } + return context->ResizeTensor(context, output_weight, weight_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]]; + TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]]; + + // Aggregate by key + std::unordered_map<int32_t, float> aggregation; + const int num_input = lookup->dims->data[0]; + const int num_rows = model_key->dims->data[0]; + const int items = model_label->dims->data[1]; + int* model_key_end = model_key->data.i32 + num_rows; + + for (int i = 0; i < num_input; i++) { + int* ptr = std::lower_bound(model_key->data.i32, model_key_end, + lookup->data.i32[i]); + if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) { + int idx = ptr - model_key->data.i32; + for (int j = 0; j < 
items; j++) { + aggregation[model_label->data.i32[idx * items + j]] += + model_weight->data.f[idx * items + j] / num_input; + } + } + } + + // Sort by value + std::vector<std::pair<int32_t, float>> sorted_labels(aggregation.begin(), + aggregation.end()); + std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater); + + PredictOption* option = PredictOption::Cast(node->user_data); + TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]]; + for (int i = 0; i < output_label->dims->data[0]; i++) { + if (i >= sorted_labels.size() || + sorted_labels[i].second < option->weight_threshold) { + // Set -1 to avoid lookup message with id 0, which is set for backoff. + output_label->data.i32[i] = -1; + output_weight->data.f[i] = 0.0f; + } else { + output_label->data.i32[i] = sorted_labels[i].first; + output_weight->data.f[i] = sorted_labels[i].second; + } + } + + return kTfLiteOk; +} + +} // namespace predict + +TfLiteRegistration* Register_PREDICT() { + static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare, + predict::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc new file mode 100644 index 0000000000..e97c58cbd1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_PREDICT(); + +namespace { + +using ::testing::ElementsAreArray; + +class PredictOpModel : public SingleOpModel { + public: + PredictOpModel(std::initializer_list<int> input_signature_shape, + std::initializer_list<int> key_shape, + std::initializer_list<int> labelweight_shape, int num_output, + float threshold) { + input_signature_ = AddInput(TensorType_INT32); + model_key_ = AddInput(TensorType_INT32); + model_label_ = AddInput(TensorType_INT32); + model_weight_ = AddInput(TensorType_FLOAT32); + output_label_ = AddOutput(TensorType_INT32); + output_weight_ = AddOutput(TensorType_FLOAT32); + + std::vector<uint8_t> predict_option; + writeInt32(num_output, &predict_option); + writeFloat32(threshold, &predict_option); + SetCustomOp("Predict", predict_option, Register_PREDICT); + BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape, + labelweight_shape}}); + } + + void SetInputSignature(std::initializer_list<int> data) { + PopulateTensor<int>(input_signature_, data); + } + + void SetModelKey(std::initializer_list<int> data) { + PopulateTensor<int>(model_key_, data); + } + + void 
SetModelLabel(std::initializer_list<int> data) { + PopulateTensor<int>(model_label_, data); + } + + void SetModelWeight(std::initializer_list<float> data) { + PopulateTensor<float>(model_weight_, data); + } + + std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); } + std::vector<float> GetWeight() { + return ExtractVector<float>(output_weight_); + } + + void writeFloat32(float value, std::vector<uint8_t>* data) { + union { + float v; + uint8_t r[4]; + } float_to_raw; + float_to_raw.v = value; + for (unsigned char i : float_to_raw.r) { + data->push_back(i); + } + } + + void writeInt32(int32_t value, std::vector<uint8_t>* data) { + union { + int32_t v; + uint8_t r[4]; + } int32_to_raw; + int32_to_raw.v = value; + for (unsigned char i : int32_to_raw.r) { + data->push_back(i); + } + } + + private: + int input_signature_; + int model_key_; + int model_label_; + int model_weight_; + int output_label_; + int output_weight_; +}; + +TEST(PredictOpTest, AllLabelsAreValid) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05}))); +} + +TEST(PredictOpTest, MoreLabelsThanRequired) { + PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1}))); +} + +TEST(PredictOpTest, OneLabelDoesNotPassThreshold) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07); + m.SetInputSignature({1, 3, 7, 9}); + 
m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0}))); +} + +TEST(PredictOpTest, NoneLabelPassThreshold) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); +} + +TEST(PredictOpTest, OnlyOneLabelGenerated) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0}); + m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0}))); +} + +TEST(PredictOpTest, NoLabelGenerated) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({5, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0}); + m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc new file mode 100644 
index 0000000000..a28222213e --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/models/smartreply/predictor.h" + +#include "absl/strings/str_split.h" +#include "re2/re2.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); + +namespace tflite { +namespace custom { +namespace smartreply { + +// Split sentence into segments (using punctuation). +std::vector<string> SplitSentence(const string& input) { + string result(input); + + RE2::GlobalReplace(&result, "([?.!,])+", " \\1"); + RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t"); + RE2::GlobalReplace(&result, "[ ]+", " "); + RE2::GlobalReplace(&result, "\t+$", ""); + + return strings::Split(result, '\t'); +} + +// Predict with TfLite model. 
+void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter, + std::map<string, float>* response_map) { + { + TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]); + tflite::DynamicBuffer buf; + buf.AddString(sentence.data(), sentence.length()); + buf.WriteToTensor(input); + interpreter->AllocateTensors(); + + interpreter->Invoke(); + + TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]); + TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]); + + for (int i = 0; i < confidence->dims->data[0]; i++) { + float weight = confidence->data.f[i]; + auto response_text = tflite::GetString(messages, i); + if (response_text.len > 0) { + (*response_map)[string(response_text.str, response_text.len)] += weight; + } + } + } +} + +void GetSegmentPredictions( + const std::vector<string>& input, const ::tflite::FlatBufferModel& model, + const SmartReplyConfig& config, + std::vector<PredictorResponse>* predictor_responses) { + // Initialize interpreter + std::unique_ptr<::tflite::Interpreter> interpreter; + ::tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); + ::tflite::InterpreterBuilder(model, resolver)(&interpreter); + + if (!model.initialized()) { + fprintf(stderr, "Failed to mmap model \n"); + return; + } + + // Execute Tflite Model + std::map<string, float> response_map; + std::vector<string> sentences; + for (const string& str : input) { + std::vector<string> splitted_str = SplitSentence(str); + sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end()); + } + for (const auto& sentence : sentences) { + ExecuteTfLite(sentence, interpreter.get(), &response_map); + } + + // Generate the result. 
+ for (const auto& iter : response_map) { + PredictorResponse prediction(iter.first, iter.second); + predictor_responses->emplace_back(prediction); + } + std::sort(predictor_responses->begin(), predictor_responses->end(), + [](const PredictorResponse& a, const PredictorResponse& b) { + return a.GetScore() > b.GetScore(); + }); + + // Add backoff response. + for (const string& backoff : config.backoff_responses) { + if (predictor_responses->size() >= config.num_response) { + break; + } + predictor_responses->push_back({backoff, config.backoff_confidence}); + } +} + +} // namespace smartreply +} // namespace custom +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h new file mode 100644 index 0000000000..3b9a2b32e1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+
+const int kDefaultNumResponse = 10;
+const float kDefaultBackoffConfidence = 1e-4;
+
+class PredictorResponse;
+struct SmartReplyConfig;
+
+// With a given string as input, predict the response with a Tflite model.
+// When config.backoff_responses is not empty, predictor_responses will be
+// filled with messages from backoff responses.
+void GetSegmentPredictions(const std::vector<string>& input,
+                           const ::tflite::FlatBufferModel& model,
+                           const SmartReplyConfig& config,
+                           std::vector<PredictorResponse>* predictor_responses);
+
+// Data object used to hold a single predictor response.
+// It includes messages, and confidence.
+class PredictorResponse {
+ public:
+  PredictorResponse(const string& response_text, float score) {
+    response_text_ = response_text;
+    prediction_score_ = score;
+  }
+
+  // Accessor methods.
+  const string& GetText() const { return response_text_; }
+  float GetScore() const { return prediction_score_; }
+
+ private:
+  string response_text_ = "";
+  float prediction_score_ = 0.0;
+};
+
+// Configurations for SmartReply.
+struct SmartReplyConfig {
+  // Maximum responses to return.
+  int num_response;
+  // Default confidence for backoff responses.
+  float backoff_confidence;
+  // Backoff responses are used when predicted responses cannot fulfill the
+  // list.
+ const std::vector<string>& backoff_responses; + + SmartReplyConfig(std::vector<string> backoff_responses) + : num_response(kDefaultNumResponse), + backoff_confidence(kDefaultBackoffConfidence), + backoff_responses(backoff_responses) {} +}; + +} // namespace smartreply +} // namespace custom +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc new file mode 100644 index 0000000000..2fa9923bc9 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/models/smartreply/predictor.h" + +#include <fstream> +#include <unordered_set> + +#include "base/logging.h" +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace custom { +namespace smartreply { +namespace { + +const char kModelName[] = "smartreply_ondevice_model.bin"; +const char kSamples[] = "smartreply_samples.tsv"; + +MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { + bool has_expected_response = false; + for (const auto &item : *arg) { + const string &response = item.GetText(); + if (expected_response.find(response) != expected_response.end()) { + has_expected_response = true; + break; + } + } + return has_expected_response; +} + +class PredictorTest : public ::testing::Test { + protected: + PredictorTest() { + model_ = tflite::FlatBufferModel::BuildFromFile( + StrCat(TestDataPath(), "/", kModelName).c_str()); + CHECK(model_); + } + ~PredictorTest() override {} + + std::unique_ptr<::tflite::FlatBufferModel> model_; +}; + +TEST_F(PredictorTest, GetSegmentPredictions) { + std::vector<PredictorResponse> predictions; + + GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions); + EXPECT_GT(predictions.size(), 0); + + float max = 0; + for (const auto &item : predictions) { + LOG(INFO) << "Response: " << item.GetText(); + if (item.GetScore() > max) { + max = item.GetScore(); + } + } + + EXPECT_GT(max, 0.3); + EXPECT_THAT( + &predictions, + IncludeAnyResponesIn(std::unordered_set<string>({"Thanks very much"}))); +} + +TEST_F(PredictorTest, TestTwoSentences) { + std::vector<PredictorResponse> predictions; + + GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}}, + &predictions); + EXPECT_GT(predictions.size(), 0); + + float max 
= 0; + for (const auto &item : predictions) { + LOG(INFO) << "Response: " << item.GetText(); + if (item.GetScore() > max) { + max = item.GetScore(); + } + } + + EXPECT_GT(max, 0.3); + EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>( + {"Hi, how are you doing?"}))); +} + +TEST_F(PredictorTest, TestBackoff) { + std::vector<PredictorResponse> predictions; + + GetSegmentPredictions({"你好"}, *model_, /*config=*/{{}}, &predictions); + EXPECT_EQ(predictions.size(), 0); + + // Backoff responses are returned in order. + GetSegmentPredictions({"你好"}, *model_, /*config=*/{{"Yes", "Ok"}}, + &predictions); + EXPECT_EQ(predictions.size(), 2); + EXPECT_EQ(predictions[0].GetText(), "Yes"); + EXPECT_EQ(predictions[1].GetText(), "Ok"); +} + +TEST_F(PredictorTest, BatchTest) { + int total_items = 0; + int total_responses = 0; + int total_triggers = 0; + + string line; + std::ifstream fin(StrCat(TestDataPath(), "/", kSamples)); + while (std::getline(fin, line)) { + const std::vector<string> &fields = strings::Split(line, '\t'); + if (fields.empty()) { + continue; + } + + // Parse sample file and predict + const string &msg = fields[0]; + std::vector<PredictorResponse> predictions; + GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions); + + // Validate response and generate stats. 
+ total_items++; + total_responses += predictions.size(); + if (!predictions.empty()) { + total_triggers++; + } + EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>( + fields.begin() + 1, fields.end()))); + } + + LOG(INFO) << "Responses: " << total_responses << " / " << total_items; + LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items; + EXPECT_EQ(total_triggers, total_items); +} + +} // namespace +} // namespace smartreply +} // namespace custom +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc new file mode 100644 index 0000000000..f5d1f436bc --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech Hotword model using TFLite Ops. 
+ +#include <string.h> + +#include <memory> +#include <string> + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +void RunTest(int model_input_tensor, int svdf_layer_state_tensor, + int model_output_tensor, const string& model_name, + const string& golden_in_name, const string& golden_out_name) { + // Read the model. + string tflite_file_path = file::JoinPath(TestDataPath(), model_name); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to read model from file " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr<Interpreter> interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Reset the SVDF layer state. + memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0, + interpreter->tensor(svdf_layer_state_tensor)->bytes); + + // Load the input frames. + Frames input_frames; + const string input_file_path = file::JoinPath(TestDataPath(), golden_in_name); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. 
+ Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), golden_out_name); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(model_input_tensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(model_input_tensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(model_output_tensor)->dims->data[1]; + const int input_sequence_size = + input_frames[0].size() / (speech_input_size * speech_batch_size); + float* input_ptr = interpreter->tensor(model_input_tensor)->data.f; + float* output_ptr = interpreter->tensor(model_output_tensor)->data.f; + + // The first layer (SVDF) input size is 40 (speech_input_size). Each speech + // input frames for this model is 1280 floats, which can be fed to input in a + // sequence of size 32 (input_sequence_size). + for (int i = 0; i < TestInputSize(input_frames); i++) { + int frame_ptr = 0; + for (int s = 0; s < input_sequence_size; s++) { + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + interpreter->Invoke(); + } + // After the whole frame (1280 floats) is fed, we can check the output frame + // matches with the golden output frame. 
+ for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +TEST(SpeechHotword, OkGoogleTestRank1) { + constexpr int kModelInputTensor = 0; + constexpr int kSvdfLayerStateTensor = 4; + constexpr int kModelOutputTensor = 18; + + RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor, + "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank1.csv"); +} + +TEST(SpeechHotword, OkGoogleTestRank2) { + constexpr int kModelInputTensor = 17; + constexpr int kSvdfLayerStateTensor = 1; + constexpr int kModelOutputTensor = 18; + RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor, + "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank2.csv"); +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc new file mode 100644 index 0000000000..687cfab0b2 --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech SpeakerId model using TFLite Ops. 
+ +#include <string.h> + +#include <memory> +#include <string> + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 19; +constexpr int kLstmLayer1CellStateTensor = 20; +constexpr int kLstmLayer2OutputStateTensor = 40; +constexpr int kLstmLayer2CellStateTensor = 41; +constexpr int kLstmLayer3OutputStateTensor = 61; +constexpr int kLstmLayer3CellStateTensor = 62; +constexpr int kModelOutputTensor = 66; + +TEST(SpeechSpeakerId, OkGoogleTest) { + // Read the model. + string tflite_file_path = + file::JoinPath(TestDataPath(), "speech_speakerid_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to read model from file " << tflite_file_path; + + // Initialize the interpreter. + ::tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); + std::unique_ptr<Interpreter> interpreter; + InterpreterBuilder(*model, resolver)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + file::JoinPath(TestDataPath(), "speech_speakerid_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. 
+ Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), "speech_speakerid_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. 
+ for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc new file mode 100644 index 0000000000..30d89a1354 --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech TERSE AM model using TFLite Ops. 
+ +#include <string.h> + +#include <memory> +#include <string> + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 19; +constexpr int kLstmLayer1CellStateTensor = 20; +constexpr int kLstmLayer2OutputStateTensor = 40; +constexpr int kLstmLayer2CellStateTensor = 41; +constexpr int kLstmLayer3OutputStateTensor = 61; +constexpr int kLstmLayer3CellStateTensor = 62; +constexpr int kLstmLayer4OutputStateTensor = 82; +constexpr int kLstmLayer4CellStateTensor = 83; +constexpr int kLstmLayer5OutputStateTensor = 103; +constexpr int kLstmLayer5CellStateTensor = 104; +constexpr int kModelOutputTensor = 109; + +TEST(SpeechTerseAm, RandomIOTest) { + // Read the model. + string tflite_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to mmap model " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr<Interpreter> interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. 
+ Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer4OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer4OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer4CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer4CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer5OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer5OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer5CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer5CellStateTensor)->bytes); + + + for (int i = 0; i < 
input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. + for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 5.2e-4); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_tts_model_test.cc b/tensorflow/contrib/lite/models/speech_tts_model_test.cc new file mode 100644 index 0000000000..e6f2673a42 --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_tts_model_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech TTS model using TFLite Ops. 
+ +#include <string.h> + +#include <memory> +#include <string> + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 25; +constexpr int kLstmLayer1CellStateTensor = 26; +constexpr int kLstmLayer2OutputStateTensor = 46; +constexpr int kLstmLayer2CellStateTensor = 47; +constexpr int kLstmLayer3OutputStateTensor = 67; +constexpr int kLstmLayer3CellStateTensor = 68; +constexpr int kRnnLayerHiddenStateTensor = 73; +constexpr int kModelOutputTensor = 74; + +TEST(SpeechTTS, RandomIOTest) { + // Read the model. + string tflite_file_path = + file::JoinPath(TestDataPath(), "speech_tts_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to mmap model " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr<Interpreter> interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + file::JoinPath(TestDataPath(), "speech_tts_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. 
+ Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), "speech_tts_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + + memset(interpreter->tensor(kRnnLayerHiddenStateTensor)->data.raw, 0, + interpreter->tensor(kRnnLayerHiddenStateTensor)->bytes); + + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. 
+ for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/test_utils.h b/tensorflow/contrib/lite/models/test_utils.h new file mode 100644 index 0000000000..b2596babd0 --- /dev/null +++ b/tensorflow/contrib/lite/models/test_utils.h @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ + +#include <stdlib.h> +#include <string.h> + +#include <fstream> +#include <memory> +#include <string> +#include <vector> + +namespace tflite { +namespace models { +using Frames = std::vector<std::vector<float>>; +} // namespace models +} // namespace tflite + +#ifndef __ANDROID__ +#include "file/base/path.h" +#include "tensorflow/core/platform/test.h" + +inline string TestDataPath() { + return string(file::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), + "contrib/lite/models/testdata/")); +} +inline int TestInputSize(const tflite::models::Frames& input_frames) { + return input_frames.size(); +} +#else +inline string TestDataPath() { + return string("third_party/tensorflow/contrib/lite/models/testdata/"); +} + +inline int TestInputSize(const tflite::models::Frames& 
input_frames) {
+  // Android TAP is very slow, we only test the first 20 frames.
+  return 20;
+}
+#endif
+
+namespace tflite {
+namespace models {
+
+// Read float data from a comma-separated file:
+// Each line will be read into a float vector.
+// The return result will be a vector of float vectors.
+void ReadFrames(const string& csv_file_path, Frames* frames) {
+  std::ifstream csv_file(csv_file_path);
+  string line;
+  while (std::getline(csv_file, line, '\n')) {
+    std::vector<float> fields;
+    // Used by strtok_r internally for successive calls on the same string.
+    char* save_ptr = nullptr;
+
+    // Tokenize the line.
+    char* next_token =
+        strtok_r(const_cast<char*>(line.c_str()), ",", &save_ptr);
+    while (next_token != nullptr) {
+      float f = strtod(next_token, nullptr);
+      fields.push_back(f);
+      next_token = strtok_r(nullptr, ",", &save_ptr);
+    }
+    frames->push_back(fields);
+  }
+  csv_file.close();
+}
+
+} // namespace models
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ |