diff options
author | 2018-06-28 16:20:10 -0700 | |
---|---|---|
committer | 2018-06-28 16:23:06 -0700 | |
commit | 2273cda3e0209d17fc4f2f055a28d27377b65988 (patch) | |
tree | c7a054d3c4ca34120063a6f05722459b51f45ab4 /tensorflow/contrib/lite/examples | |
parent | 0ea6847c892497afdd20c1150fee1e532612ca17 (diff) |
Change inputs from multi-dimensional arrays to ByteBuffer in TF Lite Object Detection app
PiperOrigin-RevId: 202564164
Diffstat (limited to 'tensorflow/contrib/lite/examples')
-rw-r--r-- | tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java | 30 |
1 files changed, 22 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java index bfb4a0a04b..580206943b 100644 --- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java @@ -25,6 +25,8 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.ArrayList; @@ -54,6 +56,14 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private static final float H_SCALE = 5.0f; private static final float W_SCALE = 5.0f; + // Float model + private static final float IMAGE_MEAN = 128.0f; + private static final float IMAGE_STD = 128.0f; + + //Number of threads in the java app + private static final int NUM_THREADS = 4; + + // Config values. private int inputSize; @@ -65,7 +75,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private float[][][] outputLocations; private float[][][] outputClasses; - float[][][][] img; + private ByteBuffer imgData = null; private Interpreter tfLite; @@ -176,9 +186,12 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { } // Pre-allocate buffers. - d.img = new float[1][inputSize][inputSize][3]; - + int numBytesPerChannel = 4; // Floating point + d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel); + d.imgData.order(ByteOrder.nativeOrder()); d.intValues = new int[d.inputSize * d.inputSize]; + + d.tfLite.setNumThreads(NUM_THREADS); d.outputLocations = new float[1][NUM_RESULTS][4]; d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; return d; @@ -198,10 +211,11 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { for (int i = 0; i < inputSize; ++i) { for (int j = 0; j < inputSize; ++j) { - int pixel = intValues[j * inputSize + i]; - img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f; - img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f; - img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f; + int pixelValue = intValues[i * inputSize + j]; + // Float model + imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); } } Trace.endSection(); // preprocessBitmap @@ -211,7 +225,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { outputLocations = new float[1][NUM_RESULTS][4]; outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; - Object[] inputArray = {img}; + Object[] inputArray = {imgData}; Map<Integer, Object> outputMap = new HashMap<>(); outputMap.put(0, outputLocations); outputMap.put(1, outputClasses); |