diff options
author | Jared Duke <jdduke@google.com> | 2018-09-25 17:26:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 17:33:59 -0700 |
commit | a7f14807417ea78aee8ea275536902f0aaa94fd4 (patch) | |
tree | 612e511178404cb3aed788efbfc03931c13a9828 /tensorflow/contrib/lite/java | |
parent | f97610daf89572e52912ddc5bf87576cc9e82f66 (diff) |
Reland "Add Interpreter.Options Java API for interpreter configuration"
The original CL broke the InterpreterTest due to use of a newly
deprecated API. This has been fixed, and deprecated API usage in the
samples has also been updated.
PiperOrigin-RevId: 214532691
Diffstat (limited to 'tensorflow/contrib/lite/java')
8 files changed, 141 insertions, 111 deletions
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java index 4f5662bc2d..3596e42011 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -58,9 +58,9 @@ import android.view.View; import android.view.ViewGroup; import android.widget.CompoundButton; import android.widget.NumberPicker; -import android.widget.ToggleButton; import android.widget.TextView; import android.widget.Toast; +import android.widget.ToggleButton; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -305,22 +305,24 @@ public class Camera2BasicFragment extends Fragment textView = (TextView) view.findViewById(R.id.text); toggle = (ToggleButton) view.findViewById(R.id.button); - toggle.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() { - public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) { - classifier.setUseNNAPI(isChecked); - } - }); + toggle.setOnCheckedChangeListener( + new CompoundButton.OnCheckedChangeListener() { + public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) { + backgroundHandler.post(() -> classifier.setUseNNAPI(isChecked)); + } + }); np = (NumberPicker) view.findViewById(R.id.np); np.setMinValue(1); np.setMaxValue(10); np.setWrapSelectorWheel(true); - np.setOnValueChangedListener(new NumberPicker.OnValueChangeListener() { - @Override - public void onValueChange(NumberPicker picker, int oldVal, int newVal){ - classifier.setNumThreads(newVal); - } - }); + np.setOnValueChangedListener( + new NumberPicker.OnValueChangeListener() { + @Override + public void onValueChange(NumberPicker picker, int oldVal, int newVal) { + backgroundHandler.post(() -> classifier.setNumThreads(newVal)); + } + }); } /** Load the model and labels. */ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index 7bb6afd9d8..2d11a57434 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -59,9 +59,15 @@ public abstract class ImageClassifier { private static final int DIM_PIXEL_SIZE = 3; - /* Preallocated buffers for storing image data in. */ + /** Preallocated buffers for storing image data in. */ private int[] intValues = new int[getImageSizeX() * getImageSizeY()]; + /** Options for configuring the Interpreter. */ + private final Interpreter.Options tfliteOptions = new Interpreter.Options(); + + /** The loaded TensorFlow Lite model. */ + private MappedByteBuffer tfliteModel; + /** An instance of the driver class to run model inference with Tensorflow Lite. */ protected Interpreter tflite; @@ -89,7 +95,8 @@ public abstract class ImageClassifier { /** Initializes an {@code ImageClassifier}. */ ImageClassifier(Activity activity) throws IOException { - tflite = new Interpreter(loadModelFile(activity)); + tfliteModel = loadModelFile(activity); + tflite = new Interpreter(tfliteModel, tfliteOptions); labelList = loadLabelList(activity); imgData = ByteBuffer.allocateDirect( @@ -150,20 +157,28 @@ public abstract class ImageClassifier { } } + private void recreateInterpreter() { + if (tflite != null) { + tflite.close(); + tflite = new Interpreter(tfliteModel, tfliteOptions); + } + } + public void setUseNNAPI(Boolean nnapi) { - if (tflite != null) - tflite.setUseNNAPI(nnapi); + tfliteOptions.setUseNNAPI(nnapi); + recreateInterpreter(); } - public void setNumThreads(int num_threads) { - if (tflite != null) - tflite.setNumThreads(num_threads); + public void setNumThreads(int numThreads) { + tfliteOptions.setNumThreads(numThreads); + recreateInterpreter(); } /** Closes tflite to release resources. */ public void close() { tflite.close(); tflite = null; + tfliteModel = null; } /** Reads label list from Assets. */ diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java index 4cf51bb0fa..fd610b054f 100644 --- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java @@ -74,7 +74,7 @@ public class OvicClassifier { } labelList = loadLabelList(labelInputStream); // OVIC uses one thread for CPU inference. - tflite = new Interpreter(model, 1); + tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1)); inputDims = TestHelper.getInputDims(tflite, 0); if (inputDims.length != 4) { throw new RuntimeException("The model's input dimensions must be 4 (BWHC)."); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index b84720ae8e..ffb04496cb 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -17,7 +17,6 @@ package org.tensorflow.lite; import java.io.File; import java.nio.ByteBuffer; -import java.nio.MappedByteBuffer; import java.util.HashMap; import java.util.Map; import org.checkerframework.checker.nullness.qual.NonNull; @@ -56,16 +55,36 @@ import org.checkerframework.checker.nullness.qual.NonNull; */ public final class Interpreter implements AutoCloseable { + /** An options class for controlling runtime interpreter behavior. */ + public static class Options { + public Options() {} + + /** + * Sets the number of threads to be used for ops that support multi-threading. Defaults to a + * platform-dependent value. + */ + public Options setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */ + public Options setUseNNAPI(boolean useNNAPI) { + this.useNNAPI = useNNAPI; + return this; + } + + int numThreads = -1; + boolean useNNAPI = false; + } + /** * Initializes a {@code Interpreter} * * @param modelFile: a File of a pre-trained TF Lite model. */ public Interpreter(@NonNull File modelFile) { - if (modelFile == null) { - return; - } - wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath()); + this(modelFile, /*options = */ null); } /** @@ -73,12 +92,22 @@ public final class Interpreter implements AutoCloseable { * * @param modelFile: a file of a pre-trained TF Lite model * @param numThreads: number of threads to use for inference + * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will + * be removed in a future release. */ + @Deprecated public Interpreter(@NonNull File modelFile, int numThreads) { - if (modelFile == null) { - return; - } - wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads); + this(modelFile, new Options().setNumThreads(numThreads)); + } + + /** + * Initializes a {@code Interpreter} and specifies the number of threads used for inference. + * + * @param modelFile: a file of a pre-trained TF Lite model + * @param options: a set of options for customizing interpreter behavior + */ + public Interpreter(@NonNull File modelFile, Options options) { + wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options); } /** @@ -89,7 +118,7 @@ public final class Interpreter implements AutoCloseable { * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. */ public Interpreter(@NonNull ByteBuffer byteBuffer) { - wrapper = new NativeInterpreterWrapper(byteBuffer); + this(byteBuffer, /* options= */ null); } /** @@ -99,30 +128,25 @@ public final class Interpreter implements AutoCloseable { * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. - */ - public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) { - wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads); - } - - /** - * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. * - * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code - * Interpreter}. + * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method + * will be removed in a future release. */ - public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) { - wrapper = new NativeInterpreterWrapper(mappedByteBuffer); + @Deprecated + public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) { + this(byteBuffer, new Options().setNumThreads(numThreads)); } /** - * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and - * specifies the number of threads used for inference. + * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom + * {@link #Options}. * - * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code - * Interpreter}. + * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The + * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a + * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. */ - public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) { - wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads); + public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) { + wrapper = new NativeInterpreterWrapper(byteBuffer, options); } /** @@ -240,12 +264,25 @@ public final class Interpreter implements AutoCloseable { return wrapper.getLastNativeInferenceDurationNanoseconds(); } - /** Turns on/off Android NNAPI for hardware acceleration when it is available. */ + /** + * Turns on/off Android NNAPI for hardware acceleration when it is available. + * + * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API. + * This method will be removed in a future release. + */ + @Deprecated public void setUseNNAPI(boolean useNNAPI) { checkNotClosed(); wrapper.setUseNNAPI(useNNAPI); } + /** + * Sets the number of threads to be used for ops that support multi-threading. + * + * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling thread + * multi-threading. This method will be removed in a future release. + */ + @Deprecated public void setNumThreads(int numThreads) { checkNotClosed(); wrapper.setNumThreads(numThreads); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index fa25082304..6feff9a618 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -23,7 +23,7 @@ import java.util.HashMap; import java.util.Map; /** - * A wrapper wraps native interpreter and controls model execution. + * An internal wrapper that wraps native interpreter and controls model execution. * * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be * explicitly freed by invoking the {@link #close()} method when the {@code @@ -32,36 +32,29 @@ import java.util.Map; final class NativeInterpreterWrapper implements AutoCloseable { NativeInterpreterWrapper(String modelPath) { - this(modelPath, /* numThreads= */ -1); + this(modelPath, /* options= */ null); } - NativeInterpreterWrapper(String modelPath, int numThreads) { + NativeInterpreterWrapper(String modelPath, Interpreter.Options options) { + if (options == null) { + options = new Interpreter.Options(); + } errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModel(modelPath, errorHandle); - interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); + interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads); isMemoryAllocated = true; inputTensors = new Tensor[getInputCount(interpreterHandle)]; outputTensors = new Tensor[getOutputCount(interpreterHandle)]; } - /** - * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer}. The ByteBuffer should - * not be modified after the construction of a {@code NativeInterpreterWrapper}. The {@code - * ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a direct - * {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. - */ NativeInterpreterWrapper(ByteBuffer byteBuffer) { - this(byteBuffer, /* numThreads= */ -1); + this(byteBuffer, /* options= */ null); } - /** - * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer} and specifies the - * number of inference threads. The ByteBuffer should not be modified after the construction of a - * {@code NativeInterpreterWrapper}. The {@code ByteBuffer} can be either a {@code - * MappedByteBuffer} that memory-maps a model file, or a direct {@code ByteBuffer} of - * nativeOrder() that contains the bytes content of a model. - */ - NativeInterpreterWrapper(ByteBuffer buffer, int numThreads) { + NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) { + if (options == null) { + options = new Interpreter.Options(); + } if (buffer == null || (!(buffer instanceof MappedByteBuffer) && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) { @@ -72,10 +65,13 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelByteBuffer = buffer; errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); - interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); + interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads); isMemoryAllocated = true; inputTensors = new Tensor[getInputCount(interpreterHandle)]; outputTensors = new Tensor[getOutputCount(interpreterHandle)]; + if (options.useNNAPI) { + setUseNNAPI(options.useNNAPI); + } } /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 9070b788b6..dfdd7d22b0 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -55,6 +55,18 @@ public final class InterpreterTest { } @Test + public void testInterpreterWithOptions() throws Exception { + Interpreter interpreter = + new Interpreter(MODEL_FILE, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true)); + assertThat(interpreter).isNotNull(); + assertThat(interpreter.getInputTensorCount()).isEqualTo(1); + assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + assertThat(interpreter.getOutputTensorCount()).isEqualTo(1); + assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + interpreter.close(); + } + + @Test public void testRunWithMappedByteBufferModel() throws Exception { Path path = MODEL_FILE.toPath(); FileChannel fileChannel = @@ -304,40 +316,14 @@ public final class InterpreterTest { } @Test - public void testTurnOffNNAPI() throws Exception { - Path path = MODEL_FILE.toPath(); - FileChannel fileChannel = - (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); - MappedByteBuffer mappedByteBuffer = - fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); - Interpreter interpreter = new Interpreter(mappedByteBuffer); - interpreter.setUseNNAPI(true); - float[] oneD = {1.23f, 6.54f, 7.81f}; - float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; - float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; - float[][][][] fourD = {threeD, threeD}; - float[][][][] parsedOutputs = new float[2][8][8][3]; - interpreter.run(fourD, parsedOutputs); - float[] outputOneD = parsedOutputs[0][0][0]; - float[] expected = {3.69f, 19.62f, 23.43f}; - assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); - interpreter.setUseNNAPI(false); - interpreter.run(fourD, parsedOutputs); - outputOneD = parsedOutputs[0][0][0]; - assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); - interpreter.close(); - fileChannel.close(); - } - - @Test public void testTurnOnNNAPI() throws Exception { Path path = MODEL_FILE.toPath(); FileChannel fileChannel = (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); - Interpreter interpreter = new Interpreter(mappedByteBuffer); - interpreter.setUseNNAPI(true); + Interpreter interpreter = + new Interpreter(mappedByteBuffer, new Interpreter.Options().setUseNNAPI(true)); float[] oneD = {1.23f, 6.54f, 7.81f}; float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 9c4a5acd79..270bd6703a 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -63,6 +63,15 @@ public final class NativeInterpreterWrapperTest { } @Test + public void testConstructorWithOptions() { + NativeInterpreterWrapper wrapper = + new NativeInterpreterWrapper( + FLOAT_MODEL_PATH, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true)); + assertThat(wrapper).isNotNull(); + wrapper.close(); + } + + @Test public void testConstructorWithInvalidModel() { try { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java index 38b740021b..af20e3280b 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -19,21 +19,6 @@ package org.tensorflow.lite; public class TestHelper { /** - * Turns on/off NNAPI of an {@code Interpreter}. - * - * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code - * IllegalArgumentException} will be thrown. - * @param useNNAPI a boolean value indicating to turn on or off NNAPI. - */ - public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) { - if (interpreter != null && interpreter.wrapper != null) { - interpreter.wrapper.setUseNNAPI(useNNAPI); - } else { - throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI."); - } - } - - /** * Gets the last inference duration in nanoseconds. It returns null if there is no previous * inference run or the last inference run failed. * |