aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/examples
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-29 18:07:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 18:10:49 -0700
commitf3785197b4de9466b48462f4f93b455c88dd622b (patch)
tree26259e30dc0a88f400fa4f15f246959b9080c63b /tensorflow/contrib/lite/examples
parentc290930ec1beacbcac414b43b3367dd44ffbd303 (diff)
TFLite Java app for object detection model
PiperOrigin-RevId: 202736707
Diffstat (limited to 'tensorflow/contrib/lite/examples')
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java13
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java210
2 files changed, 78 insertions, 145 deletions
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
index de997e454a..87160f6b3f 100644
--- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -50,9 +50,10 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
// Configuration values for the prepackaged SSD model.
private static final int TF_OD_API_INPUT_SIZE = 300;
- private static final String TF_OD_API_MODEL_FILE = "mobilenet_ssd.tflite";
+ private static final boolean TF_OD_API_IS_QUANTIZED = true;
+ private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";
-
+
// Which detection model to use: by default uses Tensorflow Object Detection API frozen
// checkpoints.
private enum DetectorMode {
@@ -107,7 +108,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
try {
detector =
TFLiteObjectDetectionAPIModel.create(
- getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE);
+ getAssets(),
+ TF_OD_API_MODEL_FILE,
+ TF_OD_API_LABELS_FILE,
+ TF_OD_API_INPUT_SIZE,
+ TF_OD_API_IS_QUANTIZED);
cropSize = TF_OD_API_INPUT_SIZE;
} catch (final IOException e) {
LOGGER.e("Exception initializing classifier!", e);
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
index 580206943b..9eb21de9d0 100644
--- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
@@ -30,12 +30,9 @@ import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.PriorityQueue;
-import java.util.StringTokenizer;
import java.util.Vector;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.lite.Interpreter;
@@ -48,40 +45,35 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
private static final Logger LOGGER = new Logger();
// Only return this many results.
- private static final int NUM_RESULTS = 1917;
- private static final int NUM_CLASSES = 91;
-
- private static final float Y_SCALE = 10.0f;
- private static final float X_SCALE = 10.0f;
- private static final float H_SCALE = 5.0f;
- private static final float W_SCALE = 5.0f;
-
+ private static final int NUM_DETECTIONS = 10;
+ private boolean isModelQuantized;
// Float model
private static final float IMAGE_MEAN = 128.0f;
private static final float IMAGE_STD = 128.0f;
-
- //Number of threads in the java app
+ // Number of threads in the java app
private static final int NUM_THREADS = 4;
-
-
// Config values.
private int inputSize;
-
- private final float[][] boxPriors = new float[4][NUM_RESULTS];
-
// Pre-allocated buffers.
private Vector<String> labels = new Vector<String>();
private int[] intValues;
+ // outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4]
+ // contains the location of detected boxes
private float[][][] outputLocations;
- private float[][][] outputClasses;
-
- private ByteBuffer imgData = null;
+ // outputClasses: array of shape [Batchsize, NUM_DETECTIONS]
+ // contains the classes of detected boxes
+ private float[][] outputClasses;
+ // outputScores: array of shape [Batchsize, NUM_DETECTIONS]
+ // contains the scores of detected boxes
+ private float[][] outputScores;
+ // numDetections: array of shape [Batchsize]
+ // contains the number of detected boxes
+ private float[] numDetections;
+
+ private ByteBuffer imgData;
private Interpreter tfLite;
- private float expit(final float x) {
- return (float) (1. / (1. + Math.exp(-x)));
- }
/** Memory-map the model file in Assets. */
private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
@@ -94,77 +86,24 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
- private void loadCoderOptions(
- final AssetManager assetManager, final String locationFilename, final float[][] boxPriors)
- throws IOException {
- // Try to be intelligent about opening from assets or sdcard depending on prefix.
- final String assetPrefix = "file:///android_asset/";
- InputStream is;
- if (locationFilename.startsWith(assetPrefix)) {
- is = assetManager.open(locationFilename.split(assetPrefix, -1)[1]);
- } else {
- is = new FileInputStream(locationFilename);
- }
-
- final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
-
- for (int lineNum = 0; lineNum < 4; ++lineNum) {
- String line = reader.readLine();
- final StringTokenizer st = new StringTokenizer(line, ", ");
- int priorIndex = 0;
- while (st.hasMoreTokens()) {
- final String token = st.nextToken();
- try {
- final float number = Float.parseFloat(token);
- boxPriors[lineNum][priorIndex++] = number;
- } catch (final NumberFormatException e) {
- // Silently ignore.
- }
- }
- if (priorIndex != NUM_RESULTS) {
- throw new RuntimeException(
- "BoxPrior length mismatch: " + priorIndex + " vs " + NUM_RESULTS);
- }
- }
-
- LOGGER.i("Loaded box priors!");
- }
-
- void decodeCenterSizeBoxes(float[][][] predictions) {
- for (int i = 0; i < NUM_RESULTS; ++i) {
- float ycenter = predictions[0][i][0] / Y_SCALE * boxPriors[2][i] + boxPriors[0][i];
- float xcenter = predictions[0][i][1] / X_SCALE * boxPriors[3][i] + boxPriors[1][i];
- float h = (float) Math.exp(predictions[0][i][2] / H_SCALE) * boxPriors[2][i];
- float w = (float) Math.exp(predictions[0][i][3] / W_SCALE) * boxPriors[3][i];
-
- float ymin = ycenter - h / 2.f;
- float xmin = xcenter - w / 2.f;
- float ymax = ycenter + h / 2.f;
- float xmax = xcenter + w / 2.f;
-
- predictions[0][i][0] = ymin;
- predictions[0][i][1] = xmin;
- predictions[0][i][2] = ymax;
- predictions[0][i][3] = xmax;
- }
- }
-
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param labelFilename The filepath of label file for classes.
+ * @param inputSize The width and height, in pixels, of the square input image fed to the model.
+ * @param isQuantized Whether the model is quantized (true) or floating point (false).
*/
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final String labelFilename,
- final int inputSize) throws IOException {
+ final int inputSize,
+ final boolean isQuantized)
+ throws IOException {
final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel();
- d.loadCoderOptions(assetManager, "file:///android_asset/box_priors.txt", d.boxPriors);
-
InputStream labelsInput = null;
String actualFilename = labelFilename.split("file:///android_asset/")[1];
labelsInput = assetManager.open(actualFilename);
@@ -185,15 +124,23 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
throw new RuntimeException(e);
}
+ d.isModelQuantized = isQuantized;
// Pre-allocate buffers.
- int numBytesPerChannel = 4; // Floating point
+ int numBytesPerChannel;
+ if (isQuantized) {
+ numBytesPerChannel = 1; // Quantized
+ } else {
+ numBytesPerChannel = 4; // Floating point
+ }
d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
d.imgData.order(ByteOrder.nativeOrder());
d.intValues = new int[d.inputSize * d.inputSize];
d.tfLite.setNumThreads(NUM_THREADS);
- d.outputLocations = new float[1][NUM_RESULTS][4];
- d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
+ d.outputLocations = new float[1][NUM_DETECTIONS][4];
+ d.outputClasses = new float[1][NUM_DETECTIONS];
+ d.outputScores = new float[1][NUM_DETECTIONS];
+ d.numDetections = new float[1];
return d;
}
@@ -209,26 +156,37 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
// on the provided parameters.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ imgData.rewind();
for (int i = 0; i < inputSize; ++i) {
for (int j = 0; j < inputSize; ++j) {
int pixelValue = intValues[i * inputSize + j];
- // Float model
- imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
- imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
- imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ if (isModelQuantized) {
+ // Quantized model
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ } else { // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
}
}
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
- outputLocations = new float[1][NUM_RESULTS][4];
- outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
+ outputLocations = new float[1][NUM_DETECTIONS][4];
+ outputClasses = new float[1][NUM_DETECTIONS];
+ outputScores = new float[1][NUM_DETECTIONS];
+ numDetections = new float[1];
Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);
+ outputMap.put(2, outputScores);
+ outputMap.put(3, numDetections);
Trace.endSection();
// Run the inference call.
@@ -236,56 +194,26 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
Trace.endSection();
- decodeCenterSizeBoxes(outputLocations);
-
- // Find the best detections.
- final PriorityQueue<Recognition> pq =
- new PriorityQueue<Recognition>(
- 1,
- new Comparator<Recognition>() {
- @Override
- public int compare(final Recognition lhs, final Recognition rhs) {
- // Intentionally reversed to put high confidence at the head of the queue.
- return Float.compare(rhs.getConfidence(), lhs.getConfidence());
- }
- });
-
- // Scale them back to the input size.
- for (int i = 0; i < NUM_RESULTS; ++i) {
- float topClassScore = -1000f;
- int topClassScoreIndex = -1;
-
- // Skip the first catch-all class.
- for (int j = 1; j < NUM_CLASSES; ++j) {
- float score = expit(outputClasses[0][i][j]);
-
- if (score > topClassScore) {
- topClassScoreIndex = j;
- topClassScore = score;
- }
- }
-
- if (topClassScore > 0.001f) {
- final RectF detection =
- new RectF(
- outputLocations[0][i][1] * inputSize,
- outputLocations[0][i][0] * inputSize,
- outputLocations[0][i][3] * inputSize,
- outputLocations[0][i][2] * inputSize);
-
- pq.add(
- new Recognition(
- "" + i,
- labels.get(topClassScoreIndex),
- outputClasses[0][i][topClassScoreIndex],
- detection));
- }
- }
-
- final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
- for (int i = 0; i < Math.min(pq.size(), 10); ++i) {
- Recognition recog = pq.poll();
- recognitions.add(recog);
+ // Collect the best detections,
+ // after scaling them back to the input size.
+ final ArrayList<Recognition> recognitions = new ArrayList<>(NUM_DETECTIONS);
+ for (int i = 0; i < NUM_DETECTIONS; ++i) {
+ final RectF detection =
+ new RectF(
+ outputLocations[0][i][1] * inputSize,
+ outputLocations[0][i][0] * inputSize,
+ outputLocations[0][i][3] * inputSize,
+ outputLocations[0][i][2] * inputSize);
+ // The SSD Mobilenet V1 model assumes class 0 is the background class
+ // in the label file, with class labels running from 1 to number_of_classes+1,
+ // while outputClasses holds class indices from 0 to number_of_classes.
+ int labelOffset = 1;
+ recognitions.add(
+ new Recognition(
+ "" + i,
+ labels.get((int) outputClasses[0][i] + labelOffset),
+ outputScores[0][i],
+ detection));
}
Trace.endSection(); // "recognizeImage"
return recognitions;