Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java')
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java  649
1 file changed, 649 insertions, 0 deletions
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
new file mode 100644
index 0000000000..211d8077a3
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
@@ -0,0 +1,649 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo.tracking;
+
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.PointF;
+import android.graphics.RectF;
+import android.graphics.Typeface;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Vector;
+import javax.microedition.khronos.opengles.GL10;
+import org.tensorflow.demo.env.Logger;
+import org.tensorflow.demo.env.Size;
+
+/**
+ * True object detector/tracker class that tracks objects across consecutive preview frames.
+ * It provides a simplified Java interface to the analogous native object defined by
+ * jni/client_vision/tracking/object_tracker.*.
+ *
+ * Currently, the ObjectTracker is a singleton due to native code restrictions, and so must
+ * be allocated by ObjectTracker.getInstance(). In addition, release() should be called
+ * as soon as the ObjectTracker is no longer needed, and before a new one is created.
+ *
+ * nextFrame() should be called as new frames become available, preferably as often as possible.
+ *
+ * After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects
+ * are associated with the ObjectTracker that created them, and are only valid while that
+ * ObjectTracker still exists.
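+ *
+ * <p>A minimal usage sketch (illustrative only; the frame dimensions, timestamps, and bounding
+ * box below are placeholder values, and passing null for the uvData and transformation-matrix
+ * arguments is an assumption rather than a requirement of this class):
+ *
+ * <pre>{@code
+ * ObjectTracker tracker = ObjectTracker.getInstance(640, 480, 640, false);
+ * tracker.nextFrame(lumaBytes, null, timestampNs, null, false);
+ * ObjectTracker.TrackedObject obj =
+ *     tracker.trackObject(new RectF(100, 100, 200, 200), timestampNs, lumaBytes);
+ * RectF position = obj.getTrackedPositionInPreviewFrame();
+ * obj.stopTracking();
+ * ObjectTracker.clearInstance();
+ * }</pre>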
+ */
+public class ObjectTracker {
+ private final Logger logger = new Logger();
+
+ private static final boolean DRAW_TEXT = false;
+
+ /**
+ * How many history points to keep track of and draw in the red history line.
+ */
+ private static final int MAX_DEBUG_HISTORY_SIZE = 30;
+
+ /**
+ * How many frames of optical flow deltas to record.
+ * TODO(andrewharp): Push this down to the native level so it can be polled
+ * efficiently into an array for upload, instead of keeping a duplicate
+ * copy in Java.
+ */
+ private static final int MAX_FRAME_HISTORY_SIZE = 200;
+
+ private static final int DOWNSAMPLE_FACTOR = 2;
+
+ private final byte[] downsampledFrame;
+
+ protected static ObjectTracker instance;
+
+ private final Map<String, TrackedObject> trackedObjects;
+
+ private long lastTimestamp;
+
+ private FrameChange lastKeypoints;
+
+ private final Vector<PointF> debugHistory;
+
+ private final LinkedList<TimestampedDeltas> timestampedDeltas;
+
+ protected final int frameWidth;
+ protected final int frameHeight;
+ private final int rowStride;
+ protected final boolean alwaysTrack;
+
+ private static class TimestampedDeltas {
+ final long timestamp;
+ final byte[] deltas;
+
+ public TimestampedDeltas(final long timestamp, final byte[] deltas) {
+ this.timestamp = timestamp;
+ this.deltas = deltas;
+ }
+ }
+
+ /**
+ * A simple class that records keypoint information: location, score, and type. It is used when
+ * calculating FrameChange.
+ */
+ public static class Keypoint {
+ public final float x;
+ public final float y;
+ public final float score;
+ public final int type;
+
+ public Keypoint(final float x, final float y) {
+ this.x = x;
+ this.y = y;
+ this.score = 0;
+ this.type = -1;
+ }
+
+ public Keypoint(final float x, final float y, final float score, final int type) {
+ this.x = x;
+ this.y = y;
+ this.score = score;
+ this.type = type;
+ }
+
+ Keypoint delta(final Keypoint other) {
+ return new Keypoint(this.x - other.x, this.y - other.y);
+ }
+ }
+
+ /**
+ * A simple class that records the change between two Keypoints. It is used when calculating the
+ * frame translation delta for optical flow.
+ */
+ public static class PointChange {
+ public final Keypoint keypointA;
+ public final Keypoint keypointB;
+ Keypoint pointDelta;
+ private final boolean wasFound;
+
+ public PointChange(final float x1, final float y1,
+ final float x2, final float y2,
+ final float score, final int type,
+ final boolean wasFound) {
+ this.wasFound = wasFound;
+
+ keypointA = new Keypoint(x1, y1, score, type);
+ keypointB = new Keypoint(x2, y2);
+ }
+
+ public Keypoint getDelta() {
+ if (pointDelta == null) {
+ pointDelta = keypointB.delta(keypointA);
+ }
+ return pointDelta;
+ }
+ }
+
+ /** A class that records a timestamped frame translation delta for optical flow. */
+ public static class FrameChange {
+ public static final int KEYPOINT_STEP = 7;
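+ // Each keypoint occupies KEYPOINT_STEP consecutive floats in the packed array parsed by the
+ // constructor below: start x, start y, a "was found" flag (> 0 means found), end x, end y,
+ // score, and type.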
+
+ public final Vector<PointChange> pointDeltas;
+
+ private final float minScore;
+ private final float maxScore;
+
+ public FrameChange(final float[] framePoints) {
+ float minScore = 100.0f;
+ float maxScore = -100.0f;
+
+ pointDeltas = new Vector<PointChange>(framePoints.length / KEYPOINT_STEP);
+
+ for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) {
+ final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR;
+ final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR;
+
+ final boolean wasFound = framePoints[i + 2] > 0.0f;
+
+ final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR;
+ final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR;
+ final float score = framePoints[i + 5];
+ final int type = (int) framePoints[i + 6];
+
+ minScore = Math.min(minScore, score);
+ maxScore = Math.max(maxScore, score);
+
+ pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound));
+ }
+
+ this.minScore = minScore;
+ this.maxScore = maxScore;
+ }
+ }
+
+ public static synchronized ObjectTracker getInstance(
+ final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
+ if (instance == null) {
+ instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack);
+ instance.init();
+ } else {
+ throw new RuntimeException(
+ "Tried to create a new objectracker before releasing the old one!");
+ }
+ return instance;
+ }
+
+ public static synchronized void clearInstance() {
+ if (instance != null) {
+ instance.release();
+ }
+ }
+
+ protected ObjectTracker(
+ final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
+ this.frameWidth = frameWidth;
+ this.frameHeight = frameHeight;
+ this.rowStride = rowStride;
+ this.alwaysTrack = alwaysTrack;
+ this.timestampedDeltas = new LinkedList<TimestampedDeltas>();
+
+ trackedObjects = new HashMap<String, TrackedObject>();
+
+ debugHistory = new Vector<PointF>(MAX_DEBUG_HISTORY_SIZE);
+
+ // Allocate enough space to hold the downsampled luminance plane.
+ downsampledFrame =
+ new byte
+ [(frameWidth + DOWNSAMPLE_FACTOR - 1)
+ / DOWNSAMPLE_FACTOR
+ * (frameHeight + DOWNSAMPLE_FACTOR - 1)
+ / DOWNSAMPLE_FACTOR];
+ }
+
+ protected void init() {
+ // The native tracker never sees the full frame, so pre-scale dimensions
+ // by the downsample factor.
+ initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack);
+ }
+
+ private final float[] matrixValues = new float[9];
+
+ private long downsampledTimestamp;
+
+ @SuppressWarnings("unused")
+ public synchronized void drawOverlay(final GL10 gl,
+ final Size cameraViewSize, final Matrix matrix) {
+ final Matrix tempMatrix = new Matrix(matrix);
+ tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR);
+ tempMatrix.getValues(matrixValues);
+ drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues);
+ }
+
+ public synchronized void nextFrame(
+ final byte[] frameData, final byte[] uvData,
+ final long timestamp, final float[] transformationMatrix,
+ final boolean updateDebugInfo) {
+ if (downsampledTimestamp != timestamp) {
+ ObjectTracker.downsampleImageNative(
+ frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
+ downsampledTimestamp = timestamp;
+ }
+
+ // Do Lucas-Kanade using the full-frame initializer.
+ nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix);
+
+ timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR)));
+ while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) {
+ timestampedDeltas.removeFirst();
+ }
+
+ for (final TrackedObject trackedObject : trackedObjects.values()) {
+ trackedObject.updateTrackedPosition();
+ }
+
+ if (updateDebugInfo) {
+ updateDebugHistory();
+ }
+
+ lastTimestamp = timestamp;
+ }
+
+ public synchronized void release() {
+ releaseMemoryNative();
+ synchronized (ObjectTracker.class) {
+ instance = null;
+ }
+ }
+
+ private void drawHistoryDebug(final Canvas canvas) {
+ drawHistoryPoint(
+ canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2);
+ }
+
+ private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) {
+ final Paint p = new Paint();
+ p.setAntiAlias(false);
+ p.setTypeface(Typeface.SERIF);
+
+ p.setColor(Color.RED);
+ p.setStrokeWidth(2.0f);
+
+ // Draw the center circle.
+ p.setColor(Color.GREEN);
+ canvas.drawCircle(startX, startY, 3.0f, p);
+
+ p.setColor(Color.RED);
+
+ // Iterate through the history in reverse order.
+ synchronized (debugHistory) {
+ final int numPoints = debugHistory.size();
+ float lastX = startX;
+ float lastY = startY;
+ for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) {
+ final PointF delta = debugHistory.get(numPoints - keypointNum - 1);
+ final float newX = lastX + delta.x;
+ final float newY = lastY + delta.y;
+ canvas.drawLine(lastX, lastY, newX, newY, p);
+ lastX = newX;
+ lastY = newY;
+ }
+ }
+ }
+
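+ // Maps a float in [0, 1] to an int in [0, 255], clamping out-of-range values.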
+ private static int floatToChar(final float value) {
+ return Math.max(0, Math.min((int) (value * 255.999f), 255));
+ }
+
+ private void drawKeypointsDebug(final Canvas canvas) {
+ final Paint p = new Paint();
+ if (lastKeypoints == null) {
+ return;
+ }
+ final int keypointSize = 3;
+
+ final float minScore = lastKeypoints.minScore;
+ final float maxScore = lastKeypoints.maxScore;
+
+ for (final PointChange keypoint : lastKeypoints.pointDeltas) {
+ if (keypoint.wasFound) {
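+ // Shade found keypoints by normalized score: higher scores trend red, lower scores trend blue.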
+ final int r =
+ floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore));
+ final int b =
+ floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore));
+
+ final int color = 0xFF000000 | (r << 16) | b;
+ p.setColor(color);
+
+ final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y,
+ keypoint.keypointB.x, keypoint.keypointB.y};
+ canvas.drawRect(screenPoints[2] - keypointSize,
+ screenPoints[3] - keypointSize,
+ screenPoints[2] + keypointSize,
+ screenPoints[3] + keypointSize, p);
+ p.setColor(Color.CYAN);
+ canvas.drawLine(screenPoints[2], screenPoints[3],
+ screenPoints[0], screenPoints[1], p);
+
+ if (DRAW_TEXT) {
+ p.setColor(Color.WHITE);
+ canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score,
+ keypoint.keypointA.x, keypoint.keypointA.y, p);
+ }
+ } else {
+ p.setColor(Color.YELLOW);
+ final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y};
+ canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p);
+ }
+ }
+ }
+
+ private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX,
+ final float positionY, final float radius) {
+ final RectF currPosition = getCurrentPosition(timestamp,
+ new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius));
+ return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY);
+ }
+
+ private synchronized RectF getCurrentPosition(final long timestamp, final RectF
+ oldPosition) {
+ final RectF downscaledFrameRect = downscaleRect(oldPosition);
+
+ final float[] delta = new float[4];
+ getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top,
+ downscaledFrameRect.right, downscaledFrameRect.bottom, delta);
+
+ final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
+
+ return upscaleRect(newPosition);
+ }
+
+ private void updateDebugHistory() {
+ lastKeypoints = new FrameChange(getKeypointsNative(false));
+
+ if (lastTimestamp == 0) {
+ return;
+ }
+
+ final PointF delta =
+ getAccumulatedDelta(
+ lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100);
+
+ synchronized (debugHistory) {
+ debugHistory.add(delta);
+
+ while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) {
+ debugHistory.remove(0);
+ }
+ }
+ }
+
+ public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) {
+ canvas.save();
+ canvas.setMatrix(frameToCanvas);
+
+ drawHistoryDebug(canvas);
+ drawKeypointsDebug(canvas);
+
+ canvas.restore();
+ }
+
+ public Vector<String> getDebugText() {
+ final Vector<String> lines = new Vector<String>();
+
+ if (lastKeypoints != null) {
+ lines.add("Num keypoints " + lastKeypoints.pointDeltas.size());
+ lines.add("Min score: " + lastKeypoints.minScore);
+ lines.add("Max score: " + lastKeypoints.maxScore);
+ }
+
+ return lines;
+ }
+
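+ /**
+ * Removes and returns the packed keypoint deltas for every recorded frame whose timestamp is at
+ * or before {@code endFrameTime}, oldest first.
+ */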
+ public synchronized List<byte[]> pollAccumulatedFlowData(final long endFrameTime) {
+ final List<byte[]> frameDeltas = new ArrayList<byte[]>();
+ while (timestampedDeltas.size() > 0) {
+ final TimestampedDeltas currentDeltas = timestampedDeltas.peek();
+ if (currentDeltas.timestamp <= endFrameTime) {
+ frameDeltas.add(currentDeltas.deltas);
+ timestampedDeltas.removeFirst();
+ } else {
+ break;
+ }
+ }
+
+ return frameDeltas;
+ }
+
+ private RectF downscaleRect(final RectF fullFrameRect) {
+ return new RectF(
+ fullFrameRect.left / DOWNSAMPLE_FACTOR,
+ fullFrameRect.top / DOWNSAMPLE_FACTOR,
+ fullFrameRect.right / DOWNSAMPLE_FACTOR,
+ fullFrameRect.bottom / DOWNSAMPLE_FACTOR);
+ }
+
+ private RectF upscaleRect(final RectF downsampledFrameRect) {
+ return new RectF(
+ downsampledFrameRect.left * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.top * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.right * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR);
+ }
+
+ /**
+ * A TrackedObject represents a native TrackedObject, and provides access to the relevant native
+ * tracking information available after every frame update. TrackedObjects may be safely passed
+ * around and accessed externally, but they become invalid once stopTracking() is called or the
+ * ObjectTracker that created them is deactivated.
+ *
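+ * <p>A per-frame check might look like the following (an illustrative sketch; reacting to an
+ * object that is no longer visible by stopping tracking is just one possible policy):
+ *
+ * <pre>{@code
+ * if (trackedObject.visibleInLastPreviewFrame()) {
+ *   RectF location = trackedObject.getTrackedPositionInPreviewFrame();
+ *   float correlation = trackedObject.getCurrentCorrelation();
+ *   // ... draw location, or use correlation to decide whether to keep tracking ...
+ * } else {
+ *   trackedObject.stopTracking();
+ * }
+ * }</pre>
+ *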
+ * @author andrewharp@google.com (Andrew Harp)
+ */
+ public class TrackedObject {
+ private final String id;
+
+ private long lastExternalPositionTime;
+
+ private RectF lastTrackedPosition;
+ private boolean visibleInLastFrame;
+
+ private boolean isDead;
+
+ TrackedObject(final RectF position, final long timestamp, final byte[] data) {
+ isDead = false;
+
+ id = Integer.toString(this.hashCode());
+
+ lastExternalPositionTime = timestamp;
+
+ synchronized (ObjectTracker.this) {
+ registerInitialAppearance(position, data);
+ setPreviousPosition(position, timestamp);
+ trackedObjects.put(id, this);
+ }
+ }
+
+ public void stopTracking() {
+ checkValidObject();
+
+ synchronized (ObjectTracker.this) {
+ isDead = true;
+ forgetNative(id);
+ trackedObjects.remove(id);
+ }
+ }
+
+ public float getCurrentCorrelation() {
+ checkValidObject();
+ return ObjectTracker.this.getCurrentCorrelation(id);
+ }
+
+ void registerInitialAppearance(final RectF position, final byte[] data) {
+ final RectF externalPosition = downscaleRect(position);
+ registerNewObjectWithAppearanceNative(id,
+ externalPosition.left, externalPosition.top,
+ externalPosition.right, externalPosition.bottom,
+ data);
+ }
+
+ synchronized void setPreviousPosition(final RectF position, final long timestamp) {
+ checkValidObject();
+ synchronized (ObjectTracker.this) {
+ if (lastExternalPositionTime > timestamp) {
+ logger.w("Tried to use older position time!");
+ return;
+ }
+ final RectF externalPosition = downscaleRect(position);
+ lastExternalPositionTime = timestamp;
+
+ setPreviousPositionNative(id,
+ externalPosition.left, externalPosition.top,
+ externalPosition.right, externalPosition.bottom,
+ lastExternalPositionTime);
+
+ updateTrackedPosition();
+ }
+ }
+
+ void setCurrentPosition(final RectF position) {
+ checkValidObject();
+ final RectF downsampledPosition = downscaleRect(position);
+ synchronized (ObjectTracker.this) {
+ setCurrentPositionNative(id,
+ downsampledPosition.left, downsampledPosition.top,
+ downsampledPosition.right, downsampledPosition.bottom);
+ }
+ }
+
+ private synchronized void updateTrackedPosition() {
+ checkValidObject();
+
+ final float[] delta = new float[4];
+ getTrackedPositionNative(id, delta);
+ lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
+
+ visibleInLastFrame = isObjectVisible(id);
+ }
+
+ public synchronized RectF getTrackedPositionInPreviewFrame() {
+ checkValidObject();
+
+ if (lastTrackedPosition == null) {
+ return null;
+ }
+ return upscaleRect(lastTrackedPosition);
+ }
+
+ synchronized long getLastExternalPositionTime() {
+ return lastExternalPositionTime;
+ }
+
+ public synchronized boolean visibleInLastPreviewFrame() {
+ return visibleInLastFrame;
+ }
+
+ private void checkValidObject() {
+ if (isDead) {
+ throw new RuntimeException("TrackedObject already removed from tracking!");
+ } else if (ObjectTracker.this != instance) {
+ throw new RuntimeException("TrackedObject created with another ObjectTracker!");
+ }
+ }
+ }
+
+ public synchronized TrackedObject trackObject(
+ final RectF position, final long timestamp, final byte[] frameData) {
+ if (downsampledTimestamp != timestamp) {
+ ObjectTracker.downsampleImageNative(
+ frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
+ downsampledTimestamp = timestamp;
+ }
+ return new TrackedObject(position, timestamp, downsampledFrame);
+ }
+
+ public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) {
+ return new TrackedObject(position, lastTimestamp, frameData);
+ }
+
+ /*********************** NATIVE CODE *************************************/
+
+ /**
+ * This holds an opaque pointer to the native ObjectTracker.
+ */
+ private int nativeObjectTracker;
+
+ private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack);
+
+ protected native void registerNewObjectWithAppearanceNative(
+ String objectId, float x1, float y1, float x2, float y2, byte[] data);
+
+ protected native void setPreviousPositionNative(
+ String objectId, float x1, float y1, float x2, float y2, long timestamp);
+
+ protected native void setCurrentPositionNative(
+ String objectId, float x1, float y1, float x2, float y2);
+
+ protected native void forgetNative(String key);
+
+ protected native String getModelIdNative(String key);
+
+ protected native boolean haveObject(String key);
+ protected native boolean isObjectVisible(String key);
+ protected native float getCurrentCorrelation(String key);
+
+ protected native float getMatchScore(String key);
+
+ protected native void getTrackedPositionNative(String key, float[] points);
+
+ protected native void nextFrameNative(
+ byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix);
+
+ protected native void releaseMemoryNative();
+
+ protected native void getCurrentPositionNative(long timestamp,
+ final float positionX1, final float positionY1,
+ final float positionX2, final float positionY2,
+ final float[] delta);
+
+ protected native byte[] getKeypointsPacked(float scaleFactor);
+
+ protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints);
+
+ protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas);
+
+ protected static native void downsampleImageNative(
+ int width, int height, int rowStride, byte[] input, int factor, byte[] output);
+
+ static {
+ System.loadLibrary("tensorflow_demo");
+ }
+}