diff options
author | Jared Duke <jdduke@google.com> | 2018-10-05 15:43:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 15:55:51 -0700 |
commit | 4aad5382f0e7148d8489d24d8355b828b3f7811b (patch) | |
tree | 3d0b68908568258af5a33f72539a5c3ba2b89bf1 /tensorflow/contrib/lite/java/src | |
parent | 5ac6e1e4b8318bad2f2bc7e5a08a58a7ed31e4c6 (diff) |
Internal change
PiperOrigin-RevId: 215978771
Diffstat (limited to 'tensorflow/contrib/lite/java/src')
3 files changed, 76 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java index 711638a9f9..d5447b3bf8 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java @@ -18,7 +18,8 @@ package org.tensorflow.lite; /** Static utility methods loading the TensorFlowLite runtime. */ public final class TensorFlowLite { - private static final String LIBNAME = "tensorflowlite_jni"; + private static final String PRIMARY_LIBNAME = "tensorflowlite_jni"; + private static final String FALLBACK_LIBNAME = "tensorflowlite_flex_jni"; private TensorFlowLite() {} @@ -29,13 +30,24 @@ public final class TensorFlowLite { * Load the TensorFlowLite runtime C library. */ static boolean init() { + Throwable primaryLibException; try { - System.loadLibrary(LIBNAME); + System.loadLibrary(PRIMARY_LIBNAME); return true; } catch (UnsatisfiedLinkError e) { - System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage()); - return false; + primaryLibException = e; } + + try { + System.loadLibrary(FALLBACK_LIBNAME); + return true; + } catch (UnsatisfiedLinkError e) { + // If the fallback fails, log the error for the primary load instead. + System.err.println( + "TensorFlowLite: failed to load native library: " + primaryLibException.getMessage()); + } + + return false; } static { diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java new file mode 100644 index 0000000000..2791c3864b --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.File; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that + * have TensorFlow ops. + */ +@RunWith(JUnit4.class) +public final class InterpreterFlexTest { + + private static final File FLEX_MODEL_FILE = + new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + + /** Smoke test validating that flex model loading works when the flex delegate is linked. */ + @Test + public void testFlexModel() throws Exception { + try (Interpreter interpreter = new Interpreter(FLEX_MODEL_FILE)) { + assertThat(interpreter.getInputTensorCount()).isEqualTo(4); + assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + assertThat(interpreter.getOutputTensorCount()).isEqualTo(4); + assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + interpreter.run(new float[1], new float[1]); + } + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index a98fca0132..f8b73c7cf3 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -43,6 +43,9 @@ public final class InterpreterTest { private static final File MOBILENET_MODEL_FILE = new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin"); + private static final File FLEX_MODEL_FILE = + new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + @Test public void testInterpreter() throws Exception { Interpreter interpreter = new Interpreter(MODEL_FILE); @@ -345,4 +348,15 @@ public final class InterpreterTest { interpreter.close(); interpreter.close(); } + + /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */ + @Test + public void testFlexModel() throws Exception { + try { + new Interpreter(FLEX_MODEL_FILE); + fail(); + } catch (IllegalStateException e) { + // Expected failure. + } + } } |