aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-10-05 15:43:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 15:55:51 -0700
commit4aad5382f0e7148d8489d24d8355b828b3f7811b (patch)
tree3d0b68908568258af5a33f72539a5c3ba2b89bf1 /tensorflow/contrib
parent5ac6e1e4b8318bad2f2bc7e5a08a58a7ed31e4c6 (diff)
Internal change
PiperOrigin-RevId: 215978771
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/java/BUILD95
-rw-r--r--tensorflow/contrib/lite/java/aar_with_jni.bzl5
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java20
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java46
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java14
5 files changed, 153 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
index 098ba7e773..e68cd26f81 100644
--- a/tensorflow/contrib/lite/java/BUILD
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -11,6 +11,10 @@ load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary")
load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni")
+JAVA_SRCS = glob([
+ "src/main/java/org/tensorflow/lite/*.java",
+])
+
# Building tensorflow-lite.aar including 4 variants of .so
# To build an aar for release, run below command:
# bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
@@ -20,28 +24,38 @@ aar_with_jni(
android_library = ":tensorflowlite",
)
+# EXPERIMENTAL: AAR target that supports TensorFlow op execution with TFLite.
+aar_with_jni(
+ name = "tensorflow-lite-flex",
+ android_library = ":tensorflowlite_flex",
+)
+
android_library(
name = "tensorflowlite",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
+ manifest = "AndroidManifest.xml",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tensorflowlite_native",
+ "@org_checkerframework_qual",
+ ],
+)
+
+# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite.
+android_library(
+ name = "tensorflowlite_flex",
+ srcs = JAVA_SRCS,
manifest = "AndroidManifest.xml",
visibility = ["//visibility:public"],
deps = [
- ":tflite_runtime",
+ ":tensorflowlite_native_flex",
"@org_checkerframework_qual",
],
)
android_library(
name = "tensorflowlite_java",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
visibility = ["//visibility:public"],
deps = [
"@org_checkerframework_qual",
@@ -50,16 +64,23 @@ android_library(
java_library(
name = "tensorflowlitelib",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
javacopts = JAVACOPTS,
visibility = ["//visibility:public"],
deps = [
":libtensorflowlite_jni.so",
- "//tensorflow/contrib/lite/java/src/main/native",
+ "@org_checkerframework_qual",
+ ],
+)
+
+# EXPERIMENTAL: Java target that supports TensorFlow op execution with TFLite.
+java_library(
+ name = "tensorflowlitelib_flex",
+ srcs = JAVA_SRCS,
+ javacopts = JAVACOPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":libtensorflowlite_flex_jni.so",
"@org_checkerframework_qual",
],
)
@@ -72,7 +93,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.TensorFlowLiteTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -87,7 +107,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.DataTypeTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -110,7 +129,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -125,13 +143,13 @@ java_test(
data = [
"src/testdata/add.bin",
"src/testdata/mobilenet.tflite.bin",
+ "//tensorflow/contrib/lite:testdata/multi_add_flex.bin",
],
javacopts = JAVACOPTS,
tags = ["no_oss"],
test_class = "org.tensorflow.lite.InterpreterTest",
visibility = ["//visibility:private"],
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -139,6 +157,24 @@ java_test(
)
java_test(
+ name = "InterpreterFlexTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"],
+ data = [
+ "//tensorflow/contrib/lite:testdata/multi_add_flex.bin",
+ ],
+ javacopts = JAVACOPTS,
+ tags = ["no_oss"],
+ test_class = "org.tensorflow.lite.InterpreterFlexTest",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":tensorflowlitelib_flex",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+java_test(
name = "TensorTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"],
@@ -164,14 +200,29 @@ filegroup(
)
cc_library(
- name = "tflite_runtime",
+ name = "tensorflowlite_native",
srcs = ["libtensorflowlite_jni.so"],
visibility = ["//visibility:public"],
)
+cc_library(
+ name = "tensorflowlite_native_flex",
+ srcs = ["libtensorflowlite_flex_jni.so"],
+ visibility = ["//visibility:public"],
+)
+
tflite_jni_binary(
name = "libtensorflowlite_jni.so",
deps = [
"//tensorflow/contrib/lite/java/src/main/native",
],
)
+
+# EXPERIMENTAL: Native target that supports TensorFlow op execution with TFLite.
+tflite_jni_binary(
+ name = "libtensorflowlite_flex_jni.so",
+ deps = [
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ "//tensorflow/contrib/lite/java/src/main/native",
+ ],
+)
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index 9d2aead266..360d622b1b 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -30,7 +30,10 @@ EOF
# In some platforms we don't have an Android SDK/NDK and this target
# can't be built. We need to prevent the build system from trying to
# use the target in that case.
- tags = ["manual"],
+ tags = [
+ "manual",
+ "no_cuda_on_cpu_tap",
+ ],
)
native.genrule(
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
index 711638a9f9..d5447b3bf8 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
@@ -18,7 +18,8 @@ package org.tensorflow.lite;
/** Static utility methods loading the TensorFlowLite runtime. */
public final class TensorFlowLite {
- private static final String LIBNAME = "tensorflowlite_jni";
+ private static final String PRIMARY_LIBNAME = "tensorflowlite_jni";
+ private static final String FALLBACK_LIBNAME = "tensorflowlite_flex_jni";
private TensorFlowLite() {}
@@ -29,13 +30,24 @@ public final class TensorFlowLite {
* Load the TensorFlowLite runtime C library.
*/
static boolean init() {
+ Throwable primaryLibException;
try {
- System.loadLibrary(LIBNAME);
+ System.loadLibrary(PRIMARY_LIBNAME);
return true;
} catch (UnsatisfiedLinkError e) {
- System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage());
- return false;
+ primaryLibException = e;
}
+
+ try {
+ System.loadLibrary(FALLBACK_LIBNAME);
+ return true;
+ } catch (UnsatisfiedLinkError e) {
+ // If the fallback fails, log the error for the primary load instead.
+ System.err.println(
+ "TensorFlowLite: failed to load native library: " + primaryLibException.getMessage());
+ }
+
+ return false;
}
static {
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
new file mode 100644
index 0000000000..2791c3864b
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import java.io.File;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that
+ * have TensorFlow ops.
+ */
+@RunWith(JUnit4.class)
+public final class InterpreterFlexTest {
+
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
+ /** Smoke test validating that flex model loading works when the flex delegate is linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try (Interpreter interpreter = new Interpreter(FLEX_MODEL_FILE)) {
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.run(new float[1], new float[1]);
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index a98fca0132..f8b73c7cf3 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -43,6 +43,9 @@ public final class InterpreterTest {
private static final File MOBILENET_MODEL_FILE =
new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin");
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
@Test
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
@@ -345,4 +348,15 @@ public final class InterpreterTest {
interpreter.close();
interpreter.close();
}
+
+ /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try {
+ new Interpreter(FLEX_MODEL_FILE);
+ fail();
+ } catch (IllegalStateException e) {
+ // Expected failure.
+ }
+ }
}