diff options
Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java')
-rw-r--r-- | tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java | 219 |
1 files changed, 219 insertions, 0 deletions
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java new file mode 100644 index 0000000000..687318c7ce --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java @@ -0,0 +1,219 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.demo; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.RectF; +import android.os.Trace; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; +import java.util.Vector; +import org.tensorflow.Graph; +import org.tensorflow.Operation; +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; +import org.tensorflow.demo.env.Logger; + +/** + * Wrapper for frozen detection models trained using the Tensorflow Object Detection API: + * github.com/tensorflow/models/tree/master/object_detection + */ +public class TensorFlowObjectDetectionAPIModel implements Classifier { + private static final Logger LOGGER = new Logger(); + + // Only return this many results. + private static final int MAX_RESULTS = 100; + + // Config values. + private String inputName; + private int inputSize; + + // Pre-allocated buffers. + private Vector<String> labels = new Vector<String>(); + private int[] intValues; + private byte[] byteValues; + private float[] outputLocations; + private float[] outputScores; + private float[] outputClasses; + private float[] outputNumDetections; + private String[] outputNames; + + private boolean logStats = false; + + private TensorFlowInferenceInterface inferenceInterface; + + /** + * Initializes a native TensorFlow session for classifying images. + * + * @param assetManager The asset manager to be used to load assets. + * @param modelFilename The filepath of the model GraphDef protocol buffer. + * @param labelFilename The filepath of label file for classes. + */ + public static Classifier create( + final AssetManager assetManager, + final String modelFilename, + final String labelFilename, + final int inputSize) throws IOException { + final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel(); + + InputStream labelsInput = null; + String actualFilename = labelFilename.split("file:///android_asset/")[1]; + labelsInput = assetManager.open(actualFilename); + BufferedReader br = null; + br = new BufferedReader(new InputStreamReader(labelsInput)); + String line; + while ((line = br.readLine()) != null) { + LOGGER.w(line); + d.labels.add(line); + } + br.close(); + + + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); + + final Graph g = d.inferenceInterface.graph(); + + d.inputName = "image_tensor"; + // The inputName node has a shape of [N, H, W, C], where + // N is the batch size + // H = W are the height and width + // C is the number of channels (3 for our purposes - RGB) + final Operation inputOp = g.operation(d.inputName); + if (inputOp == null) { + throw new RuntimeException("Failed to find input Node '" + d.inputName + "'"); + } + d.inputSize = inputSize; + // The outputScoresName node has a shape of [N, NumLocations], where N + // is the batch size. + final Operation outputOp1 = g.operation("detection_scores"); + if (outputOp1 == null) { + throw new RuntimeException("Failed to find output Node 'detection_scores'"); + } + final Operation outputOp2 = g.operation("detection_boxes"); + if (outputOp2 == null) { + throw new RuntimeException("Failed to find output Node 'detection_boxes'"); + } + final Operation outputOp3 = g.operation("detection_classes"); + if (outputOp3 == null) { + throw new RuntimeException("Failed to find output Node 'detection_classes'"); + } + + // Pre-allocate buffers. + d.outputNames = new String[] {"detection_boxes", "detection_scores", + "detection_classes", "num_detections"}; + d.intValues = new int[d.inputSize * d.inputSize]; + d.byteValues = new byte[d.inputSize * d.inputSize * 3]; + d.outputScores = new float[MAX_RESULTS]; + d.outputLocations = new float[MAX_RESULTS * 4]; + d.outputClasses = new float[MAX_RESULTS]; + d.outputNumDetections = new float[1]; + return d; + } + + private TensorFlowObjectDetectionAPIModel() {} + + @Override + public List<Recognition> recognizeImage(final Bitmap bitmap) { + // Log this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + Trace.beginSection("preprocessBitmap"); + // Preprocess the image data from 0-255 int to normalized float based + // on the provided parameters. + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + + for (int i = 0; i < intValues.length; ++i) { + byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF); + byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF); + byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF); + } + Trace.endSection(); // preprocessBitmap + + // Copy the input data into TensorFlow. + Trace.beginSection("feed"); + inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3); + Trace.endSection(); + + // Run the inference call. + Trace.beginSection("run"); + inferenceInterface.run(outputNames, logStats); + Trace.endSection(); + + // Copy the output Tensor back into the output array. + Trace.beginSection("fetch"); + outputLocations = new float[MAX_RESULTS * 4]; + outputScores = new float[MAX_RESULTS]; + outputClasses = new float[MAX_RESULTS]; + outputNumDetections = new float[1]; + inferenceInterface.fetch(outputNames[0], outputLocations); + inferenceInterface.fetch(outputNames[1], outputScores); + inferenceInterface.fetch(outputNames[2], outputClasses); + inferenceInterface.fetch(outputNames[3], outputNumDetections); + Trace.endSection(); + + // Find the best detections. + final PriorityQueue<Recognition> pq = + new PriorityQueue<Recognition>( + 1, + new Comparator<Recognition>() { + @Override + public int compare(final Recognition lhs, final Recognition rhs) { + // Intentionally reversed to put high confidence at the head of the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + // Scale them back to the input size. + for (int i = 0; i < outputScores.length; ++i) { + final RectF detection = + new RectF( + outputLocations[4 * i + 1] * inputSize, + outputLocations[4 * i] * inputSize, + outputLocations[4 * i + 3] * inputSize, + outputLocations[4 * i + 2] * inputSize); + pq.add( + new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection)); + } + + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); + for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) { + recognitions.add(pq.poll()); + } + Trace.endSection(); // "recognizeImage" + return recognitions; + } + + @Override + public void enableStatLogging(final boolean logStats) { + this.logStats = logStats; + } + + @Override + public String getStatString() { + return inferenceInterface.getStatString(); + } + + @Override + public void close() { + inferenceInterface.close(); + } +} |