aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/examples
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-28 16:20:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 16:23:06 -0700
commit2273cda3e0209d17fc4f2f055a28d27377b65988 (patch)
treec7a054d3c4ca34120063a6f05722459b51f45ab4 /tensorflow/contrib/lite/examples
parent0ea6847c892497afdd20c1150fee1e532612ca17 (diff)
Change inputs from multi-dimensional arrays to ByteBuffer in TF Lite Object Detection app
PiperOrigin-RevId: 202564164
Diffstat (limited to 'tensorflow/contrib/lite/examples')
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java30
1 files changed, 22 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
index bfb4a0a04b..580206943b 100644
--- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
@@ -25,6 +25,8 @@ import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
@@ -54,6 +56,14 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
private static final float H_SCALE = 5.0f;
private static final float W_SCALE = 5.0f;
+ // Float model
+ private static final float IMAGE_MEAN = 128.0f;
+ private static final float IMAGE_STD = 128.0f;
+
+ //Number of threads in the java app
+ private static final int NUM_THREADS = 4;
+
+
// Config values.
private int inputSize;
@@ -65,7 +75,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
private float[][][] outputLocations;
private float[][][] outputClasses;
- float[][][][] img;
+ private ByteBuffer imgData = null;
private Interpreter tfLite;
@@ -176,9 +186,12 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
}
// Pre-allocate buffers.
- d.img = new float[1][inputSize][inputSize][3];
-
+ int numBytesPerChannel = 4; // Floating point
+ d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
+ d.imgData.order(ByteOrder.nativeOrder());
d.intValues = new int[d.inputSize * d.inputSize];
+
+ d.tfLite.setNumThreads(NUM_THREADS);
d.outputLocations = new float[1][NUM_RESULTS][4];
d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
return d;
@@ -198,10 +211,11 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
for (int i = 0; i < inputSize; ++i) {
for (int j = 0; j < inputSize; ++j) {
- int pixel = intValues[j * inputSize + i];
- img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f;
- img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f;
- img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f;
+ int pixelValue = intValues[i * inputSize + j];
+ // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
}
}
Trace.endSection(); // preprocessBitmap
@@ -211,7 +225,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
outputLocations = new float[1][NUM_RESULTS][4];
outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
- Object[] inputArray = {img};
+ Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);