aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-25 09:32:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 09:44:19 -0700
commit954d6a0ace9b96cdd54659b99e9378a1138a7266 (patch)
treef0c904fb3137dcd9ff0f0b96cdec3c47297f292f /tensorflow/contrib/lite/java
parentaee2ab023837adbfc61253ffec07f8d2dcd6c2a8 (diff)
Add Interpreter.Options Java API for interpreter configuration
PiperOrigin-RevId: 214451901
Diffstat (limited to 'tensorflow/contrib/lite/java')
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java2
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java93
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java36
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java12
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java9
5 files changed, 103 insertions, 49 deletions
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
index 4cf51bb0fa..fd610b054f 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -74,7 +74,7 @@ public class OvicClassifier {
}
labelList = loadLabelList(labelInputStream);
// OVIC uses one thread for CPU inference.
- tflite = new Interpreter(model, 1);
+ tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1));
inputDims = TestHelper.getInputDims(tflite, 0);
if (inputDims.length != 4) {
throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index b84720ae8e..ffb04496cb 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -17,7 +17,6 @@ package org.tensorflow.lite;
import java.io.File;
import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
import java.util.HashMap;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
@@ -56,16 +55,36 @@ import org.checkerframework.checker.nullness.qual.NonNull;
*/
public final class Interpreter implements AutoCloseable {
+ /** An options class for controlling runtime interpreter behavior. */
+ public static class Options {
+ public Options() {}
+
+ /**
+ * Sets the number of threads to be used for ops that support multi-threading. Defaults to a
+ * platform-dependent value.
+ */
+ public Options setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */
+ public Options setUseNNAPI(boolean useNNAPI) {
+ this.useNNAPI = useNNAPI;
+ return this;
+ }
+
+ int numThreads = -1;
+ boolean useNNAPI = false;
+ }
+
/**
* Initializes a {@code Interpreter}
*
* @param modelFile: a File of a pre-trained TF Lite model.
*/
public Interpreter(@NonNull File modelFile) {
- if (modelFile == null) {
- return;
- }
- wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath());
+    this(modelFile, /* options= */ null);
}
/**
@@ -73,12 +92,22 @@ public final class Interpreter implements AutoCloseable {
*
* @param modelFile: a file of a pre-trained TF Lite model
* @param numThreads: number of threads to use for inference
+ * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will
+ * be removed in a future release.
*/
+ @Deprecated
public Interpreter(@NonNull File modelFile, int numThreads) {
- if (modelFile == null) {
- return;
- }
- wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads);
+ this(modelFile, new Options().setNumThreads(numThreads));
+ }
+
+  /**
+   * Initializes a {@code Interpreter} with a set of custom {@link Options} for customizing
+   * interpreter behavior.
+   *
+   * @param modelFile: a file of a pre-trained TF Lite model
+   * @param options: a set of options for customizing interpreter behavior
+   */
+ public Interpreter(@NonNull File modelFile, Options options) {
+ wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
}
/**
@@ -89,7 +118,7 @@ public final class Interpreter implements AutoCloseable {
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
public Interpreter(@NonNull ByteBuffer byteBuffer) {
- wrapper = new NativeInterpreterWrapper(byteBuffer);
+ this(byteBuffer, /* options= */ null);
}
/**
@@ -99,30 +128,25 @@ public final class Interpreter implements AutoCloseable {
* <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
* {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
- */
- public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
- wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads);
- }
-
- /**
- * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
*
- * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
- * Interpreter}.
+ * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
+ * will be removed in a future release.
*/
- public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
- wrapper = new NativeInterpreterWrapper(mappedByteBuffer);
+ @Deprecated
+ public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
+ this(byteBuffer, new Options().setNumThreads(numThreads));
}
/**
- * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and
- * specifies the number of threads used for inference.
+   * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
+   * {@link Options}.
*
- * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
- * Interpreter}.
+ * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
+ * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
+ * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
- public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) {
- wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads);
+ public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
+ wrapper = new NativeInterpreterWrapper(byteBuffer, options);
}
/**
@@ -240,12 +264,25 @@ public final class Interpreter implements AutoCloseable {
return wrapper.getLastNativeInferenceDurationNanoseconds();
}
- /** Turns on/off Android NNAPI for hardware acceleration when it is available. */
+ /**
+ * Turns on/off Android NNAPI for hardware acceleration when it is available.
+ *
+ * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API.
+ * This method will be removed in a future release.
+ */
+ @Deprecated
public void setUseNNAPI(boolean useNNAPI) {
checkNotClosed();
wrapper.setUseNNAPI(useNNAPI);
}
+ /**
+ * Sets the number of threads to be used for ops that support multi-threading.
+ *
+   * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling the
+   *     number of threads used for multi-threaded ops. This method will be removed in a future
+   *     release.
+ */
+ @Deprecated
public void setNumThreads(int numThreads) {
checkNotClosed();
wrapper.setNumThreads(numThreads);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index fa25082304..6feff9a618 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -23,7 +23,7 @@ import java.util.HashMap;
import java.util.Map;
/**
- * A wrapper wraps native interpreter and controls model execution.
+ * An internal wrapper that wraps native interpreter and controls model execution.
*
* <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
* explicitly freed by invoking the {@link #close()} method when the {@code
@@ -32,36 +32,29 @@ import java.util.Map;
final class NativeInterpreterWrapper implements AutoCloseable {
NativeInterpreterWrapper(String modelPath) {
- this(modelPath, /* numThreads= */ -1);
+ this(modelPath, /* options= */ null);
}
- NativeInterpreterWrapper(String modelPath, int numThreads) {
+ NativeInterpreterWrapper(String modelPath, Interpreter.Options options) {
+ if (options == null) {
+ options = new Interpreter.Options();
+ }
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModel(modelPath, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
- /**
- * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer}. The ByteBuffer should
- * not be modified after the construction of a {@code NativeInterpreterWrapper}. The {@code
- * ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a direct
- * {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
- */
NativeInterpreterWrapper(ByteBuffer byteBuffer) {
- this(byteBuffer, /* numThreads= */ -1);
+ this(byteBuffer, /* options= */ null);
}
- /**
- * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer} and specifies the
- * number of inference threads. The ByteBuffer should not be modified after the construction of a
- * {@code NativeInterpreterWrapper}. The {@code ByteBuffer} can be either a {@code
- * MappedByteBuffer} that memory-maps a model file, or a direct {@code ByteBuffer} of
- * nativeOrder() that contains the bytes content of a model.
- */
- NativeInterpreterWrapper(ByteBuffer buffer, int numThreads) {
+ NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
+ if (options == null) {
+ options = new Interpreter.Options();
+ }
if (buffer == null
|| (!(buffer instanceof MappedByteBuffer)
&& (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) {
@@ -72,10 +65,13 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelByteBuffer = buffer;
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+ if (options.useNNAPI) {
+ setUseNNAPI(options.useNNAPI);
+ }
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 9070b788b6..fefaa88911 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -55,6 +55,18 @@ public final class InterpreterTest {
}
@Test
+ public void testInterpreterWithOptions() throws Exception {
+ Interpreter interpreter =
+ new Interpreter(MODEL_FILE, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
+ assertThat(interpreter).isNotNull();
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.close();
+ }
+
+ @Test
public void testRunWithMappedByteBufferModel() throws Exception {
Path path = MODEL_FILE.toPath();
FileChannel fileChannel =
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 9c4a5acd79..270bd6703a 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -63,6 +63,15 @@ public final class NativeInterpreterWrapperTest {
}
@Test
+ public void testConstructorWithOptions() {
+ NativeInterpreterWrapper wrapper =
+ new NativeInterpreterWrapper(
+ FLOAT_MODEL_PATH, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
+ assertThat(wrapper).isNotNull();
+ wrapper.close();
+ }
+
+ @Test
public void testConstructorWithInvalidModel() {
try {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);