Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java')
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java  381
1 file changed, 381 insertions(+), 0 deletions(-)
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
new file mode 100644
index 0000000000..24e5cb57df
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
@@ -0,0 +1,381 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo.tracking;
+
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Paint.Cap;
+import android.graphics.Paint.Join;
+import android.graphics.Paint.Style;
+import android.graphics.RectF;
+import android.util.DisplayMetrics;
+import android.util.Pair;
+import android.util.TypedValue;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+
+import org.tensorflow.demo.Classifier.Recognition;
+import org.tensorflow.demo.env.BorderedText;
+import org.tensorflow.demo.env.ImageUtils;
+import org.tensorflow.demo.env.Logger;
+
+/**
+ * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
+ * objects to new detections.
+ */
+public class MultiBoxTracker {
+ private final Logger logger = new Logger();
+
+ private static final float TEXT_SIZE_DIP = 18;
+
+ // Maximum fraction of a box that may be overlapped by another box at detection time.
+ // Beyond this, the lower-scored box (new or old) will be removed.
+ private static final float MAX_OVERLAP = 0.35f;
+
+ private static final float MIN_SIZE = 16.0f;
+
+ // Allow replacement of the tracked box with new results if
+ // correlation has dropped below this level.
+ private static final float MARGINAL_CORRELATION = 0.75f;
+
+ // Consider object to be lost if correlation falls below this threshold.
+ private static final float MIN_CORRELATION = 0.3f;
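+
+ // Together these thresholds form a hysteresis band: an overlapping new detection may
+ // replace a track once its correlation drops below MARGINAL_CORRELATION, but a track
+ // is only dropped outright when correlation falls below MIN_CORRELATION.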
+
+ private static final int[] COLORS = {
+ Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA
+ };
+
+ private final Queue<Integer> availableColors = new LinkedList<Integer>();
+
+ public ObjectTracker objectTracker;
+
+ final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();
+
+ private static class TrackedRecognition {
+ ObjectTracker.TrackedObject trackedObject;
+ float detectionConfidence;
+ int color;
+ }
+
+ private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();
+
+ private final Paint boxPaint = new Paint();
+
+ private final float textSizePx;
+ private final BorderedText borderedText;
+
+ private Matrix frameToCanvasMatrix;
+
+ private int frameWidth;
+ private int frameHeight;
+
+ private int sensorOrientation;
+
+ public MultiBoxTracker(final DisplayMetrics metrics) {
+ for (final int color : COLORS) {
+ availableColors.add(color);
+ }
+
+ boxPaint.setColor(Color.RED);
+ boxPaint.setStyle(Style.STROKE);
+ boxPaint.setStrokeWidth(12.0f);
+ boxPaint.setStrokeCap(Cap.ROUND);
+ boxPaint.setStrokeJoin(Join.ROUND);
+ boxPaint.setStrokeMiter(100);
+
+ textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, metrics);
+ borderedText = new BorderedText(textSizePx);
+ }
+
+ private Matrix getFrameToCanvasMatrix() {
+ return frameToCanvasMatrix;
+ }
+
+ public synchronized void drawDebug(final Canvas canvas) {
+ final Paint textPaint = new Paint();
+ textPaint.setColor(Color.WHITE);
+ textPaint.setTextSize(60.0f);
+
+ final Paint boxPaint = new Paint();
+ boxPaint.setColor(Color.RED);
+ boxPaint.setAlpha(200);
+ boxPaint.setStyle(Style.STROKE);
+
+ for (final Pair<Float, RectF> detection : screenRects) {
+ final RectF rect = detection.second;
+ canvas.drawRect(rect, boxPaint);
+ canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
+ borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
+ }
+
+ if (objectTracker == null) {
+ return;
+ }
+
+ // Draw correlations.
+ for (final TrackedRecognition recognition : trackedObjects) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+
+ final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
+
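+ // mapRect() transforms trackedPos in place and returns true only when the matrix maps
+ // rectangles to rectangles; labels are skipped for skewed or perspective transforms.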
+ if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
+ final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
+ borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
+ }
+ }
+
+ final Matrix matrix = getFrameToCanvasMatrix();
+ objectTracker.drawDebug(canvas, matrix);
+ }
+
+ public synchronized void trackResults(
+ final List<Recognition> results, final byte[] frame, final long timestamp) {
+ logger.i("Processing %d results from timestamp %d", results.size(), timestamp);
+ processResults(timestamp, results, frame);
+ }
+
+ public synchronized void draw(final Canvas canvas) {
+ if (objectTracker == null) {
+ return;
+ }
+
+ // TODO(andrewharp): This may not work for non-90 deg rotations.
+ final float multiplier =
+ Math.min(canvas.getWidth() / (float) frameHeight, canvas.getHeight() / (float) frameWidth);
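+ // The preview frame is rotated relative to the canvas (typically by 90 degrees), so
+ // frame width and height are deliberately swapped when sizing the destination.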
+ frameToCanvasMatrix =
+ ImageUtils.getTransformationMatrix(
+ frameWidth,
+ frameHeight,
+ (int) (multiplier * frameHeight),
+ (int) (multiplier * frameWidth),
+ sensorOrientation,
+ false);
+
+ for (final TrackedRecognition recognition : trackedObjects) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+
+ final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
+
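+ // As in drawDebug() above, mapRect() transforms trackedPos in place and gates drawing
+ // on the matrix being rectangle-preserving.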
+ if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
+ boxPaint.setColor(recognition.color);
+
+ final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
+ canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);
+
+ final String labelString = String.format("%.2f", recognition.detectionConfidence);
+ borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
+ }
+ }
+ }
+
+ public synchronized void onFrame(
+ final int w,
+ final int h,
+ final int rowStride,
+ final int sensorOrientation,
+ final byte[] frame,
+ final long timestamp) {
+ if (objectTracker == null) {
+ ObjectTracker.clearInstance();
+
+ logger.i("Initializing ObjectTracker: %dx%d", w, h);
+ objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
+ frameWidth = w;
+ frameHeight = h;
+ this.sensorOrientation = sensorOrientation;
+ }
+
+ objectTracker.nextFrame(frame, null, timestamp, null, true);
+
+ // Clean up any objects not worth tracking any more.
+ final LinkedList<TrackedRecognition> copyList =
+ new LinkedList<TrackedRecognition>(trackedObjects);
+ for (final TrackedRecognition recognition : copyList) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+ final float correlation = trackedObject.getCurrentCorrelation();
+ if (correlation < MIN_CORRELATION) {
+ logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
+ trackedObject.stopTracking();
+ trackedObjects.remove(recognition);
+
+ availableColors.add(recognition.color);
+ }
+ }
+ }
+
+ private void processResults(
+ final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
+ final List<Pair<Float, RectF>> rectsToTrack = new LinkedList<Pair<Float, RectF>>();
+
+ screenRects.clear();
+ final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());
+
+ for (final Recognition result : results) {
+ if (result.getLocation() == null) {
+ continue;
+ }
+ final RectF detectionFrameRect = new RectF(result.getLocation());
+
+ final RectF detectionScreenRect = new RectF();
+ rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);
+
+ logger.v(
+ "Result! Frame: " + result.getLocation() + " mapped to screen: " + detectionScreenRect);
+
+ screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));
+
+ if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
+ logger.w("Degenerate rectangle! " + detectionFrameRect);
+ continue;
+ }
+
+ rectsToTrack.add(new Pair<Float, RectF>(result.getConfidence(), detectionFrameRect));
+ }
+
+ if (rectsToTrack.isEmpty()) {
+ logger.v("Nothing to track, aborting.");
+ return;
+ }
+
+ if (objectTracker == null) {
+ logger.w("No ObjectTracker, can't track anything!");
+ return;
+ }
+
+ logger.i("%d rects to track", rectsToTrack.size());
+ for (final Pair<Float, RectF> potential : rectsToTrack) {
+ handleDetection(originalFrame, timestamp, potential);
+ }
+ }
+
+ private void handleDetection(
+ final byte[] frameCopy, final long timestamp, final Pair<Float, RectF> potential) {
+ final ObjectTracker.TrackedObject potentialObject =
+ objectTracker.trackObject(potential.second, timestamp, frameCopy);
+
+ final float potentialCorrelation = potentialObject.getCurrentCorrelation();
+ logger.v(
+ "Tracked object went from %s to %s with correlation %.2f",
+ potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);
+
+ if (potentialCorrelation < MARGINAL_CORRELATION) {
+ logger.v("Correlation too low to begin tracking %s.", potentialObject);
+ potentialObject.stopTracking();
+ return;
+ }
+
+ final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();
+
+ float maxIntersect = 0.0f;
+
+ // This is the current tracked object whose color we will take. If left null we'll take the
+ // first one from the color queue.
+ TrackedRecognition recogToReplace = null;
+
+ // Look for intersections that will be overridden by this object or an intersection that would
+ // prevent this one from being placed.
+ for (final TrackedRecognition trackedRecognition : trackedObjects) {
+ final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
+ final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
+ final RectF intersection = new RectF();
+ final boolean intersects = intersection.setIntersect(a, b);
+
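+ // Overlap is measured as intersection area over the smaller of the two boxes rather
+ // than IoU, so a small box nested inside a larger one counts as fully overlapped.
+ // The intersection rect is only meaningful when intersects is true (checked below).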
+ final float intersectAmount =
+ intersection.width()
+ * intersection.height()
+ / Math.min(a.width() * a.height(), b.width() * b.height());
+
+ // If there is an intersection with this currently tracked box above the maximum overlap
+ // percentage allowed, either the new recognition needs to be dismissed or the old
+ // recognition needs to be removed and possibly replaced with the new one.
+ if (intersects && intersectAmount > MAX_OVERLAP) {
+ if (potential.first < trackedRecognition.detectionConfidence
+ && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
+ // If track for the existing object is still going strong and the detection score was
+ // good, reject this new object.
+ potentialObject.stopTracking();
+ return;
+ } else {
+ removeList.add(trackedRecognition);
+
+ // Let the previously tracked object with max intersection amount donate its color to
+ // the new object.
+ if (intersectAmount > maxIntersect) {
+ maxIntersect = intersectAmount;
+ recogToReplace = trackedRecognition;
+ }
+ }
+ }
+ }
+
+ // If we're already tracking the maximum number of objects and no intersecting boxes
+ // were found to bump off, pick the lowest-confidence currently tracked object to
+ // remove, but only if it also scores worse than this candidate object.
+ if (availableColors.isEmpty() && removeList.isEmpty()) {
+ for (final TrackedRecognition candidate : trackedObjects) {
+ if (candidate.detectionConfidence < potential.first) {
+ if (recogToReplace == null
+ || candidate.detectionConfidence < recogToReplace.detectionConfidence) {
+ // Save it so that we use this color for the new object.
+ recogToReplace = candidate;
+ }
+ }
+ }
+ if (recogToReplace != null) {
+ logger.v("Found non-intersecting object to remove.");
+ removeList.add(recogToReplace);
+ } else {
+ logger.v("No non-intersecting object found to remove.");
+ }
+ }
+
+ // Remove everything that got intersected.
+ for (final TrackedRecognition trackedRecognition : removeList) {
+ logger.v(
+ "Removing tracked object %s with detection confidence %.2f, correlation %.2f",
+ trackedRecognition.trackedObject,
+ trackedRecognition.detectionConfidence,
+ trackedRecognition.trackedObject.getCurrentCorrelation());
+ trackedRecognition.trackedObject.stopTracking();
+ trackedObjects.remove(trackedRecognition);
+ if (trackedRecognition != recogToReplace) {
+ availableColors.add(trackedRecognition.color);
+ }
+ }
+
+ if (recogToReplace == null && availableColors.isEmpty()) {
+ logger.e("No room to track this object, aborting.");
+ potentialObject.stopTracking();
+ return;
+ }
+
+ // Finally safe to say we can track this object.
+ logger.v(
+ "Tracking object %s with detection confidence %.2f at position %s",
+ potentialObject, potential.first, potential.second);
+ final TrackedRecognition trackedRecognition = new TrackedRecognition();
+ trackedRecognition.detectionConfidence = potential.first;
+ trackedRecognition.trackedObject = potentialObject;
+
+ // Use the color from a replaced object before taking one from the color queue.
+ trackedRecognition.color =
+ recogToReplace != null ? recogToReplace.color : availableColors.poll();
+ trackedObjects.add(trackedRecognition);
+ }
+}
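
For context, a minimal sketch of how a caller might wire this class up. The bridge class, the camera callback, and the overlay hook are hypothetical stand-ins invented for illustration (TrackerBridge is not part of this change), but every MultiBoxTracker call matches the signatures in the file above.

import android.graphics.Canvas;
import android.util.DisplayMetrics;
import java.util.List;
import org.tensorflow.demo.Classifier.Recognition;
import org.tensorflow.demo.tracking.MultiBoxTracker;

// Hypothetical glue between the camera pipeline, the detector, and an overlay View.
public class TrackerBridge {
  private final MultiBoxTracker tracker;

  public TrackerBridge(final DisplayMetrics metrics) {
    tracker = new MultiBoxTracker(metrics);
  }

  // Call once per camera preview frame with the luminance (Y-plane) bytes so the
  // underlying ObjectTracker can update tracked positions.
  public void onPreviewFrame(final byte[] luminance, final int width, final int height,
      final int rowStride, final int sensorOrientation, final long timestamp) {
    tracker.onFrame(width, height, rowStride, sensorOrientation, luminance, timestamp);
  }

  // Call whenever the detector produces results (typically every few frames); the
  // tracker handles overlap suppression and matching against existing tracks.
  public void onDetection(
      final List<Recognition> results, final byte[] frameCopy, final long timestamp) {
    tracker.trackResults(results, frameCopy, timestamp);
  }

  // Call from the overlay View's onDraw() to render the tracked boxes and labels.
  public void onDraw(final Canvas canvas) {
    tracker.draw(canvas);
  }
}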