diff options
author | Jared Duke <jdduke@google.com> | 2018-09-25 09:32:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 09:44:19 -0700 |
commit | 954d6a0ace9b96cdd54659b99e9378a1138a7266 (patch) | |
tree | f0c904fb3137dcd9ff0f0b96cdec3c47297f292f /tensorflow/contrib | |
parent | aee2ab023837adbfc61253ffec07f8d2dcd6c2a8 (diff) |
Add Interpreter.Options Java API for interpreter configuration
PiperOrigin-RevId: 214451901
Diffstat (limited to 'tensorflow/contrib')
5 files changed, 103 insertions, 49 deletions
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java index 4cf51bb0fa..fd610b054f 100644 --- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java @@ -74,7 +74,7 @@ public class OvicClassifier { } labelList = loadLabelList(labelInputStream); // OVIC uses one thread for CPU inference. - tflite = new Interpreter(model, 1); + tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1)); inputDims = TestHelper.getInputDims(tflite, 0); if (inputDims.length != 4) { throw new RuntimeException("The model's input dimensions must be 4 (BWHC)."); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index b84720ae8e..ffb04496cb 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -17,7 +17,6 @@ package org.tensorflow.lite; import java.io.File; import java.nio.ByteBuffer; -import java.nio.MappedByteBuffer; import java.util.HashMap; import java.util.Map; import org.checkerframework.checker.nullness.qual.NonNull; @@ -56,16 +55,36 @@ import org.checkerframework.checker.nullness.qual.NonNull; */ public final class Interpreter implements AutoCloseable { + /** An options class for controlling runtime interpreter behavior. */ + public static class Options { + public Options() {} + + /** + * Sets the number of threads to be used for ops that support multi-threading. Defaults to a + * platform-dependent value. + */ + public Options setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */ + public Options setUseNNAPI(boolean useNNAPI) { + this.useNNAPI = useNNAPI; + return this; + } + + int numThreads = -1; + boolean useNNAPI = false; + } + /** * Initializes a {@code Interpreter} * * @param modelFile: a File of a pre-trained TF Lite model. */ public Interpreter(@NonNull File modelFile) { - if (modelFile == null) { - return; - } - wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath()); + this(modelFile, /*options = */ null); } /** @@ -73,12 +92,22 @@ public final class Interpreter implements AutoCloseable { * * @param modelFile: a file of a pre-trained TF Lite model * @param numThreads: number of threads to use for inference + * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will + * be removed in a future release. */ + @Deprecated public Interpreter(@NonNull File modelFile, int numThreads) { - if (modelFile == null) { - return; - } - wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads); + this(modelFile, new Options().setNumThreads(numThreads)); + } + + /** + * Initializes a {@code Interpreter} and specifies the number of threads used for inference. + * + * @param modelFile: a file of a pre-trained TF Lite model + * @param options: a set of options for customizing interpreter behavior + */ + public Interpreter(@NonNull File modelFile, Options options) { + wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options); } /** @@ -89,7 +118,7 @@ public final class Interpreter implements AutoCloseable { * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. */ public Interpreter(@NonNull ByteBuffer byteBuffer) { - wrapper = new NativeInterpreterWrapper(byteBuffer); + this(byteBuffer, /* options= */ null); } /** @@ -99,30 +128,25 @@ public final class Interpreter implements AutoCloseable { * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. - */ - public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) { - wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads); - } - - /** - * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. * - * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code - * Interpreter}. + * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method + * will be removed in a future release. */ - public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) { - wrapper = new NativeInterpreterWrapper(mappedByteBuffer); + @Deprecated + public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) { + this(byteBuffer, new Options().setNumThreads(numThreads)); } /** - * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and - * specifies the number of threads used for inference. + * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom + * {@link #Options}. * - * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code - * Interpreter}. + * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The + * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a + * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. */ - public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) { - wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads); + public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) { + wrapper = new NativeInterpreterWrapper(byteBuffer, options); } /** @@ -240,12 +264,25 @@ public final class Interpreter implements AutoCloseable { return wrapper.getLastNativeInferenceDurationNanoseconds(); } - /** Turns on/off Android NNAPI for hardware acceleration when it is available. */ + /** + * Turns on/off Android NNAPI for hardware acceleration when it is available. + * + * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API. + * This method will be removed in a future release. + */ + @Deprecated public void setUseNNAPI(boolean useNNAPI) { checkNotClosed(); wrapper.setUseNNAPI(useNNAPI); } + /** + * Sets the number of threads to be used for ops that support multi-threading. + * + * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling thread + * multi-threading. This method will be removed in a future release. + */ + @Deprecated public void setNumThreads(int numThreads) { checkNotClosed(); wrapper.setNumThreads(numThreads); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index fa25082304..6feff9a618 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -23,7 +23,7 @@ import java.util.HashMap; import java.util.Map; /** - * A wrapper wraps native interpreter and controls model execution. + * An internal wrapper that wraps native interpreter and controls model execution. * * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be * explicitly freed by invoking the {@link #close()} method when the {@code @@ -32,36 +32,29 @@ import java.util.Map; final class NativeInterpreterWrapper implements AutoCloseable { NativeInterpreterWrapper(String modelPath) { - this(modelPath, /* numThreads= */ -1); + this(modelPath, /* options= */ null); } - NativeInterpreterWrapper(String modelPath, int numThreads) { + NativeInterpreterWrapper(String modelPath, Interpreter.Options options) { + if (options == null) { + options = new Interpreter.Options(); + } errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModel(modelPath, errorHandle); - interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); + interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads); isMemoryAllocated = true; inputTensors = new Tensor[getInputCount(interpreterHandle)]; outputTensors = new Tensor[getOutputCount(interpreterHandle)]; } - /** - * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer}. The ByteBuffer should - * not be modified after the construction of a {@code NativeInterpreterWrapper}. The {@code - * ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a direct - * {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. - */ NativeInterpreterWrapper(ByteBuffer byteBuffer) { - this(byteBuffer, /* numThreads= */ -1); + this(byteBuffer, /* options= */ null); } - /** - * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer} and specifies the - * number of inference threads. The ByteBuffer should not be modified after the construction of a - * {@code NativeInterpreterWrapper}. The {@code ByteBuffer} can be either a {@code - * MappedByteBuffer} that memory-maps a model file, or a direct {@code ByteBuffer} of - * nativeOrder() that contains the bytes content of a model. - */ - NativeInterpreterWrapper(ByteBuffer buffer, int numThreads) { + NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) { + if (options == null) { + options = new Interpreter.Options(); + } if (buffer == null || (!(buffer instanceof MappedByteBuffer) && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) { @@ -72,10 +65,13 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelByteBuffer = buffer; errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); - interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); + interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads); isMemoryAllocated = true; inputTensors = new Tensor[getInputCount(interpreterHandle)]; outputTensors = new Tensor[getOutputCount(interpreterHandle)]; + if (options.useNNAPI) { + setUseNNAPI(options.useNNAPI); + } } /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 9070b788b6..fefaa88911 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -55,6 +55,18 @@ public final class InterpreterTest { } @Test + public void testInterpreterWithOptions() throws Exception { + Interpreter interpreter = + new Interpreter(MODEL_FILE, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true)); + assertThat(interpreter).isNotNull(); + assertThat(interpreter.getInputTensorCount()).isEqualTo(1); + assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + assertThat(interpreter.getOutputTensorCount()).isEqualTo(1); + assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + interpreter.close(); + } + + @Test public void testRunWithMappedByteBufferModel() throws Exception { Path path = MODEL_FILE.toPath(); FileChannel fileChannel = diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 9c4a5acd79..270bd6703a 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -63,6 +63,15 @@ public final class NativeInterpreterWrapperTest { } @Test + public void testConstructorWithOptions() { + NativeInterpreterWrapper wrapper = + new NativeInterpreterWrapper( + FLOAT_MODEL_PATH, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true)); + assertThat(wrapper).isNotNull(); + wrapper.close(); + } + + @Test public void testConstructorWithInvalidModel() { try { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); |