From 4aad5382f0e7148d8489d24d8355b828b3f7811b Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Fri, 5 Oct 2018 15:43:58 -0700 Subject: Internal change PiperOrigin-RevId: 215978771 --- tensorflow/contrib/lite/java/BUILD | 95 +++++++++++++++++----- tensorflow/contrib/lite/java/aar_with_jni.bzl | 5 +- .../java/org/tensorflow/lite/TensorFlowLite.java | 20 ++++- .../org/tensorflow/lite/InterpreterFlexTest.java | 46 +++++++++++ .../java/org/tensorflow/lite/InterpreterTest.java | 14 ++++ 5 files changed, 153 insertions(+), 27 deletions(-) create mode 100644 tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java (limited to 'tensorflow/contrib') diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 098ba7e773..e68cd26f81 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -11,6 +11,10 @@ load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni") +JAVA_SRCS = glob([ + "src/main/java/org/tensorflow/lite/*.java", +]) + # Building tensorflow-lite.aar including 4 variants of .so # To build an aar for release, run below command: # bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ @@ -20,28 +24,38 @@ aar_with_jni( android_library = ":tensorflowlite", ) +# EXPERIMENTAL: AAR target that supports TensorFlow op execution with TFLite. +aar_with_jni( + name = "tensorflow-lite-flex", + android_library = ":tensorflowlite_flex", +) + android_library( name = "tensorflowlite", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, + manifest = "AndroidManifest.xml", + visibility = ["//visibility:public"], + deps = [ + ":tensorflowlite_native", + "@org_checkerframework_qual", + ], +) + +# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite. +android_library( + name = "tensorflowlite_flex", + srcs = JAVA_SRCS, manifest = "AndroidManifest.xml", visibility = ["//visibility:public"], deps = [ - ":tflite_runtime", + ":tensorflowlite_native_flex", "@org_checkerframework_qual", ], ) android_library( name = "tensorflowlite_java", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, visibility = ["//visibility:public"], deps = [ "@org_checkerframework_qual", @@ -50,16 +64,23 @@ android_library( java_library( name = "tensorflowlitelib", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, javacopts = JAVACOPTS, visibility = ["//visibility:public"], deps = [ ":libtensorflowlite_jni.so", - "//tensorflow/contrib/lite/java/src/main/native", + "@org_checkerframework_qual", + ], +) + +# EXPERIMENTAL: Java target that supports TensorFlow op execution with TFLite. +java_library( + name = "tensorflowlitelib_flex", + srcs = JAVA_SRCS, + javacopts = JAVACOPTS, + visibility = ["//visibility:public"], + deps = [ + ":libtensorflowlite_flex_jni.so", "@org_checkerframework_qual", ], ) @@ -72,7 +93,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorFlowLiteTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -87,7 +107,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.DataTypeTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -110,7 +129,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -125,19 +143,37 @@ java_test( data = [ "src/testdata/add.bin", "src/testdata/mobilenet.tflite.bin", + "//tensorflow/contrib/lite:testdata/multi_add_flex.bin", ], javacopts = JAVACOPTS, tags = ["no_oss"], test_class = "org.tensorflow.lite.InterpreterTest", visibility = ["//visibility:private"], deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", ], ) +java_test( + name = "InterpreterFlexTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"], + data = [ + "//tensorflow/contrib/lite:testdata/multi_add_flex.bin", + ], + javacopts = JAVACOPTS, + tags = ["no_oss"], + test_class = "org.tensorflow.lite.InterpreterFlexTest", + visibility = ["//visibility:private"], + deps = [ + ":tensorflowlitelib_flex", + "@com_google_truth", + "@junit", + ], +) + java_test( name = "TensorTest", size = "small", @@ -164,14 +200,29 @@ filegroup( ) cc_library( - name = "tflite_runtime", + name = "tensorflowlite_native", srcs = ["libtensorflowlite_jni.so"], visibility = ["//visibility:public"], ) +cc_library( + name = "tensorflowlite_native_flex", + srcs = ["libtensorflowlite_flex_jni.so"], + visibility = ["//visibility:public"], +) + tflite_jni_binary( name = "libtensorflowlite_jni.so", deps = [ "//tensorflow/contrib/lite/java/src/main/native", ], ) + +# EXPERIMENTAL: Native target that supports TensorFlow op execution with TFLite. +tflite_jni_binary( + name = "libtensorflowlite_flex_jni.so", + deps = [ + "//tensorflow/contrib/lite/delegates/flex:delegate", + "//tensorflow/contrib/lite/java/src/main/native", + ], +) diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl index 9d2aead266..360d622b1b 100644 --- a/tensorflow/contrib/lite/java/aar_with_jni.bzl +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -30,7 +30,10 @@ EOF # In some platforms we don't have an Android SDK/NDK and this target # can't be built. We need to prevent the build system from trying to # use the target in that case. - tags = ["manual"], + tags = [ + "manual", + "no_cuda_on_cpu_tap", + ], ) native.genrule( diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java index 711638a9f9..d5447b3bf8 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java @@ -18,7 +18,8 @@ package org.tensorflow.lite; /** Static utility methods loading the TensorFlowLite runtime. */ public final class TensorFlowLite { - private static final String LIBNAME = "tensorflowlite_jni"; + private static final String PRIMARY_LIBNAME = "tensorflowlite_jni"; + private static final String FALLBACK_LIBNAME = "tensorflowlite_flex_jni"; private TensorFlowLite() {} @@ -29,13 +30,24 @@ public final class TensorFlowLite { * Load the TensorFlowLite runtime C library. */ static boolean init() { + Throwable primaryLibException; try { - System.loadLibrary(LIBNAME); + System.loadLibrary(PRIMARY_LIBNAME); return true; } catch (UnsatisfiedLinkError e) { - System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage()); - return false; + primaryLibException = e; } + + try { + System.loadLibrary(FALLBACK_LIBNAME); + return true; + } catch (UnsatisfiedLinkError e) { + // If the fallback fails, log the error for the primary load instead. + System.err.println( + "TensorFlowLite: failed to load native library: " + primaryLibException.getMessage()); + } + + return false; } static { diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java new file mode 100644 index 0000000000..2791c3864b --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.File; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that + * have TensorFlow ops. + */ +@RunWith(JUnit4.class) +public final class InterpreterFlexTest { + + private static final File FLEX_MODEL_FILE = + new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + + /** Smoke test validating that flex model loading works when the flex delegate is linked. */ + @Test + public void testFlexModel() throws Exception { + try (Interpreter interpreter = new Interpreter(FLEX_MODEL_FILE)) { + assertThat(interpreter.getInputTensorCount()).isEqualTo(4); + assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + assertThat(interpreter.getOutputTensorCount()).isEqualTo(4); + assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + interpreter.run(new float[1], new float[1]); + } + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index a98fca0132..f8b73c7cf3 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -43,6 +43,9 @@ public final class InterpreterTest { private static final File MOBILENET_MODEL_FILE = new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin"); + private static final File FLEX_MODEL_FILE = + new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + @Test public void testInterpreter() throws Exception { Interpreter interpreter = new Interpreter(MODEL_FILE); @@ -345,4 +348,15 @@ public final class InterpreterTest { interpreter.close(); interpreter.close(); } + + /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */ + @Test + public void testFlexModel() throws Exception { + try { + new Interpreter(FLEX_MODEL_FILE); + fail(); + } catch (IllegalStateException e) { + // Expected failure. + } + } } -- cgit v1.2.3