aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-30 23:58:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 00:02:04 -0800
commit6b6244c40197b34f49bb50aa52efb082380d4637 (patch)
treece50b9f28330c7ad194b27263f2534221f176457
parent370e521762f3cbd558a7e56992e3b062236b626f (diff)
Build demo app for SmartReply
PiperOrigin-RevId: 177559103
-rw-r--r--tensorflow/contrib/lite/build_def.bzl5
-rw-r--r--tensorflow/contrib/lite/models/smartreply/BUILD85
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml38
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD65
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD15
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt16
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java99
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java44
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java129
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml44
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc129
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc9
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/normalize.cc7
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.cc21
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.h12
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor_test.cc9
-rw-r--r--tensorflow/contrib/lite/tools/BUILD1
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration_main.cc48
-rw-r--r--tensorflow/contrib/lite/tools/mutable_op_resolver.h2
-rw-r--r--tensorflow/workspace.bzl18
-rw-r--r--third_party/tflite_smartreply.BUILD13
21 files changed, 758 insertions, 51 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index e3c9cdd99b..5813b3de4d 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -223,11 +223,12 @@ def gen_selected_ops(name, model):
"""
out = name + "_registration.cc"
tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
+ tflite_path = "//tensorflow/contrib/lite"
native.genrule(
name = name,
srcs = [model],
outs = [out],
- cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)")
- % (tool, model, out),
+ cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s")
+ % (tool, model, out, tflite_path[2:]),
tools = [tool],
)
diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD
index fbdf19f205..733c3f4c7f 100644
--- a/tensorflow/contrib/lite/models/smartreply/BUILD
+++ b/tensorflow/contrib/lite/models/smartreply/BUILD
@@ -1,7 +1,92 @@
package(default_visibility = ["//visibility:public"])
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
+
licenses(["notice"]) # Apache 2.0
+gen_selected_ops(
+ name = "smartreply_ops",
+ model = "@tflite_smartreply//:smartreply.tflite",
+)
+
+cc_library(
+ name = "custom_ops",
+ srcs = [
+ "ops/extract_feature.cc",
+ "ops/normalize.cc",
+ "ops/predict.cc",
+ ":smartreply_ops",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/tools:mutable_op_resolver",
+ "@com_google_absl//absl/strings",
+ "@com_googlesource_code_re2//:re2",
+ "@farmhash_archive//:farmhash",
+ ],
+)
+
+cc_library(
+ name = "predictor_lib",
+ srcs = ["predictor.cc"],
+ hdrs = ["predictor.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":custom_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/tools:mutable_op_resolver",
+ "@com_google_absl//absl/strings",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+cc_test(
+ name = "extract_feature_op_test",
+ size = "small",
+ srcs = ["ops/extract_feature_test.cc"],
+ deps = [
+ ":custom_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@farmhash_archive//:farmhash",
+ ],
+)
+
+cc_test(
+ name = "normalize_op_test",
+ size = "small",
+ srcs = ["ops/normalize_test.cc"],
+ deps = [
+ ":custom_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "predict_op_test",
+ size = "small",
+ srcs = ["ops/predict_test.cc"],
+ deps = [
+ ":custom_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000..75ed9432c8
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml
@@ -0,0 +1,38 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2017 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.example.android.smartreply" >
+
+ <uses-sdk
+ android:minSdkVersion="15"
+ android:targetSdkVersion="24" />
+
+ <application android:label="TfLite SmartReply Demo">
+ <activity
+ android:name="com.example.android.smartreply.MainActivity"
+ android:configChanges="orientation|keyboardHidden|screenSize"
+ android:windowSoftInputMode="stateUnchanged|adjustPan"
+ android:label="TfLite SmartReply Demo"
+ android:screenOrientation="portrait" >
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
+ </application>
+
+</manifest>
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
new file mode 100644
index 0000000000..f8767b443a
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
@@ -0,0 +1,65 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/contrib/lite:build_def.bzl",
+ "tflite_copts",
+ "tflite_jni_binary",
+)
+
+filegroup(
+ name = "assets",
+ srcs = [
+ "@tflite_smartreply//:model_files",
+ ],
+)
+
+android_binary(
+ name = "SmartReplyDemo",
+ srcs = glob(["java/**/*.java"]),
+ assets = [":assets"],
+ assets_dir = "",
+ custom_package = "com.example.android.smartreply",
+ manifest = "AndroidManifest.xml",
+ nocompress_extensions = [
+ ".tflite",
+ ],
+ resource_files = glob(["res/**"]),
+ tags = ["manual"],
+ deps = [
+ ":smartreply_runtime",
+ "@androidsdk//com.android.support:support-v13-25.2.0",
+ "@androidsdk//com.android.support:support-v4-25.2.0",
+ ],
+)
+
+cc_library(
+ name = "smartreply_runtime",
+ srcs = ["libsmartreply_jni.so"],
+ visibility = ["//visibility:public"],
+)
+
+tflite_jni_binary(
+ name = "libsmartreply_jni.so",
+ deps = [
+ ":smartreply_jni_lib",
+ ],
+)
+
+cc_library(
+ name = "smartreply_jni_lib",
+ srcs = [
+ "smartreply_jni.cc",
+ ],
+ copts = tflite_copts(),
+ linkopts = [
+ "-lm",
+ "-ldl",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/models/smartreply:predictor_lib",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD
new file mode 100644
index 0000000000..3c882ffc43
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD
@@ -0,0 +1,15 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(glob(["*"]))
+
+filegroup(
+ name = "assets_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "BUILD",
+ ],
+ ),
+)
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt
new file mode 100644
index 0000000000..a0a5b46b5f
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt
@@ -0,0 +1,16 @@
+Ok
+Yes
+No
+👍
+☺
+😟
+❤️
+Lol
+Thanks
+Got it
+Done
+Nice
+I don't know
+What?
+Why?
+What's up?
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java
new file mode 100644
index 0000000000..02fec9ae5e
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java
@@ -0,0 +1,99 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.smartreply;
+
+import android.app.Activity;
+import android.os.Bundle;
+import android.os.Handler;
+import android.util.Log;
+import android.view.View;
+import android.widget.Button;
+import android.widget.EditText;
+import android.widget.TextView;
+
+/**
+ * The main (and only) activity of this demo app. Displays a text box which updates as messages are
+ * received.
+ */
+public class MainActivity extends Activity {
+ private static final String TAG = "SmartReplyDemo";
+ private SmartReplyClient client;
+
+ private Button sendButton;
+ private TextView messageTextView;
+ private EditText messageInput;
+
+ private Handler handler;
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ Log.v(TAG, "onCreate");
+ setContentView(R.layout.main_activity);
+
+ client = new SmartReplyClient(getApplicationContext());
+ handler = new Handler();
+
+ sendButton = (Button) findViewById(R.id.send_button);
+ sendButton.setOnClickListener(
+ (View v) -> {
+ send(messageInput.getText().toString());
+ });
+
+ messageTextView = (TextView) findViewById(R.id.message_text);
+ messageInput = (EditText) findViewById(R.id.message_input);
+ }
+
+ @Override
+ protected void onStart() {
+ super.onStart();
+ Log.v(TAG, "onStart");
+ handler.post(
+ () -> {
+ client.loadModel();
+ });
+ }
+
+ @Override
+ protected void onStop() {
+ super.onStop();
+ Log.v(TAG, "onStop");
+ handler.post(
+ () -> {
+ client.unloadModel();
+ });
+ }
+
+ private void send(final String message) {
+ handler.post(
+ () -> {
+ messageTextView.append("Input: " + message + "\n");
+
+ SmartReply[] ans = client.predict(new String[] {message});
+ for (SmartReply reply : ans) {
+ appendMessage("Reply: " + reply.getText());
+ }
+ appendMessage("------");
+ });
+ }
+
+ private void appendMessage(final String message) {
+ handler.post(
+ () -> {
+ messageTextView.append(message + "\n");
+ });
+ }
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java
new file mode 100644
index 0000000000..3357fd17c1
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.smartreply;
+
+import android.support.annotation.Keep;
+
+/**
+ * SmartReply contains predicted message, and confidence.
+ *
+ * <p>NOTE: this class used by JNI, class name and constructor should not be obfuscated.
+ */
+@Keep
+public class SmartReply {
+
+ private final String text;
+ private final float score;
+
+ @Keep
+ public SmartReply(String text, float score) {
+ this.text = text;
+ this.score = score;
+ }
+
+ public String getText() {
+ return text;
+ }
+
+ public float getScore() {
+ return score;
+ }
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java
new file mode 100644
index 0000000000..d5b1ac0ffb
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java
@@ -0,0 +1,129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.smartreply;
+
+import android.content.Context;
+import android.content.res.AssetFileDescriptor;
+import android.support.annotation.Keep;
+import android.support.annotation.WorkerThread;
+import android.util.Log;
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Interface to load TfLite model and provide predictions. */
+public class SmartReplyClient implements AutoCloseable {
+ private static final String TAG = "SmartReplyDemo";
+ private static final String MODEL_PATH = "smartreply.tflite";
+ private static final String BACKOFF_PATH = "backoff_response.txt";
+ private static final String JNI_LIB = "smartreply_jni";
+
+ private final Context context;
+ private long storage;
+ private MappedByteBuffer model;
+
+ private volatile boolean isLibraryLoaded;
+
+ public SmartReplyClient(Context context) {
+ this.context = context;
+ }
+
+ public boolean isLoaded() {
+ return storage != 0;
+ }
+
+ @WorkerThread
+ public synchronized void loadModel() {
+ if (!isLibraryLoaded) {
+ System.loadLibrary(JNI_LIB);
+ isLibraryLoaded = true;
+ }
+
+ try {
+ model = loadModelFile();
+ String[] backoff = loadBackoffList();
+ storage = loadJNI(model, backoff);
+ } catch (IOException e) {
+ Log.e(TAG, "Fail to load model", e);
+ return;
+ }
+ }
+
+ @WorkerThread
+ public synchronized SmartReply[] predict(String[] input) {
+ if (storage != 0) {
+ return predictJNI(storage, input);
+ } else {
+ return new SmartReply[] {};
+ }
+ }
+
+ @WorkerThread
+ public synchronized void unloadModel() {
+ close();
+ }
+
+ @Override
+ public synchronized void close() {
+ if (storage != 0) {
+ unloadJNI(storage);
+ storage = 0;
+ }
+ }
+
+ private MappedByteBuffer loadModelFile() throws IOException {
+ AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
+ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
+ try {
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ } finally {
+ inputStream.close();
+ }
+ }
+
+ private String[] loadBackoffList() throws IOException {
+ List<String> labelList = new ArrayList<String>();
+ BufferedReader reader =
+ new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH)));
+ String line;
+ while ((line = reader.readLine()) != null) {
+ if (!line.isEmpty()) {
+ labelList.add(line);
+ }
+ }
+ reader.close();
+ String[] ans = new String[labelList.size()];
+ labelList.toArray(ans);
+ return ans;
+ }
+
+ @Keep
+ private native long loadJNI(MappedByteBuffer buffer, String[] backoff);
+
+ @Keep
+ private native SmartReply[] predictJNI(long storage, String[] text);
+
+ @Keep
+ private native void unloadJNI(long storage);
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml
new file mode 100644
index 0000000000..23b4cadc00
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml
@@ -0,0 +1,44 @@
+<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ xmlns:tools="http://schemas.android.com/tools"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:orientation="vertical">
+
+ <LinearLayout
+ android:layout_width="fill_parent"
+ android:layout_height="0dp"
+ android:padding="5dip"
+ android:layout_weight="3">
+
+ <TextView
+ android:id="@+id/message_text"
+ android:layout_width="fill_parent"
+ android:layout_height="fill_parent"
+ android:scrollbars="vertical"
+ android:gravity="bottom"/>
+ </LinearLayout>
+
+ <LinearLayout
+ android:layout_width="fill_parent"
+ android:layout_height="0dp"
+ android:padding="5dip"
+ android:layout_weight="1">
+
+ <EditText
+ android:id="@+id/message_input"
+ android:layout_width="0dp"
+ android:layout_height="fill_parent"
+ android:layout_weight="6"
+ android:scrollbars="vertical"
+ android:hint="Enter Text"
+ android:gravity="top"
+ android:inputType="text"/>
+ <Button
+ android:id="@+id/send_button"
+ android:layout_width="0dp"
+ android:layout_height="fill_parent"
+ android:layout_weight="2"
+ android:text="Send" />
+ </LinearLayout>
+
+</LinearLayout>
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc
new file mode 100644
index 0000000000..f158cc511a
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc
@@ -0,0 +1,129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
+
+const char kIllegalStateException[] = "java/lang/IllegalStateException";
+
+using tflite::custom::smartreply::GetSegmentPredictions;
+using tflite::custom::smartreply::PredictorResponse;
+
+template <typename T>
+T CheckNotNull(JNIEnv* env, T&& t) {
+ if (t == nullptr) {
+ env->ThrowNew(env->FindClass(kIllegalStateException), "");
+ return nullptr;
+ }
+ return std::forward<T>(t);
+}
+
+std::vector<std::string> jniStringArrayToVector(JNIEnv* env,
+ jobjectArray string_array) {
+ int count = env->GetArrayLength(string_array);
+ std::vector<std::string> result;
+ for (int i = 0; i < count; i++) {
+ auto jstr =
+ reinterpret_cast<jstring>(env->GetObjectArrayElement(string_array, i));
+ const char* raw_str = env->GetStringUTFChars(jstr, JNI_FALSE);
+ result.emplace_back(std::string(raw_str));
+ env->ReleaseStringUTFChars(jstr, raw_str);
+ }
+ return result;
+}
+
+struct JNIStorage {
+ std::vector<std::string> backoff_list;
+ std::unique_ptr<::tflite::FlatBufferModel> model;
+};
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_example_android_smartreply_SmartReplyClient_loadJNI(
+ JNIEnv* env, jobject thiz, jobject model_buffer,
+ jobjectArray backoff_list) {
+ const char* buf =
+ static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
+ jlong capacity = env->GetDirectBufferCapacity(model_buffer);
+
+ JNIStorage* storage = new JNIStorage;
+ storage->model = tflite::FlatBufferModel::BuildFromBuffer(
+ buf, static_cast<size_t>(capacity));
+ storage->backoff_list = jniStringArrayToVector(env, backoff_list);
+
+ if (!storage->model) {
+ delete storage;
+ env->ThrowNew(env->FindClass(kIllegalStateException), "");
+ return 0;
+ }
+ return reinterpret_cast<jlong>(storage);
+}
+
+extern "C" JNIEXPORT jobjectArray JNICALL
+Java_com_example_android_smartreply_SmartReplyClient_predictJNI(
+ JNIEnv* env, jobject /*thiz*/, jlong storage_ptr, jobjectArray input_text) {
+ // Predict
+ if (storage_ptr == 0) {
+ return nullptr;
+ }
+ JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
+ if (storage == nullptr) {
+ return nullptr;
+ }
+ std::vector<PredictorResponse> responses;
+ GetSegmentPredictions(jniStringArrayToVector(env, input_text),
+ *storage->model, {storage->backoff_list}, &responses);
+
+ // Create a SmartReply[] to return back to Java
+ jclass smart_reply_class = CheckNotNull(
+ env, env->FindClass("com/example/android/smartreply/SmartReply"));
+ if (env->ExceptionCheck()) {
+ return nullptr;
+ }
+ jmethodID smart_reply_ctor = CheckNotNull(
+ env,
+ env->GetMethodID(smart_reply_class, "<init>", "(Ljava/lang/String;F)V"));
+ if (env->ExceptionCheck()) {
+ return nullptr;
+ }
+ jobjectArray array = CheckNotNull(
+ env, env->NewObjectArray(responses.size(), smart_reply_class, nullptr));
+ if (env->ExceptionCheck()) {
+ return nullptr;
+ }
+ for (int i = 0; i < responses.size(); i++) {
+ jstring text =
+ CheckNotNull(env, env->NewStringUTF(responses[i].GetText().data()));
+ if (env->ExceptionCheck()) {
+ return nullptr;
+ }
+ jobject reply = env->NewObject(smart_reply_class, smart_reply_ctor, text,
+ responses[i].GetScore());
+ env->SetObjectArrayElement(array, i, reply);
+ }
+ return array;
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_com_example_android_smartreply_SmartReplyClient_unloadJNI(
+ JNIEnv* env, jobject thiz, jlong storage_ptr) {
+ if (storage_ptr != 0) {
+ JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
+ delete storage;
+ }
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
index 1c422b659a..f97a6486d6 100644
--- a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include <algorithm>
#include <map>
-#include "re2/re2.h"
+
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/string_util.h"
@@ -81,7 +81,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* label = GetOutput(context, node, 0);
TfLiteTensor* weight = GetOutput(context, node, 1);
- std::map<int64, int> feature_id_counts;
+ std::map<int64_t, int> feature_id_counts;
for (int i = 0; i < num_strings; i++) {
// Use fingerprint of feature name as id.
auto strref = tflite::GetString(input, i);
@@ -91,10 +91,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
continue;
}
- int64 feature_id =
+ int64_t feature_id =
::util::Fingerprint64(strref.str, strref.len) % kMaxDimension;
-
- label->data.i32[i] = static_cast<int32>(feature_id);
+ label->data.i32[i] = static_cast<int32_t>(feature_id);
weight->data.f[i] =
std::count(strref.str, strref.str + strref.len, ' ') + 1;
}
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
index d0dc2a35a7..c55ac9f52f 100644
--- a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
+++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
@@ -21,7 +21,10 @@ limitations under the License.
// Output:
// Output[0]: Normalized sentence. string[1]
//
-#include "absl/strings/ascii.h"
+
+#include <algorithm>
+#include <string>
+
#include "absl/strings/str_cat.h"
#include "absl/strings/strip.h"
#include "re2/re2.h"
@@ -50,7 +53,7 @@ const std::map<string, string>* kRegexTransforms =
static const char kStartToken[] = "<S>";
static const char kEndToken[] = "<E>";
-static const int32 kMaxInputChars = 300;
+static const int32_t kMaxInputChars = 300;
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc
index a28222213e..6da5cc8eec 100644
--- a/tensorflow/contrib/lite/models/smartreply/predictor.cc
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc
@@ -30,7 +30,7 @@ namespace custom {
namespace smartreply {
// Split sentence into segments (using punctuation).
-std::vector<string> SplitSentence(const string& input) {
+std::vector<std::string> SplitSentence(const std::string& input) {
string result(input);
RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
@@ -38,12 +38,13 @@ std::vector<string> SplitSentence(const string& input) {
RE2::GlobalReplace(&result, "[ ]+", " ");
RE2::GlobalReplace(&result, "\t+$", "");
- return strings::Split(result, '\t');
+ return absl::StrSplit(result, '\t');
}
// Predict with TfLite model.
-void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
- std::map<string, float>* response_map) {
+void ExecuteTfLite(const std::string& sentence,
+ ::tflite::Interpreter* interpreter,
+ std::map<std::string, float>* response_map) {
{
TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
tflite::DynamicBuffer buf;
@@ -67,8 +68,8 @@ void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
}
void GetSegmentPredictions(
- const std::vector<string>& input, const ::tflite::FlatBufferModel& model,
- const SmartReplyConfig& config,
+ const std::vector<std::string>& input,
+ const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config,
std::vector<PredictorResponse>* predictor_responses) {
// Initialize interpreter
std::unique_ptr<::tflite::Interpreter> interpreter;
@@ -82,10 +83,10 @@ void GetSegmentPredictions(
}
// Execute Tflite Model
- std::map<string, float> response_map;
- std::vector<string> sentences;
- for (const string& str : input) {
- std::vector<string> splitted_str = SplitSentence(str);
+ std::map<std::string, float> response_map;
+ std::vector<std::string> sentences;
+ for (const std::string& str : input) {
+ std::vector<std::string> splitted_str = SplitSentence(str);
sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
}
for (const auto& sentence : sentences) {
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h
index 3b9a2b32e1..d17323a3f9 100644
--- a/tensorflow/contrib/lite/models/smartreply/predictor.h
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.h
@@ -34,7 +34,7 @@ struct SmartReplyConfig;
// With a given string as input, predict the response with a Tflite model.
// When config.backoff_response is not empty, predictor_responses will be filled
// with messagees from backoff response.
-void GetSegmentPredictions(const std::vector<string>& input,
+void GetSegmentPredictions(const std::vector<std::string>& input,
const ::tflite::FlatBufferModel& model,
const SmartReplyConfig& config,
std::vector<PredictorResponse>* predictor_responses);
@@ -43,17 +43,17 @@ void GetSegmentPredictions(const std::vector<string>& input,
// It includes messages, and confidence.
class PredictorResponse {
public:
- PredictorResponse(const string& response_text, float score) {
+ PredictorResponse(const std::string& response_text, float score) {
response_text_ = response_text;
prediction_score_ = score;
}
// Accessor methods.
- const string& GetText() const { return response_text_; }
+ const std::string& GetText() const { return response_text_; }
float GetScore() const { return prediction_score_; }
private:
- string response_text_ = "";
+ std::string response_text_ = "";
float prediction_score_ = 0.0;
};
@@ -65,9 +65,9 @@ struct SmartReplyConfig {
float backoff_confidence;
// Backoff responses are used when predicted responses cannot fulfill the
// list.
- const std::vector<string>& backoff_responses;
+ const std::vector<std::string>& backoff_responses;
- SmartReplyConfig(std::vector<string> backoff_responses)
+ SmartReplyConfig(std::vector<std::string> backoff_responses)
: num_response(kDefaultNumResponse),
backoff_confidence(kDefaultBackoffConfidence),
backoff_responses(backoff_responses) {}
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
index 2fa9923bc9..97d3c650e2 100644
--- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
+++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
@@ -18,12 +18,12 @@ limitations under the License.
#include <fstream>
#include <unordered_set>
-#include "base/logging.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/contrib/lite/models/test_utils.h"
+#include "tensorflow/contrib/lite/string_util.h"
namespace tflite {
namespace custom {
@@ -65,7 +65,6 @@ TEST_F(PredictorTest, GetSegmentPredictions) {
float max = 0;
for (const auto &item : predictions) {
- LOG(INFO) << "Response: " << item.GetText();
if (item.GetScore() > max) {
max = item.GetScore();
}
@@ -86,7 +85,6 @@ TEST_F(PredictorTest, TestTwoSentences) {
float max = 0;
for (const auto &item : predictions) {
- LOG(INFO) << "Response: " << item.GetText();
if (item.GetScore() > max) {
max = item.GetScore();
}
@@ -119,7 +117,7 @@ TEST_F(PredictorTest, BatchTest) {
string line;
std::ifstream fin(StrCat(TestDataPath(), "/", kSamples));
while (std::getline(fin, line)) {
- const std::vector<string> &fields = strings::Split(line, '\t');
+ const std::vector<string> fields = absl::StrSplit(line, '\t');
if (fields.empty()) {
continue;
}
@@ -139,9 +137,8 @@ TEST_F(PredictorTest, BatchTest) {
fields.begin() + 1, fields.end())));
}
- LOG(INFO) << "Responses: " << total_responses << " / " << total_items;
- LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items;
EXPECT_EQ(total_triggers, total_items);
+ EXPECT_GE(total_responses, total_triggers);
}
} // namespace
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index 21b32d8434..751682215b 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -13,6 +13,7 @@ tf_cc_binary(
"//tensorflow/contrib/lite/tools:gen_op_registration",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
index 1b28b8bcd9..17b514c916 100644
--- a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
+++ b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
@@ -13,30 +13,50 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cassert>
#include <fstream>
+#include <map>
#include <sstream>
#include <string>
#include <vector>
+#include "absl/strings/strip.h"
#include "tensorflow/contrib/lite/tools/gen_op_registration.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
+const char kInputModelFlag[] = "input_model";
+const char kOutputRegistrationFlag[] = "output_registration";
+const char kTfLitePathFlag[] = "tflite_path";
+
using tensorflow::Flag;
using tensorflow::Flags;
using tensorflow::string;
+void ParseFlagAndInit(int argc, char** argv, string* input_model,
+ string* output_registration, string* tflite_path) {
+ std::vector<tensorflow::Flag> flag_list = {
+ Flag(kInputModelFlag, input_model, "path to the tflite model"),
+ Flag(kOutputRegistrationFlag, output_registration,
+ "filename for generated registration code"),
+ Flag(kTfLitePathFlag, tflite_path, "Path to tensorflow lite dir"),
+ };
+
+ Flags::Parse(&argc, argv, flag_list);
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+}
+
namespace {
-void GenerateFileContent(const string& filename,
+void GenerateFileContent(const std::string& tflite_path,
+ const std::string& filename,
const std::vector<string>& builtin_ops,
const std::vector<string>& custom_ops) {
std::ofstream fout(filename);
- fout << "#include "
- "\"third_party/tensorflow/contrib/lite/model.h\"\n";
- fout << "#include "
- "\"third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h\"\n";
+ fout << "#include \"" << tflite_path << "/model.h\"\n";
+ fout << "#include \"" << tflite_path << "/tools/mutable_op_resolver.h\"\n";
+
fout << "namespace tflite {\n";
fout << "namespace ops {\n";
if (!builtin_ops.empty()) {
@@ -78,22 +98,20 @@ void GenerateFileContent(const string& filename,
int main(int argc, char** argv) {
string input_model;
string output_registration;
- std::vector<tensorflow::Flag> flag_list = {
- Flag("input_model", &input_model, "path to the tflite model"),
- Flag("output_registration", &output_registration,
- "filename for generated registration code"),
- };
- Flags::Parse(&argc, argv, flag_list);
+ string tflite_path;
+ ParseFlagAndInit(argc, argv, &input_model, &output_registration,
+ &tflite_path);
- tensorflow::port::InitMain(argv[0], &argc, &argv);
std::vector<string> builtin_ops;
std::vector<string> custom_ops;
-
std::ifstream fin(input_model);
std::stringstream content;
content << fin.rdbuf();
- const ::tflite::Model* model = ::tflite::GetModel(content.str().data());
+ // Need to store content data first, otherwise, it won't work in bazel.
+ string content_str = content.str();
+ const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops);
- GenerateFileContent(output_registration, builtin_ops, custom_ops);
+ GenerateFileContent(tflite_path, output_registration, builtin_ops,
+ custom_ops);
return 0;
}
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
index be60cf476d..906553da57 100644
--- a/tensorflow/contrib/lite/tools/mutable_op_resolver.h
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
@@ -46,7 +46,7 @@ class MutableOpResolver : public OpResolver {
void AddCustom(const char* name, TfLiteRegistration* registration);
private:
- std::map<tflite::BuiltinOperator, TfLiteRegistration*> builtins_;
+ std::map<int, TfLiteRegistration*> builtins_;
std::map<std::string, TfLiteRegistration*> custom_ops_;
};
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 25e036e24c..11f9aa2259 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -207,11 +207,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
native.http_archive(
name = "com_googlesource_code_re2",
urls = [
- "https://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz",
- "https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz",
+ "https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz",
+ "https://github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz",
+
],
- sha256 = "bd63550101e056427c9e7ff12a408c1c8b74e9803f393ca916b2926fc2c4906f",
- strip_prefix = "re2-b94b7cd42e9f02673cd748c1ac1d16db4052514c",
+ sha256 = "e57eeb837ac40b5be37b2c6197438766e73343ffb32368efea793dfd8b28653b",
+ strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857",
)
native.http_archive(
@@ -800,3 +801,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
],
)
+
+ native.new_http_archive(
+ name = "tflite_smartreply",
+ build_file = str(Label("//third_party:tflite_smartreply.BUILD")),
+ sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
+ urls = [
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip"
+ ],
+ )
diff --git a/third_party/tflite_smartreply.BUILD b/third_party/tflite_smartreply.BUILD
new file mode 100644
index 0000000000..75663eff48
--- /dev/null
+++ b/third_party/tflite_smartreply.BUILD
@@ -0,0 +1,13 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "model_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "BUILD",
+ ],
+ ),
+)