diff options
Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java')
-rw-r--r-- | tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java | 381 |
1 file changed, 381 insertions, 0 deletions
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.demo.tracking;

import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Cap;
import android.graphics.Paint.Join;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.util.DisplayMetrics;
import android.util.Pair;
import android.util.TypedValue;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

import org.tensorflow.demo.Classifier.Recognition;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;

/**
 * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
 * objects to new detections.
 *
 * <p>New detections arrive via {@link #trackResults}; camera frames are fed in via
 * {@link #onFrame}, which lazily initializes the underlying {@link ObjectTracker} on the first
 * frame and drops tracks whose correlation falls below {@link #MIN_CORRELATION}. Rendering is
 * done by {@link #draw} (and {@link #drawDebug} for diagnostics).
 *
 * <p>NOTE(review): public methods that touch shared state are {@code synchronized}, so this class
 * appears intended to be called from multiple threads (camera callback vs. UI draw) — confirm
 * against the callers.
 */
public class MultiBoxTracker {
  private final Logger logger = new Logger();

  // Label text size in density-independent pixels; converted to px in the constructor.
  private static final float TEXT_SIZE_DIP = 18;

  // Maximum percentage of a box that can be overlapped by another box at detection time. Otherwise
  // the lower scored box (new or old) will be removed.
  private static final float MAX_OVERLAP = 0.35f;

  // Detections smaller than this (in frame pixels, either dimension) are considered degenerate
  // and are never handed to the tracker.
  private static final float MIN_SIZE = 16.0f;

  // Allow replacement of the tracked box with new results if
  // correlation has dropped below this level.
  private static final float MARGINAL_CORRELATION = 0.75f;

  // Consider object to be lost if correlation falls below this threshold.
  private static final float MIN_CORRELATION = 0.3f;

  // Fixed palette for tracked boxes; its size is also the implicit cap on the number of
  // simultaneously tracked objects (a new track needs a color from availableColors or a
  // replaced track's color).
  private static final int[] COLORS = {
    Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA
  };

  // Colors not currently assigned to a track. Drained by handleDetection, refilled whenever a
  // track is dropped (onFrame / handleDetection).
  private final Queue<Integer> availableColors = new LinkedList<Integer>();

  // Lazily created on the first onFrame() call; null until then, which several methods use as
  // an "uninitialized" guard.
  public ObjectTracker objectTracker;

  // Most recent raw detections mapped to screen coordinates, kept only for drawDebug().
  final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();

  // A live track: the underlying correlation tracker plus the detection score and display color
  // it was created with.
  private static class TrackedRecognition {
    ObjectTracker.TrackedObject trackedObject;
    float detectionConfidence;
    int color;
  }

  private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();

  private final Paint boxPaint = new Paint();

  private final float textSizePx;
  private final BorderedText borderedText;

  // Maps frame (preview) coordinates to canvas coordinates; recomputed on every draw() call.
  private Matrix frameToCanvasMatrix;

  private int frameWidth;
  private int frameHeight;

  private int sensorOrientation;

  /**
   * Creates the tracker and precomputes drawing resources.
   *
   * @param metrics display metrics used to convert {@link #TEXT_SIZE_DIP} to pixels.
   */
  public MultiBoxTracker(final DisplayMetrics metrics) {
    for (final int color : COLORS) {
      availableColors.add(color);
    }

    boxPaint.setColor(Color.RED);
    boxPaint.setStyle(Style.STROKE);
    boxPaint.setStrokeWidth(12.0f);
    boxPaint.setStrokeCap(Cap.ROUND);
    boxPaint.setStrokeJoin(Join.ROUND);
    boxPaint.setStrokeMiter(100);

    textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, metrics);
    borderedText = new BorderedText(textSizePx);
  }

  // NOTE(review): returns null until draw() has run at least once — callers below tolerate this
  // only because Android's Matrix(Matrix) copy constructor accepts null; confirm.
  private Matrix getFrameToCanvasMatrix() {
    return frameToCanvasMatrix;
  }

  /**
   * Renders diagnostic overlays: raw detection rectangles with confidences, per-track correlation
   * values, and whatever ObjectTracker itself wants to draw.
   */
  public synchronized void drawDebug(final Canvas canvas) {
    final Paint textPaint = new Paint();
    textPaint.setColor(Color.WHITE);
    textPaint.setTextSize(60.0f);

    // NOTE(review): this local intentionally(?) shadows the boxPaint field so debug boxes get
    // their own alpha — confirm the shadowing is deliberate.
    final Paint boxPaint = new Paint();
    boxPaint.setColor(Color.RED);
    boxPaint.setAlpha(200);
    boxPaint.setStyle(Style.STROKE);

    for (final Pair<Float, RectF> detection : screenRects) {
      final RectF rect = detection.second;
      canvas.drawRect(rect, boxPaint);
      canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
      borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
    }

    if (objectTracker == null) {
      return;
    }

    // Draw correlations.
    for (final TrackedRecognition recognition : trackedObjects) {
      final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;

      final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();

      // mapRect returns true when the mapped rect is still a rect (non-degenerate transform);
      // skip drawing the label otherwise.
      if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
        // NOTE(review): String.format without an explicit Locale is locale-dependent
        // (e.g. "0,75" under comma-decimal locales) — debug-only here, but worth fixing.
        final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
        borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
      }
    }

    final Matrix matrix = getFrameToCanvasMatrix();
    objectTracker.drawDebug(canvas, matrix);
  }

  /**
   * Ingests a new batch of detection results.
   *
   * @param results detections for the frame captured at {@code timestamp}.
   * @param frame luminance (frame) data the tracker will sample appearance from.
   * @param timestamp capture time of {@code frame}; must match a frame previously passed to
   *     {@link #onFrame} — TODO confirm against ObjectTracker's contract.
   */
  public synchronized void trackResults(
      final List<Recognition> results, final byte[] frame, final long timestamp) {
    logger.i("Processing %d results from %d", results.size(), timestamp);
    processResults(timestamp, results, frame);
  }

  /**
   * Draws the current set of tracked boxes (rounded rectangles plus confidence labels) onto
   * {@code canvas}, recomputing the frame-to-canvas transform first.
   */
  public synchronized void draw(final Canvas canvas) {
    if (objectTracker == null) {
      return;
    }

    // TODO(andrewharp): This may not work for non-90 deg rotations.
    // Width/height are swapped here because the preview frame is rotated 90 degrees relative to
    // the canvas — presumably; see the TODO above.
    final float multiplier =
        Math.min(canvas.getWidth() / (float) frameHeight, canvas.getHeight() / (float) frameWidth);
    frameToCanvasMatrix =
        ImageUtils.getTransformationMatrix(
            frameWidth,
            frameHeight,
            (int) (multiplier * frameHeight),
            (int) (multiplier * frameWidth),
            sensorOrientation,
            false);

    for (final TrackedRecognition recognition : trackedObjects) {
      final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;

      final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();

      if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
        boxPaint.setColor(recognition.color);

        // Corner radius scales with box size so small boxes don't degenerate into circles.
        final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
        canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);

        final String labelString = String.format("%.2f", recognition.detectionConfidence);
        borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
      }
    }
  }

  /**
   * Feeds a camera frame to the tracker, creating the ObjectTracker on first use, then prunes
   * tracks whose correlation has dropped below {@link #MIN_CORRELATION}.
   *
   * @param w frame width in pixels.
   * @param h frame height in pixels.
   * @param rowStride bytes per row of {@code frame}.
   * @param sensorOrienation sensor rotation in degrees (parameter name is a pre-existing typo
   *     for "sensorOrientation"; kept to avoid touching code in this pass).
   * @param frame luminance data for this frame.
   * @param timestamp capture time of the frame.
   */
  public synchronized void onFrame(
      final int w,
      final int h,
      final int rowStride,
      final int sensorOrienation,
      final byte[] frame,
      final long timestamp) {
    if (objectTracker == null) {
      // Tear down any tracker left over from a previous instance before creating ours.
      ObjectTracker.clearInstance();

      logger.i("Initializing ObjectTracker: %dx%d", w, h);
      objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
      frameWidth = w;
      frameHeight = h;
      this.sensorOrientation = sensorOrienation;
    }

    objectTracker.nextFrame(frame, null, timestamp, null, true);

    // Clean up any objects not worth tracking any more.
    // Iterate over a copy so trackedObjects can be mutated inside the loop.
    final LinkedList<TrackedRecognition> copyList =
        new LinkedList<TrackedRecognition>(trackedObjects);
    for (final TrackedRecognition recognition : copyList) {
      final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
      final float correlation = trackedObject.getCurrentCorrelation();
      if (correlation < MIN_CORRELATION) {
        logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
        trackedObject.stopTracking();
        trackedObjects.remove(recognition);

        // Return the color to the pool so a future track can reuse it.
        availableColors.add(recognition.color);
      }
    }
  }

  /**
   * Filters incoming detections (null or degenerate boxes are dropped), records screen-space
   * copies for debug drawing, and hands each surviving detection to
   * {@link #handleDetection}.
   */
  private void processResults(
      final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
    final List<Pair<Float, RectF>> rectsToTrack = new LinkedList<Pair<Float, RectF>>();

    screenRects.clear();
    // Matrix(null) yields identity, so this is safe even before the first draw() call.
    final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());

    for (final Recognition result : results) {
      if (result.getLocation() == null) {
        continue;
      }
      final RectF detectionFrameRect = new RectF(result.getLocation());

      final RectF detectionScreenRect = new RectF();
      rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);

      logger.v(
          "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect);

      screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));

      if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
        logger.w("Degenerate rectangle! " + detectionFrameRect);
        continue;
      }

      rectsToTrack.add(new Pair<Float, RectF>(result.getConfidence(), detectionFrameRect));
    }

    if (rectsToTrack.isEmpty()) {
      logger.v("Nothing to track, aborting.");
      return;
    }

    if (objectTracker == null) {
      logger.w("No ObjectTracker, can't track anything!");
      return;
    }

    logger.i("%d rects to track", rectsToTrack.size());
    for (final Pair<Float, RectF> potential : rectsToTrack) {
      handleDetection(originalFrame, timestamp, potential);
    }
  }

  /**
   * Decides whether a single new detection becomes a track, performing non-max suppression
   * against existing tracks.
   *
   * <p>Outcomes, in order of evaluation:
   * <ol>
   *   <li>If the new track's initial correlation is below {@link #MARGINAL_CORRELATION}, it is
   *       discarded immediately.
   *   <li>Any existing track overlapping it by more than {@link #MAX_OVERLAP} either rejects the
   *       newcomer (when the old track is both higher-confidence and still correlating well) or
   *       is queued for removal — the most-overlapping removed track donates its color.
   *   <li>If no color is free and nothing was removed, the weakest existing track with lower
   *       confidence than the newcomer is evicted instead; failing that, the newcomer is dropped.
   * </ol>
   *
   * @param frameCopy luminance data for appearance sampling.
   * @param timestamp capture time the detection belongs to.
   * @param potential (confidence, frame-space box) for the new detection.
   */
  private void handleDetection(
      final byte[] frameCopy, final long timestamp, final Pair<Float, RectF> potential) {
    // Start tracking immediately — ObjectTracker needs a live track to measure correlation.
    // If we bail out below we must call stopTracking() on every exit path.
    final ObjectTracker.TrackedObject potentialObject =
        objectTracker.trackObject(potential.second, timestamp, frameCopy);

    final float potentialCorrelation = potentialObject.getCurrentCorrelation();
    logger.v(
        "Tracked object went from %s to %s with correlation %.2f",
        potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);

    if (potentialCorrelation < MARGINAL_CORRELATION) {
      logger.v("Correlation too low to begin tracking %s.", potentialObject);
      potentialObject.stopTracking();
      return;
    }

    final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();

    float maxIntersect = 0.0f;

    // This is the current tracked object whose color we will take. If left null we'll take the
    // first one from the color queue.
    TrackedRecognition recogToReplace = null;

    // Look for intersections that will be overridden by this object or an intersection that would
    // prevent this one from being placed.
    for (final TrackedRecognition trackedRecognition : trackedObjects) {
      final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
      final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
      final RectF intersection = new RectF();
      final boolean intersects = intersection.setIntersect(a, b);

      // Overlap is measured against the SMALLER of the two boxes (not IoU), so a small box
      // fully inside a big one counts as 100% overlap.
      final float intersectAmount =
          intersection.width()
              * intersection.height()
              / Math.min(a.width() * a.height(), b.width() * b.height());

      // If there is an intersection with this currently tracked box above the maximum overlap
      // percentage allowed, either the new recognition needs to be dismissed or the old
      // recognition needs to be removed and possibly replaced with the new one.
      if (intersects && intersectAmount > MAX_OVERLAP) {
        if (potential.first < trackedRecognition.detectionConfidence
            && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
          // If track for the existing object is still going strong and the detection score was
          // good, reject this new object.
          potentialObject.stopTracking();
          return;
        } else {
          removeList.add(trackedRecognition);

          // Let the previously tracked object with max intersection amount donate its color to
          // the new object.
          if (intersectAmount > maxIntersect) {
            maxIntersect = intersectAmount;
            recogToReplace = trackedRecognition;
          }
        }
      }
    }

    // If we're already tracking the max object and no intersections were found to bump off,
    // pick the worst current tracked object to remove, if it's also worse than this candidate
    // object.
    if (availableColors.isEmpty() && removeList.isEmpty()) {
      for (final TrackedRecognition candidate : trackedObjects) {
        if (candidate.detectionConfidence < potential.first) {
          if (recogToReplace == null
              || candidate.detectionConfidence < recogToReplace.detectionConfidence) {
            // Save it so that we use this color for the new object.
            recogToReplace = candidate;
          }
        }
      }
      if (recogToReplace != null) {
        logger.v("Found non-intersecting object to remove.");
        removeList.add(recogToReplace);
      } else {
        logger.v("No non-intersecting object found to remove");
      }
    }

    // Remove everything that got intersected.
    for (final TrackedRecognition trackedRecognition : removeList) {
      logger.v(
          "Removing tracked object %s with detection confidence %.2f, correlation %.2f",
          trackedRecognition.trackedObject,
          trackedRecognition.detectionConfidence,
          trackedRecognition.trackedObject.getCurrentCorrelation());
      trackedRecognition.trackedObject.stopTracking();
      trackedObjects.remove(trackedRecognition);
      // The replaced track keeps its color out of the pool — the newcomer inherits it below.
      if (trackedRecognition != recogToReplace) {
        availableColors.add(trackedRecognition.color);
      }
    }

    if (recogToReplace == null && availableColors.isEmpty()) {
      logger.e("No room to track this object, aborting.");
      potentialObject.stopTracking();
      return;
    }

    // Finally safe to say we can track this object.
    logger.v(
        "Tracking object %s with detection confidence %.2f at position %s",
        potentialObject, potential.first, potential.second);
    final TrackedRecognition trackedRecognition = new TrackedRecognition();
    trackedRecognition.detectionConfidence = potential.first;
    trackedRecognition.trackedObject = potentialObject;

    // Use the color from a replaced object before taking one from the color queue.
    trackedRecognition.color =
        recogToReplace != null ? recogToReplace.color : availableColors.poll();
    trackedObjects.add(trackedRecognition);
  }
}