diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-30 23:58:26 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-01 00:02:04 -0800 |
commit | 6b6244c40197b34f49bb50aa52efb082380d4637 (patch) | |
tree | ce50b9f28330c7ad194b27263f2534221f176457 | |
parent | 370e521762f3cbd558a7e56992e3b062236b626f (diff) |
Build demo app for SmartReply
PiperOrigin-RevId: 177559103
21 files changed, 758 insertions, 51 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index e3c9cdd99b..5813b3de4d 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -223,11 +223,12 @@ def gen_selected_ops(name, model): """ out = name + "_registration.cc" tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/contrib/lite" native.genrule( name = name, srcs = [model], outs = [out], - cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)") - % (tool, model, out), + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") + % (tool, model, out, tflite_path[2:]), tools = [tool], ) diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD index fbdf19f205..733c3f4c7f 100644 --- a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -1,7 +1,92 @@ package(default_visibility = ["//visibility:public"]) +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + licenses(["notice"]) # Apache 2.0 +gen_selected_ops( + name = "smartreply_ops", + model = "@tflite_smartreply//:smartreply.tflite", +) + +cc_library( + name = "custom_ops", + srcs = [ + "ops/extract_feature.cc", + "ops/normalize.cc", + "ops/predict.cc", + ":smartreply_ops", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + "@farmhash_archive//:farmhash", + ], +) + +cc_library( + name = "predictor_lib", + srcs = ["predictor.cc"], + hdrs = ["predictor.h"], + copts = tflite_copts(), + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "extract_feature_op_test", + size = "small", + srcs = ["ops/extract_feature_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@farmhash_archive//:farmhash", + ], +) + +cc_test( + name = "normalize_op_test", + size = "small", + srcs = ["ops/normalize_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "predict_op_test", + size = "small", + srcs = ["ops/predict_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..75ed9432c8 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,38 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + Copyright 2017 The Android Open Source Project + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="com.example.android.smartreply" > + + <uses-sdk + android:minSdkVersion="15" + android:targetSdkVersion="24" /> + + <application android:label="TfLite SmartReply Demo"> + <activity + android:name="com.example.android.smartreply.MainActivity" + android:configChanges="orientation|keyboardHidden|screenSize" + android:windowSoftInputMode="stateUnchanged|adjustPan" + android:label="TfLite SmartReply Demo" + android:screenOrientation="portrait" > + <intent-filter> + <action android:name="android.intent.action.MAIN" /> + <category android:name="android.intent.category.LAUNCHER" /> + </intent-filter> + </activity> + </application> + +</manifest> diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD new file mode 100644 index 0000000000..f8767b443a --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -0,0 +1,65 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +filegroup( + name = "assets", + srcs = [ + "@tflite_smartreply//:model_files", + ], +) + +android_binary( + name = "SmartReplyDemo", + srcs = glob(["java/**/*.java"]), + assets = [":assets"], + assets_dir = "", + custom_package = "com.example.android.smartreply", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + ":smartreply_runtime", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +cc_library( + name = "smartreply_runtime", + srcs = ["libsmartreply_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libsmartreply_jni.so", + deps = [ + ":smartreply_jni_lib", + ], +) + +cc_library( + name = "smartreply_jni_lib", + srcs = [ + "smartreply_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/models/smartreply:predictor_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000..3c882ffc43 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(glob(["*"])) + +filegroup( + name = "assets_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt new file mode 100644 index 0000000000..a0a5b46b5f --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt @@ -0,0 +1,16 @@ +Ok +Yes +No +👍 +☺ +😟 +❤️ +Lol +Thanks +Got it +Done +Nice +I don't know +What? +Why? +What's up? diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java new file mode 100644 index 0000000000..02fec9ae5e --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.app.Activity; +import android.os.Bundle; +import android.os.Handler; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.EditText; +import android.widget.TextView; + +/** + * The main (and only) activity of this demo app. Displays a text box which updates as messages are + * received. + */ +public class MainActivity extends Activity { + private static final String TAG = "SmartReplyDemo"; + private SmartReplyClient client; + + private Button sendButton; + private TextView messageTextView; + private EditText messageInput; + + private Handler handler; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + Log.v(TAG, "onCreate"); + setContentView(R.layout.main_activity); + + client = new SmartReplyClient(getApplicationContext()); + handler = new Handler(); + + sendButton = (Button) findViewById(R.id.send_button); + sendButton.setOnClickListener( + (View v) -> { + send(messageInput.getText().toString()); + }); + + messageTextView = (TextView) findViewById(R.id.message_text); + messageInput = (EditText) findViewById(R.id.message_input); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.loadModel(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unloadModel(); + }); + } + + private void send(final String message) { + handler.post( + () -> { + messageTextView.append("Input: " + message + "\n"); + + SmartReply[] ans = client.predict(new String[] {message}); + for (SmartReply reply : ans) { + appendMessage("Reply: " + reply.getText()); + } + appendMessage("------"); + }); + } + + private void appendMessage(final String message) { + handler.post( + () -> { + messageTextView.append(message + "\n"); + }); + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java new file mode 100644 index 0000000000..3357fd17c1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.support.annotation.Keep; + +/** + * SmartReply contains predicted message, and confidence. + * + * <p>NOTE: this class used by JNI, class name and constructor should not be obfuscated. + */ +@Keep +public class SmartReply { + + private final String text; + private final float score; + + @Keep + public SmartReply(String text, float score) { + this.text = text; + this.score = score; + } + + public String getText() { + return text; + } + + public float getScore() { + return score; + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java new file mode 100644 index 0000000000..d5b1ac0ffb --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.support.annotation.Keep; +import android.support.annotation.WorkerThread; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.List; + +/** Interface to load TfLite model and provide predictions. */ +public class SmartReplyClient implements AutoCloseable { + private static final String TAG = "SmartReplyDemo"; + private static final String MODEL_PATH = "smartreply.tflite"; + private static final String BACKOFF_PATH = "backoff_response.txt"; + private static final String JNI_LIB = "smartreply_jni"; + + private final Context context; + private long storage; + private MappedByteBuffer model; + + private volatile boolean isLibraryLoaded; + + public SmartReplyClient(Context context) { + this.context = context; + } + + public boolean isLoaded() { + return storage != 0; + } + + @WorkerThread + public synchronized void loadModel() { + if (!isLibraryLoaded) { + System.loadLibrary(JNI_LIB); + isLibraryLoaded = true; + } + + try { + model = loadModelFile(); + String[] backoff = loadBackoffList(); + storage = loadJNI(model, backoff); + } catch (IOException e) { + Log.e(TAG, "Fail to load model", e); + return; + } + } + + @WorkerThread + public synchronized SmartReply[] predict(String[] input) { + if (storage != 0) { + return predictJNI(storage, input); + } else { + return new SmartReply[] {}; + } + } + + @WorkerThread + public synchronized void unloadModel() { + close(); + } + + @Override + public synchronized void close() { + if (storage != 0) { + unloadJNI(storage); + storage = 0; + } + } + + private MappedByteBuffer loadModelFile() throws IOException { + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + try { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } finally { + inputStream.close(); + } + } + + private String[] loadBackoffList() throws IOException { + List<String> labelList = new ArrayList<String>(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH))); + String line; + while ((line = reader.readLine()) != null) { + if (!line.isEmpty()) { + labelList.add(line); + } + } + reader.close(); + String[] ans = new String[labelList.size()]; + labelList.toArray(ans); + return ans; + } + + @Keep + private native long loadJNI(MappedByteBuffer buffer, String[] backoff); + + @Keep + private native SmartReply[] predictJNI(long storage, String[] text); + + @Keep + private native void unloadJNI(long storage); +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml new file mode 100644 index 0000000000..23b4cadc00 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml @@ -0,0 +1,44 @@ +<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:tools="http://schemas.android.com/tools" + android:layout_width="match_parent" + android:layout_height="match_parent" + android:orientation="vertical"> + + <LinearLayout + android:layout_width="fill_parent" + android:layout_height="0dp" + android:padding="5dip" + android:layout_weight="3"> + + <TextView + android:id="@+id/message_text" + android:layout_width="fill_parent" + android:layout_height="fill_parent" + android:scrollbars="vertical" + android:gravity="bottom"/> + </LinearLayout> + + <LinearLayout + android:layout_width="fill_parent" + android:layout_height="0dp" + android:padding="5dip" + android:layout_weight="1"> + + <EditText + android:id="@+id/message_input" + android:layout_width="0dp" + android:layout_height="fill_parent" + android:layout_weight="6" + android:scrollbars="vertical" + android:hint="Enter Text" + android:gravity="top" + android:inputType="text"/> + <Button + android:id="@+id/send_button" + android:layout_width="0dp" + android:layout_height="fill_parent" + android:layout_weight="2" + android:text="Send" /> + </LinearLayout> + +</LinearLayout> diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc new file mode 100644 index 0000000000..f158cc511a --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <jni.h> +#include <utility> +#include <vector> + +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/smartreply/predictor.h" + +const char kIllegalStateException[] = "java/lang/IllegalStateException"; + +using tflite::custom::smartreply::GetSegmentPredictions; +using tflite::custom::smartreply::PredictorResponse; + +template <typename T> +T CheckNotNull(JNIEnv* env, T&& t) { + if (t == nullptr) { + env->ThrowNew(env->FindClass(kIllegalStateException), ""); + return nullptr; + } + return std::forward<T>(t); +} + +std::vector<std::string> jniStringArrayToVector(JNIEnv* env, + jobjectArray string_array) { + int count = env->GetArrayLength(string_array); + std::vector<std::string> result; + for (int i = 0; i < count; i++) { + auto jstr = + reinterpret_cast<jstring>(env->GetObjectArrayElement(string_array, i)); + const char* raw_str = env->GetStringUTFChars(jstr, JNI_FALSE); + result.emplace_back(std::string(raw_str)); + env->ReleaseStringUTFChars(jstr, raw_str); + } + return result; +} + +struct JNIStorage { + std::vector<std::string> backoff_list; + std::unique_ptr<::tflite::FlatBufferModel> model; +}; + +extern "C" JNIEXPORT jlong JNICALL +Java_com_example_android_smartreply_SmartReplyClient_loadJNI( + JNIEnv* env, jobject thiz, jobject model_buffer, + jobjectArray backoff_list) { + const char* buf = + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)); + jlong capacity = env->GetDirectBufferCapacity(model_buffer); + + JNIStorage* storage = new JNIStorage; + storage->model = tflite::FlatBufferModel::BuildFromBuffer( + buf, static_cast<size_t>(capacity)); + storage->backoff_list = jniStringArrayToVector(env, backoff_list); + + if (!storage->model) { + delete storage; + env->ThrowNew(env->FindClass(kIllegalStateException), ""); + return 0; + } + return reinterpret_cast<jlong>(storage); +} + +extern "C" JNIEXPORT jobjectArray JNICALL +Java_com_example_android_smartreply_SmartReplyClient_predictJNI( + JNIEnv* env, jobject /*thiz*/, jlong storage_ptr, jobjectArray input_text) { + // Predict + if (storage_ptr == 0) { + return nullptr; + } + JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr); + if (storage == nullptr) { + return nullptr; + } + std::vector<PredictorResponse> responses; + GetSegmentPredictions(jniStringArrayToVector(env, input_text), + *storage->model, {storage->backoff_list}, &responses); + + // Create a SmartReply[] to return back to Java + jclass smart_reply_class = CheckNotNull( + env, env->FindClass("com/example/android/smartreply/SmartReply")); + if (env->ExceptionCheck()) { + return nullptr; + } + jmethodID smart_reply_ctor = CheckNotNull( + env, + env->GetMethodID(smart_reply_class, "<init>", "(Ljava/lang/String;F)V")); + if (env->ExceptionCheck()) { + return nullptr; + } + jobjectArray array = CheckNotNull( + env, env->NewObjectArray(responses.size(), smart_reply_class, nullptr)); + if (env->ExceptionCheck()) { + return nullptr; + } + for (int i = 0; i < responses.size(); i++) { + jstring text = + CheckNotNull(env, env->NewStringUTF(responses[i].GetText().data())); + if (env->ExceptionCheck()) { + return nullptr; + } + jobject reply = env->NewObject(smart_reply_class, smart_reply_ctor, text, + responses[i].GetScore()); + env->SetObjectArrayElement(array, i, reply); + } + return array; +} + +extern "C" JNIEXPORT void JNICALL +Java_com_example_android_smartreply_SmartReplyClient_unloadJNI( + JNIEnv* env, jobject thiz, jlong storage_ptr) { + if (storage_ptr != 0) { + JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr); + delete storage; + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc index 1c422b659a..f97a6486d6 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc +++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc @@ -23,7 +23,7 @@ limitations under the License. #include <algorithm> #include <map> -#include "re2/re2.h" + #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/string_util.h" @@ -81,7 +81,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* label = GetOutput(context, node, 0); TfLiteTensor* weight = GetOutput(context, node, 1); - std::map<int64, int> feature_id_counts; + std::map<int64_t, int> feature_id_counts; for (int i = 0; i < num_strings; i++) { // Use fingerprint of feature name as id. auto strref = tflite::GetString(input, i); @@ -91,10 +91,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { continue; } - int64 feature_id = + int64_t feature_id = ::util::Fingerprint64(strref.str, strref.len) % kMaxDimension; - - label->data.i32[i] = static_cast<int32>(feature_id); + label->data.i32[i] = static_cast<int32_t>(feature_id); weight->data.f[i] = std::count(strref.str, strref.str + strref.len, ' ') + 1; } diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc index d0dc2a35a7..c55ac9f52f 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc +++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc @@ -21,7 +21,10 @@ limitations under the License. // Output: // Output[0]: Normalized sentence. string[1] // -#include "absl/strings/ascii.h" + +#include <algorithm> +#include <string> + #include "absl/strings/str_cat.h" #include "absl/strings/strip.h" #include "re2/re2.h" @@ -50,7 +53,7 @@ const std::map<string, string>* kRegexTransforms = static const char kStartToken[] = "<S>"; static const char kEndToken[] = "<E>"; -static const int32 kMaxInputChars = 300; +static const int32_t kMaxInputChars = 300; TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0); diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc index a28222213e..6da5cc8eec 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc @@ -30,7 +30,7 @@ namespace custom { namespace smartreply { // Split sentence into segments (using punctuation). -std::vector<string> SplitSentence(const string& input) { +std::vector<std::string> SplitSentence(const std::string& input) { string result(input); RE2::GlobalReplace(&result, "([?.!,])+", " \\1"); @@ -38,12 +38,13 @@ std::vector<string> SplitSentence(const string& input) { RE2::GlobalReplace(&result, "[ ]+", " "); RE2::GlobalReplace(&result, "\t+$", ""); - return strings::Split(result, '\t'); + return absl::StrSplit(result, '\t'); } // Predict with TfLite model. -void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter, - std::map<string, float>* response_map) { +void ExecuteTfLite(const std::string& sentence, + ::tflite::Interpreter* interpreter, + std::map<std::string, float>* response_map) { { TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]); tflite::DynamicBuffer buf; @@ -67,8 +68,8 @@ void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter, } void GetSegmentPredictions( - const std::vector<string>& input, const ::tflite::FlatBufferModel& model, - const SmartReplyConfig& config, + const std::vector<std::string>& input, + const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config, std::vector<PredictorResponse>* predictor_responses) { // Initialize interpreter std::unique_ptr<::tflite::Interpreter> interpreter; @@ -82,10 +83,10 @@ void GetSegmentPredictions( } // Execute Tflite Model - std::map<string, float> response_map; - std::vector<string> sentences; - for (const string& str : input) { - std::vector<string> splitted_str = SplitSentence(str); + std::map<std::string, float> response_map; + std::vector<std::string> sentences; + for (const std::string& str : input) { + std::vector<std::string> splitted_str = SplitSentence(str); sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end()); } for (const auto& sentence : sentences) { diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h index 3b9a2b32e1..d17323a3f9 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.h +++ b/tensorflow/contrib/lite/models/smartreply/predictor.h @@ -34,7 +34,7 @@ struct SmartReplyConfig; // With a given string as input, predict the response with a Tflite model. // When config.backoff_response is not empty, predictor_responses will be filled // with messagees from backoff response. -void GetSegmentPredictions(const std::vector<string>& input, +void GetSegmentPredictions(const std::vector<std::string>& input, const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config, std::vector<PredictorResponse>* predictor_responses); @@ -43,17 +43,17 @@ void GetSegmentPredictions(const std::vector<string>& input, // It includes messages, and confidence. class PredictorResponse { public: - PredictorResponse(const string& response_text, float score) { + PredictorResponse(const std::string& response_text, float score) { response_text_ = response_text; prediction_score_ = score; } // Accessor methods. - const string& GetText() const { return response_text_; } + const std::string& GetText() const { return response_text_; } float GetScore() const { return prediction_score_; } private: - string response_text_ = ""; + std::string response_text_ = ""; float prediction_score_ = 0.0; }; @@ -65,9 +65,9 @@ struct SmartReplyConfig { float backoff_confidence; // Backoff responses are used when predicted responses cannot fulfill the // list. - const std::vector<string>& backoff_responses; + const std::vector<std::string>& backoff_responses; - SmartReplyConfig(std::vector<string> backoff_responses) + SmartReplyConfig(std::vector<std::string> backoff_responses) : num_response(kDefaultNumResponse), backoff_confidence(kDefaultBackoffConfidence), backoff_responses(backoff_responses) {} diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc index 2fa9923bc9..97d3c650e2 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include <fstream> #include <unordered_set> -#include "base/logging.h" #include <gmock/gmock.h> #include <gtest/gtest.h> #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "tensorflow/contrib/lite/models/test_utils.h" +#include "tensorflow/contrib/lite/string_util.h" namespace tflite { namespace custom { @@ -65,7 +65,6 @@ TEST_F(PredictorTest, GetSegmentPredictions) { float max = 0; for (const auto &item : predictions) { - LOG(INFO) << "Response: " << item.GetText(); if (item.GetScore() > max) { max = item.GetScore(); } @@ -86,7 +85,6 @@ TEST_F(PredictorTest, TestTwoSentences) { float max = 0; for (const auto &item : predictions) { - LOG(INFO) << "Response: " << item.GetText(); if (item.GetScore() > max) { max = item.GetScore(); } @@ -119,7 +117,7 @@ TEST_F(PredictorTest, BatchTest) { string line; std::ifstream fin(StrCat(TestDataPath(), "/", kSamples)); while (std::getline(fin, line)) { - const std::vector<string> &fields = strings::Split(line, '\t'); + const std::vector<string> fields = absl::StrSplit(line, '\t'); if (fields.empty()) { continue; } @@ -139,9 +137,8 @@ TEST_F(PredictorTest, BatchTest) { fields.begin() + 1, fields.end()))); } - LOG(INFO) << "Responses: " << total_responses << " / " << total_items; - LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items; EXPECT_EQ(total_triggers, total_items); + EXPECT_GE(total_responses, total_triggers); } } // namespace diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 21b32d8434..751682215b 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -13,6 +13,7 @@ tf_cc_binary( "//tensorflow/contrib/lite/tools:gen_op_registration", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc index 1b28b8bcd9..17b514c916 100644 --- a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc +++ b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc @@ -13,30 +13,50 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <cassert> #include <fstream> +#include <map> #include <sstream> #include <string> #include <vector> +#include "absl/strings/strip.h" #include "tensorflow/contrib/lite/tools/gen_op_registration.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" +const char kInputModelFlag[] = "input_model"; +const char kOutputRegistrationFlag[] = "output_registration"; +const char kTfLitePathFlag[] = "tflite_path"; + using tensorflow::Flag; using tensorflow::Flags; using tensorflow::string; +void ParseFlagAndInit(int argc, char** argv, string* input_model, + string* output_registration, string* tflite_path) { + std::vector<tensorflow::Flag> flag_list = { + Flag(kInputModelFlag, input_model, "path to the tflite model"), + Flag(kOutputRegistrationFlag, output_registration, + "filename for generated registration code"), + Flag(kTfLitePathFlag, tflite_path, "Path to tensorflow lite dir"), + }; + + Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(argv[0], &argc, &argv); +} + namespace { -void GenerateFileContent(const string& filename, +void GenerateFileContent(const std::string& tflite_path, + const std::string& filename, const std::vector<string>& builtin_ops, const std::vector<string>& custom_ops) { std::ofstream fout(filename); - fout << "#include " - "\"third_party/tensorflow/contrib/lite/model.h\"\n"; - fout << "#include " - "\"third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h\"\n"; + fout << "#include \"" << tflite_path << "/model.h\"\n"; + fout << "#include \"" << tflite_path << "/tools/mutable_op_resolver.h\"\n"; + fout << "namespace tflite {\n"; fout << "namespace ops {\n"; if (!builtin_ops.empty()) { @@ -78,22 +98,20 @@ void GenerateFileContent(const string& filename, int main(int argc, char** argv) { string input_model; string output_registration; - std::vector<tensorflow::Flag> flag_list = { - Flag("input_model", &input_model, "path to the tflite model"), - Flag("output_registration", &output_registration, - "filename for generated registration code"), - }; - Flags::Parse(&argc, argv, flag_list); + string tflite_path; + ParseFlagAndInit(argc, argv, &input_model, &output_registration, + &tflite_path); - tensorflow::port::InitMain(argv[0], &argc, &argv); std::vector<string> builtin_ops; std::vector<string> custom_ops; - std::ifstream fin(input_model); std::stringstream content; content << fin.rdbuf(); - const ::tflite::Model* model = ::tflite::GetModel(content.str().data()); + // Need to store content data first, otherwise, it won't work in bazel. + string content_str = content.str(); + const ::tflite::Model* model = ::tflite::GetModel(content_str.data()); ::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops); - GenerateFileContent(output_registration, builtin_ops, custom_ops); + GenerateFileContent(tflite_path, output_registration, builtin_ops, + custom_ops); return 0; } diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h index be60cf476d..906553da57 100644 --- a/tensorflow/contrib/lite/tools/mutable_op_resolver.h +++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h @@ -46,7 +46,7 @@ class MutableOpResolver : public OpResolver { void AddCustom(const char* name, TfLiteRegistration* registration); private: - std::map<tflite::BuiltinOperator, TfLiteRegistration*> builtins_; + std::map<int, TfLiteRegistration*> builtins_; std::map<std::string, TfLiteRegistration*> custom_ops_; }; diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 25e036e24c..11f9aa2259 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -207,11 +207,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_googlesource_code_re2", urls = [ - "https://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", - "https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", + "https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", + "https://github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", + ], - sha256 = "bd63550101e056427c9e7ff12a408c1c8b74e9803f393ca916b2926fc2c4906f", - strip_prefix = "re2-b94b7cd42e9f02673cd748c1ac1d16db4052514c", + sha256 = "e57eeb837ac40b5be37b2c6197438766e73343ffb32368efea793dfd8b28653b", + strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857", ) native.http_archive( @@ -800,3 +801,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", ], ) + + native.new_http_archive( + name = "tflite_smartreply", + build_file = str(Label("//third_party:tflite_smartreply.BUILD")), + sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c", + urls = [ + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip" + ], + ) diff --git a/third_party/tflite_smartreply.BUILD b/third_party/tflite_smartreply.BUILD new file mode 100644 index 0000000000..75663eff48 --- /dev/null +++ b/third_party/tflite_smartreply.BUILD @@ -0,0 +1,13 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "model_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) |