aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java/src
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-10-05 15:43:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 15:55:51 -0700
commit4aad5382f0e7148d8489d24d8355b828b3f7811b (patch)
tree3d0b68908568258af5a33f72539a5c3ba2b89bf1 /tensorflow/contrib/lite/java/src
parent5ac6e1e4b8318bad2f2bc7e5a08a58a7ed31e4c6 (diff)
Internal change
PiperOrigin-RevId: 215978771
Diffstat (limited to 'tensorflow/contrib/lite/java/src')
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java20
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java46
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java14
3 files changed, 76 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
index 711638a9f9..d5447b3bf8 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
@@ -18,7 +18,8 @@ package org.tensorflow.lite;
/** Static utility methods loading the TensorFlowLite runtime. */
public final class TensorFlowLite {
- private static final String LIBNAME = "tensorflowlite_jni";
+ private static final String PRIMARY_LIBNAME = "tensorflowlite_jni";
+ private static final String FALLBACK_LIBNAME = "tensorflowlite_flex_jni";
private TensorFlowLite() {}
@@ -29,13 +30,24 @@ public final class TensorFlowLite {
* Load the TensorFlowLite runtime C library.
*/
static boolean init() {
+ Throwable primaryLibException;
try {
- System.loadLibrary(LIBNAME);
+ System.loadLibrary(PRIMARY_LIBNAME);
return true;
} catch (UnsatisfiedLinkError e) {
- System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage());
- return false;
+ primaryLibException = e;
}
+
+ try {
+ System.loadLibrary(FALLBACK_LIBNAME);
+ return true;
+ } catch (UnsatisfiedLinkError e) {
+ // If the fallback fails, log the error for the primary load instead.
+ System.err.println(
+ "TensorFlowLite: failed to load native library: " + primaryLibException.getMessage());
+ }
+
+ return false;
}
static {
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
new file mode 100644
index 0000000000..2791c3864b
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import java.io.File;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that
+ * have TensorFlow ops.
+ */
+@RunWith(JUnit4.class)
+public final class InterpreterFlexTest {
+
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
+ /** Smoke test validating that flex model loading works when the flex delegate is linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try (Interpreter interpreter = new Interpreter(FLEX_MODEL_FILE)) {
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.run(new float[1], new float[1]);
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index a98fca0132..f8b73c7cf3 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -43,6 +43,9 @@ public final class InterpreterTest {
private static final File MOBILENET_MODEL_FILE =
new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin");
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
@Test
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
@@ -345,4 +348,15 @@ public final class InterpreterTest {
interpreter.close();
interpreter.close();
}
+
+ /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try {
+ new Interpreter(FLEX_MODEL_FILE);
+ fail();
+ } catch (IllegalStateException e) {
+ // Expected failure.
+ }
+ }
}