From 6b6244c40197b34f49bb50aa52efb082380d4637 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2017 23:58:26 -0800 Subject: Build demo app for SmartReply PiperOrigin-RevId: 177559103 --- tensorflow/contrib/lite/models/smartreply/BUILD | 85 ++++++++++++++ .../demo/app/src/main/AndroidManifest.xml | 38 ++++++ .../lite/models/smartreply/demo/app/src/main/BUILD | 65 +++++++++++ .../smartreply/demo/app/src/main/assets/BUILD | 15 +++ .../demo/app/src/main/assets/backoff_response.txt | 16 +++ .../example/android/smartreply/MainActivity.java | 99 ++++++++++++++++ .../com/example/android/smartreply/SmartReply.java | 44 +++++++ .../android/smartreply/SmartReplyClient.java | 129 +++++++++++++++++++++ .../demo/app/src/main/res/layout/main_activity.xml | 44 +++++++ .../smartreply/demo/app/src/main/smartreply_jni.cc | 129 +++++++++++++++++++++ .../lite/models/smartreply/ops/extract_feature.cc | 9 +- .../lite/models/smartreply/ops/normalize.cc | 7 +- .../contrib/lite/models/smartreply/predictor.cc | 21 ++-- .../contrib/lite/models/smartreply/predictor.h | 12 +- .../lite/models/smartreply/predictor_test.cc | 9 +- 15 files changed, 693 insertions(+), 29 deletions(-) create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml create mode 100644 tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc (limited to 'tensorflow/contrib/lite/models') diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD index fbdf19f205..733c3f4c7f 100644 --- a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -1,7 +1,92 @@ package(default_visibility = ["//visibility:public"]) +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + licenses(["notice"]) # Apache 2.0 +gen_selected_ops( + name = "smartreply_ops", + model = "@tflite_smartreply//:smartreply.tflite", +) + +cc_library( + name = "custom_ops", + srcs = [ + "ops/extract_feature.cc", + "ops/normalize.cc", + "ops/predict.cc", + ":smartreply_ops", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + "@farmhash_archive//:farmhash", + ], +) + +cc_library( + name = "predictor_lib", + srcs = ["predictor.cc"], + hdrs = ["predictor.h"], + copts = tflite_copts(), + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "extract_feature_op_test", + size = "small", + srcs = ["ops/extract_feature_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@farmhash_archive//:farmhash", + ], +) + +cc_test( + name = "normalize_op_test", + size = "small", + srcs = ["ops/normalize_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "predict_op_test", + size = "small", + srcs = ["ops/predict_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..75ed9432c8 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD new file mode 100644 index 0000000000..f8767b443a --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -0,0 +1,65 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +filegroup( + name = "assets", + srcs = [ + "@tflite_smartreply//:model_files", + ], +) + +android_binary( + name = "SmartReplyDemo", + srcs = glob(["java/**/*.java"]), + assets = [":assets"], + assets_dir = "", + custom_package = "com.example.android.smartreply", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + ":smartreply_runtime", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +cc_library( + name = "smartreply_runtime", + srcs = ["libsmartreply_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libsmartreply_jni.so", + deps = [ + ":smartreply_jni_lib", + ], +) + +cc_library( + name = "smartreply_jni_lib", + srcs = [ + "smartreply_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/models/smartreply:predictor_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000..3c882ffc43 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(glob(["*"])) + +filegroup( + name = "assets_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt new file mode 100644 index 0000000000..a0a5b46b5f --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt @@ -0,0 +1,16 @@ +Ok +Yes +No +👍 +☺ +😟 +❤️ +Lol +Thanks +Got it +Done +Nice +I don't know +What? +Why? +What's up? diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java new file mode 100644 index 0000000000..02fec9ae5e --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.app.Activity; +import android.os.Bundle; +import android.os.Handler; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.EditText; +import android.widget.TextView; + +/** + * The main (and only) activity of this demo app. Displays a text box which updates as messages are + * received. + */ +public class MainActivity extends Activity { + private static final String TAG = "SmartReplyDemo"; + private SmartReplyClient client; + + private Button sendButton; + private TextView messageTextView; + private EditText messageInput; + + private Handler handler; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + Log.v(TAG, "onCreate"); + setContentView(R.layout.main_activity); + + client = new SmartReplyClient(getApplicationContext()); + handler = new Handler(); + + sendButton = (Button) findViewById(R.id.send_button); + sendButton.setOnClickListener( + (View v) -> { + send(messageInput.getText().toString()); + }); + + messageTextView = (TextView) findViewById(R.id.message_text); + messageInput = (EditText) findViewById(R.id.message_input); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.loadModel(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unloadModel(); + }); + } + + private void send(final String message) { + handler.post( + () -> { + messageTextView.append("Input: " + message + "\n"); + + SmartReply[] ans = client.predict(new String[] {message}); + for (SmartReply reply : ans) { + appendMessage("Reply: " + reply.getText()); + } + appendMessage("------"); + }); + } + + private void appendMessage(final String message) { + handler.post( + () -> { + messageTextView.append(message + "\n"); + }); + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java new file mode 100644 index 0000000000..3357fd17c1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.support.annotation.Keep; + +/** + * SmartReply contains predicted message, and confidence. + * + *

NOTE: this class used by JNI, class name and constructor should not be obfuscated. + */ +@Keep +public class SmartReply { + + private final String text; + private final float score; + + @Keep + public SmartReply(String text, float score) { + this.text = text; + this.score = score; + } + + public String getText() { + return text; + } + + public float getScore() { + return score; + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java new file mode 100644 index 0000000000..d5b1ac0ffb --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.support.annotation.Keep; +import android.support.annotation.WorkerThread; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.List; + +/** Interface to load TfLite model and provide predictions. */ +public class SmartReplyClient implements AutoCloseable { + private static final String TAG = "SmartReplyDemo"; + private static final String MODEL_PATH = "smartreply.tflite"; + private static final String BACKOFF_PATH = "backoff_response.txt"; + private static final String JNI_LIB = "smartreply_jni"; + + private final Context context; + private long storage; + private MappedByteBuffer model; + + private volatile boolean isLibraryLoaded; + + public SmartReplyClient(Context context) { + this.context = context; + } + + public boolean isLoaded() { + return storage != 0; + } + + @WorkerThread + public synchronized void loadModel() { + if (!isLibraryLoaded) { + System.loadLibrary(JNI_LIB); + isLibraryLoaded = true; + } + + try { + model = loadModelFile(); + String[] backoff = loadBackoffList(); + storage = loadJNI(model, backoff); + } catch (IOException e) { + Log.e(TAG, "Fail to load model", e); + return; + } + } + + @WorkerThread + public synchronized SmartReply[] predict(String[] input) { + if (storage != 0) { + return predictJNI(storage, input); + } else { + return new SmartReply[] {}; + } + } + + @WorkerThread + public synchronized void unloadModel() { + close(); + } + + @Override + public synchronized void close() { + if (storage != 0) { + unloadJNI(storage); + storage = 0; + } + } + + private MappedByteBuffer loadModelFile() throws IOException { + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + try { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } finally { + inputStream.close(); + } + } + + private String[] loadBackoffList() throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH))); + String line; + while ((line = reader.readLine()) != null) { + if (!line.isEmpty()) { + labelList.add(line); + } + } + reader.close(); + String[] ans = new String[labelList.size()]; + labelList.toArray(ans); + return ans; + } + + @Keep + private native long loadJNI(MappedByteBuffer buffer, String[] backoff); + + @Keep + private native SmartReply[] predictJNI(long storage, String[] text); + + @Keep + private native void unloadJNI(long storage); +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml new file mode 100644 index 0000000000..23b4cadc00 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml @@ -0,0 +1,44 @@ + + + + + + + + + + +