-rw-r--r--  WORKSPACE | 7
-rw-r--r--  tensorflow/examples/android/AndroidManifest.xml | 9
-rw-r--r--  tensorflow/examples/android/BUILD | 19
-rw-r--r--  tensorflow/examples/android/README.md | 39
-rw-r--r--  tensorflow/examples/android/jni/box_coder_jni.cc | 92
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/config.h | 300
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/flow_cache.h | 306
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/frame_pair.cc | 308
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/frame_pair.h | 103
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/geom.h | 319
-rwxr-xr-x  tensorflow/examples/android/jni/object_tracking/gl_utils.h | 55
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image-inl.h | 642
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image.h | 346
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image_data.h | 270
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image_neon.cc | 270
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image_utils.h | 301
-rwxr-xr-x  tensorflow/examples/android/jni/object_tracking/integral_image.h | 187
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/jni_utils.h | 62
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/keypoint.h | 48
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc | 549
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/keypoint_detector.h | 133
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/log_streaming.h | 37
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_detector.cc | 27
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_detector.h | 232
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_model.h | 101
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_tracker.cc | 690
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_tracker.h | 271
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc | 463
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/optical_flow.cc | 490
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/optical_flow.h | 111
-rwxr-xr-x  tensorflow/examples/android/jni/object_tracking/sprite.h | 205
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/time_log.cc | 29
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/time_log.h | 138
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/tracked_object.cc | 163
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/tracked_object.h | 191
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/utils.h | 386
-rwxr-xr-x  tensorflow/examples/android/jni/object_tracking/utils_neon.cc | 151
-rw-r--r--  tensorflow/examples/android/proto/box_coder.proto | 42
-rw-r--r--  tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml | 30
-rw-r--r--  tensorflow/examples/android/res/values/base-strings.xml | 5
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java | 11
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java | 317
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java | 218
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java | 381
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java | 649
45 files changed, 9684 insertions, 19 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 30aba396b8..20c0285084 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -29,6 +29,13 @@ new_http_archive(
sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364"
)
+new_http_archive(
+ name = "mobile_multibox",
+ build_file = "models.BUILD",
+ url = "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1.zip",
+ sha256 = "b4c178fd6236dcf0a20d25d07c45eebe85281263978c6a6f1dfc49d75befc45f"
+)
+
# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT
new_http_archive(
diff --git a/tensorflow/examples/android/AndroidManifest.xml b/tensorflow/examples/android/AndroidManifest.xml
index 0a48d3d50b..e388734564 100644
--- a/tensorflow/examples/android/AndroidManifest.xml
+++ b/tensorflow/examples/android/AndroidManifest.xml
@@ -41,6 +41,15 @@
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
+
+ <activity android:name="org.tensorflow.demo.DetectorActivity"
+ android:screenOrientation="portrait"
+ android:label="@string/activity_name_detection">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
</application>
</manifest>
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index beb8337702..088133fed6 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -35,6 +35,7 @@ cc_binary(
"notap",
],
deps = [
+ ":demo_proto_lib_cc",
"//tensorflow/contrib/android:android_tensorflow_inference_jni",
"//tensorflow/core:android_tensorflow_lib",
LINKER_SCRIPT,
@@ -60,6 +61,7 @@ android_binary(
assets = [
"//tensorflow/examples/android/assets:asset_files",
"@inception5h//:model_files",
+ "@mobile_multibox//:model_files",
],
assets_dir = "",
custom_package = "org.tensorflow.demo",
@@ -111,3 +113,20 @@ filegroup(
)
exports_files(["AndroidManifest.xml"])
+
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library",
+)
+
+tf_proto_library(
+ name = "demo_proto_lib",
+ srcs = glob(
+ ["**/*.proto"],
+ ),
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
+
+# -----------------------------------------------------------------------------
+# Google-internal targets go here (must be at the end).
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index b0465f7faa..b6556cdef4 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -1,11 +1,24 @@
# TensorFlow Android Camera Demo
-This folder contains a simple camera-based demo application utilizing TensorFlow.
+This folder contains an example application utilizing TensorFlow for Android
+devices.
## Description
-This demo uses a Google Inception model to classify camera frames in real-time,
-displaying the top results in an overlay on the camera image.
+The demos in this folder are designed to give straightforward samples of using
+TensorFlow in mobile applications. Inference is done using the Java JNI API
+exposed by `tensorflow/contrib/android`.
+
+Current samples:
+
+1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java):
+ Uses the [Google Inception](https://arxiv.org/abs/1409.4842)
+ model to classify camera frames in real-time, displaying the top results
+ in an overlay on the camera image.
+2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java):
+ Demonstrates a model based on [Scalable Object Detection
+ using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to
+ localize and track people in the camera preview in real-time.
## To build/install/run
@@ -19,9 +32,9 @@ installed on your system.
3. The Android SDK and build tools may be obtained from:
https://developer.android.com/tools/revisions/build-tools.html
-The Android entries in [`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L2-L13) must be
-uncommented with the paths filled in appropriately depending on where you
-installed the NDK and SDK. Otherwise an error such as:
+The Android entries in [`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L2-L13)
+must be uncommented with the paths filled in appropriately depending on where
+you installed the NDK and SDK. Otherwise an error such as:
"The external label '//external:android/sdk' is not bound to anything" will
be reported.
@@ -29,19 +42,21 @@ The TensorFlow `GraphDef` that contains the model definition and weights
is not packaged in the repo because of its size. It will be downloaded
automatically via a new_http_archive defined in WORKSPACE.
-**Optional**: If you wish to place the model in your assets manually (E.g. for
-non-Bazel builds), remove the
-`inception_5` entry in `BUILD` and download the archive yourself to the
-`assets` directory in the source tree:
+**Optional**: If you wish to place the models in your assets manually (e.g. for
+non-Bazel builds), remove the `inception_5` and `mobile_multibox` entries in
+`BUILD` and download the archives yourself to the `assets` directory in the
+source tree:
```bash
$ curl -L https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip -o /tmp/inception5h.zip
+$ curl -L https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1.zip -o /tmp/mobile_multibox_v1.zip
$ unzip /tmp/inception5h.zip -d tensorflow/examples/android/assets/
+$ unzip /tmp/mobile_multibox_v1.zip -d tensorflow/examples/android/assets/
```
-The labels file describing the possible classification will also be in the
-assets directory.
+The associated label and box prior files for the models will also be extracted
+into the assets directory.
After editing your WORKSPACE file to update the SDK/NDK configuration,
you may build the APK. Run this from your workspace root:
diff --git a/tensorflow/examples/android/jni/box_coder_jni.cc b/tensorflow/examples/android/jni/box_coder_jni.cc
new file mode 100644
index 0000000000..be85414fc1
--- /dev/null
+++ b/tensorflow/examples/android/jni/box_coder_jni.cc
@@ -0,0 +1,92 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file loads the box coder mappings.
+
+#include <android/asset_manager.h>
+#include <android/asset_manager_jni.h>
+#include <android/bitmap.h>
+
+#include <jni.h>
+#include <pthread.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#include <map>
+#include <queue>
+#include <sstream>
+#include <string>
+
+#include "tensorflow/contrib/android/jni/jni_utils.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/proto/box_coder.pb.h"
+
+#define TENSORFLOW_METHOD(METHOD_NAME) \
+ Java_org_tensorflow_demo_TensorFlowMultiBoxDetector_##METHOD_NAME // NOLINT
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+JNIEXPORT void JNICALL TENSORFLOW_METHOD(loadCoderOptions)(
+ JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring location,
+ jfloatArray priors);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+JNIEXPORT void JNICALL TENSORFLOW_METHOD(loadCoderOptions)(
+ JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring location,
+ jfloatArray priors) {
+ AAssetManager* const asset_manager =
+ AAssetManager_fromJava(env, java_asset_manager);
+ LOG(INFO) << "Acquired AssetManager.";
+
+ const std::string location_str = GetString(env, location);
+
+ org_tensorflow_demo::MultiBoxCoderOptions multi_options;
+
+ LOG(INFO) << "Reading file to proto: " << location_str;
+ ReadFileToProtoOrDie(asset_manager, location_str.c_str(), &multi_options);
+
+ LOG(INFO) << "Read file. " << multi_options.box_coder_size() << " entries.";
+
+ jboolean iCopied = JNI_FALSE;
+ jfloat* values = env->GetFloatArrayElements(priors, &iCopied);
+
+ const int array_length = env->GetArrayLength(priors);
+ LOG(INFO) << "Array length: " << array_length
+ << " (/8 = " << (array_length / 8) << ")";
+ CHECK_EQ(array_length % 8, 0);
+
+ const int num_items =
+ std::min(array_length / 8, multi_options.box_coder_size());
+
+ for (int i = 0; i < num_items; ++i) {
+ const org_tensorflow_demo::BoxCoderOptions& options =
+ multi_options.box_coder(i);
+
+ for (int j = 0; j < 4; ++j) {
+ const org_tensorflow_demo::BoxCoderPrior& prior = options.priors(j);
+ values[i * 8 + j * 2] = prior.mean();
+ values[i * 8 + j * 2 + 1] = prior.stddev();
+ }
+ }
+ env->ReleaseFloatArrayElements(priors, values, 0);
+
+ LOG(INFO) << "Read " << num_items << " options";
+}
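
For reference, the stride-8 prior layout written by `loadCoderOptions` above pairs a (mean, stddev) with each of a box's four coordinates. The following is a minimal, illustrative C++ sketch of combining such a flat prior array with raw network outputs; the actual decoding in this commit happens on the Java side in `TensorFlowMultiBoxDetector.java`, and the standard MultiBox form `prediction * stddev + mean` is assumed here rather than taken from that file.

```cpp
// Illustrative only: decodes box locations from a flat prior array laid out as
// (mean, stddev) pairs for each of the four coordinates, 8 floats per box.
#include <vector>

struct DecodedBox {
  float ymin, xmin, ymax, xmax;
};

std::vector<DecodedBox> DecodeLocations(const std::vector<float>& predictions,
                                        const std::vector<float>& priors) {
  // predictions: 4 floats per box; priors: 8 floats per box.
  std::vector<DecodedBox> boxes;
  const int num_boxes = static_cast<int>(priors.size()) / 8;
  for (int i = 0; i < num_boxes; ++i) {
    float coords[4];
    for (int j = 0; j < 4; ++j) {
      const float mean = priors[i * 8 + j * 2];
      const float stddev = priors[i * 8 + j * 2 + 1];
      coords[j] = predictions[i * 4 + j] * stddev + mean;
    }
    boxes.push_back({coords[0], coords[1], coords[2], coords[3]});
  }
  return boxes;
}
```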
diff --git a/tensorflow/examples/android/jni/object_tracking/config.h b/tensorflow/examples/android/jni/object_tracking/config.h
new file mode 100644
index 0000000000..86e9fc71b6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/config.h
@@ -0,0 +1,300 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
+
+#include <math.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+
+namespace tf_tracking {
+
+// Arbitrary keypoint type ids for labeling the origin of tracked keypoints.
+enum KeypointType {
+ KEYPOINT_TYPE_DEFAULT = 0,
+ KEYPOINT_TYPE_FAST = 1,
+ KEYPOINT_TYPE_INTEREST = 2
+};
+
+// Struct that can be used to more richly store the results of a detection
+// than a single number, while still maintaining comparability.
+struct MatchScore {
+ explicit MatchScore(double val) : value(val) {}
+ MatchScore() { value = 0.0; }
+
+ double value;
+
+ MatchScore& operator+(const MatchScore& rhs) {
+ value += rhs.value;
+ return *this;
+ }
+
+ friend std::ostream& operator<<(std::ostream& stream,
+ const MatchScore& detection) {
+ stream << detection.value;
+ return stream;
+ }
+};
+inline bool operator< (const MatchScore& cC1, const MatchScore& cC2) {
+ return cC1.value < cC2.value;
+}
+inline bool operator> (const MatchScore& cC1, const MatchScore& cC2) {
+ return cC1.value > cC2.value;
+}
+inline bool operator>= (const MatchScore& cC1, const MatchScore& cC2) {
+ return cC1.value >= cC2.value;
+}
+inline bool operator<= (const MatchScore& cC1, const MatchScore& cC2) {
+ return cC1.value <= cC2.value;
+}
+
+// Fixed seed used for all random number generators.
+static const int kRandomNumberSeed = 11111;
+
+// TODO(andrewharp): Move as many of these settings as possible into a settings
+// object which can be passed in from Java at runtime.
+
+// Whether or not to use ESM instead of LK flow.
+static const bool kUseEsm = false;
+
+// This constant gets added to the diagonal of the Hessian
+// before solving for translation in 2dof ESM.
+// It ensures better behavior especially in the absence of
+// strong texture.
+static const int kEsmRegularizer = 20;
+
+// Do we want to brightness-normalize each keypoint patch when we compute
+// its flow using ESM?
+static const bool kDoBrightnessNormalize = true;
+
+// Whether or not to use fixed-point interpolated pixel lookups in optical flow.
+#define USE_FIXED_POINT_FLOW 1
+
+// Whether to normalize keypoint windows for intensity in LK optical flow.
+// This is a define for now because it helps keep the code streamlined.
+#define NORMALIZE 1
+
+// Number of keypoints to store per frame.
+static const int kMaxKeypoints = 76;
+
+// Keypoint detection.
+static const int kMaxTempKeypoints = 1024;
+
+// Number of floats each keypoint takes up when exporting to an array.
+static const int kKeypointStep = 7;
+
+// Number of frame deltas to keep around in the circular queue.
+static const int kNumFrames = 512;
+
+// Number of iterations to do tracking on each keypoint at each pyramid level.
+static const int kNumIterations = 3;
+
+// The number of bins (on a side) to divide each bin from the previous
+// cache level into. Higher numbers will decrease performance by increasing
+// cache misses, but mean that cache hits are more locally relevant.
+static const int kCacheBranchFactor = 2;
+
+// Number of levels to put in the cache.
+// Each level of the cache is a square grid of bins, length:
+// branch_factor^(level - 1) on each side.
+//
+// This may be greater than kNumPyramidLevels. Setting it to 0 means no
+// caching is enabled.
+static const int kNumCacheLevels = 3;
+
+// The level at which the cache pyramid gets cut off and replaced by a matrix
+// transform if such a matrix has been provided to the cache.
+static const int kCacheCutoff = 1;
+
+static const int kNumPyramidLevels = 4;
+
+// The minimum number of keypoints needed in an object's area.
+static const int kMaxKeypointsForObject = 16;
+
+// Minimum number of pyramid levels to use after getting cached value.
+// This allows fine-scale adjustment from the cached value, which is taken
+// from the center of the corresponding top cache level box.
+// Can be [0, kNumPyramidLevels).
+static const int kMinNumPyramidLevelsToUseForAdjustment = 1;
+
+// Window size to integrate over to find local image derivative.
+static const int kFlowIntegrationWindowSize = 3;
+
+// Total area of integration windows.
+static const int kFlowArraySize =
+ (2 * kFlowIntegrationWindowSize + 1) * (2 * kFlowIntegrationWindowSize + 1);
+
+// Error that's considered good enough to early abort tracking.
+static const float kTrackingAbortThreshold = 0.03f;
+
+// Maximum number of deviations a keypoint-correspondence delta can be from the
+// weighted average before being thrown out for region-based queries.
+static const float kNumDeviations = 2.0f;
+
+// The length of the allowed delta between the forward and the backward
+// flow deltas in terms of the length of the forward flow vector.
+static const float kMaxForwardBackwardErrorAllowed = 0.5f;
+
+// Threshold for pixels to be considered different.
+static const int kFastDiffAmount = 10;
+
+// How far from edge of frame to stop looking for FAST keypoints.
+static const int kFastBorderBuffer = 10;
+
+// Determines if non-detected arbitrary keypoints should be added to regions.
+// This will help if no keypoints have been detected in the region yet.
+static const bool kAddArbitraryKeypoints = true;
+
+// How many arbitrary keypoints to add along each axis as candidates for each
+// region?
+static const int kNumToAddAsCandidates = 1;
+
+// In terms of region dimensions, how closely can we place keypoints
+// next to each other?
+static const float kClosestPercent = 0.6f;
+
+// How many FAST qualifying pixels must be connected to a pixel for it to be
+// considered a candidate keypoint for Harris filtering.
+static const int kMinNumConnectedForFastKeypoint = 8;
+
+// Size of the window to integrate over for Harris filtering.
+// Compare to kFlowIntegrationWindowSize.
+static const int kHarrisWindowSize = 2;
+
+
+// DETECTOR PARAMETERS
+
+// Before relocalizing, make sure the new proposed position is better than
+// the existing position by a small amount to prevent thrashing.
+static const MatchScore kMatchScoreBuffer(0.01f);
+
+// Minimum score a tracked object can have and still be considered a match.
+// TODO(andrewharp): Make this a per detector thing.
+static const MatchScore kMinimumMatchScore(0.5f);
+
+static const float kMinimumCorrelationForTracking = 0.4f;
+
+static const MatchScore kMatchScoreForImmediateTermination(0.0f);
+
+// Run the detector every N frames.
+static const int kDetectEveryNFrames = 4;
+
+// How many features does each feature_set contain?
+static const int kFeaturesPerFeatureSet = 10;
+
+// The number of FeatureSets managed by the object detector.
+// More FeatureSets can increase recall at the cost of performance.
+static const int kNumFeatureSets = 7;
+
+// How many FeatureSets must respond affirmatively for a candidate descriptor
+// and position to be given more thorough attention?
+static const int kNumFeatureSetsForCandidate = 2;
+
+// How large the thumbnails used for correlation validation are. Used for both
+// width and height.
+static const int kNormalizedThumbnailSize = 11;
+
+// The area of intersection divided by union for the bounding boxes that tells
+// if this tracking has slipped enough to invalidate all unlocked examples.
+static const float kPositionOverlapThreshold = 0.6f;
+
+// The number of detection failures allowed before an object goes invisible.
+// Tracking will still occur, so if it is actually still being tracked and
+// comes back into a detectable position, it's likely to be found.
+static const int kMaxNumDetectionFailures = 4;
+
+
+// Minimum square size to scan with sliding window.
+static const float kScanMinSquareSize = 16.0f;
+
+// Maximum square size to scan with sliding window.
+static const float kScanMaxSquareSize = 64.0f;
+
+// Scale difference for consecutive scans of the sliding window.
+static const float kScanScaleFactor = sqrtf(2.0f);
+
+// Step size for sliding window.
+static const int kScanStepSize = 10;
+
+
+// How tightly to pack the descriptor boxes for confirmed exemplars.
+static const float kLockedScaleFactor = 1 / sqrtf(2.0f);
+
+// How tightly to pack the descriptor boxes for unconfirmed exemplars.
+static const float kUnlockedScaleFactor = 1 / 2.0f;
+
+// How tightly the boxes to scan centered at the last known position will be
+// packed.
+static const float kLastKnownPositionScaleFactor = 1.0f / sqrtf(2.0f);
+
+// The bounds on how close a new object example must be to existing object
+// examples for detection to be valid.
+static const float kMinCorrelationForNewExample = 0.75f;
+static const float kMaxCorrelationForNewExample = 0.99f;
+
+
+// The number of safe tries an exemplar has after being created before
+// missed detections count against it.
+static const int kFreeTries = 5;
+
+// A false positive is worth this many missed detections.
+static const int kFalsePositivePenalty = 5;
+
+struct ObjectDetectorConfig {
+ const Size image_size;
+
+ explicit ObjectDetectorConfig(const Size& image_size)
+ : image_size(image_size) {}
+ virtual ~ObjectDetectorConfig() = default;
+};
+
+struct KeypointDetectorConfig {
+ const Size image_size;
+
+ bool detect_skin;
+
+ explicit KeypointDetectorConfig(const Size& image_size)
+ : image_size(image_size),
+ detect_skin(false) {}
+};
+
+
+struct OpticalFlowConfig {
+ const Size image_size;
+
+ explicit OpticalFlowConfig(const Size& image_size)
+ : image_size(image_size) {}
+};
+
+struct TrackerConfig {
+ const Size image_size;
+ KeypointDetectorConfig keypoint_detector_config;
+ OpticalFlowConfig flow_config;
+ bool always_track;
+
+ float object_box_scale_factor_for_features;
+
+ explicit TrackerConfig(const Size& image_size)
+ : image_size(image_size),
+ keypoint_detector_config(image_size),
+ flow_config(image_size),
+ always_track(false),
+ object_box_scale_factor_for_features(1.0f) {}
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
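
The configuration structs above are composed rather than flat: `TrackerConfig` builds a `KeypointDetectorConfig` and an `OpticalFlowConfig` from the same image size, and only a few fields are intended to be adjusted after construction. A minimal sketch of setting one up, assuming an arbitrary 640x480 preview size (the `ConfigureTracker` helper is hypothetical, for illustration only):

```cpp
#include "tensorflow/examples/android/jni/object_tracking/config.h"

// Hypothetical caller; the real construction happens in the tracker JNI code
// added later in this commit.
void ConfigureTracker() {
  const tf_tracking::Size frame_size(640, 480);  // Assumed preview size.
  tf_tracking::TrackerConfig config(frame_size);

  // The nested configs share the image size; only these knobs are mutable.
  config.always_track = true;
  config.object_box_scale_factor_for_features = 1.2f;
  config.keypoint_detector_config.detect_skin = false;
}
```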
diff --git a/tensorflow/examples/android/jni/object_tracking/flow_cache.h b/tensorflow/examples/android/jni/object_tracking/flow_cache.h
new file mode 100644
index 0000000000..8813ab6d71
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/flow_cache.h
@@ -0,0 +1,306 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+namespace tf_tracking {
+
+// Class that helps OpticalFlow to speed up flow computation
+// by caching coarse-grained flow.
+class FlowCache {
+ public:
+ explicit FlowCache(const OpticalFlowConfig* const config)
+ : config_(config),
+ image_size_(config->image_size),
+ optical_flow_(config),
+ fullframe_matrix_(NULL) {
+ for (int i = 0; i < kNumCacheLevels; ++i) {
+ const int curr_dims = BlockDimForCacheLevel(i);
+ has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
+ displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
+ }
+ }
+
+ ~FlowCache() {
+ for (int i = 0; i < kNumCacheLevels; ++i) {
+ SAFE_DELETE(has_cache_[i]);
+ SAFE_DELETE(displacements_[i]);
+ }
+ delete[](fullframe_matrix_);
+ fullframe_matrix_ = NULL;
+ }
+
+ void NextFrame(ImageData* const new_frame,
+ const float* const align_matrix23) {
+ ClearCache();
+ SetFullframeAlignmentMatrix(align_matrix23);
+ optical_flow_.NextFrame(new_frame);
+ }
+
+ void ClearCache() {
+ for (int i = 0; i < kNumCacheLevels; ++i) {
+ has_cache_[i]->Clear(false);
+ }
+ delete[](fullframe_matrix_);
+ fullframe_matrix_ = NULL;
+ }
+
+ // Finds the flow at a point, using the cache for performance.
+ bool FindFlowAtPoint(const float u_x, const float u_y,
+ float* const flow_x, float* const flow_y) const {
+ // Get the best guess from the cache.
+ const Point2f guess_from_cache = LookupGuess(u_x, u_y);
+
+ *flow_x = guess_from_cache.x;
+ *flow_y = guess_from_cache.y;
+
+ // Now refine the guess using the image pyramid.
+ for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
+ pyramid_level >= 0; --pyramid_level) {
+ if (!optical_flow_.FindFlowAtPointSingleLevel(
+ pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ // Determines the displacement of a point, and uses that to calculate a new
+ // position.
+ // Returns true iff the displacement determination worked and the new position
+ // is in the image.
+ bool FindNewPositionOfPoint(const float u_x, const float u_y,
+ float* final_x, float* final_y) const {
+ float flow_x;
+ float flow_y;
+ if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
+ return false;
+ }
+
+ // Add in the displacement to get the final position.
+ *final_x = u_x + flow_x;
+ *final_y = u_y + flow_y;
+
+ // Assign the best guess, if we're still in the image.
+ if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) &&
+ InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ // Comparison function for qsort.
+ static int Compare(const void* a, const void* b) {
+ return *reinterpret_cast<const float*>(a) -
+ *reinterpret_cast<const float*>(b);
+ }
+
+ // Returns the median flow within the given bounding box as determined
+ // by a grid_width x grid_height grid.
+ Point2f GetMedianFlow(const BoundingBox& bounding_box,
+ const bool filter_by_fb_error,
+ const int grid_width,
+ const int grid_height) const {
+ const int kMaxPoints = 100;
+ SCHECK(grid_width * grid_height <= kMaxPoints,
+ "Too many points for Median flow!");
+
+ const BoundingBox valid_box = bounding_box.Intersect(
+ BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
+
+ if (valid_box.GetArea() <= 0.0f) {
+ return Point2f(0, 0);
+ }
+
+ float x_deltas[kMaxPoints];
+ float y_deltas[kMaxPoints];
+
+ int curr_offset = 0;
+ for (int i = 0; i < grid_width; ++i) {
+ for (int j = 0; j < grid_height; ++j) {
+ const float x_in = valid_box.left_ +
+ (valid_box.GetWidth() * i) / (grid_width - 1);
+
+ const float y_in = valid_box.top_ +
+ (valid_box.GetHeight() * j) / (grid_height - 1);
+
+ float curr_flow_x;
+ float curr_flow_y;
+ const bool success = FindNewPositionOfPoint(x_in, y_in,
+ &curr_flow_x, &curr_flow_y);
+
+ if (success) {
+ x_deltas[curr_offset] = curr_flow_x;
+ y_deltas[curr_offset] = curr_flow_y;
+ ++curr_offset;
+ } else {
+ LOGW("Tracking failure!");
+ }
+ }
+ }
+
+ if (curr_offset > 0) {
+ qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
+ qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
+
+ return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
+ }
+
+ LOGW("No points were valid!");
+ return Point2f(0, 0);
+ }
+
+ void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
+ if (align_matrix23 != NULL) {
+ if (fullframe_matrix_ == NULL) {
+ fullframe_matrix_ = new float[6];
+ }
+
+ memcpy(fullframe_matrix_, align_matrix23,
+ 6 * sizeof(fullframe_matrix_[0]));
+ }
+ }
+
+ private:
+ Point2f LookupGuessFromLevel(
+ const int cache_level, const float x, const float y) const {
+ // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
+
+ // Cutoff at the target level and use the matrix transform instead.
+ if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
+ const float xnew = x * fullframe_matrix_[0] +
+ y * fullframe_matrix_[1] +
+ fullframe_matrix_[2];
+ const float ynew = x * fullframe_matrix_[3] +
+ y * fullframe_matrix_[4] +
+ fullframe_matrix_[5];
+
+ return Point2f(xnew - x, ynew - y);
+ }
+
+ const int level_dim = BlockDimForCacheLevel(cache_level);
+ const int pixels_per_cache_block_x =
+ (image_size_.width + level_dim - 1) / level_dim;
+ const int pixels_per_cache_block_y =
+ (image_size_.height + level_dim - 1) / level_dim;
+ const int index_x = x / pixels_per_cache_block_x;
+ const int index_y = y / pixels_per_cache_block_y;
+
+ Point2f displacement;
+ if (!(*has_cache_[cache_level])[index_y][index_x]) {
+ (*has_cache_[cache_level])[index_y][index_x] = true;
+
+ // Get the lower cache level's best guess, if it exists.
+ displacement = cache_level >= kNumCacheLevels - 1 ?
+ Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
+ // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
+ // best_guess.x, best_guess.y);
+
+ // Find the center of the block.
+ const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
+ const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
+ const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
+
+ // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
+ // "Querying %5.2f, %5.2f at pyramid level %d, ",
+ // cache_level, index_x, index_y,
+ // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
+ // center_x, center_y, pyramid_level);
+
+ // TODO(andrewharp): Turn on FB error filtering.
+ const bool success = optical_flow_.FindFlowAtPointSingleLevel(
+ pyramid_level, center_x, center_y, false,
+ &displacement.x, &displacement.y);
+
+ if (!success) {
+ LOGV("Computation of cached value failed for level %d!", cache_level);
+ }
+
+ // Store the value for later use.
+ (*displacements_[cache_level])[index_y][index_x] = displacement;
+ } else {
+ displacement = (*displacements_[cache_level])[index_y][index_x];
+ }
+
+ // LOGI("Returning %5.2f, %5.2f for level %d",
+ // displacement.x, displacement.y, cache_level);
+ return displacement;
+ }
+
+ Point2f LookupGuess(const float x, const float y) const {
+ if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
+ return Point2f(0, 0);
+ }
+
+ // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
+ if (kNumCacheLevels > 0) {
+ return LookupGuessFromLevel(0, x, y);
+ } else {
+ return Point2f(0, 0);
+ }
+ }
+
+ // Returns the number of cache bins in each dimension for a given level
+ // of the cache.
+ int BlockDimForCacheLevel(const int cache_level) const {
+ // The highest (coarsest) cache level has a block dim of kCacheBranchFactor,
+ // thus if there are 4 cache levels, requesting level 3 (0-based) should
+ // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2,
+ // and so on.
+ int block_dim = kNumCacheLevels;
+ for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
+ --curr_level) {
+ block_dim *= kCacheBranchFactor;
+ }
+ return block_dim;
+ }
+
+ // Returns the level of the image pyramid that a given cache level maps to.
+ int PyramidLevelForCacheLevel(const int cache_level) const {
+ // Higher cache and pyramid levels have smaller dimensions. The highest
+ // cache level should refer to the highest image pyramid level. The
+ // lower, finer image pyramid levels are uncached (assuming
+ // kNumCacheLevels < kNumPyramidLevels).
+ return cache_level + (kNumPyramidLevels - kNumCacheLevels);
+ }
+
+ const OpticalFlowConfig* const config_;
+
+ const Size image_size_;
+ OpticalFlow optical_flow_;
+
+ float* fullframe_matrix_;
+
+ // Whether this value is currently present in the cache.
+ Image<bool>* has_cache_[kNumCacheLevels];
+
+ // The cached displacement values.
+ Image<Point2f>* displacements_[kNumCacheLevels];
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
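
A sketch of how a caller might drive `FlowCache` each frame: hand it the new frame data plus an optional 2x3 full-frame alignment matrix, then query per-point displacements. The `ImageData` setup and the frame loop are assumed to exist elsewhere in the tracker; the `TrackPoint` helper below is hypothetical.

```cpp
// Illustrative only; ImageData population and the alignment matrix come from
// other parts of the tracker, which this sketch assumes already exist.
#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"

void TrackPoint(tf_tracking::FlowCache* cache,
                tf_tracking::ImageData* new_frame,
                const float* align_matrix23,  // May be NULL.
                float x, float y) {
  // Advances the cache (and the underlying OpticalFlow) to the new frame.
  cache->NextFrame(new_frame, align_matrix23);

  float new_x = 0.0f;
  float new_y = 0.0f;
  if (cache->FindNewPositionOfPoint(x, y, &new_x, &new_y)) {
    // (new_x, new_y) is the point's estimated position in the new frame.
  }
}
```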
diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.cc b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc
new file mode 100644
index 0000000000..fa86e2363c
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc
@@ -0,0 +1,308 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <float.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
+
+namespace tf_tracking {
+
+void FramePair::Init(const int64 start_time, const int64 end_time) {
+ start_time_ = start_time;
+ end_time_ = end_time;
+ memset(optical_flow_found_keypoint_, false,
+ sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
+ number_of_keypoints_ = 0;
+}
+
+void FramePair::AdjustBox(const BoundingBox box,
+ float* const translation_x,
+ float* const translation_y,
+ float* const scale_x,
+ float* const scale_y) const {
+ static float weights[kMaxKeypoints];
+ static Point2f deltas[kMaxKeypoints];
+ memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
+
+ BoundingBox resized_box(box);
+ resized_box.Scale(0.4f, 0.4f);
+ FillWeights(resized_box, weights);
+ FillTranslations(deltas);
+
+ const Point2f translation = GetWeightedMedian(weights, deltas);
+
+ *translation_x = translation.x;
+ *translation_y = translation.y;
+
+ const Point2f old_center = box.GetCenter();
+ const int good_scale_points =
+ FillScales(old_center, translation, weights, deltas);
+
+ // Default scale factor is 1 for x and y.
+ *scale_x = 1.0f;
+ *scale_y = 1.0f;
+
+ // The assumption is that all deltas that make it to this stage with a
+// corresponding optical_flow_found_keypoint_[i] == true are not in
+ // themselves degenerate.
+ //
+ // The degeneracy with scale arose because if the points are too close to the
+ // center of the objects, the scale ratio determination might be incalculable.
+ //
+ // The check for kMinNumInRange is not a degeneracy check, but merely an
+ // attempt to ensure some sort of stability. The actual degeneracy check is in
+ // the comparison to EPSILON in FillScales (which I've updated to return the
+// number of good points remaining as well).
+ static const int kMinNumInRange = 5;
+ if (good_scale_points >= kMinNumInRange) {
+ const float scale_factor = GetWeightedMedianScale(weights, deltas);
+
+ if (scale_factor > 0.0f) {
+ *scale_x = scale_factor;
+ *scale_y = scale_factor;
+ }
+ }
+}
+
+int FramePair::FillWeights(const BoundingBox& box,
+ float* const weights) const {
+ // Compute the max score.
+ float max_score = -FLT_MAX;
+ float min_score = FLT_MAX;
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ if (optical_flow_found_keypoint_[i]) {
+ max_score = MAX(max_score, frame1_keypoints_[i].score_);
+ min_score = MIN(min_score, frame1_keypoints_[i].score_);
+ }
+ }
+
+ int num_in_range = 0;
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ if (!optical_flow_found_keypoint_[i]) {
+ weights[i] = 0.0f;
+ continue;
+ }
+
+ const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
+ if (in_box) {
+ ++num_in_range;
+ }
+
+ // The weighting based off distance. Anything within the bounding box
+ // has a weight of 1, and everything outside of that is within the range
+ // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
+ float distance_score = 1.0f;
+ if (!in_box) {
+ const Point2f initial = box.GetCenter();
+ const float sq_x_dist =
+ Square(initial.x - frame1_keypoints_[i].pos_.x);
+ const float sq_y_dist =
+ Square(initial.y - frame1_keypoints_[i].pos_.y);
+ const float squared_half_width = Square(box.GetWidth() / 2.0f);
+ const float squared_half_height = Square(box.GetHeight() / 2.0f);
+
+ static const float kOutOfBoxMultiplier = 0.5f;
+ distance_score = kOutOfBoxMultiplier *
+ MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
+ }
+
+ // The weighting based on relative score strength, in the range [kBaseScore, 1.0f].
+ float intrinsic_score = 1.0f;
+ if (max_score > min_score) {
+ static const float kBaseScore = 0.5f;
+ intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
+ (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
+ }
+
+ // The final score will be in the range [0, 1].
+ weights[i] = distance_score * intrinsic_score;
+ }
+
+ return num_in_range;
+}
+
+void FramePair::FillTranslations(Point2f* const translations) const {
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ if (!optical_flow_found_keypoint_[i]) {
+ continue;
+ }
+ translations[i].x =
+ frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
+ translations[i].y =
+ frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
+ }
+}
+
+int FramePair::FillScales(const Point2f& old_center,
+ const Point2f& translation,
+ float* const weights,
+ Point2f* const scales) const {
+ int num_good = 0;
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ if (!optical_flow_found_keypoint_[i]) {
+ continue;
+ }
+
+ const Keypoint keypoint1 = frame1_keypoints_[i];
+ const Keypoint keypoint2 = frame2_keypoints_[i];
+
+ const float dist1_x = keypoint1.pos_.x - old_center.x;
+ const float dist1_y = keypoint1.pos_.y - old_center.y;
+
+ const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
+ const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
+
+ // Make sure that the scale makes sense; points too close to the center
+ // will result in either NaNs or infinite results for scale due to
+ // limited tracking and floating point resolution.
+ // Also check that the parity of the points is the same with respect to
+ // x and y, as we can't really make sense of data that has flipped.
+ if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
+ (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
+ ((dist2_y > EPSILON && dist1_y > EPSILON) ||
+ (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
+ scales[i].x = dist2_x / dist1_x;
+ scales[i].y = dist2_y / dist1_y;
+ ++num_good;
+ } else {
+ weights[i] = 0.0f;
+ scales[i].x = 1.0f;
+ scales[i].y = 1.0f;
+ }
+ }
+ return num_good;
+}
+
+struct WeightedDelta {
+ float weight;
+ float delta;
+};
+
+// Sort by delta, not by weight.
+inline int WeightedDeltaCompare(const void* const a, const void* const b) {
+ return (reinterpret_cast<const WeightedDelta*>(a)->delta -
+ reinterpret_cast<const WeightedDelta*>(b)->delta) <= 0 ? 1 : -1;
+}
+
+// Returns the median delta from a sorted set of weighted deltas.
+static float GetMedian(const int num_items,
+ const WeightedDelta* const weighted_deltas,
+ const float sum) {
+ if (num_items == 0 || sum < EPSILON) {
+ return 0.0f;
+ }
+
+ float current_weight = 0.0f;
+ const float target_weight = sum / 2.0f;
+ for (int i = 0; i < num_items; ++i) {
+ if (weighted_deltas[i].weight > 0.0f) {
+ current_weight += weighted_deltas[i].weight;
+ if (current_weight >= target_weight) {
+ return weighted_deltas[i].delta;
+ }
+ }
+ }
+ LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
+ return 0.0f;
+}
+
+Point2f FramePair::GetWeightedMedian(
+ const float* const weights, const Point2f* const deltas) const {
+ Point2f median_delta;
+
+ // TODO(andrewharp): only sort deltas that could possibly have an effect.
+ static WeightedDelta weighted_deltas[kMaxKeypoints];
+
+ // Compute median X value.
+ {
+ float total_weight = 0.0f;
+
+ // Compute weighted mean and deltas.
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ weighted_deltas[i].delta = deltas[i].x;
+ const float weight = weights[i];
+ weighted_deltas[i].weight = weight;
+ if (weight > 0.0f) {
+ total_weight += weight;
+ }
+ }
+ qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
+ WeightedDeltaCompare);
+ median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
+ }
+
+ // Compute median Y value.
+ {
+ float total_weight = 0.0f;
+
+ // Compute weighted mean and deltas.
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ const float weight = weights[i];
+ weighted_deltas[i].weight = weight;
+ weighted_deltas[i].delta = deltas[i].y;
+ if (weight > 0.0f) {
+ total_weight += weight;
+ }
+ }
+ qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
+ WeightedDeltaCompare);
+ median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
+ }
+
+ return median_delta;
+}
+
+float FramePair::GetWeightedMedianScale(
+ const float* const weights, const Point2f* const deltas) const {
+ float median_delta;
+
+ // TODO(andrewharp): only sort deltas that could possibly have an effect.
+ static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
+
+ // Compute median scale value across x and y.
+ {
+ float total_weight = 0.0f;
+
+ // Add X values.
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ weighted_deltas[i].delta = deltas[i].x;
+ const float weight = weights[i];
+ weighted_deltas[i].weight = weight;
+ if (weight > 0.0f) {
+ total_weight += weight;
+ }
+ }
+
+ // Add Y values.
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
+ const float weight = weights[i];
+ weighted_deltas[i + kMaxKeypoints].weight = weight;
+ if (weight > 0.0f) {
+ total_weight += weight;
+ }
+ }
+
+ qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
+ WeightedDeltaCompare);
+
+ median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
+ }
+
+ return median_delta;
+}
+
+} // namespace tf_tracking
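
The weighted median used above sorts the correspondences by delta and walks the cumulative weight until it crosses half of the total. A small standalone illustration of the same idea (sorted ascending here with `std::sort`; the original comparator sorts descending via `qsort`, which yields the same median up to tie handling):

```cpp
#include <algorithm>
#include <vector>

struct WeightedValue {
  float weight;
  float value;
};

// Returns the value at which the cumulative weight first reaches half of the
// total weight, mirroring the idea of FramePair::GetMedian above.
float WeightedMedian(std::vector<WeightedValue> items) {
  float total = 0.0f;
  for (const WeightedValue& item : items) total += item.weight;
  if (items.empty() || total <= 0.0f) return 0.0f;

  std::sort(items.begin(), items.end(),
            [](const WeightedValue& a, const WeightedValue& b) {
              return a.value < b.value;
            });

  float cumulative = 0.0f;
  for (const WeightedValue& item : items) {
    cumulative += item.weight;
    if (cumulative >= total / 2.0f) return item.value;
  }
  return items.back().value;
}
```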
diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.h b/tensorflow/examples/android/jni/object_tracking/frame_pair.h
new file mode 100644
index 0000000000..3f2559a5e0
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.h
@@ -0,0 +1,103 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+
+namespace tf_tracking {
+
+// A class that records keypoint correspondences from pairs of
+// consecutive frames.
+class FramePair {
+ public:
+ FramePair()
+ : start_time_(0),
+ end_time_(0),
+ number_of_keypoints_(0) {}
+
+ // Cleans up the FramePair so that it can be reused.
+ void Init(const int64 start_time, const int64 end_time);
+
+ void AdjustBox(const BoundingBox box,
+ float* const translation_x,
+ float* const translation_y,
+ float* const scale_x,
+ float* const scale_y) const;
+
+ private:
+ // Returns the weighted median of the given deltas, computed independently on
+ // x and y. Returns 0,0 in case of failure. The assumption is that a
+ // translation of 0.0 in the degenerate case is the best that can be done, and
+ // should not be considered an error.
+ //
+ // In the case of scale, a slight exception is made just to be safe and
+ // there is an explicit check for 0.0, but that shouldn't ever happen
+ // naturally because of the non-zero and parity checks in FillScales.
+ Point2f GetWeightedMedian(const float* const weights,
+ const Point2f* const deltas) const;
+
+ float GetWeightedMedianScale(const float* const weights,
+ const Point2f* const deltas) const;
+
+ // Weights points based on the query_point and cutoff_dist.
+ int FillWeights(const BoundingBox& box,
+ float* const weights) const;
+
+ // Fills in the array of deltas with the translations of the points
+ // between frames.
+ void FillTranslations(Point2f* const translations) const;
+
+ // Fills in the array of deltas with the relative scale factor of points
+ // relative to a given center. Has the ability to override the weight to 0 if
+ // a degenerate scale is detected.
+ // Translation is the amount the center of the box has moved from one frame to
+ // the next.
+ int FillScales(const Point2f& old_center,
+ const Point2f& translation,
+ float* const weights,
+ Point2f* const scales) const;
+
+ // TODO(andrewharp): Make these private.
+ public:
+ // The time at frame1.
+ int64 start_time_;
+
+ // The time at frame2.
+ int64 end_time_;
+
+ // This array will contain the keypoints found in frame 1.
+ Keypoint frame1_keypoints_[kMaxKeypoints];
+
+ // Contains the locations of the keypoints from frame 1 in frame 2.
+ Keypoint frame2_keypoints_[kMaxKeypoints];
+
+ // The number of keypoints in frame 1.
+ int number_of_keypoints_;
+
+ // Keeps track of which keypoint correspondences were actually found from one
+ // frame to another.
+ // The i-th element of this array will be non-zero if and only if the i-th
+ // keypoint of frame 1 was found in frame 2.
+ bool optical_flow_found_keypoint_[kMaxKeypoints];
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(FramePair);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/geom.h b/tensorflow/examples/android/jni/object_tracking/geom.h
new file mode 100644
index 0000000000..5d5249cd97
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/geom.h
@@ -0,0 +1,319 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+namespace tf_tracking {
+
+struct Size {
+ Size(const int width, const int height) : width(width), height(height) {}
+
+ int width;
+ int height;
+};
+
+
+class Point2f {
+ public:
+ Point2f() : x(0.0f), y(0.0f) {}
+ Point2f(const float x, const float y) : x(x), y(y) {}
+
+ inline Point2f operator- (const Point2f& that) const {
+ return Point2f(this->x - that.x, this->y - that.y);
+ }
+
+ inline Point2f operator+ (const Point2f& that) const {
+ return Point2f(this->x + that.x, this->y + that.y);
+ }
+
+ inline Point2f& operator+= (const Point2f& that) {
+ this->x += that.x;
+ this->y += that.y;
+ return *this;
+ }
+
+ inline Point2f& operator-= (const Point2f& that) {
+ this->x -= that.x;
+ this->y -= that.y;
+ return *this;
+ }
+
+ inline Point2f operator- (const Point2f& that) {
+ return Point2f(this->x - that.x, this->y - that.y);
+ }
+
+ inline float LengthSquared() {
+ return Square(this->x) + Square(this->y);
+ }
+
+ inline float Length() {
+ return sqrtf(LengthSquared());
+ }
+
+ inline float DistanceSquared(const Point2f& that) {
+ return Square(this->x - that.x) + Square(this->y - that.y);
+ }
+
+ inline float Distance(const Point2f& that) {
+ return sqrtf(DistanceSquared(that));
+ }
+
+ float x;
+ float y;
+};
+
+inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) {
+ stream << point.x << "," << point.y;
+ return stream;
+}
+
+class BoundingBox {
+ public:
+ BoundingBox()
+ : left_(0),
+ top_(0),
+ right_(0),
+ bottom_(0) {}
+
+ BoundingBox(const BoundingBox& bounding_box)
+ : left_(bounding_box.left_),
+ top_(bounding_box.top_),
+ right_(bounding_box.right_),
+ bottom_(bounding_box.bottom_) {
+ SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
+ SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
+ }
+
+ BoundingBox(const float left,
+ const float top,
+ const float right,
+ const float bottom)
+ : left_(left),
+ top_(top),
+ right_(right),
+ bottom_(bottom) {
+ SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
+ SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
+ }
+
+ BoundingBox(const Point2f& point1, const Point2f& point2)
+ : left_(MIN(point1.x, point2.x)),
+ top_(MIN(point1.y, point2.y)),
+ right_(MAX(point1.x, point2.x)),
+ bottom_(MAX(point1.y, point2.y)) {}
+
+ inline void CopyToArray(float* const bounds_array) const {
+ bounds_array[0] = left_;
+ bounds_array[1] = top_;
+ bounds_array[2] = right_;
+ bounds_array[3] = bottom_;
+ }
+
+ inline float GetWidth() const {
+ return right_ - left_;
+ }
+
+ inline float GetHeight() const {
+ return bottom_ - top_;
+ }
+
+ inline float GetArea() const {
+ const float width = GetWidth();
+ const float height = GetHeight();
+ if (width <= 0 || height <= 0) {
+ return 0.0f;
+ }
+
+ return width * height;
+ }
+
+ inline Point2f GetCenter() const {
+ return Point2f((left_ + right_) / 2.0f,
+ (top_ + bottom_) / 2.0f);
+ }
+
+ inline bool ValidBox() const {
+ return GetArea() > 0.0f;
+ }
+
+ // Returns a bounding box created from the overlapping area of these two.
+ inline BoundingBox Intersect(const BoundingBox& that) const {
+ const float new_left = MAX(this->left_, that.left_);
+ const float new_right = MIN(this->right_, that.right_);
+
+ if (new_left >= new_right) {
+ return BoundingBox();
+ }
+
+ const float new_top = MAX(this->top_, that.top_);
+ const float new_bottom = MIN(this->bottom_, that.bottom_);
+
+ if (new_top >= new_bottom) {
+ return BoundingBox();
+ }
+
+ return BoundingBox(new_left, new_top, new_right, new_bottom);
+ }
+
+ // Returns a bounding box that can contain both boxes.
+ inline BoundingBox Union(const BoundingBox& that) const {
+ return BoundingBox(MIN(this->left_, that.left_),
+ MIN(this->top_, that.top_),
+ MAX(this->right_, that.right_),
+ MAX(this->bottom_, that.bottom_));
+ }
+
+ inline float PascalScore(const BoundingBox& that) const {
+ SCHECK(GetArea() > 0.0f, "Empty bounding box!");
+ SCHECK(that.GetArea() > 0.0f, "Empty bounding box!");
+
+ const float intersect_area = this->Intersect(that).GetArea();
+
+ if (intersect_area <= 0) {
+ return 0;
+ }
+
+ const float score =
+ intersect_area / (GetArea() + that.GetArea() - intersect_area);
+ SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score);
+ return score;
+ }
+
+ inline bool Intersects(const BoundingBox& that) const {
+ return InRange(that.left_, left_, right_)
+ || InRange(that.right_, left_, right_)
+ || InRange(that.top_, top_, bottom_)
+ || InRange(that.bottom_, top_, bottom_);
+ }
+
+ // Returns whether another bounding box is completely inside of this bounding
+ // box. Sharing edges is ok.
+ inline bool Contains(const BoundingBox& that) const {
+ return that.left_ >= left_ &&
+ that.right_ <= right_ &&
+ that.top_ >= top_ &&
+ that.bottom_ <= bottom_;
+ }
+
+ inline bool Contains(const Point2f& point) const {
+ return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_);
+ }
+
+ inline void Shift(const Point2f shift_amount) {
+ left_ += shift_amount.x;
+ top_ += shift_amount.y;
+ right_ += shift_amount.x;
+ bottom_ += shift_amount.y;
+ }
+
+ inline void ScaleOrigin(const float scale_x, const float scale_y) {
+ left_ *= scale_x;
+ right_ *= scale_x;
+ top_ *= scale_y;
+ bottom_ *= scale_y;
+ }
+
+ inline void Scale(const float scale_x, const float scale_y) {
+ const Point2f center = GetCenter();
+ const float half_width = GetWidth() / 2.0f;
+ const float half_height = GetHeight() / 2.0f;
+
+ left_ = center.x - half_width * scale_x;
+ right_ = center.x + half_width * scale_x;
+
+ top_ = center.y - half_height * scale_y;
+ bottom_ = center.y + half_height * scale_y;
+ }
+
+ float left_;
+ float top_;
+ float right_;
+ float bottom_;
+};
+inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) {
+ stream << "[" << box.left_ << " - " << box.right_
+ << ", " << box.top_ << " - " << box.bottom_
+ << ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]";
+ return stream;
+}
+
+
+class BoundingSquare {
+ public:
+ BoundingSquare(const float x, const float y, const float size)
+ : x_(x), y_(y), size_(size) {}
+
+ explicit BoundingSquare(const BoundingBox& box)
+ : x_(box.left_), y_(box.top_), size_(box.GetWidth()) {
+#ifdef SANITY_CHECKS
+ if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) {
+ LOG(WARNING) << "This is not a square: " << box << std::endl;
+ }
+#endif
+ }
+
+ inline BoundingBox ToBoundingBox() const {
+ return BoundingBox(x_, y_, x_ + size_, y_ + size_);
+ }
+
+ inline bool ValidBox() {
+ return size_ > 0.0f;
+ }
+
+ inline void Shift(const Point2f shift_amount) {
+ x_ += shift_amount.x;
+ y_ += shift_amount.y;
+ }
+
+ inline void Scale(const float scale) {
+ const float new_size = size_ * scale;
+ const float position_diff = (new_size - size_) / 2.0f;
+ x_ -= position_diff;
+ y_ -= position_diff;
+ size_ = new_size;
+ }
+
+ float x_;
+ float y_;
+ float size_;
+};
+inline std::ostream& operator<<(std::ostream& stream,
+ const BoundingSquare& square) {
+ stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]";
+ return stream;
+}
+
+
+inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box,
+ const float size) {
+ const float width_diff = (original_box.GetWidth() - size) / 2.0f;
+ const float height_diff = (original_box.GetHeight() - size) / 2.0f;
+ return BoundingSquare(original_box.left_ + width_diff,
+ original_box.top_ + height_diff,
+ size);
+}
+
+inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) {
+ return GetCenteredSquare(
+ original_box, MIN(original_box.GetWidth(), original_box.GetHeight()));
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
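
`BoundingBox::PascalScore` above is the usual intersection-over-union overlap measure. A quick worked example using the types just defined, with arbitrarily chosen box coordinates:

```cpp
#include "tensorflow/examples/android/jni/object_tracking/geom.h"

void OverlapExample() {
  // Two 10x10 boxes offset by 5 pixels in x.
  const tf_tracking::BoundingBox a(0.0f, 0.0f, 10.0f, 10.0f);
  const tf_tracking::BoundingBox b(5.0f, 0.0f, 15.0f, 10.0f);

  const float intersection = a.Intersect(b).GetArea();  // 5 * 10 = 50.
  const float iou = a.PascalScore(b);  // 50 / (100 + 100 - 50) = 1/3.

  // Well below the kPositionOverlapThreshold of 0.6 declared in config.h.
  (void)intersection;
  (void)iou;
}
```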
diff --git a/tensorflow/examples/android/jni/object_tracking/gl_utils.h b/tensorflow/examples/android/jni/object_tracking/gl_utils.h
new file mode 100755
index 0000000000..bd5c233f4f
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/gl_utils.h
@@ -0,0 +1,55 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
+
+#include <GLES/gl.h>
+#include <GLES/glext.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+
+namespace tf_tracking {
+
+// Draws a box at the given position.
+inline static void DrawBox(const BoundingBox& bounding_box) {
+ const GLfloat line[] = {
+ bounding_box.left_, bounding_box.bottom_,
+ bounding_box.left_, bounding_box.top_,
+ bounding_box.left_, bounding_box.top_,
+ bounding_box.right_, bounding_box.top_,
+ bounding_box.right_, bounding_box.top_,
+ bounding_box.right_, bounding_box.bottom_,
+ bounding_box.right_, bounding_box.bottom_,
+ bounding_box.left_, bounding_box.bottom_
+ };
+
+ glVertexPointer(2, GL_FLOAT, 0, line);
+ glEnableClientState(GL_VERTEX_ARRAY);
+
+ glDrawArrays(GL_LINES, 0, 8);
+}
+
+
+// Changes the coordinate system such that drawing to an arbitrary square in
+// the world can thereafter be drawn to using coordinates 0 - 1.
+inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) {
+ glScalef(square.size_, square.size_, 1.0f);
+ glTranslatef(square.x_ / square.size_, square.y_ / square.size_, 0.0f);
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image-inl.h b/tensorflow/examples/android/jni/object_tracking/image-inl.h
new file mode 100644
index 0000000000..18123cef01
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image-inl.h
@@ -0,0 +1,642 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+namespace tf_tracking {
+
+template <typename T>
+Image<T>::Image(const int width, const int height)
+ : width_less_one_(width - 1),
+ height_less_one_(height - 1),
+ data_size_(width * height),
+ own_data_(true),
+ width_(width),
+ height_(height),
+ stride_(width) {
+ Allocate();
+}
+
+template <typename T>
+Image<T>::Image(const Size& size)
+ : width_less_one_(size.width - 1),
+ height_less_one_(size.height - 1),
+ data_size_(size.width * size.height),
+ own_data_(true),
+ width_(size.width),
+ height_(size.height),
+ stride_(size.width) {
+ Allocate();
+}
+
+// Constructor that creates an image from preallocated data.
+// Note: The image takes ownership of the data lifecycle, unless own_data is
+// set to false.
+template <typename T>
+Image<T>::Image(const int width, const int height, T* const image_data,
+ const bool own_data) :
+ width_less_one_(width - 1),
+ height_less_one_(height - 1),
+ data_size_(width * height),
+ own_data_(own_data),
+ width_(width),
+ height_(height),
+ stride_(width) {
+ image_data_ = image_data;
+ SCHECK(image_data_ != NULL, "Can't create image with NULL data!");
+}
+
+template <typename T>
+Image<T>::~Image() {
+ if (own_data_) {
+ delete[] image_data_;
+ }
+ image_data_ = NULL;
+}
+
+template<typename T>
+template<class DstType>
+bool Image<T>::ExtractPatchAtSubpixelFixed1616(const int fp_x,
+ const int fp_y,
+ const int patchwidth,
+ const int patchheight,
+ DstType* to_data) const {
+ // Calculate weights.
+ const int trunc_x = fp_x >> 16;
+ const int trunc_y = fp_y >> 16;
+
+ if (trunc_x < 0 || trunc_y < 0 ||
+ (trunc_x + patchwidth) >= width_less_one_ ||
+ (trunc_y + patchheight) >= height_less_one_) {
+ return false;
+ }
+
+ // Now walk over destination patch and fill from interpolated source image.
+ for (int y = 0; y < patchheight; ++y, to_data += patchwidth) {
+ for (int x = 0; x < patchwidth; ++x) {
+ to_data[x] =
+ static_cast<DstType>(GetPixelInterpFixed1616(fp_x + (x << 16),
+ fp_y + (y << 16)));
+ }
+ }
+
+ return true;
+}
+
+template <typename T>
+Image<T>* Image<T>::Crop(
+ const int left, const int top, const int right, const int bottom) const {
+ SCHECK(left >= 0 && left < width_, "out of bounds at %d!", left);
+ SCHECK(right >= 0 && right < width_, "out of bounds at %d!", right);
+ SCHECK(top >= 0 && top < height_, "out of bounds at %d!", top);
+ SCHECK(bottom >= 0 && bottom < height_, "out of bounds at %d!", bottom);
+
+ SCHECK(left <= right, "mismatch!");
+ SCHECK(top <= bottom, "mismatch!");
+
+ const int new_width = right - left + 1;
+ const int new_height = bottom - top + 1;
+
+ Image<T>* const cropped_image = new Image(new_width, new_height);
+
+ for (int y = 0; y < new_height; ++y) {
+ memcpy((*cropped_image)[y], ((*this)[y + top] + left),
+ new_width * sizeof(T));
+ }
+
+ return cropped_image;
+}
+
+template <typename T>
+inline float Image<T>::GetPixelInterp(const float x, const float y) const {
+ // Do int conversion one time.
+ const int floored_x = static_cast<int>(x);
+ const int floored_y = static_cast<int>(y);
+
+ // Note: it might be the case that the *_[min|max] values are clipped, and
+ // these (the a b c d vals) aren't (for speed purposes), but that doesn't
+ // matter. We'll just be blending the pixel with itself in that case anyway.
+ const float b = x - floored_x;
+ const float a = 1.0f - b;
+
+ const float d = y - floored_y;
+ const float c = 1.0f - d;
+
+ SCHECK(ValidInterpPixel(x, y),
+ "x or y out of bounds! %.2f [0 - %d), %.2f [0 - %d)",
+ x, width_less_one_, y, height_less_one_);
+
+ const T* const pix_ptr = (*this)[floored_y] + floored_x;
+
+ // Get the pixel values surrounding this point.
+ const T& p1 = pix_ptr[0];
+ const T& p2 = pix_ptr[1];
+ const T& p3 = pix_ptr[width_];
+ const T& p4 = pix_ptr[width_ + 1];
+
+ // Simple bilinear interpolation between four reference pixels.
+ // If x is the value requested:
+ // a b
+ // -------
+ // c |p1 p2|
+ // | x |
+ // d |p3 p4|
+ // -------
+ return c * ((a * p1) + (b * p2)) +
+ d * ((a * p3) + (b * p4));
+}
+
+
+template <typename T>
+inline T Image<T>::GetPixelInterpFixed1616(
+ const int fp_x_whole, const int fp_y_whole) const {
+ static const int kFixedPointOne = 0x00010000;
+ static const int kFixedPointHalf = 0x00008000;
+ static const int kFixedPointTruncateMask = 0xFFFF0000;
+
+ int trunc_x = fp_x_whole & kFixedPointTruncateMask;
+ int trunc_y = fp_y_whole & kFixedPointTruncateMask;
+ const int fp_x = fp_x_whole - trunc_x;
+ const int fp_y = fp_y_whole - trunc_y;
+
+ // Scale the truncated values back to regular ints.
+ trunc_x >>= 16;
+ trunc_y >>= 16;
+
+ const int one_minus_fp_x = kFixedPointOne - fp_x;
+ const int one_minus_fp_y = kFixedPointOne - fp_y;
+
+ const T* trunc_start = (*this)[trunc_y] + trunc_x;
+
+ const T a = trunc_start[0];
+ const T b = trunc_start[1];
+ const T c = trunc_start[stride_];
+ const T d = trunc_start[stride_ + 1];
+
+ return ((one_minus_fp_y * static_cast<int64>(one_minus_fp_x * a + fp_x * b) +
+ fp_y * static_cast<int64>(one_minus_fp_x * c + fp_x * d) +
+ kFixedPointHalf) >> 32);
+}
+
+template <typename T>
+inline bool Image<T>::ValidPixel(const int x, const int y) const {
+ return InRange(x, ZERO, width_less_one_) &&
+ InRange(y, ZERO, height_less_one_);
+}
+
+template <typename T>
+inline BoundingBox Image<T>::GetContainingBox() const {
+ return BoundingBox(
+ 0, 0, width_less_one_ - EPSILON, height_less_one_ - EPSILON);
+}
+
+template <typename T>
+inline bool Image<T>::Contains(const BoundingBox& bounding_box) const {
+ // TODO(andrewharp): Come up with a more elegant way of ensuring that bounds
+ // are ok.
+ return GetContainingBox().Contains(bounding_box);
+}
+
+template <typename T>
+inline bool Image<T>::ValidInterpPixel(const float x, const float y) const {
+ // Exclusive of max because we can be more efficient if we don't handle
+ // interpolating on or past the last pixel.
+ return (x >= ZERO) && (x < width_less_one_) &&
+ (y >= ZERO) && (y < height_less_one_);
+}
+
+template <typename T>
+void Image<T>::DownsampleAveraged(const T* const original, const int stride,
+ const int factor) {
+#ifdef __ARM_NEON
+ if (factor == 4 || factor == 2) {
+ DownsampleAveragedNeon(original, stride, factor);
+ return;
+ }
+#endif
+
+ // TODO(andrewharp): delete or enable this for non-uint8 downsamples.
+ const int pixels_per_block = factor * factor;
+
+ // For every pixel in resulting image.
+ for (int y = 0; y < height_; ++y) {
+ const int orig_y = y * factor;
+ const int y_bound = orig_y + factor;
+
+ // Sum up the original pixels.
+ for (int x = 0; x < width_; ++x) {
+ const int orig_x = x * factor;
+ const int x_bound = orig_x + factor;
+
+ // Making this int32 because type U or T might overflow.
+ int32 pixel_sum = 0;
+
+ // Grab all the pixels that make up this pixel.
+ for (int curr_y = orig_y; curr_y < y_bound; ++curr_y) {
+ const T* p = original + curr_y * stride + orig_x;
+
+ for (int curr_x = orig_x; curr_x < x_bound; ++curr_x) {
+ pixel_sum += *p++;
+ }
+ }
+
+ (*this)[y][x] = pixel_sum / pixels_per_block;
+ }
+ }
+}
+
+template <typename T>
+void Image<T>::DownsampleInterpolateNearest(const Image<T>& original) {
+ // Calculating the scaling factors based on target image size.
+ const float factor_x = static_cast<float>(original.GetWidth()) /
+ static_cast<float>(width_);
+ const float factor_y = static_cast<float>(original.GetHeight()) /
+ static_cast<float>(height_);
+
+ // Calculating initial offset in x-axis.
+ const float offset_x = 0.5f * (original.GetWidth() - width_) / width_;
+
+ // Calculating initial offset in y-axis.
+ const float offset_y = 0.5f * (original.GetHeight() - height_) / height_;
+
+ float orig_y = offset_y;
+
+ // For every pixel in resulting image.
+ for (int y = 0; y < height_; ++y) {
+ float orig_x = offset_x;
+
+ // Finding nearest pixel on y-axis.
+ const int nearest_y = static_cast<int>(orig_y + 0.5f);
+ const T* row_data = original[nearest_y];
+
+ T* pixel_ptr = (*this)[y];
+
+ for (int x = 0; x < width_; ++x) {
+ // Finding nearest pixel on x-axis.
+ const int nearest_x = static_cast<int>(orig_x + 0.5f);
+
+ *pixel_ptr++ = row_data[nearest_x];
+
+ orig_x += factor_x;
+ }
+
+ orig_y += factor_y;
+ }
+}
+
+template <typename T>
+void Image<T>::DownsampleInterpolateLinear(const Image<T>& original) {
+ // TODO(andrewharp): Turn this into a general compare sizes/bulk
+ // copy method.
+ if (original.GetWidth() == GetWidth() &&
+ original.GetHeight() == GetHeight() &&
+ original.stride() == stride()) {
+ memcpy(image_data_, original.data(), data_size_ * sizeof(T));
+ return;
+ }
+
+ // Calculating the scaling factors based on target image size.
+ const float factor_x = static_cast<float>(original.GetWidth()) /
+ static_cast<float>(width_);
+ const float factor_y = static_cast<float>(original.GetHeight()) /
+ static_cast<float>(height_);
+
+ // Calculating initial offset in x-axis.
+ const float offset_x = 0;
+ const int offset_x_fp = RealToFixed1616(offset_x);
+
+ // Calculating initial offset in y-axis.
+ const float offset_y = 0;
+ const int offset_y_fp = RealToFixed1616(offset_y);
+
+ // Get the fixed point scaling factor value.
+ // Shift by 8 so we can fit everything into a 4 byte int later for speed
+ // reasons. This means the precision is limited to 1 / 256th of a pixel,
+ // but this should be good enough.
+ const int factor_x_fp = RealToFixed1616(factor_x) >> 8;
+ const int factor_y_fp = RealToFixed1616(factor_y) >> 8;
+
+ int src_y_fp = offset_y_fp >> 8;
+
+ static const int kFixedPointOne8 = 0x00000100;
+ static const int kFixedPointHalf8 = 0x00000080;
+ static const int kFixedPointTruncateMask8 = 0xFFFFFF00;
+
+ // For every pixel in resulting image.
+ for (int y = 0; y < height_; ++y) {
+ int src_x_fp = offset_x_fp >> 8;
+
+ int trunc_y = src_y_fp & kFixedPointTruncateMask8;
+ const int fp_y = src_y_fp - trunc_y;
+
+ // Scale the truncated values back to regular ints.
+ trunc_y >>= 8;
+
+ const int one_minus_fp_y = kFixedPointOne8 - fp_y;
+
+ T* pixel_ptr = (*this)[y];
+
+ // Make sure not to read from an invalid row.
+ const int trunc_y_b = MIN(original.height_less_one_, trunc_y + 1);
+ const T* other_top_ptr = original[trunc_y];
+ const T* other_bot_ptr = original[trunc_y_b];
+
+ int last_trunc_x = -1;
+ int trunc_x = -1;
+
+ T a = 0;
+ T b = 0;
+ T c = 0;
+ T d = 0;
+
+ for (int x = 0; x < width_; ++x) {
+ trunc_x = src_x_fp & kFixedPointTruncateMask8;
+
+ const int fp_x = (src_x_fp - trunc_x) >> 8;
+
+ // Scale the truncated values back to regular ints.
+ trunc_x >>= 8;
+
+      // It's possible we're reading from the same pixels as the last
+      // iteration, so only reload the four neighbors when trunc_x changes.
+ if (trunc_x != last_trunc_x) {
+ // Make sure not to read from an invalid column.
+ const int trunc_x_b = MIN(original.width_less_one_, trunc_x + 1);
+ a = other_top_ptr[trunc_x];
+ b = other_top_ptr[trunc_x_b];
+ c = other_bot_ptr[trunc_x];
+ d = other_bot_ptr[trunc_x_b];
+ last_trunc_x = trunc_x;
+ }
+
+ const int one_minus_fp_x = kFixedPointOne8 - fp_x;
+
+      const int32 value =
+          ((one_minus_fp_y * (one_minus_fp_x * a + fp_x * b) +
+            fp_y * (one_minus_fp_x * c + fp_x * d)) +
+           kFixedPointHalf8) >> 16;
+
+ *pixel_ptr++ = value;
+
+ src_x_fp += factor_x_fp;
+ }
+ src_y_fp += factor_y_fp;
+ }
+}
+
+template <typename T>
+void Image<T>::DownsampleSmoothed3x3(const Image<T>& original) {
+ for (int y = 0; y < height_; ++y) {
+ const int orig_y = Clip(2 * y, ZERO, original.height_less_one_);
+ const int min_y = Clip(orig_y - 1, ZERO, original.height_less_one_);
+ const int max_y = Clip(orig_y + 1, ZERO, original.height_less_one_);
+
+ for (int x = 0; x < width_; ++x) {
+ const int orig_x = Clip(2 * x, ZERO, original.width_less_one_);
+ const int min_x = Clip(orig_x - 1, ZERO, original.width_less_one_);
+ const int max_x = Clip(orig_x + 1, ZERO, original.width_less_one_);
+
+ // Center.
+ int32 pixel_sum = original[orig_y][orig_x] * 4;
+
+ // Sides.
+ pixel_sum += (original[orig_y][max_x] +
+ original[orig_y][min_x] +
+ original[max_y][orig_x] +
+ original[min_y][orig_x]) * 2;
+
+ // Diagonals.
+ pixel_sum += (original[min_y][max_x] +
+ original[min_y][min_x] +
+ original[max_y][max_x] +
+ original[max_y][min_x]);
+
+ (*this)[y][x] = pixel_sum >> 4; // 16
+ }
+ }
+}
+
+template <typename T>
+void Image<T>::DownsampleSmoothed5x5(const Image<T>& original) {
+ const int max_x = original.width_less_one_;
+ const int max_y = original.height_less_one_;
+
+  // The J.-Y. Bouguet paper on Lucas-Kanade recommends a
+ // [1/16 1/4 3/8 1/4 1/16]^2 filter.
+ // This works out to a [1 4 6 4 1]^2 / 256 array, precomputed below.
+ static const int window_radius = 2;
+ static const int window_size = window_radius*2 + 1;
+ static const int window_weights[] = {1, 4, 6, 4, 1, // 16 +
+ 4, 16, 24, 16, 4, // 64 +
+ 6, 24, 36, 24, 6, // 96 +
+ 4, 16, 24, 16, 4, // 64 +
+ 1, 4, 6, 4, 1}; // 16 = 256
+
+  // We'll multiply and sum with the whole numbers first, then divide by
+ // the total weight to normalize at the last moment.
+ for (int y = 0; y < height_; ++y) {
+ for (int x = 0; x < width_; ++x) {
+ int32 pixel_sum = 0;
+
+ const int* w = window_weights;
+ const int start_x = Clip((x << 1) - window_radius, ZERO, max_x);
+
+ // Clip the boundaries to the size of the image.
+ for (int window_y = 0; window_y < window_size; ++window_y) {
+ const int start_y =
+ Clip((y << 1) - window_radius + window_y, ZERO, max_y);
+
+ const T* p = original[start_y] + start_x;
+
+ for (int window_x = 0; window_x < window_size; ++window_x) {
+ pixel_sum += *p++ * *w++;
+ }
+ }
+
+ // Conversion to type T will happen here after shifting right 8 bits to
+ // divide by 256.
+ (*this)[y][x] = pixel_sum >> 8;
+ }
+ }
+}
+
+template <typename T>
+template <typename U>
+inline T Image<T>::ScharrPixelX(const Image<U>& original,
+ const int center_x, const int center_y) const {
+ const int min_x = Clip(center_x - 1, ZERO, original.width_less_one_);
+ const int max_x = Clip(center_x + 1, ZERO, original.width_less_one_);
+ const int min_y = Clip(center_y - 1, ZERO, original.height_less_one_);
+ const int max_y = Clip(center_y + 1, ZERO, original.height_less_one_);
+
+ // Convolution loop unrolled for performance...
+ return (3 * (original[min_y][max_x]
+ + original[max_y][max_x]
+ - original[min_y][min_x]
+ - original[max_y][min_x])
+ + 10 * (original[center_y][max_x]
+ - original[center_y][min_x])) / 32;
+}
+
+template <typename T>
+template <typename U>
+inline T Image<T>::ScharrPixelY(const Image<U>& original,
+ const int center_x, const int center_y) const {
+ const int min_x = Clip(center_x - 1, 0, original.width_less_one_);
+ const int max_x = Clip(center_x + 1, 0, original.width_less_one_);
+ const int min_y = Clip(center_y - 1, 0, original.height_less_one_);
+ const int max_y = Clip(center_y + 1, 0, original.height_less_one_);
+
+ // Convolution loop unrolled for performance...
+ return (3 * (original[max_y][min_x]
+ + original[max_y][max_x]
+ - original[min_y][min_x]
+ - original[min_y][max_x])
+ + 10 * (original[max_y][center_x]
+ - original[min_y][center_x])) / 32;
+}
+
+template <typename T>
+template <typename U>
+inline void Image<T>::ScharrX(const Image<U>& original) {
+ for (int y = 0; y < height_; ++y) {
+ for (int x = 0; x < width_; ++x) {
+ SetPixel(x, y, ScharrPixelX(original, x, y));
+ }
+ }
+}
+
+template <typename T>
+template <typename U>
+inline void Image<T>::ScharrY(const Image<U>& original) {
+ for (int y = 0; y < height_; ++y) {
+ for (int x = 0; x < width_; ++x) {
+ SetPixel(x, y, ScharrPixelY(original, x, y));
+ }
+ }
+}
+
+template <typename T>
+template <typename U>
+void Image<T>::DerivativeX(const Image<U>& original) {
+ for (int y = 0; y < height_; ++y) {
+ const U* const source_row = original[y];
+ T* const dest_row = (*this)[y];
+
+ // Compute first pixel. Approximated with forward difference.
+ dest_row[0] = source_row[1] - source_row[0];
+
+ // All the pixels in between. Central difference method.
+ const U* source_prev_pixel = source_row;
+ T* dest_pixel = dest_row + 1;
+ const U* source_next_pixel = source_row + 2;
+ for (int x = 1; x < width_less_one_; ++x) {
+ *dest_pixel++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
+ }
+
+ // Last pixel. Approximated with backward difference.
+ dest_row[width_less_one_] =
+ source_row[width_less_one_] - source_row[width_less_one_ - 1];
+ }
+}
+
+template <typename T>
+template <typename U>
+void Image<T>::DerivativeY(const Image<U>& original) {
+ const int src_stride = original.stride();
+
+ // Compute 1st row. Approximated with forward difference.
+ {
+ const U* const src_row = original[0];
+ T* dest_row = (*this)[0];
+ for (int x = 0; x < width_; ++x) {
+ dest_row[x] = src_row[x + src_stride] - src_row[x];
+ }
+ }
+
+ // Compute all rows in between using central difference.
+ for (int y = 1; y < height_less_one_; ++y) {
+ T* dest_row = (*this)[y];
+
+ const U* source_prev_pixel = original[y - 1];
+ const U* source_next_pixel = original[y + 1];
+ for (int x = 0; x < width_; ++x) {
+ *dest_row++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
+ }
+ }
+
+ // Compute last row. Approximated with backward difference.
+ {
+ const U* const src_row = original[height_less_one_];
+ T* dest_row = (*this)[height_less_one_];
+ for (int x = 0; x < width_; ++x) {
+ dest_row[x] = src_row[x] - src_row[x - src_stride];
+ }
+ }
+}
+
+template <typename T>
+template <typename U>
+inline T Image<T>::ConvolvePixel3x3(const Image<U>& original,
+ const int* const filter,
+ const int center_x, const int center_y,
+ const int total) const {
+ int32 sum = 0;
+ for (int filter_y = 0; filter_y < 3; ++filter_y) {
+    const int y = Clip(center_y - 1 + filter_y, 0, original.height_less_one_);
+    for (int filter_x = 0; filter_x < 3; ++filter_x) {
+      const int x = Clip(center_x - 1 + filter_x, 0, original.width_less_one_);
+ sum += original[y][x] * filter[filter_y * 3 + filter_x];
+ }
+ }
+ return sum / total;
+}
+
+template <typename T>
+template <typename U>
+inline void Image<T>::Convolve3x3(const Image<U>& original,
+ const int32* const filter) {
+ int32 sum = 0;
+ for (int i = 0; i < 9; ++i) {
+ sum += abs(filter[i]);
+ }
+ for (int y = 0; y < height_; ++y) {
+ for (int x = 0; x < width_; ++x) {
+ SetPixel(x, y, ConvolvePixel3x3(original, filter, x, y, sum));
+ }
+ }
+}
+
+template <typename T>
+inline void Image<T>::FromArray(const T* const pixels, const int stride,
+ const int factor) {
+ if (factor == 1 && stride == width_) {
+ // If not subsampling, memcpy per line should be faster.
+ memcpy(this->image_data_, pixels, data_size_ * sizeof(T));
+ return;
+ }
+
+ DownsampleAveraged(pixels, stride, factor);
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
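GetPixelInterpFixed1616() above does bilinear interpolation entirely in 16:16 fixed point: the fractional parts of the coordinates act as weights that sum to 0x10000, and the product of two such weights carries 32 fractional bits, hence the final shift. The following standalone sketch mirrors that arithmetic on four hand-picked neighbor values; BlendFixed1616 is a hypothetical helper, not part of the change, and its rounding term is scaled to match the 32-bit shift.

    #include <cstdint>
    #include <cstdio>

    // Blend four neighbors a (top-left), b (top-right), c (bottom-left) and
    // d (bottom-right) with fractional offsets fp_x, fp_y in 16:16 fixed point.
    static uint8_t BlendFixed1616(uint8_t a, uint8_t b, uint8_t c, uint8_t d,
                                  int fp_x, int fp_y) {
      const int64_t kOne = 0x00010000;
      const int64_t top = (kOne - fp_x) * a + int64_t{fp_x} * b;
      const int64_t bottom = (kOne - fp_x) * c + int64_t{fp_x} * d;
      // Each term is a product of two 16:16 weights and an 8-bit value, so the
      // result has 32 fractional bits; add one half before truncating.
      return static_cast<uint8_t>(
          ((kOne - fp_y) * top + int64_t{fp_y} * bottom + (int64_t{1} << 31))
          >> 32);
    }

    int main() {
      // fp = 0x8000 is 0.5 in 16:16, i.e. exactly halfway between all four
      // neighbors, so the result is the plain average (10+20+30+40)/4 = 25.
      printf("%d\n", BlendFixed1616(10, 20, 30, 40, 0x8000, 0x8000));
      return 0;
    }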
diff --git a/tensorflow/examples/android/jni/object_tracking/image.h b/tensorflow/examples/android/jni/object_tracking/image.h
new file mode 100644
index 0000000000..29b0adbda8
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image.h
@@ -0,0 +1,346 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
+
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+using namespace tensorflow;
+
+// TODO(andrewharp): Make this a cast to uint32 if/when we go unsigned for
+// operations.
+#define ZERO 0
+
+#ifdef SANITY_CHECKS
+ #define CHECK_PIXEL(IMAGE, X, Y) {\
+ SCHECK((IMAGE)->ValidPixel((X), (Y)), \
+ "CHECK_PIXEL(%d,%d) in %dx%d image.", \
+ static_cast<int>(X), static_cast<int>(Y), \
+ (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
+ }
+
+ #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {\
+    SCHECK((IMAGE)->ValidInterpPixel((X), (Y)), \
+ "CHECK_PIXEL_INTERP(%.2f, %.2f) in %dx%d image.", \
+ static_cast<float>(X), static_cast<float>(Y), \
+ (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
+ }
+#else
+ #define CHECK_PIXEL(image, x, y) {}
+ #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {}
+#endif
+
+namespace tf_tracking {
+
+#ifdef SANITY_CHECKS
+// Class which exists solely to provide bounds checking for array-style image
+// data access.
+template <typename T>
+class RowData {
+ public:
+ RowData(T* const row_data, const int max_col)
+ : row_data_(row_data), max_col_(max_col) {}
+
+ inline T& operator[](const int col) const {
+ SCHECK(InRange(col, 0, max_col_),
+ "Column out of range: %d (%d max)", col, max_col_);
+ return row_data_[col];
+ }
+
+ inline operator T*() const {
+ return row_data_;
+ }
+
+ private:
+ T* const row_data_;
+ const int max_col_;
+};
+#endif
+
+// Naive templated comparison function for use with qsort.
+template <typename T>
+int Comp(const void* a, const void* b) {
+ const T val1 = *reinterpret_cast<const T*>(a);
+ const T val2 = *reinterpret_cast<const T*>(b);
+
+ if (val1 == val2) {
+ return 0;
+ } else if (val1 < val2) {
+ return -1;
+ } else {
+ return 1;
+ }
+}
+
+// TODO(andrewharp): Make explicit which operations support negative numbers or
+// struct/class types in image data (possibly create fast multi-dim array class
+// for data where pixel arithmetic does not make sense).
+
+// Image class optimized for working on numeric arrays as grayscale image data.
+// Supports other data types as a 2D array class, so long as no pixel math
+// operations are called (convolution, downsampling, etc).
+template <typename T>
+class Image {
+ public:
+ Image(const int width, const int height);
+ explicit Image(const Size& size);
+
+ // Constructor that creates an image from preallocated data.
+ // Note: The image takes ownership of the data lifecycle, unless own_data is
+ // set to false.
+ Image(const int width, const int height, T* const image_data,
+ const bool own_data = true);
+
+ ~Image();
+
+ // Extract a pixel patch from this image, starting at a subpixel location.
+ // Uses 16:16 fixed point format for representing real values and doing the
+ // bilinear interpolation.
+ //
+ // Arguments fp_x and fp_y tell the subpixel position in fixed point format,
+ // patchwidth/patchheight give the size of the patch in pixels and
+ // to_data must be a valid pointer to a *contiguous* destination data array.
+ template<class DstType>
+ bool ExtractPatchAtSubpixelFixed1616(const int fp_x,
+ const int fp_y,
+ const int patchwidth,
+ const int patchheight,
+ DstType* to_data) const;
+
+ Image<T>* Crop(
+ const int left, const int top, const int right, const int bottom) const;
+
+ inline int GetWidth() const { return width_; }
+ inline int GetHeight() const { return height_; }
+
+ // Bilinearly sample a value between pixels. Values must be within the image.
+ inline float GetPixelInterp(const float x, const float y) const;
+
+  // Bilinearly sample a pixel at a subpixel position using fixed-point
+  // arithmetic.
+ // Avoids float<->int conversions.
+ // Values must be within the image.
+ // Arguments fp_x and fp_y tell the subpixel position in
+ // 16:16 fixed point format.
+ //
+ // Important: This function only makes sense for integer-valued images, such
+ // as Image<uint8> or Image<int> etc.
+ inline T GetPixelInterpFixed1616(const int fp_x_whole,
+ const int fp_y_whole) const;
+
+ // Returns true iff the pixel is in the image's boundaries.
+ inline bool ValidPixel(const int x, const int y) const;
+
+ inline BoundingBox GetContainingBox() const;
+
+ inline bool Contains(const BoundingBox& bounding_box) const;
+
+ inline T GetMedianValue() {
+ qsort(image_data_, data_size_, sizeof(image_data_[0]), Comp<T>);
+ return image_data_[data_size_ >> 1];
+ }
+
+ // Returns true iff the pixel is in the image's boundaries for interpolation
+ // purposes.
+ // TODO(andrewharp): check in interpolation follow-up change.
+ inline bool ValidInterpPixel(const float x, const float y) const;
+
+ // Safe lookup with boundary enforcement.
+ inline T GetPixelClipped(const int x, const int y) const {
+ return (*this)[Clip(y, ZERO, height_less_one_)]
+ [Clip(x, ZERO, width_less_one_)];
+ }
+
+#ifdef SANITY_CHECKS
+ inline RowData<T> operator[](const int row) {
+ SCHECK(InRange(row, 0, height_less_one_),
+ "Row out of range: %d (%d max)", row, height_less_one_);
+ return RowData<T>(image_data_ + row * stride_, width_less_one_);
+ }
+
+ inline const RowData<T> operator[](const int row) const {
+ SCHECK(InRange(row, 0, height_less_one_),
+ "Row out of range: %d (%d max)", row, height_less_one_);
+ return RowData<T>(image_data_ + row * stride_, width_less_one_);
+ }
+#else
+ inline T* operator[](const int row) {
+ return image_data_ + row * stride_;
+ }
+
+ inline const T* operator[](const int row) const {
+ return image_data_ + row * stride_;
+ }
+#endif
+
+ const T* data() const { return image_data_; }
+
+ inline int stride() const { return stride_; }
+
+ // Clears image to a single value.
+ inline void Clear(const T& val) {
+ memset(image_data_, val, sizeof(*image_data_) * data_size_);
+ }
+
+#ifdef __ARM_NEON
+ void Downsample2x32ColumnsNeon(const uint8* const original,
+ const int stride,
+ const int orig_x);
+
+ void Downsample4x32ColumnsNeon(const uint8* const original,
+ const int stride,
+ const int orig_x);
+
+ void DownsampleAveragedNeon(const uint8* const original, const int stride,
+ const int factor);
+#endif
+
+ // Naive downsampler that reduces image size by factor by averaging pixels in
+ // blocks of size factor x factor.
+ void DownsampleAveraged(const T* const original, const int stride,
+ const int factor);
+
+ // Naive downsampler that reduces image size by factor by averaging pixels in
+ // blocks of size factor x factor.
+ inline void DownsampleAveraged(const Image<T>& original, const int factor) {
+ DownsampleAveraged(original.data(), original.GetWidth(), factor);
+ }
+
+  // Naive downsampler that reduces image size using nearest-neighbor
+  // interpolation.
+ void DownsampleInterpolateNearest(const Image<T>& original);
+
+  // Naive downsampler that reduces image size using fixed-point bilinear
+  // interpolation.
+ void DownsampleInterpolateLinear(const Image<T>& original);
+
+ // Relatively efficient downsampling of an image by a factor of two with a
+ // low-pass 3x3 smoothing operation thrown in.
+ void DownsampleSmoothed3x3(const Image<T>& original);
+
+ // Relatively efficient downsampling of an image by a factor of two with a
+ // low-pass 5x5 smoothing operation thrown in.
+ void DownsampleSmoothed5x5(const Image<T>& original);
+
+ // Optimized Scharr filter on a single pixel in the X direction.
+ // Scharr filters are like central-difference operators, but have more
+ // rotational symmetry in their response because they also consider the
+ // diagonal neighbors.
+ template <typename U>
+ inline T ScharrPixelX(const Image<U>& original,
+ const int center_x, const int center_y) const;
+
+  // Optimized Scharr filter on a single pixel in the Y direction.
+ // Scharr filters are like central-difference operators, but have more
+ // rotational symmetry in their response because they also consider the
+ // diagonal neighbors.
+ template <typename U>
+ inline T ScharrPixelY(const Image<U>& original,
+ const int center_x, const int center_y) const;
+
+ // Convolve the image with a Scharr filter in the X direction.
+ // Much faster than an equivalent generic convolution.
+ template <typename U>
+ inline void ScharrX(const Image<U>& original);
+
+ // Convolve the image with a Scharr filter in the Y direction.
+ // Much faster than an equivalent generic convolution.
+ template <typename U>
+ inline void ScharrY(const Image<U>& original);
+
+ static inline T HalfDiff(int32 first, int32 second) {
+ return (second - first) / 2;
+ }
+
+ template <typename U>
+ void DerivativeX(const Image<U>& original);
+
+ template <typename U>
+ void DerivativeY(const Image<U>& original);
+
+ // Generic function for convolving pixel with 3x3 filter.
+ // Filter pixels should be in row major order.
+ template <typename U>
+ inline T ConvolvePixel3x3(const Image<U>& original,
+ const int* const filter,
+ const int center_x, const int center_y,
+ const int total) const;
+
+ // Generic function for convolving an image with a 3x3 filter.
+ // TODO(andrewharp): Generalize this for any size filter.
+ template <typename U>
+ inline void Convolve3x3(const Image<U>& original,
+ const int32* const filter);
+
+ // Load this image's data from a data array. The data at pixels is assumed to
+ // have dimensions equivalent to this image's dimensions * factor.
+ inline void FromArray(const T* const pixels, const int stride,
+ const int factor = 1);
+
+ // Copy the image back out to an appropriately sized data array.
+ inline void ToArray(T* const pixels) const {
+ // If not subsampling, memcpy should be faster.
+ memcpy(pixels, this->image_data_, data_size_ * sizeof(T));
+ }
+
+ // Precompute these for efficiency's sake as they're used by a lot of
+ // clipping code and loop code.
+ // TODO(andrewharp): make these only accessible by other Images.
+ const int width_less_one_;
+ const int height_less_one_;
+
+ // The raw size of the allocated data.
+ const int data_size_;
+
+ private:
+ inline void Allocate() {
+ image_data_ = new T[data_size_];
+ if (image_data_ == NULL) {
+ LOGE("Couldn't allocate image data!");
+ }
+ }
+
+ T* image_data_;
+
+ bool own_data_;
+
+ const int width_;
+ const int height_;
+
+ // The image stride (offset to next row).
+ // TODO(andrewharp): Make sure that stride is honored in all code.
+ const int stride_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Image);
+};
+
+template <typename T>
+inline std::ostream& operator<<(std::ostream& stream, const Image<T>& image) {
+ for (int y = 0; y < image.GetHeight(); ++y) {
+ for (int x = 0; x < image.GetWidth(); ++x) {
+ stream << image[y][x] << " ";
+ }
+ stream << std::endl;
+ }
+ return stream;
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
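The Image<T> class above is essentially header-only apart from the NEON specializations in image_neon.cc. Below is a minimal usage sketch of construction and block-averaged downsampling: a hypothetical standalone program, not part of the change, assuming a non-NEON build so the scalar DownsampleAveraged path runs and assuming the headers resolve on the include path.

    #include <cstdio>

    #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
    #include "tensorflow/examples/android/jni/object_tracking/image.h"

    int main() {
      // A 4x4 source image holding a simple ramp: pixel (x, y) = 40*y + 10*x.
      tf_tracking::Image<uint8_t> src(4, 4);
      for (int y = 0; y < src.GetHeight(); ++y) {
        for (int x = 0; x < src.GetWidth(); ++x) {
          src[y][x] = static_cast<uint8_t>(40 * y + 10 * x);
        }
      }

      // Averaging 2x2 blocks halves each dimension.
      tf_tracking::Image<uint8_t> dst(2, 2);
      dst.DownsampleAveraged(src, 2);

      // Top-left output pixel is the mean of {0, 10, 40, 50}, i.e. 25.
      printf("%d\n", static_cast<int>(dst[0][0]));
      return 0;
    }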
diff --git a/tensorflow/examples/android/jni/object_tracking/image_data.h b/tensorflow/examples/android/jni/object_tracking/image_data.h
new file mode 100644
index 0000000000..16b1864ee6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image_data.h
@@ -0,0 +1,270 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+
+#include <memory>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+// Class that encapsulates all bulky processed data for a frame.
+class ImageData {
+ public:
+ explicit ImageData(const int width, const int height)
+ : uv_frame_width_(width << 1),
+ uv_frame_height_(height << 1),
+ timestamp_(0),
+ image_(width, height) {
+ InitPyramid(width, height);
+ ResetComputationCache();
+ }
+
+ private:
+ void ResetComputationCache() {
+ uv_data_computed_ = false;
+ integral_image_computed_ = false;
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ spatial_x_computed_[i] = false;
+ spatial_y_computed_[i] = false;
+ pyramid_sqrt2_computed_[i * 2] = false;
+ pyramid_sqrt2_computed_[i * 2 + 1] = false;
+ }
+ }
+
+ void InitPyramid(const int width, const int height) {
+ int level_width = width;
+ int level_height = height;
+
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ pyramid_sqrt2_[i * 2] = NULL;
+ pyramid_sqrt2_[i * 2 + 1] = NULL;
+ spatial_x_[i] = NULL;
+ spatial_y_[i] = NULL;
+
+ level_width /= 2;
+ level_height /= 2;
+ }
+
+ // Alias the first pyramid level to image_.
+ pyramid_sqrt2_[0] = &image_;
+ }
+
+ public:
+ ~ImageData() {
+ // The first pyramid level is actually an alias to image_,
+ // so make sure it doesn't get deleted here.
+ pyramid_sqrt2_[0] = NULL;
+
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ SAFE_DELETE(pyramid_sqrt2_[i * 2]);
+ SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]);
+ SAFE_DELETE(spatial_x_[i]);
+ SAFE_DELETE(spatial_y_[i]);
+ }
+ }
+
+ void SetData(const uint8* const new_frame, const int stride,
+ const int64 timestamp, const int downsample_factor) {
+ SetData(new_frame, NULL, stride, timestamp, downsample_factor);
+ }
+
+ void SetData(const uint8* const new_frame,
+ const uint8* const uv_frame,
+ const int stride,
+ const int64 timestamp, const int downsample_factor) {
+ ResetComputationCache();
+
+ timestamp_ = timestamp;
+
+ TimeLog("SetData!");
+
+ pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor);
+ pyramid_sqrt2_computed_[0] = true;
+ TimeLog("Downsampled image");
+
+ if (uv_frame != NULL) {
+ if (u_data_.get() == NULL) {
+ u_data_.reset(new Image<uint8>(uv_frame_width_, uv_frame_height_));
+ v_data_.reset(new Image<uint8>(uv_frame_width_, uv_frame_height_));
+ }
+
+ GetUV(uv_frame, u_data_.get(), v_data_.get());
+ uv_data_computed_ = true;
+ TimeLog("Copied UV data");
+ } else {
+ LOGV("No uv data!");
+ }
+
+#ifdef LOG_TIME
+ // If profiling is enabled, precompute here to make it easier to distinguish
+ // total costs.
+ Precompute();
+#endif
+ }
+
+ inline const uint64 GetTimestamp() const {
+ return timestamp_;
+ }
+
+ inline const Image<uint8>* GetImage() const {
+ SCHECK(pyramid_sqrt2_computed_[0], "image not set!");
+ return pyramid_sqrt2_[0];
+ }
+
+ const Image<uint8>* GetPyramidSqrt2Level(const int level) const {
+ if (!pyramid_sqrt2_computed_[level]) {
+ SCHECK(level != 0, "Level equals 0!");
+ if (level == 1) {
+ const Image<uint8>& upper_level = *GetPyramidSqrt2Level(0);
+ if (pyramid_sqrt2_[level] == NULL) {
+ const int new_width =
+ (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2;
+ const int new_height =
+ (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 *
+ 2;
+
+ pyramid_sqrt2_[level] = new Image<uint8>(new_width, new_height);
+ }
+ pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level);
+ } else {
+ const Image<uint8>& upper_level = *GetPyramidSqrt2Level(level - 2);
+ if (pyramid_sqrt2_[level] == NULL) {
+ pyramid_sqrt2_[level] = new Image<uint8>(
+ upper_level.GetWidth() / 2, upper_level.GetHeight() / 2);
+ }
+ pyramid_sqrt2_[level]->DownsampleAveraged(
+ upper_level.data(), upper_level.stride(), 2);
+ }
+ pyramid_sqrt2_computed_[level] = true;
+ }
+ return pyramid_sqrt2_[level];
+ }
+
+ inline const Image<int32>* GetSpatialX(const int level) const {
+ if (!spatial_x_computed_[level]) {
+ const Image<uint8>& src = *GetPyramidSqrt2Level(level * 2);
+ if (spatial_x_[level] == NULL) {
+ spatial_x_[level] = new Image<int32>(src.GetWidth(), src.GetHeight());
+ }
+ spatial_x_[level]->DerivativeX(src);
+ spatial_x_computed_[level] = true;
+ }
+ return spatial_x_[level];
+ }
+
+ inline const Image<int32>* GetSpatialY(const int level) const {
+ if (!spatial_y_computed_[level]) {
+ const Image<uint8>& src = *GetPyramidSqrt2Level(level * 2);
+ if (spatial_y_[level] == NULL) {
+ spatial_y_[level] = new Image<int32>(src.GetWidth(), src.GetHeight());
+ }
+ spatial_y_[level]->DerivativeY(src);
+ spatial_y_computed_[level] = true;
+ }
+ return spatial_y_[level];
+ }
+
+ // The integral image is currently only used for object detection, so lazily
+ // initialize it on request.
+ inline const IntegralImage* GetIntegralImage() const {
+ if (integral_image_.get() == NULL) {
+ integral_image_.reset(new IntegralImage(image_));
+ } else if (!integral_image_computed_) {
+ integral_image_->Recompute(image_);
+ }
+ integral_image_computed_ = true;
+ return integral_image_.get();
+ }
+
+ inline const Image<uint8>* GetU() const {
+ SCHECK(uv_data_computed_, "UV data not provided!");
+ return u_data_.get();
+ }
+
+ inline const Image<uint8>* GetV() const {
+ SCHECK(uv_data_computed_, "UV data not provided!");
+ return v_data_.get();
+ }
+
+ private:
+ void Precompute() {
+ // Create the smoothed pyramids.
+ for (int i = 0; i < kNumPyramidLevels * 2; i += 2) {
+ (void) GetPyramidSqrt2Level(i);
+ }
+ TimeLog("Created smoothed pyramids");
+
+    // Create the interleaved sqrt(2)-scaled pyramid levels.
+ for (int i = 1; i < kNumPyramidLevels * 2; i += 2) {
+ (void) GetPyramidSqrt2Level(i);
+ }
+ TimeLog("Created smoothed sqrt pyramids");
+
+ // Create the spatial derivatives for frame 1.
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ (void) GetSpatialX(i);
+ (void) GetSpatialY(i);
+ }
+ TimeLog("Created spatial derivatives");
+
+ (void) GetIntegralImage();
+ TimeLog("Got integral image!");
+ }
+
+ const int uv_frame_width_;
+ const int uv_frame_height_;
+
+ int64 timestamp_;
+
+ Image<uint8> image_;
+
+ bool uv_data_computed_;
+ std::unique_ptr<Image<uint8> > u_data_;
+ std::unique_ptr<Image<uint8> > v_data_;
+
+ mutable bool spatial_x_computed_[kNumPyramidLevels];
+ mutable Image<int32>* spatial_x_[kNumPyramidLevels];
+
+ mutable bool spatial_y_computed_[kNumPyramidLevels];
+ mutable Image<int32>* spatial_y_[kNumPyramidLevels];
+
+ // Mutable so the lazy initialization can work when this class is const.
+ // Whether or not the integral image has been computed for the current image.
+ mutable bool integral_image_computed_;
+ mutable std::unique_ptr<IntegralImage> integral_image_;
+
+ mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2];
+ mutable Image<uint8>* pyramid_sqrt2_[kNumPyramidLevels * 2];
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ImageData);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
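GetPyramidSqrt2Level() above builds an interleaved half-octave pyramid: level 1 is a sqrt(2) bilinear downsample of level 0 rounded to even dimensions, and every level from 2 upward is a factor-of-two block average of the level two steps above, yielding two interleaved octave chains. The sketch below isolates just the size arithmetic; PyramidSqrt2Width is a hypothetical helper mirroring the expressions above, not part of the change.

    #include <cmath>
    #include <cstdio>

    // Width of interleaved pyramid level `level` for a frame of width `w`.
    static int PyramidSqrt2Width(int w, int level) {
      if (level == 0) return w;
      if (level == 1) {
        // Shrink level 0 by sqrt(2) and round to an even size.
        return (static_cast<int>(w / sqrtf(2.0f)) + 1) / 2 * 2;
      }
      // Levels 2 and up halve the level two steps above.
      return PyramidSqrt2Width(w, level - 2) / 2;
    }

    int main() {
      // For a 320-pixel-wide frame: 320, 226, 160, 113, 80, 56.
      for (int level = 0; level < 6; ++level) {
        printf("level %d: width %d\n", level, PyramidSqrt2Width(320, level));
      }
      return 0;
    }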
diff --git a/tensorflow/examples/android/jni/object_tracking/image_neon.cc b/tensorflow/examples/android/jni/object_tracking/image_neon.cc
new file mode 100644
index 0000000000..ddd8447bf3
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image_neon.cc
@@ -0,0 +1,270 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NEON implementations of Image methods for compatible devices. Control
+// should never enter this compilation unit on incompatible devices.
+
+#ifdef __ARM_NEON
+
+#include <arm_neon.h>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+// This function does the bulk of the work.
+template <>
+void Image<uint8>::Downsample2x32ColumnsNeon(const uint8* const original,
+ const int stride,
+ const int orig_x) {
+ // Divide input x offset by 2 to find output offset.
+ const int new_x = orig_x >> 1;
+
+ // Initial offset into top row.
+ const uint8* offset = original + orig_x;
+
+  // This points to the leftmost pixel of our 16 horizontally arranged
+ // pixels in the destination data.
+ uint8* ptr_dst = (*this)[0] + new_x;
+
+ // Sum along vertical columns.
+ // Process 32x2 input pixels and 16x1 output pixels per iteration.
+ for (int new_y = 0; new_y < height_; ++new_y) {
+ uint16x8_t accum1 = vdupq_n_u16(0);
+ uint16x8_t accum2 = vdupq_n_u16(0);
+
+    // Go top to bottom across the two rows of input pixels that make up
+ // this output row.
+ for (int row_num = 0; row_num < 2; ++row_num) {
+ // First 16 bytes.
+ {
+ // Load 16 bytes of data from current offset.
+ const uint8x16_t curr_data1 = vld1q_u8(offset);
+
+ // Pairwise add and accumulate into accum vectors (16 bit to account
+ // for values above 255).
+ accum1 = vpadalq_u8(accum1, curr_data1);
+ }
+
+ // Second 16 bytes.
+ {
+ // Load 16 bytes of data from current offset.
+ const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
+
+ // Pairwise add and accumulate into accum vectors (16 bit to account
+ // for values above 255).
+ accum2 = vpadalq_u8(accum2, curr_data2);
+ }
+
+ // Move offset down one row.
+ offset += stride;
+ }
+
+ // Divide by 4 (number of input pixels per output
+ // pixel) and narrow data from 16 bits per pixel to 8 bpp.
+ const uint8x8_t tmp_pix1 = vqshrn_n_u16(accum1, 2);
+ const uint8x8_t tmp_pix2 = vqshrn_n_u16(accum2, 2);
+
+ // Concatenate 8x1 pixel strips into 16x1 pixel strip.
+ const uint8x16_t allpixels = vcombine_u8(tmp_pix1, tmp_pix2);
+
+ // Copy all pixels from composite 16x1 vector into output strip.
+ vst1q_u8(ptr_dst, allpixels);
+
+ ptr_dst += stride_;
+ }
+}
+
+// This function does the bulk of the work.
+template <>
+void Image<uint8>::Downsample4x32ColumnsNeon(const uint8* const original,
+ const int stride,
+ const int orig_x) {
+ // Divide input x offset by 4 to find output offset.
+ const int new_x = orig_x >> 2;
+
+ // Initial offset into top row.
+ const uint8* offset = original + orig_x;
+
+ // This points to the leftmost pixel of our 8 horizontally arranged
+ // pixels in the destination data.
+ uint8* ptr_dst = (*this)[0] + new_x;
+
+ // Sum along vertical columns.
+ // Process 32x4 input pixels and 8x1 output pixels per iteration.
+ for (int new_y = 0; new_y < height_; ++new_y) {
+ uint16x8_t accum1 = vdupq_n_u16(0);
+ uint16x8_t accum2 = vdupq_n_u16(0);
+
+ // Go top to bottom across the four rows of input pixels that make up
+ // this output row.
+ for (int row_num = 0; row_num < 4; ++row_num) {
+ // First 16 bytes.
+ {
+ // Load 16 bytes of data from current offset.
+ const uint8x16_t curr_data1 = vld1q_u8(offset);
+
+ // Pairwise add and accumulate into accum vectors (16 bit to account
+ // for values above 255).
+ accum1 = vpadalq_u8(accum1, curr_data1);
+ }
+
+ // Second 16 bytes.
+ {
+ // Load 16 bytes of data from current offset.
+ const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
+
+ // Pairwise add and accumulate into accum vectors (16 bit to account
+ // for values above 255).
+ accum2 = vpadalq_u8(accum2, curr_data2);
+ }
+
+ // Move offset down one row.
+ offset += stride;
+ }
+
+ // Add and widen, then divide by 16 (number of input pixels per output
+ // pixel) and narrow data from 32 bits per pixel to 16 bpp.
+ const uint16x4_t tmp_pix1 = vqshrn_n_u32(vpaddlq_u16(accum1), 4);
+ const uint16x4_t tmp_pix2 = vqshrn_n_u32(vpaddlq_u16(accum2), 4);
+
+ // Combine 4x1 pixel strips into 8x1 pixel strip and narrow from
+ // 16 bits to 8 bits per pixel.
+ const uint8x8_t allpixels = vmovn_u16(vcombine_u16(tmp_pix1, tmp_pix2));
+
+ // Copy all pixels from composite 8x1 vector into output strip.
+ vst1_u8(ptr_dst, allpixels);
+
+ ptr_dst += stride_;
+ }
+}
+
+
+// Hardware accelerated downsampling method for supported devices.
+// Requires that image size be a multiple of 16 pixels in each dimension,
+// and that downsampling be by a factor of 2 or 4.
+template <>
+void Image<uint8>::DownsampleAveragedNeon(const uint8* const original,
+ const int stride, const int factor) {
+ // TODO(andrewharp): stride is a bad approximation for the src image's width.
+ // Better to pass that in directly.
+ SCHECK(width_ * factor <= stride, "Uh oh!");
+ const int last_starting_index = width_ * factor - 32;
+
+ // We process 32 input pixels lengthwise at a time.
+ // The output per pass of this loop is an 8 wide by downsampled height tall
+ // pixel strip.
+ int orig_x = 0;
+ for (; orig_x <= last_starting_index; orig_x += 32) {
+ if (factor == 2) {
+ Downsample2x32ColumnsNeon(original, stride, orig_x);
+ } else {
+ Downsample4x32ColumnsNeon(original, stride, orig_x);
+ }
+ }
+
+ // If a last pass is required, push it to the left enough so that it never
+ // goes out of bounds. This will result in some extra computation on devices
+ // whose frame widths are multiples of 16 and not 32.
+ if (orig_x < last_starting_index + 32) {
+ if (factor == 2) {
+ Downsample2x32ColumnsNeon(original, stride, last_starting_index);
+ } else {
+ Downsample4x32ColumnsNeon(original, stride, last_starting_index);
+ }
+ }
+}
+
+
+// Puts the image gradient matrix about a pixel into the 2x2 float array G.
+// vals_x should be an array of the window x gradient values, whose indices
+// can be in any order but are parallel to the vals_y entries.
+// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
+void CalculateGNeon(const float* const vals_x, const float* const vals_y,
+ const int num_vals, float* const G) {
+ const float32_t* const arm_vals_x = (const float32_t*) vals_x;
+ const float32_t* const arm_vals_y = (const float32_t*) vals_y;
+
+ // Running sums.
+ float32x4_t xx = vdupq_n_f32(0.0f);
+ float32x4_t xy = vdupq_n_f32(0.0f);
+ float32x4_t yy = vdupq_n_f32(0.0f);
+
+ // Maximum index we can load 4 consecutive values from.
+ // e.g. if there are 81 values, our last full pass can be from index 77:
+ // 81-4=>77 (77, 78, 79, 80)
+ const int max_i = num_vals - 4;
+
+ // Defined here because we want to keep track of how many values were
+ // processed by NEON, so that we can finish off the remainder the normal
+ // way.
+ int i = 0;
+
+ // Process values 4 at a time, accumulating the sums of
+ // the pixel-wise x*x, x*y, and y*y values.
+ for (; i <= max_i; i += 4) {
+ // Load xs
+ float32x4_t x = vld1q_f32(arm_vals_x + i);
+
+ // Multiply x*x and accumulate.
+ xx = vmlaq_f32(xx, x, x);
+
+ // Load ys
+ float32x4_t y = vld1q_f32(arm_vals_y + i);
+
+ // Multiply x*y and accumulate.
+ xy = vmlaq_f32(xy, x, y);
+
+ // Multiply y*y and accumulate.
+ yy = vmlaq_f32(yy, y, y);
+ }
+
+ static float32_t xx_vals[4];
+ static float32_t xy_vals[4];
+ static float32_t yy_vals[4];
+
+ vst1q_f32(xx_vals, xx);
+ vst1q_f32(xy_vals, xy);
+ vst1q_f32(yy_vals, yy);
+
+  // Accumulated values are stored in sets of 4; we have to manually add
+  // the four lanes of each sum together at the end.
+ for (int j = 0; j < 4; ++j) {
+ G[0] += xx_vals[j];
+ G[1] += xy_vals[j];
+ G[3] += yy_vals[j];
+ }
+
+ // Finishes off last few values (< 4) from above.
+ for (; i < num_vals; ++i) {
+ G[0] += Square(vals_x[i]);
+ G[1] += vals_x[i] * vals_y[i];
+ G[3] += Square(vals_y[i]);
+ }
+
+ // The matrix is symmetric, so this is a given.
+ G[2] = G[1];
+}
+
+} // namespace tf_tracking
+
+#endif
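CalculateGNeon() above accumulates the 2x2 spatial gradient matrix G = [sum(Ix*Ix), sum(Ix*Iy); sum(Ix*Iy), sum(Iy*Iy)] used by the Lucas-Kanade step. The toy worked example below calls the dispatching CalculateG() wrapper declared in image_utils.h (next file); it is a hypothetical standalone program, not part of the change, assuming a non-NEON build so the scalar path runs, and it zero-initializes G because the function accumulates into it.

    #include <cstdio>

    #include "tensorflow/examples/android/jni/object_tracking/image_utils.h"

    int main() {
      // Gradients of a 2x2 window, flattened (the ordering only has to match
      // between the two arrays).
      const float vals_x[] = {1.0f, 2.0f, 3.0f, 4.0f};
      const float vals_y[] = {4.0f, 3.0f, 2.0f, 1.0f};

      float G[4] = {0.0f, 0.0f, 0.0f, 0.0f};
      tf_tracking::CalculateG(vals_x, vals_y, 4, G);

      // Expected: [30 20; 20 30], since sum(x*x) = sum(y*y) = 30 and
      // sum(x*y) = 20.
      printf("[%.0f %.0f; %.0f %.0f]\n", G[0], G[1], G[2], G[3]);
      return 0;
    }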
diff --git a/tensorflow/examples/android/jni/object_tracking/image_utils.h b/tensorflow/examples/android/jni/object_tracking/image_utils.h
new file mode 100644
index 0000000000..5357a9352f
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image_utils.h
@@ -0,0 +1,301 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+inline void GetUV(
+ const uint8* const input, Image<uint8>* const u, Image<uint8>* const v) {
+ const uint8* pUV = input;
+
+ for (int row = 0; row < u->GetHeight(); ++row) {
+ uint8* u_curr = (*u)[row];
+ uint8* v_curr = (*v)[row];
+ for (int col = 0; col < u->GetWidth(); ++col) {
+#ifdef __APPLE__
+ *u_curr++ = *pUV++;
+ *v_curr++ = *pUV++;
+#else
+ *v_curr++ = *pUV++;
+ *u_curr++ = *pUV++;
+#endif
+ }
+ }
+}
+
+// Marks as true every point within a circle of the given radius on the given
+// boolean image.
+template <typename U>
+inline static void MarkImage(const int x, const int y, const int radius,
+ Image<U>* const img) {
+ SCHECK(img->ValidPixel(x, y), "Marking invalid pixel in image! %d, %d", x, y);
+
+ // Precomputed for efficiency.
+ const int squared_radius = Square(radius);
+
+ // Mark every row in the circle.
+ for (int d_y = 0; d_y <= radius; ++d_y) {
+ const int squared_y_dist = Square(d_y);
+
+ const int min_y = MAX(y - d_y, 0);
+ const int max_y = MIN(y + d_y, img->height_less_one_);
+
+    // The max d_x of the circle must be greater than or equal to
+ // radius - d_y for any positive d_y. Thus, starting from radius - d_y will
+ // reduce the number of iterations required as compared to starting from
+ // either 0 and counting up or radius and counting down.
+ for (int d_x = radius - d_y; d_x <= radius; ++d_x) {
+      // The first time this criterion is met, we know the width of the circle at
+ // this row (without using sqrt).
+ if (squared_y_dist + Square(d_x) >= squared_radius) {
+ const int min_x = MAX(x - d_x, 0);
+ const int max_x = MIN(x + d_x, img->width_less_one_);
+
+ // Mark both above and below the center row.
+ bool* const top_row_start = (*img)[min_y] + min_x;
+ bool* const bottom_row_start = (*img)[max_y] + min_x;
+
+ const int x_width = max_x - min_x + 1;
+ memset(top_row_start, true, sizeof(*top_row_start) * x_width);
+ memset(bottom_row_start, true, sizeof(*bottom_row_start) * x_width);
+
+ // This row is marked, time to move on to the next row.
+ break;
+ }
+ }
+ }
+}
+
+#ifdef __ARM_NEON
+void CalculateGNeon(
+ const float* const vals_x, const float* const vals_y,
+ const int num_vals, float* const G);
+#endif
+
+// Puts the image gradient matrix about a pixel into the 2x2 float array G.
+// vals_x should be an array of the window x gradient values, whose indices
+// can be in any order but are parallel to the vals_y entries.
+// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
+inline void CalculateG(const float* const vals_x, const float* const vals_y,
+ const int num_vals, float* const G) {
+#ifdef __ARM_NEON
+ CalculateGNeon(vals_x, vals_y, num_vals, G);
+ return;
+#endif
+
+ // Non-accelerated version.
+ for (int i = 0; i < num_vals; ++i) {
+ G[0] += Square(vals_x[i]);
+ G[1] += vals_x[i] * vals_y[i];
+ G[3] += Square(vals_y[i]);
+ }
+
+ // The matrix is symmetric, so this is a given.
+ G[2] = G[1];
+}
+
+
+inline void CalculateGInt16(const int16* const vals_x,
+ const int16* const vals_y,
+ const int num_vals, int* const G) {
+ // Non-accelerated version.
+ for (int i = 0; i < num_vals; ++i) {
+ G[0] += Square(vals_x[i]);
+ G[1] += vals_x[i] * vals_y[i];
+ G[3] += Square(vals_y[i]);
+ }
+
+ // The matrix is symmetric, so this is a given.
+ G[2] = G[1];
+}
+
+
+// Puts the image gradient matrix about a pixel into the 2x2 float array G.
+// Looks up interpolated pixels, then calls above method for implementation.
+inline void CalculateG(const int window_radius,
+ const float center_x, const float center_y,
+ const Image<int32>& I_x, const Image<int32>& I_y,
+ float* const G) {
+ SCHECK(I_x.ValidPixel(center_x, center_y), "Problem in calculateG!");
+
+  // Hardcoded to allow for a max window radius of 5 (11 pixels x 11 pixels).
+ static const int kMaxWindowRadius = 5;
+ SCHECK(window_radius <= kMaxWindowRadius,
+ "Window %d > %d!", window_radius, kMaxWindowRadius);
+
+  // Diameter of the window is 2 * radius, plus 1 for the center pixel.
+ static const int kWindowBufferSize =
+ (kMaxWindowRadius * 2 + 1) * (kMaxWindowRadius * 2 + 1);
+
+ // Preallocate buffers statically for efficiency.
+ static int16 vals_x[kWindowBufferSize];
+ static int16 vals_y[kWindowBufferSize];
+
+ const int src_left_fixed = RealToFixed1616(center_x - window_radius);
+ const int src_top_fixed = RealToFixed1616(center_y - window_radius);
+
+ int16* vals_x_ptr = vals_x;
+ int16* vals_y_ptr = vals_y;
+
+ const int window_size = 2 * window_radius + 1;
+ for (int y = 0; y < window_size; ++y) {
+ const int fp_y = src_top_fixed + (y << 16);
+
+ for (int x = 0; x < window_size; ++x) {
+ const int fp_x = src_left_fixed + (x << 16);
+
+ *vals_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
+ *vals_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
+ }
+ }
+
+ int32 g_temp[] = {0, 0, 0, 0};
+ CalculateGInt16(vals_x, vals_y, window_size * window_size, g_temp);
+
+ for (int i = 0; i < 4; ++i) {
+ G[i] = g_temp[i];
+ }
+}
+
+inline float ImageCrossCorrelation(const Image<float>& image1,
+ const Image<float>& image2,
+ const int x_offset, const int y_offset) {
+ SCHECK(image1.GetWidth() == image2.GetWidth() &&
+ image1.GetHeight() == image2.GetHeight(),
+ "Dimension mismatch! %dx%d vs %dx%d",
+ image1.GetWidth(), image1.GetHeight(),
+ image2.GetWidth(), image2.GetHeight());
+
+ const int num_pixels = image1.GetWidth() * image1.GetHeight();
+ const float* data1 = image1.data();
+ const float* data2 = image2.data();
+ return ComputeCrossCorrelation(data1, data2, num_pixels);
+}
+
+// Copies an arbitrary region of an image to another (floating point)
+// image, scaling as it goes using bilinear interpolation.
+inline void CopyArea(const Image<uint8>& image,
+ const BoundingBox& area_to_copy,
+ Image<float>* const patch_image) {
+ VLOG(2) << "Copying from: " << area_to_copy << std::endl;
+
+ const int patch_width = patch_image->GetWidth();
+ const int patch_height = patch_image->GetHeight();
+
+  const float x_dist_between_samples = patch_width > 1 ?
+      area_to_copy.GetWidth() / (patch_width - 1) : 0;
+
+  const float y_dist_between_samples = patch_height > 1 ?
+      area_to_copy.GetHeight() / (patch_height - 1) : 0;
+
+ for (int y_index = 0; y_index < patch_height; ++y_index) {
+ const float sample_y =
+ y_index * y_dist_between_samples + area_to_copy.top_;
+
+ for (int x_index = 0; x_index < patch_width; ++x_index) {
+ const float sample_x =
+ x_index * x_dist_between_samples + area_to_copy.left_;
+
+ if (image.ValidInterpPixel(sample_x, sample_y)) {
+ // TODO(andrewharp): Do area averaging when downsampling.
+ (*patch_image)[y_index][x_index] =
+ image.GetPixelInterp(sample_x, sample_y);
+ } else {
+ (*patch_image)[y_index][x_index] = -1.0f;
+ }
+ }
+ }
+}
+
+
+// Takes a floating point image and normalizes it in-place.
+//
+// First, negative values will be set to the mean of the non-negative pixels
+// in the image.
+//
+// Then, the resulting image will be normalized such that it has a mean value
+// of 0.0 and a standard deviation of 1.0.
+inline void NormalizeImage(Image<float>* const image) {
+ const float* const data_ptr = image->data();
+
+ // Copy only the non-negative values to some temp memory.
+ float running_sum = 0.0f;
+ int num_data_gte_zero = 0;
+ {
+ float* const curr_data = (*image)[0];
+ for (int i = 0; i < image->data_size_; ++i) {
+ if (curr_data[i] >= 0.0f) {
+ running_sum += curr_data[i];
+ ++num_data_gte_zero;
+ } else {
+ curr_data[i] = -1.0f;
+ }
+ }
+ }
+
+ // If none of the pixels are valid, just set the entire thing to 0.0f.
+ if (num_data_gte_zero == 0) {
+ image->Clear(0.0f);
+ return;
+ }
+
+ const float corrected_mean = running_sum / num_data_gte_zero;
+
+ float* curr_data = (*image)[0];
+ for (int i = 0; i < image->data_size_; ++i) {
+ const float curr_val = *curr_data;
+ *curr_data++ = curr_val < 0 ? 0 : curr_val - corrected_mean;
+ }
+
+ const float std_dev = ComputeStdDev(data_ptr, image->data_size_, 0.0f);
+
+ if (std_dev > 0.0f) {
+ curr_data = (*image)[0];
+ for (int i = 0; i < image->data_size_; ++i) {
+ *curr_data++ /= std_dev;
+ }
+
+#ifdef SANITY_CHECKS
+ LOGV("corrected_mean: %1.2f std_dev: %1.2f", corrected_mean, std_dev);
+ const float correlation =
+ ComputeCrossCorrelation(image->data(),
+ image->data(),
+ image->data_size_);
+
+ if (std::abs(correlation - 1.0f) > EPSILON) {
+ LOG(ERROR) << "Bad image!" << std::endl;
+ LOG(ERROR) << *image << std::endl;
+ }
+
+ SCHECK(std::abs(correlation - 1.0f) < EPSILON,
+ "Correlation wasn't 1.0f: %.10f", correlation);
+#endif
+ }
+}
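+
+// Sketch of how the helpers above are typically composed (illustrative only;
+// the names frame, box and kPatchSize are assumptions of this example):
+//
+//   Image<float> patch(kPatchSize, kPatchSize);
+//   CopyArea(frame, box, &patch);  // Bilinearly resample box into the patch.
+//   NormalizeImage(&patch);        // Zero mean, unit standard deviation.
+//   // The patch can now be scored against another normalized patch with
+//   // ImageCrossCorrelation(patch, other_patch, 0, 0).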
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/integral_image.h b/tensorflow/examples/android/jni/object_tracking/integral_image.h
new file mode 100755
index 0000000000..28b2045572
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/integral_image.h
@@ -0,0 +1,187 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+namespace tf_tracking {
+
+typedef uint8 Code;
+
+class IntegralImage: public Image<uint32> {
+ public:
+ explicit IntegralImage(const Image<uint8>& image_base) :
+ Image<uint32>(image_base.GetWidth(), image_base.GetHeight()) {
+ Recompute(image_base);
+ }
+
+ IntegralImage(const int width, const int height) :
+ Image<uint32>(width, height) {}
+
+ void Recompute(const Image<uint8>& image_base) {
+ SCHECK(image_base.GetWidth() == GetWidth() &&
+ image_base.GetHeight() == GetHeight(), "Dimensions don't match!");
+
+ // Sum along first row.
+ {
+ int x_sum = 0;
+ for (int x = 0; x < image_base.GetWidth(); ++x) {
+ x_sum += image_base[0][x];
+ (*this)[0][x] = x_sum;
+ }
+ }
+
+ // Sum everything else.
+ for (int y = 1; y < image_base.GetHeight(); ++y) {
+ uint32* curr_sum = (*this)[y];
+
+ // Previously summed pointers.
+ const uint32* up_one = (*this)[y - 1];
+
+ // Current value pointer.
+ const uint8* curr_delta = image_base[y];
+
+ uint32 row_till_now = 0;
+
+ for (int x = 0; x < GetWidth(); ++x) {
+ // Add the one above and the one to the left.
+ row_till_now += *curr_delta;
+ *curr_sum = *up_one + row_till_now;
+
+ // Scoot everything along.
+ ++curr_sum;
+ ++up_one;
+ ++curr_delta;
+ }
+ }
+
+ SCHECK(VerifyData(image_base), "Images did not match!");
+ }
+
+ bool VerifyData(const Image<uint8>& image_base) {
+ for (int y = 0; y < GetHeight(); ++y) {
+ for (int x = 0; x < GetWidth(); ++x) {
+ uint32 curr_val = (*this)[y][x];
+
+ if (x > 0) {
+ curr_val -= (*this)[y][x - 1];
+ }
+
+ if (y > 0) {
+ curr_val -= (*this)[y - 1][x];
+ }
+
+ if (x > 0 && y > 0) {
+ curr_val += (*this)[y - 1][x - 1];
+ }
+
+ if (curr_val != image_base[y][x]) {
+ LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]);
+ return false;
+ }
+
+ if (GetRegionSum(x, y, x, y) != curr_val) {
+ LOGE("Mismatch!");
+ }
+ }
+ }
+
+ return true;
+ }
+
+ // Returns the sum of all pixels in the specified region.
+ inline uint32 GetRegionSum(const int x1, const int y1,
+ const int x2, const int y2) const {
+ SCHECK(x1 >= 0 && y1 >= 0 &&
+ x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(),
+ "indices out of bounds! %d-%d / %d, %d-%d / %d, ",
+ x1, x2, GetWidth(), y1, y2, GetHeight());
+
+ const uint32 everything = (*this)[y2][x2];
+
+ uint32 sum = everything;
+ if (x1 > 0 && y1 > 0) {
+ // Most common case.
+ const uint32 left = (*this)[y2][x1 - 1];
+ const uint32 top = (*this)[y1 - 1][x2];
+ const uint32 top_left = (*this)[y1 - 1][x1 - 1];
+
+ sum = everything - left - top + top_left;
+ SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d",
+ everything, left, top, top_left, sum, x1, y1, x2, y2);
+  } else if (x1 > 0) {
+    // Flush against the top of the image.
+    // Subtract out the region to the left only.
+    const uint32 left = (*this)[y2][x1 - 1];
+    sum = everything - left;
+    SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, left, sum);
+  } else if (y1 > 0) {
+    // Flush against the left side of the image.
+    // Subtract out the region above only.
+    const uint32 top = (*this)[y1 - 1][x2];
+    sum = everything - top;
+    SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, top, sum);
+  }
+
+ SCHECK(sum >= 0, "Negative sum!");
+
+ return sum;
+ }
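+
+  // Worked example of the inclusion-exclusion above: for a source image whose
+  // every pixel is 1, the integral image holds (*this)[y][x] == (x+1) * (y+1),
+  // so GetRegionSum(2, 1, 4, 3) == 20 - 8 - 5 + 2 == 9, i.e. the 3x3 patch of
+  // ones whose top-left corner is at (2, 1).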
+
+  // Returns the 2-bit code associated with this region, which represents
+ // the overall gradient.
+ inline Code GetCode(const BoundingBox& bounding_box) const {
+ return GetCode(bounding_box.left_, bounding_box.top_,
+ bounding_box.right_, bounding_box.bottom_);
+ }
+
+ inline Code GetCode(const int x1, const int y1,
+ const int x2, const int y2) const {
+ SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d",
+ x1, y1, x2, y2);
+
+ // Gradient computed vertically.
+ const int box_height = (y2 - y1) / 2;
+ const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height);
+ const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2);
+ const bool vertical_code = top_sum > bottom_sum;
+
+ // Gradient computed horizontally.
+ const int box_width = (x2 - x1) / 2;
+ const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2);
+ const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2);
+ const bool horizontal_code = left_sum > right_sum;
+
+ const Code final_code = (vertical_code << 1) | horizontal_code;
+
+ SCHECK(InRange(final_code, static_cast<Code>(0), static_cast<Code>(3)),
+ "Invalid code! %d", final_code);
+
+ // Returns a value 0-3.
+ return final_code;
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/examples/android/jni/object_tracking/jni_utils.h
new file mode 100644
index 0000000000..92458536b6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/jni_utils.h
@@ -0,0 +1,62 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
+
+#include <android/log.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+// The JniIntField class is used to access Java fields from native code. This
+// technique of hiding pointers to native objects in opaque Java fields is how
+// the Android hardware libraries work. This reduces the number of static
+// native methods and makes it easier to manage the lifetime of native objects.
+class JniIntField {
+ public:
+ JniIntField(const char* field_name) : field_name_(field_name), field_ID_(0) {}
+
+ int get(JNIEnv* env, jobject thiz) {
+ if (field_ID_ == 0) {
+ jclass cls = env->GetObjectClass(thiz);
+ CHECK_ALWAYS(cls != 0, "Unable to find class");
+ field_ID_ = env->GetFieldID(cls, field_name_, "I");
+ CHECK_ALWAYS(field_ID_ != 0,
+ "Unable to find field %s. (Check proguard cfg)", field_name_);
+ }
+
+ return env->GetIntField(thiz, field_ID_);
+ }
+
+ void set(JNIEnv* env, jobject thiz, int value) {
+ if (field_ID_ == 0) {
+ jclass cls = env->GetObjectClass(thiz);
+ CHECK_ALWAYS(cls != 0, "Unable to find class");
+ field_ID_ = env->GetFieldID(cls, field_name_, "I");
+ CHECK_ALWAYS(field_ID_ != 0,
+ "Unable to find field %s (Check proguard cfg)", field_name_);
+ }
+
+ env->SetIntField(thiz, field_ID_, value);
+ }
+
+ private:
+ const char* const field_name_;
+
+  // Cached field ID, lazily looked up on first use.
+ jfieldID field_ID_;
+};
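+
+// Illustrative sketch of the intended usage pattern. The Java field name
+// "nativeObjectTracker" and the ObjectTracker cast are assumptions of this
+// example rather than part of this header:
+//
+//   static JniIntField object_tracker_field("nativeObjectTracker");
+//
+//   static tf_tracking::ObjectTracker* GetTracker(JNIEnv* env, jobject thiz) {
+//     return reinterpret_cast<tf_tracking::ObjectTracker*>(
+//         static_cast<intptr_t>(object_tracker_field.get(env, thiz)));
+//   }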
+
+#endif  // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint.h b/tensorflow/examples/android/jni/object_tracking/keypoint.h
new file mode 100644
index 0000000000..82917261cb
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/keypoint.h
@@ -0,0 +1,48 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+
+namespace tf_tracking {
+
+// For keeping track of keypoints.
+struct Keypoint {
+ Keypoint() : pos_(0.0f, 0.0f), score_(0.0f), type_(0) {}
+ Keypoint(const float x, const float y)
+ : pos_(x, y), score_(0.0f), type_(0) {}
+
+ Point2f pos_;
+ float score_;
+ uint8 type_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+                                const Keypoint& keypoint) {
+ return stream << "[" << keypoint.pos_ << ", "
+ << keypoint.score_ << ", " << keypoint.type_ << "]";
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc
new file mode 100644
index 0000000000..6cc6b4e73f
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc
@@ -0,0 +1,549 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Various keypoint detecting functions.
+
+#include <float.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+
+namespace tf_tracking {
+
+static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) {
+ return Square(vec1[0] - vec2[0]) + Square(vec1[1] - vec2[1]);
+}
+
+void KeypointDetector::ScoreKeypoints(const ImageData& image_data,
+ const int num_candidates,
+ Keypoint* const candidate_keypoints) {
+ const Image<int>& I_x = *image_data.GetSpatialX(0);
+ const Image<int>& I_y = *image_data.GetSpatialY(0);
+
+ if (config_->detect_skin) {
+ const Image<uint8>& u_data = *image_data.GetU();
+ const Image<uint8>& v_data = *image_data.GetV();
+
+ static const int reference[] = {111, 155};
+
+ // Score all the keypoints.
+ for (int i = 0; i < num_candidates; ++i) {
+ Keypoint* const keypoint = candidate_keypoints + i;
+
+ const int x_pos = keypoint->pos_.x * 2;
+ const int y_pos = keypoint->pos_.y * 2;
+
+ const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]};
+ keypoint->score_ =
+ HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) /
+ GetDistSquaredBetween(reference, curr_color);
+ }
+ } else {
+ // Score all the keypoints.
+ for (int i = 0; i < num_candidates; ++i) {
+ Keypoint* const keypoint = candidate_keypoints + i;
+ keypoint->score_ =
+ HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y);
+ }
+ }
+}
+
+
+inline int KeypointCompare(const void* const a, const void* const b) {
+ return (reinterpret_cast<const Keypoint*>(a)->score_ -
+ reinterpret_cast<const Keypoint*>(b)->score_) <= 0 ? 1 : -1;
+}
+
+
+// Quicksorts detected keypoints in descending order of score.
+void KeypointDetector::SortKeypoints(const int num_candidates,
+ Keypoint* const candidate_keypoints) const {
+ qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare);
+
+#ifdef SANITY_CHECKS
+ // Verify that the array got sorted.
+ float last_score = FLT_MAX;
+ for (int i = 0; i < num_candidates; ++i) {
+ const float curr_score = candidate_keypoints[i].score_;
+
+    // Scores should be monotonically non-increasing.
+ SCHECK(last_score >= curr_score,
+ "Quicksort failure! %d: %.5f > %d: %.5f (%d total)",
+ i - 1, last_score, i, curr_score, num_candidates);
+
+ last_score = curr_score;
+ }
+#endif
+}
+
+
+int KeypointDetector::SelectKeypointsInBox(
+ const BoundingBox& box,
+ const Keypoint* const candidate_keypoints,
+ const int num_candidates,
+ const int max_keypoints,
+ const int num_existing_keypoints,
+ const Keypoint* const existing_keypoints,
+ Keypoint* const final_keypoints) const {
+ if (max_keypoints <= 0) {
+ return 0;
+ }
+
+  // The minimum spacing enforced between keypoints within this box, roughly
+  // proportional to the box dimensions.
+ const int distance =
+ MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f);
+
+ // First, mark keypoints that already happen to be inside this region. Ignore
+ // keypoints that are outside it, however close they might be.
+ interest_map_->Clear(false);
+ for (int i = 0; i < num_existing_keypoints; ++i) {
+ const Keypoint& candidate = existing_keypoints[i];
+
+ const int x_pos = candidate.pos_.x;
+ const int y_pos = candidate.pos_.y;
+ if (box.Contains(candidate.pos_)) {
+ MarkImage(x_pos, y_pos, distance, interest_map_.get());
+ }
+ }
+
+ // Now, go through and check which keypoints will still fit in the box.
+ int num_keypoints_selected = 0;
+ for (int i = 0; i < num_candidates; ++i) {
+ const Keypoint& candidate = candidate_keypoints[i];
+
+ const int x_pos = candidate.pos_.x;
+ const int y_pos = candidate.pos_.y;
+
+ if (!box.Contains(candidate.pos_) ||
+ !interest_map_->ValidPixel(x_pos, y_pos)) {
+ continue;
+ }
+
+ if (!(*interest_map_)[y_pos][x_pos]) {
+ final_keypoints[num_keypoints_selected++] = candidate;
+ if (num_keypoints_selected >= max_keypoints) {
+ break;
+ }
+ MarkImage(x_pos, y_pos, distance, interest_map_.get());
+ }
+ }
+ return num_keypoints_selected;
+}
+
+
+void KeypointDetector::SelectKeypoints(
+ const std::vector<BoundingBox>& boxes,
+ const Keypoint* const candidate_keypoints,
+ const int num_candidates,
+ FramePair* const curr_change) const {
+  // Now select all the interesting keypoints that fall inside our boxes.
+ curr_change->number_of_keypoints_ = 0;
+ for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
+ iter != boxes.end(); ++iter) {
+ const BoundingBox bounding_box = *iter;
+
+ // Count up keypoints that have already been selected, and fall within our
+ // box.
+ int num_keypoints_already_in_box = 0;
+ for (int i = 0; i < curr_change->number_of_keypoints_; ++i) {
+ if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) {
+ ++num_keypoints_already_in_box;
+ }
+ }
+
+ const int max_keypoints_to_find_in_box =
+ MIN(kMaxKeypointsForObject - num_keypoints_already_in_box,
+ kMaxKeypoints - curr_change->number_of_keypoints_);
+
+ const int num_new_keypoints_in_box = SelectKeypointsInBox(
+ bounding_box,
+ candidate_keypoints,
+ num_candidates,
+ max_keypoints_to_find_in_box,
+ curr_change->number_of_keypoints_,
+ curr_change->frame1_keypoints_,
+ curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_);
+
+ curr_change->number_of_keypoints_ += num_new_keypoints_in_box;
+
+ LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_);
+ }
+}
+
+
+// Walks along the given circle checking for pixels above or below the center.
+// Returns a score, or 0 if the keypoint did not pass the criteria.
+//
+// Parameters:
+// circle_perimeter: the circumference in pixels of the circle.
+// threshold: the minimum number of contiguous pixels that must be above or
+// below the center value.
+// center_ptr: the location of the center pixel in memory
+// offsets: the relative offsets from the center pixel of the edge pixels.
+inline int TestCircle(const int circle_perimeter, const int threshold,
+ const uint8* const center_ptr,
+ const int* offsets) {
+ // Get the actual value of the center pixel for easier reference later on.
+ const int center_value = static_cast<int>(*center_ptr);
+
+  // Total number of pixels to check. We have to wrap around a bit past the
+  // start in case the contiguous section straddles the ends of the array.
+ const int num_total = circle_perimeter + threshold - 1;
+
+ int num_above = 0;
+ int above_diff = 0;
+
+ int num_below = 0;
+ int below_diff = 0;
+
+  // Used to tell when this pixel definitely cannot meet the threshold, so we
+  // can abort early.
+ int minimum_by_now = threshold - num_total + 1;
+
+ // Go through every pixel along the perimeter of the circle, and then around
+ // again a little bit.
+ for (int i = 0; i < num_total; ++i) {
+ // This should be faster than mod.
+ const int perim_index = i < circle_perimeter ? i : i - circle_perimeter;
+
+ // This gets the value of the current pixel along the perimeter by using
+ // a precomputed offset.
+ const int curr_value =
+ static_cast<int>(center_ptr[offsets[perim_index]]);
+
+ const int difference = curr_value - center_value;
+
+ if (difference > kFastDiffAmount) {
+ above_diff += difference;
+ ++num_above;
+
+ num_below = 0;
+ below_diff = 0;
+
+ if (num_above >= threshold) {
+ return above_diff;
+ }
+ } else if (difference < -kFastDiffAmount) {
+ below_diff += difference;
+ ++num_below;
+
+ num_above = 0;
+ above_diff = 0;
+
+ if (num_below >= threshold) {
+ return below_diff;
+ }
+ } else {
+ num_above = 0;
+ num_below = 0;
+ above_diff = 0;
+ below_diff = 0;
+ }
+
+ // See if there's any chance of making the threshold.
+ if (MAX(num_above, num_below) < minimum_by_now) {
+ // Didn't pass.
+ return 0;
+ }
+ ++minimum_by_now;
+ }
+
+ // Didn't pass.
+ return 0;
+}
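+
+// Illustrative trace of the early-abort bookkeeping above, using the full
+// circle defined below (perimeter 16, threshold 12): num_total is 27 and
+// minimum_by_now starts at 12 - 27 + 1 = -14. At iteration i it equals
+// i - 14, so once the current contiguous count falls below i - 14 the
+// remaining 26 - i pixels can no longer complete a run of 12 and the loop
+// returns 0 early.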
+
+
+// Returns a score in the range [0.0, positive infinity) which represents the
+// relative likelihood of a point being a corner.
+float KeypointDetector::HarrisFilter(const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const float x, const float y) const {
+ if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) &&
+ I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) {
+ // Image gradient matrix.
+ float G[] = { 0, 0, 0, 0 };
+ CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G);
+
+ const float dx = G[0];
+ const float dy = G[3];
+ const float dxy = G[1];
+
+    // Harris-Noble corner score.
+ return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN);
+ }
+
+ return 0.0f;
+}
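+
+// For reference, the score above is Noble's variant of the Harris measure,
+//   score = det(G) / (trace(G) + epsilon)
+//         = (dx * dy - dxy^2) / (dx + dy + FLT_MIN),
+// which avoids the empirically tuned k of the classic Harris response
+// det(G) - k * trace(G)^2.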
+
+
+int KeypointDetector::AddExtraCandidatesForBoxes(
+ const std::vector<BoundingBox>& boxes,
+ const int max_num_keypoints,
+ Keypoint* const keypoints) const {
+ int num_keypoints_added = 0;
+
+ for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
+ iter != boxes.end(); ++iter) {
+ const BoundingBox box = *iter;
+
+ for (int i = 0; i < kNumToAddAsCandidates; ++i) {
+ for (int j = 0; j < kNumToAddAsCandidates; ++j) {
+ if (num_keypoints_added >= max_num_keypoints) {
+ LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints);
+ return num_keypoints_added;
+ }
+
+      Keypoint& curr_keypoint = keypoints[num_keypoints_added++];
+ curr_keypoint.pos_ = Point2f(
+ box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates,
+ box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates);
+ curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST;
+ }
+ }
+ }
+
+ return num_keypoints_added;
+}
+
+
+void KeypointDetector::FindKeypoints(const ImageData& image_data,
+ const std::vector<BoundingBox>& rois,
+ const FramePair& prev_change,
+ FramePair* const curr_change) {
+ // Copy keypoints from second frame of last pass to temp keypoints of this
+ // pass.
+ int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_);
+
+ const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints;
+ number_of_tmp_keypoints +=
+ FindFastKeypoints(image_data, max_num_fast,
+ tmp_keypoints_ + number_of_tmp_keypoints);
+
+ TimeLog("Found FAST keypoints");
+
+ if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
+ LOGW("Hit cap of %d for temporary keypoints (FAST)! %d keypoints",
+ kMaxTempKeypoints, number_of_tmp_keypoints);
+ }
+
+ if (kAddArbitraryKeypoints) {
+ // Add some for each object prior to scoring.
+ const int max_num_box_keypoints =
+ kMaxTempKeypoints - number_of_tmp_keypoints;
+ number_of_tmp_keypoints +=
+ AddExtraCandidatesForBoxes(rois, max_num_box_keypoints,
+ tmp_keypoints_ + number_of_tmp_keypoints);
+ TimeLog("Added box keypoints");
+
+ if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
+ LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints",
+ kMaxTempKeypoints, number_of_tmp_keypoints);
+ }
+ }
+
+ // Score them...
+ LOGV("Scoring %d keypoints!", number_of_tmp_keypoints);
+ ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_);
+ TimeLog("Scored keypoints");
+
+ // Now pare it down a bit.
+ SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_);
+ TimeLog("Sorted keypoints");
+
+ LOGV("%d keypoints to select from!", number_of_tmp_keypoints);
+
+ SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change);
+ TimeLog("Selected keypoints");
+
+ LOGV("Picked %d (%d max) final keypoints out of %d potential.",
+ curr_change->number_of_keypoints_,
+ kMaxKeypoints, number_of_tmp_keypoints);
+}
+
+
+int KeypointDetector::CopyKeypoints(const FramePair& prev_change,
+ Keypoint* const new_keypoints) {
+ int number_of_keypoints = 0;
+
+ // Caching values from last pass, just copy and compact.
+ for (int i = 0; i < prev_change.number_of_keypoints_; ++i) {
+ if (prev_change.optical_flow_found_keypoint_[i]) {
+ new_keypoints[number_of_keypoints] =
+ prev_change.frame2_keypoints_[i];
+
+ new_keypoints[number_of_keypoints].score_ =
+ prev_change.frame1_keypoints_[i].score_;
+
+ ++number_of_keypoints;
+ }
+ }
+
+ TimeLog("Copied keypoints");
+ return number_of_keypoints;
+}
+
+
+// FAST keypoint detector.
+int KeypointDetector::FindFastKeypoints(const Image<uint8>& frame,
+ const int quadrant,
+ const int downsample_factor,
+ const int max_num_keypoints,
+ Keypoint* const keypoints) {
+ /*
+ // Reference for a circle of diameter 7.
+ const int circle[] = {0, 0, 1, 1, 1, 0, 0,
+ 0, 1, 0, 0, 0, 1, 0,
+ 1, 0, 0, 0, 0, 0, 1,
+ 1, 0, 0, 0, 0, 0, 1,
+ 1, 0, 0, 0, 0, 0, 1,
+ 0, 1, 0, 0, 0, 1, 0,
+ 0, 0, 1, 1, 1, 0, 0};
+ const int circle_offset[] =
+ {2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46};
+ */
+
+ // Quick test of compass directions. Any length 16 circle with a break of up
+ // to 4 pixels will have at least 3 of these 4 pixels active.
+ static const int short_circle_perimeter = 4;
+ static const int short_threshold = 3;
+ static const int short_circle_x[] = { -3, 0, +3, 0 };
+ static const int short_circle_y[] = { 0, -3, 0, +3 };
+
+ // Precompute image offsets.
+ int short_offsets[short_circle_perimeter];
+ for (int i = 0; i < short_circle_perimeter; ++i) {
+ short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth();
+ }
+
+ // Large circle values.
+ static const int full_circle_perimeter = 16;
+ static const int full_threshold = 12;
+ static const int full_circle_x[] =
+ { -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 };
+ static const int full_circle_y[] =
+ { -3, -3, -3, -2, -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 };
+
+ // Precompute image offsets.
+ int full_offsets[full_circle_perimeter];
+ for (int i = 0; i < full_circle_perimeter; ++i) {
+ full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth();
+ }
+
+ const int scratch_stride = frame.stride();
+
+ keypoint_scratch_->Clear(0);
+
+ // Set up the bounds on the region to test based on the passed-in quadrant.
+ const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer;
+ const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer;
+ const int start_x =
+ kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width);
+ const int start_y =
+ kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height);
+ const int end_x = start_x + quadrant_width;
+ const int end_y = start_y + quadrant_height;
+
+ // Loop through once to find FAST keypoint clumps.
+ for (int img_y = start_y; img_y < end_y; ++img_y) {
+ const uint8* curr_pixel_ptr = frame[img_y] + start_x;
+
+ for (int img_x = start_x; img_x < end_x; ++img_x) {
+ // Only insert it if it meets the quick minimum requirements test.
+ if (TestCircle(short_circle_perimeter, short_threshold,
+ curr_pixel_ptr, short_offsets) != 0) {
+        // Longer test to compute the actual keypoint score.
+ const int fast_score = TestCircle(full_circle_perimeter,
+ full_threshold,
+ curr_pixel_ptr,
+ full_offsets);
+
+ // Non-zero score means the keypoint was found.
+ if (fast_score != 0) {
+ uint8* const center_ptr = (*keypoint_scratch_)[img_y] + img_x;
+
+ // Increase the keypoint count on this pixel and the pixels in all
+ // 4 cardinal directions.
+ *center_ptr += 5;
+ *(center_ptr - 1) += 1;
+ *(center_ptr + 1) += 1;
+ *(center_ptr - scratch_stride) += 1;
+ *(center_ptr + scratch_stride) += 1;
+ }
+ }
+
+ ++curr_pixel_ptr;
+ } // x
+ } // y
+
+ TimeLog("Found FAST keypoints.");
+
+ int num_keypoints = 0;
+ // Loop through again and Harris filter pixels in the center of clumps.
+ // We can shrink the window by 1 pixel on every side.
+ for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) {
+ const uint8* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x;
+
+ for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) {
+ if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) {
+ Keypoint* const keypoint = keypoints + num_keypoints;
+ keypoint->pos_ = Point2f(
+ img_x * downsample_factor, img_y * downsample_factor);
+ keypoint->score_ = 0;
+ keypoint->type_ = KEYPOINT_TYPE_FAST;
+
+ ++num_keypoints;
+ if (num_keypoints >= max_num_keypoints) {
+ return num_keypoints;
+ }
+ }
+
+ ++curr_pixel_ptr;
+ } // x
+ } // y
+
+ TimeLog("Picked FAST keypoints.");
+
+ return num_keypoints;
+}
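+
+// Note on the quadrant argument above: quadrant 0 covers the top-left quarter
+// of the frame, 1 the top-right, 2 the bottom-left and 3 the bottom-right (see
+// start_x/start_y above). The caller below cycles fast_quadrant_ so a full
+// frame is scanned once every four calls.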
+
+int KeypointDetector::FindFastKeypoints(const ImageData& image_data,
+ const int max_num_keypoints,
+ Keypoint* const keypoints) {
+ int downsample_factor = 1;
+ int num_found = 0;
+
+ // TODO(andrewharp): Get this working for multiple image scales.
+ for (int i = 0; i < 1; ++i) {
+ const Image<uint8>& frame = *image_data.GetPyramidSqrt2Level(i);
+ num_found += FindFastKeypoints(
+ frame, fast_quadrant_,
+ downsample_factor, max_num_keypoints, keypoints + num_found);
+ downsample_factor *= 2;
+ }
+
+ // Increment the current quadrant.
+ fast_quadrant_ = (fast_quadrant_ + 1) % 4;
+
+ return num_found;
+}
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h
new file mode 100644
index 0000000000..6cdd5dde11
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h
@@ -0,0 +1,133 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+struct Keypoint;
+
+class KeypointDetector {
+ public:
+ explicit KeypointDetector(const KeypointDetectorConfig* const config)
+ : config_(config),
+ keypoint_scratch_(new Image<uint8>(config_->image_size)),
+ interest_map_(new Image<bool>(config_->image_size)),
+ fast_quadrant_(0) {
+ interest_map_->Clear(false);
+ }
+
+ ~KeypointDetector() {}
+
+ // Finds a new set of keypoints for the current frame, picked from the current
+ // set of keypoints and also from a set discovered via a keypoint detector.
+ // Special attention is applied to make sure that keypoints are distributed
+ // within the supplied ROIs.
+ void FindKeypoints(const ImageData& image_data,
+ const std::vector<BoundingBox>& rois,
+ const FramePair& prev_change,
+ FramePair* const curr_change);
+
+ private:
+ // Compute the corneriness of a point in the image.
+ float HarrisFilter(const Image<int32>& I_x, const Image<int32>& I_y,
+ const float x, const float y) const;
+
+  // Adds a grid of kNumToAddAsCandidates^2 candidate keypoints to each given
+  // box, capped at max_num_keypoints in total across all boxes.
+ int AddExtraCandidatesForBoxes(
+ const std::vector<BoundingBox>& boxes,
+ const int max_num_keypoints,
+ Keypoint* const keypoints) const;
+
+ // Scan the frame for potential keypoints using the FAST keypoint detector.
+ // Quadrant is an argument 0-3 which refers to the quadrant of the image in
+ // which to detect keypoints.
+ int FindFastKeypoints(const Image<uint8>& frame,
+ const int quadrant,
+ const int downsample_factor,
+ const int max_num_keypoints,
+ Keypoint* const keypoints);
+
+ int FindFastKeypoints(const ImageData& image_data,
+ const int max_num_keypoints,
+ Keypoint* const keypoints);
+
+ // Score a bunch of candidate keypoints. Assigns the scores to the input
+ // candidate_keypoints array entries.
+ void ScoreKeypoints(const ImageData& image_data,
+ const int num_candidates,
+ Keypoint* const candidate_keypoints);
+
+ void SortKeypoints(const int num_candidates,
+ Keypoint* const candidate_keypoints) const;
+
+ // Selects a set of keypoints falling within the supplied box such that the
+ // most highly rated keypoints are picked first, and so that none of them are
+ // too close together.
+ int SelectKeypointsInBox(
+ const BoundingBox& box,
+ const Keypoint* const candidate_keypoints,
+ const int num_candidates,
+ const int max_keypoints,
+ const int num_existing_keypoints,
+ const Keypoint* const existing_keypoints,
+ Keypoint* const final_keypoints) const;
+
+ // Selects from the supplied sorted keypoint pool a set of keypoints that will
+ // best cover the given set of boxes, such that each box is covered at a
+ // resolution proportional to its size.
+ void SelectKeypoints(
+ const std::vector<BoundingBox>& boxes,
+ const Keypoint* const candidate_keypoints,
+ const int num_candidates,
+ FramePair* const frame_change) const;
+
+ // Copies and compacts the found keypoints in the second frame of prev_change
+ // into the array at new_keypoints.
+ static int CopyKeypoints(const FramePair& prev_change,
+ Keypoint* const new_keypoints);
+
+ const KeypointDetectorConfig* const config_;
+
+ // Scratch memory for keypoint candidacy detection and non-max suppression.
+ std::unique_ptr<Image<uint8> > keypoint_scratch_;
+
+ // Regions of the image to pay special attention to.
+ std::unique_ptr<Image<bool> > interest_map_;
+
+ // The current quadrant of the image to detect FAST keypoints in.
+ // Keypoint detection is staggered for performance reasons. Every four frames
+ // a full scan of the frame will have been performed.
+ int fast_quadrant_;
+
+ Keypoint tmp_keypoints_[kMaxTempKeypoints];
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/log_streaming.h b/tensorflow/examples/android/jni/object_tracking/log_streaming.h
new file mode 100644
index 0000000000..e68945cc72
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/log_streaming.h
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+
+#include <string.h>
+#include <string>
+
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+#define LOGV(...)
+#define LOGD(...)
+#define LOGI(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__);
+#define LOGW(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__);
+#define LOGE(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__);
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.cc b/tensorflow/examples/android/jni/object_tracking/object_detector.cc
new file mode 100644
index 0000000000..7f65716fdf
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_detector.cc
@@ -0,0 +1,27 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NOTE: no native object detectors are currently provided or used by the code
+// in this directory. This class remains mainly for historical reasons.
+// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
+
+#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
+
+namespace tf_tracking {
+
+// This is here so that the vtable gets created properly.
+ObjectDetectorBase::~ObjectDetectorBase() {}
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.h b/tensorflow/examples/android/jni/object_tracking/object_detector.h
new file mode 100644
index 0000000000..043f606e1d
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_detector.h
@@ -0,0 +1,232 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NOTE: no native object detectors are currently provided or used by the code
+// in this directory. This class remains mainly for historical reasons.
+// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
+
+// Defines the ObjectDetector class that is the main interface for detecting
+// ObjectModelBases in frames.
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
+
+#include <float.h>
+#include <map>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#ifdef __RENDER_OPENGL__
+#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
+#endif
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
+
+namespace tf_tracking {
+
+// Adds BoundingSquares to a vector such that the first square added is centered
+// in the position given and of square_size, and the remaining squares are added
+// concentrically, scaling down by scale_factor until the minimum threshold
+// size is passed.
+// Squares that do not fall completely within image_bounds will not be added.
+static inline void FillWithSquares(
+ const BoundingBox& image_bounds,
+ const BoundingBox& position,
+ const float starting_square_size,
+ const float smallest_square_size,
+ const float scale_factor,
+ std::vector<BoundingSquare>* const squares) {
+ BoundingSquare descriptor_area =
+ GetCenteredSquare(position, starting_square_size);
+
+ SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor);
+
+ // Use a do/while loop to ensure that at least one descriptor is created.
+ do {
+ if (image_bounds.Contains(descriptor_area.ToBoundingBox())) {
+ squares->push_back(descriptor_area);
+ }
+ descriptor_area.Scale(scale_factor);
+ } while (descriptor_area.size_ >= smallest_square_size - EPSILON);
+ LOGV("Created %zu squares starting from size %.2f to min size %.2f "
+ "using scale factor: %.2f",
+ squares->size(), starting_square_size, smallest_square_size,
+ scale_factor);
+}
+
+
+// Represents a potential detection of a specific ObjectExemplar and Descriptor
+// at a specific position in the image.
+class Detection {
+ public:
+ explicit Detection(const ObjectModelBase* const object_model,
+ const MatchScore match_score,
+ const BoundingBox& bounding_box)
+ : object_model_(object_model),
+ match_score_(match_score),
+ bounding_box_(bounding_box) {}
+
+ Detection(const Detection& other)
+ : object_model_(other.object_model_),
+ match_score_(other.match_score_),
+ bounding_box_(other.bounding_box_) {}
+
+ virtual ~Detection() {}
+
+ inline BoundingBox GetObjectBoundingBox() const {
+ return bounding_box_;
+ }
+
+ inline MatchScore GetMatchScore() const {
+ return match_score_;
+ }
+
+ inline const ObjectModelBase* GetObjectModel() const {
+ return object_model_;
+ }
+
+ inline bool Intersects(const Detection& other) {
+ // Check if any of the four axes separates us, there must be at least one.
+ return bounding_box_.Intersects(other.bounding_box_);
+ }
+
+ struct Comp {
+ inline bool operator()(const Detection& a, const Detection& b) const {
+ return a.match_score_ > b.match_score_;
+ }
+ };
+
+ // TODO(andrewharp): add accessors to update these instead.
+ const ObjectModelBase* object_model_;
+ MatchScore match_score_;
+ BoundingBox bounding_box_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const Detection& detection) {
+ const BoundingBox actual_area = detection.GetObjectBoundingBox();
+ stream << actual_area;
+ return stream;
+}
+
+class ObjectDetectorBase {
+ public:
+ explicit ObjectDetectorBase(const ObjectDetectorConfig* const config)
+ : config_(config),
+ image_data_(NULL) {}
+
+ virtual ~ObjectDetectorBase();
+
+ // Sets the current image data. All calls to ObjectDetector other than
+ // FillDescriptors use the image data last set.
+ inline void SetImageData(const ImageData* const image_data) {
+ image_data_ = image_data;
+ }
+
+ // Main entry point into the detection algorithm.
+ // Scans the frame for candidates, tweaks them, and fills in the
+ // given std::vector of Detection objects with acceptable matches.
+ virtual void Detect(const std::vector<BoundingSquare>& positions,
+ std::vector<Detection>* const detections) const = 0;
+
+ virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0;
+
+ virtual void DeleteObjectModel(const std::string& name) = 0;
+
+ virtual void GetObjectModels(
+ std::vector<const ObjectModelBase*>* models) const = 0;
+
+ // Creates a new ObjectExemplar from the given position in the context of
+ // the last frame passed to NextFrame.
+ // Will return null in the case that there's no room for a descriptor to be
+ // created in the example area, or the example area is not completely
+ // contained within the frame.
+ virtual void UpdateModel(
+ const Image<uint8>& base_image,
+ const IntegralImage& integral_image,
+ const BoundingBox& bounding_box,
+ const bool locked,
+ ObjectModelBase* model) const = 0;
+
+ virtual void Draw() const = 0;
+
+ virtual bool AllowSpontaneousDetections() = 0;
+
+ protected:
+ const std::unique_ptr<const ObjectDetectorConfig> config_;
+
+ // The latest frame data, upon which all detections will be performed.
+ // Not owned by this object, just provided for reference by ObjectTracker
+ // via SetImageData().
+ const ImageData* image_data_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase);
+};
+
+template <typename ModelType>
+class ObjectDetector : public ObjectDetectorBase {
+ public:
+ explicit ObjectDetector(const ObjectDetectorConfig* const config)
+ : ObjectDetectorBase(config) {}
+
+ virtual ~ObjectDetector() {
+ typename std::map<std::string, ModelType*>::const_iterator it =
+ object_models_.begin();
+ for (; it != object_models_.end(); ++it) {
+ ModelType* model = it->second;
+ delete model;
+ }
+ }
+
+ virtual void DeleteObjectModel(const std::string& name) {
+ ModelType* model = object_models_[name];
+ CHECK_ALWAYS(model != NULL, "Model was null!");
+ object_models_.erase(name);
+ SAFE_DELETE(model);
+ }
+
+ virtual void GetObjectModels(
+ std::vector<const ObjectModelBase*>* models) const {
+ typename std::map<std::string, ModelType*>::const_iterator it =
+ object_models_.begin();
+ for (; it != object_models_.end(); ++it) {
+ models->push_back(it->second);
+ }
+ }
+
+ virtual bool AllowSpontaneousDetections() {
+ return false;
+ }
+
+ protected:
+ std::map<std::string, ModelType*> object_models_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_model.h b/tensorflow/examples/android/jni/object_tracking/object_model.h
new file mode 100644
index 0000000000..2d359668b2
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_model.h
@@ -0,0 +1,101 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NOTE: no native object detectors are currently provided or used by the code
+// in this directory. This class remains mainly for historical reasons.
+// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
+
+// Contains ObjectModelBase declaration.
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+
+#ifdef __RENDER_OPENGL__
+#include <GLES/gl.h>
+#include <GLES/glext.h>
+#endif
+
+#include <vector>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#ifdef __RENDER_OPENGL__
+#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
+#endif
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+
+namespace tf_tracking {
+
+// The ObjectModelBase class represents all the known appearance information for
+// an object. It is not a specific instance of the object in the world,
+// but just the general appearance information that enables detection. An
+// ObjectModelBase can be reused across multiple instances of TrackedObject.
+class ObjectModelBase {
+ public:
+ ObjectModelBase(const std::string& name) : name_(name) {}
+
+ virtual ~ObjectModelBase() {}
+
+ // Called when the next step in an ongoing track occurs.
+ virtual void TrackStep(
+ const BoundingBox& position, const Image<uint8>& image,
+ const IntegralImage& integral_image, const bool authoritative) {}
+
+ // Called when an object track is lost.
+ virtual void TrackLost() {}
+
+ // Called when an object track is confirmed as legitimate.
+ virtual void TrackConfirmed() {}
+
+ virtual float GetMaxCorrelation(const Image<float>& patch_image) const = 0;
+
+ virtual MatchScore GetMatchScore(
+ const BoundingBox& position, const ImageData& image_data) const = 0;
+
+ virtual void Draw(float* const depth) const = 0;
+
+ inline const std::string& GetName() const {
+ return name_;
+ }
+
+ protected:
+ const std::string name_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectModelBase);
+};
+
+template <typename DetectorType>
+class ObjectModel : public ObjectModelBase {
+ public:
+  ObjectModel(const DetectorType* const detector,
+              const std::string& name)
+      : ObjectModelBase(name), detector_(detector) {}
+
+ protected:
+ const DetectorType* const detector_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectModel<DetectorType>);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.cc b/tensorflow/examples/android/jni/object_tracking/object_tracker.cc
new file mode 100644
index 0000000000..1d867b934b
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.cc
@@ -0,0 +1,690 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef __RENDER_OPENGL__
+#include <GLES/gl.h>
+#include <GLES/glext.h>
+#endif
+
+#include <string>
+#include <map>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+namespace tf_tracking {
+
+ObjectTracker::ObjectTracker(const TrackerConfig* const config,
+ ObjectDetectorBase* const detector)
+ : config_(config),
+ frame_width_(config->image_size.width),
+ frame_height_(config->image_size.height),
+ curr_time_(0),
+ num_frames_(0),
+ flow_cache_(&config->flow_config),
+ keypoint_detector_(&config->keypoint_detector_config),
+ curr_num_frame_pairs_(0),
+ first_frame_index_(0),
+ frame1_(new ImageData(frame_width_, frame_height_)),
+ frame2_(new ImageData(frame_width_, frame_height_)),
+ detector_(detector),
+ num_detected_(0) {
+ for (int i = 0; i < kNumFrames; ++i) {
+ frame_pairs_[i].Init(-1, -1);
+ }
+}
+
+
+ObjectTracker::~ObjectTracker() {
+ for (TrackedObjectMap::iterator iter = objects_.begin();
+ iter != objects_.end(); iter++) {
+ TrackedObject* object = iter->second;
+ SAFE_DELETE(object);
+ }
+}
+
+
+// Finds the correspondences for all the points in the current pair of frames.
+// Stores the results in the given FramePair.
+void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const {
+  // Initially mark every keypoint as not yet found in the second frame.
+ memset(frame_pair->optical_flow_found_keypoint_, false,
+ sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints);
+ TimeLog("Cleared old found keypoints");
+
+ int num_keypoints_found = 0;
+
+ // For every keypoint...
+ for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) {
+ Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat;
+ Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat;
+
+ if (flow_cache_.FindNewPositionOfPoint(
+ keypoint1->pos_.x, keypoint1->pos_.y,
+ &keypoint2->pos_.x, &keypoint2->pos_.y)) {
+ frame_pair->optical_flow_found_keypoint_[i_feat] = true;
+ ++num_keypoints_found;
+ }
+ }
+
+ TimeLog("Found correspondences");
+
+ LOGV("Found %d of %d keypoint correspondences",
+ num_keypoints_found, frame_pair->number_of_keypoints_);
+}
+
+
+void ObjectTracker::NextFrame(const uint8* const new_frame,
+ const uint8* const uv_frame,
+ const int64 timestamp,
+ const float* const alignment_matrix_2x3) {
+ IncrementFrameIndex();
+ LOGV("Received frame %d", num_frames_);
+
+ FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0);
+ curr_change->Init(curr_time_, timestamp);
+
+ CHECK_ALWAYS(curr_time_ < timestamp,
+ "Timestamp must monotonically increase! Went from %lld to %lld"
+ " on frame %d.",
+ curr_time_, timestamp, num_frames_);
+ curr_time_ = timestamp;
+
+ // Swap the frames.
+ frame1_.swap(frame2_);
+
+ frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1);
+
+ if (detector_.get() != NULL) {
+ detector_->SetImageData(frame2_.get());
+ }
+
+ flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3);
+
+ if (num_frames_ == 1) {
+ // This must be the first frame, so abort.
+ return;
+ }
+
+ if (config_->always_track || objects_.size() > 0) {
+ LOGV("Tracking %zu targets", objects_.size());
+ ComputeKeypoints(true);
+ TimeLog("Keypoints computed!");
+
+ FindCorrespondences(curr_change);
+ TimeLog("Flow computed!");
+
+ TrackObjects();
+ }
+ TimeLog("Targets tracked!");
+
+ if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) {
+ DetectTargets();
+ }
+ TimeLog("Detected objects.");
+}
+
+
+TrackedObject* ObjectTracker::MaybeAddObject(
+ const std::string& id,
+ const Image<uint8>& source_image,
+ const BoundingBox& bounding_box,
+ const ObjectModelBase* object_model) {
+  // If this object is already being tracked, just return the existing entry.
+ if (objects_.find(id) != objects_.end()) {
+ return objects_[id];
+ }
+
+  // Objects are backed by a mutable, detector-owned model; if a detector is
+  // registered, create one under the same name as the provided model.
+ ObjectModelBase* model = NULL;
+ if (detector_ != NULL) {
+ // If a detector is registered, then this new object must have a model.
+ CHECK_ALWAYS(object_model != NULL, "No model given!");
+ model = detector_->CreateObjectModel(object_model->GetName());
+ }
+ TrackedObject* const object =
+ new TrackedObject(id, source_image, bounding_box, model);
+
+ objects_[id] = object;
+ return object;
+}
+
+
+void ObjectTracker::RegisterNewObjectWithAppearance(
+ const std::string& id, const uint8* const new_frame,
+ const BoundingBox& bounding_box) {
+ ObjectModelBase* object_model = NULL;
+
+ Image<uint8> image(frame_width_, frame_height_);
+ image.FromArray(new_frame, frame_width_, 1);
+
+ if (detector_ != NULL) {
+ object_model = detector_->CreateObjectModel(id);
+ CHECK_ALWAYS(object_model != NULL, "Null object model!");
+
+ const IntegralImage integral_image(image);
+ object_model->TrackStep(bounding_box, image, integral_image, true);
+ }
+
+ // Create an object at this position.
+ CHECK_ALWAYS(!HaveObject(id), "Already have this object!");
+ if (objects_.find(id) == objects_.end()) {
+ TrackedObject* const object =
+ MaybeAddObject(id, image, bounding_box, object_model);
+ CHECK_ALWAYS(object != NULL, "Object not created!");
+ }
+}
+
+
+void ObjectTracker::SetPreviousPositionOfObject(const std::string& id,
+ const BoundingBox& bounding_box,
+ const int64 timestamp) {
+ CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %lld", timestamp);
+ CHECK_ALWAYS(timestamp <= curr_time_,
+ "Timestamp too great! %lld vs %lld", timestamp, curr_time_);
+
+ TrackedObject* const object = GetObject(id);
+
+ // Track this bounding box from the past to the current time.
+ const BoundingBox current_position = TrackBox(bounding_box, timestamp);
+
+ object->UpdatePosition(current_position, curr_time_, *frame2_, false);
+
+ VLOG(2) << "Set tracked position for " << id << " to " << bounding_box
+ << std::endl;
+}
+
+
+void ObjectTracker::SetCurrentPositionOfObject(
+ const std::string& id, const BoundingBox& bounding_box) {
+ SetPreviousPositionOfObject(id, bounding_box, curr_time_);
+}
+
+
+void ObjectTracker::ForgetTarget(const std::string& id) {
+ LOGV("Forgetting object %s", id.c_str());
+ TrackedObject* const object = GetObject(id);
+ delete object;
+ objects_.erase(id);
+
+ if (detector_ != NULL) {
+ detector_->DeleteObjectModel(id);
+ }
+}
+
+
+int ObjectTracker::GetKeypointsPacked(uint16* const out_data,
+ const float scale) const {
+ const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
+ uint16* curr_data = out_data;
+ int num_keypoints = 0;
+
+ for (int i = 0; i < change.number_of_keypoints_; ++i) {
+ if (change.optical_flow_found_keypoint_[i]) {
+ ++num_keypoints;
+ const Point2f& point1 = change.frame1_keypoints_[i].pos_;
+ *curr_data++ = RealToFixed115(point1.x * scale);
+ *curr_data++ = RealToFixed115(point1.y * scale);
+
+ const Point2f& point2 = change.frame2_keypoints_[i].pos_;
+ *curr_data++ = RealToFixed115(point2.x * scale);
+ *curr_data++ = RealToFixed115(point2.y * scale);
+ }
+ }
+
+ return num_keypoints;
+}
+
+
+int ObjectTracker::GetKeypoints(const bool only_found,
+ float* const out_data) const {
+ int curr_keypoint = 0;
+ const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
+
+ for (int i = 0; i < change.number_of_keypoints_; ++i) {
+ if (!only_found || change.optical_flow_found_keypoint_[i]) {
+ const int base = curr_keypoint * kKeypointStep;
+ out_data[base + 0] = change.frame1_keypoints_[i].pos_.x;
+ out_data[base + 1] = change.frame1_keypoints_[i].pos_.y;
+
+ out_data[base + 2] =
+ change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f;
+ out_data[base + 3] = change.frame2_keypoints_[i].pos_.x;
+ out_data[base + 4] = change.frame2_keypoints_[i].pos_.y;
+
+ out_data[base + 5] = change.frame1_keypoints_[i].score_;
+ out_data[base + 6] = change.frame1_keypoints_[i].type_;
+ ++curr_keypoint;
+ }
+ }
+
+ LOGV("Got %d keypoints.", curr_keypoint);
+
+ return curr_keypoint;
+}
+
+
+BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
+ const FramePair& frame_pair) const {
+ float translation_x;
+ float translation_y;
+
+ float scale_x;
+ float scale_y;
+
+ BoundingBox tracked_box(region);
+ frame_pair.AdjustBox(
+ tracked_box, &translation_x, &translation_y, &scale_x, &scale_y);
+
+ tracked_box.Shift(Point2f(translation_x, translation_y));
+
+ if (scale_x > 0 && scale_y > 0) {
+ tracked_box.Scale(scale_x, scale_y);
+ }
+ return tracked_box;
+}
+
+
+BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
+ const int64 timestamp) const {
+ CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %lld", timestamp);
+ CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!");
+
+ // Anything that ended before the requested timestamp is of no concern to us.
+ bool found_it = false;
+ int num_frames_back = -1;
+ for (int i = 0; i < curr_num_frame_pairs_; ++i) {
+ const FramePair& frame_pair =
+ frame_pairs_[GetNthIndexFromEnd(i)];
+
+ if (frame_pair.end_time_ <= timestamp) {
+ num_frames_back = i - 1;
+
+ if (num_frames_back > 0) {
+ LOGV("Went %d out of %d frames before finding frame. (index: %d)",
+ num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i));
+ }
+
+ found_it = true;
+ break;
+ }
+ }
+
+ if (!found_it) {
+ LOGW("History did not go back far enough! %lld vs %lld",
+ frame_pairs_[GetNthIndexFromEnd(0)].end_time_ -
+ frame_pairs_[GetNthIndexFromStart(0)].end_time_,
+ frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp);
+ }
+
+  // Loop over all the frames in the queue, accumulating the box's motion from
+  // frame to frame. The box may drift out of frame, but keep tracking as well
+  // as possible using the keypoints near the edge of the screen where it went
+  // out of bounds.
+ BoundingBox tracked_box(region);
+ for (int i = num_frames_back; i >= 0; --i) {
+ const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)];
+ SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!");
+ tracked_box = TrackBox(tracked_box, frame_pair);
+ }
+ return tracked_box;
+}
+
+
+// Converts a row-major 3x3 2d transformation matrix to a column-major 4x4
+// 3d transformation matrix.
+inline void Convert3x3To4x4(
+ const float* const in_matrix, float* const out_matrix) {
+ // X
+ out_matrix[0] = in_matrix[0];
+ out_matrix[1] = in_matrix[3];
+ out_matrix[2] = 0.0f;
+ out_matrix[3] = 0.0f;
+
+ // Y
+ out_matrix[4] = in_matrix[1];
+ out_matrix[5] = in_matrix[4];
+ out_matrix[6] = 0.0f;
+ out_matrix[7] = 0.0f;
+
+ // Z
+ out_matrix[8] = 0.0f;
+ out_matrix[9] = 0.0f;
+ out_matrix[10] = 1.0f;
+ out_matrix[11] = 0.0f;
+
+ // Translation
+ out_matrix[12] = in_matrix[2];
+ out_matrix[13] = in_matrix[5];
+ out_matrix[14] = 0.0f;
+ out_matrix[15] = 1.0f;
+}
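// Illustration only (not from this commit): a quick sanity check on the
// element ordering above, with a helper name of our own choosing. A pure
// translation by (7, 3) -- row-major input {1, 0, 7, 0, 1, 3} -- must leave an
// identity diagonal and place the translation in elements 12 and 13 of the
// column-major output.
static void Convert3x3To4x4SanityCheck() {
  const float in[6] = {1.0f, 0.0f, 7.0f,   // row 0: [1 0 tx]
                       0.0f, 1.0f, 3.0f};  // row 1: [0 1 ty]
  float out[16];
  Convert3x3To4x4(in, out);
  SCHECK(out[0] == 1.0f && out[5] == 1.0f && out[10] == 1.0f && out[15] == 1.0f,
         "Diagonal should be all ones for a pure translation.");
  SCHECK(out[12] == 7.0f && out[13] == 3.0f,
         "Translation belongs in elements 12 and 13 (column-major).");
}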
+
+
+void ObjectTracker::Draw(const int canvas_width, const int canvas_height,
+ const float* const frame_to_canvas) const {
+#ifdef __RENDER_OPENGL__
+ glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
+
+ glMatrixMode(GL_PROJECTION);
+ glLoadIdentity();
+
+ glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f);
+
+ // To make Y go the right direction (0 at top of frame).
+ glScalef(1.0f, -1.0f, 1.0f);
+ glTranslatef(0.0f, -canvas_height, 0.0f);
+
+ glMatrixMode(GL_MODELVIEW);
+ glLoadIdentity();
+
+ glPushMatrix();
+
+ // Apply the frame to canvas transformation.
+ static GLfloat transformation[16];
+ Convert3x3To4x4(frame_to_canvas, transformation);
+ glMultMatrixf(transformation);
+
+ // Draw tracked object bounding boxes.
+ for (TrackedObjectMap::const_iterator iter = objects_.begin();
+ iter != objects_.end(); ++iter) {
+ TrackedObject* tracked_object = iter->second;
+ tracked_object->Draw();
+ }
+
+ static const bool kRenderDebugPyramid = false;
+ if (kRenderDebugPyramid) {
+ glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
+ for (int i = 0; i < kNumPyramidLevels * 2; ++i) {
+ Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw();
+ }
+ }
+
+ static const bool kRenderDebugDerivative = false;
+ if (kRenderDebugDerivative) {
+ glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ const Image<int32>& dx = *frame1_->GetSpatialX(i);
+ Image<uint8> render_image(dx.GetWidth(), dx.GetHeight());
+ for (int y = 0; y < dx.GetHeight(); ++y) {
+ const int32* dx_ptr = dx[y];
+ uint8* dst_ptr = render_image[y];
+ for (int x = 0; x < dx.GetWidth(); ++x) {
+ *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255);
+ }
+ }
+
+ Sprite(render_image).Draw();
+ }
+ }
+
+ if (detector_ != NULL) {
+ glDisable(GL_CULL_FACE);
+ detector_->Draw();
+ }
+ glPopMatrix();
+#endif
+}
+
+static void AddQuadrants(const BoundingBox& box,
+ std::vector<BoundingBox>* boxes) {
+ const Point2f center = box.GetCenter();
+
+ float x1 = box.left_;
+ float x2 = center.x;
+ float x3 = box.right_;
+
+ float y1 = box.top_;
+ float y2 = center.y;
+ float y3 = box.bottom_;
+
+ // Upper left.
+ boxes->push_back(BoundingBox(x1, y1, x2, y2));
+
+ // Upper right.
+ boxes->push_back(BoundingBox(x2, y1, x3, y2));
+
+ // Bottom left.
+ boxes->push_back(BoundingBox(x1, y2, x2, y3));
+
+ // Bottom right.
+ boxes->push_back(BoundingBox(x2, y2, x3, y3));
+
+ // Whole thing.
+ boxes->push_back(box);
+}
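// For example (illustrative, not from this commit), assuming GetCenter()
// returns the box midpoint, a box spanning (0, 0) to (100, 100) yields five
// candidate regions:
//   std::vector<BoundingBox> boxes;
//   AddQuadrants(BoundingBox(0.0f, 0.0f, 100.0f, 100.0f), &boxes);
//   // boxes: (0,0)-(50,50), (50,0)-(100,50), (0,50)-(50,100),
//   //        (50,50)-(100,100), and the full (0,0)-(100,100).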
+
+void ObjectTracker::ComputeKeypoints(const bool cached_ok) {
+ const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)];
+ FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)];
+
+ std::vector<BoundingBox> boxes;
+
+ for (TrackedObjectMap::iterator object_iter = objects_.begin();
+ object_iter != objects_.end(); ++object_iter) {
+ BoundingBox box = object_iter->second->GetPosition();
+ box.Scale(config_->object_box_scale_factor_for_features,
+ config_->object_box_scale_factor_for_features);
+ AddQuadrants(box, &boxes);
+ }
+
+ AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes);
+
+ keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change);
+}
+
+
+// Given a detection, finds the best-matching tracked object (if any) and
+// stores it in *match. Returns false if the detection overlaps an existing
+// object that it cannot legally be matched to (a collision), true otherwise.
+bool ObjectTracker::GetBestObjectForDetection(
+ const Detection& detection, TrackedObject** match) const {
+ TrackedObject* best_match = NULL;
+ float best_overlap = -FLT_MAX;
+
+ LOGV("Looking for matches in %zu objects!", objects_.size());
+ for (TrackedObjectMap::const_iterator object_iter = objects_.begin();
+ object_iter != objects_.end(); ++object_iter) {
+ TrackedObject* const tracked_object = object_iter->second;
+
+ const float overlap = tracked_object->GetPosition().PascalScore(
+ detection.GetObjectBoundingBox());
+
+ if (!detector_->AllowSpontaneousDetections() &&
+ (detection.GetObjectModel() != tracked_object->GetModel())) {
+ if (overlap > 0.0f) {
+ return false;
+ }
+ continue;
+ }
+
+ const float jump_distance =
+ (tracked_object->GetPosition().GetCenter() -
+ detection.GetObjectBoundingBox().GetCenter()).LengthSquared();
+
+ const float allowed_distance =
+ tracked_object->GetAllowableDistanceSquared();
+
+ LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f",
+ jump_distance, allowed_distance, overlap);
+
+ // TODO(andrewharp): No need to do this verification twice, eliminate
+ // one of the score checks (the other being in OnDetection).
+ if (jump_distance < allowed_distance &&
+ overlap > best_overlap &&
+ tracked_object->GetMatchScore() + kMatchScoreBuffer <
+ detection.GetMatchScore()) {
+ best_match = tracked_object;
+ best_overlap = overlap;
+ } else if (overlap > 0.0f) {
+ return false;
+ }
+ }
+
+ *match = best_match;
+ return true;
+}
+
+
+void ObjectTracker::ProcessDetections(
+ std::vector<Detection>* const detections) {
+ LOGV("Initial detection done, iterating over %zu detections now.",
+ detections->size());
+
+ const bool spontaneous_detections_allowed =
+ detector_->AllowSpontaneousDetections();
+ for (std::vector<Detection>::const_iterator it = detections->begin();
+ it != detections->end(); ++it) {
+ const Detection& detection = *it;
+ SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()),
+ "Frame does not contain bounding box!");
+
+ TrackedObject* best_match = NULL;
+
+ const bool no_collisions =
+ GetBestObjectForDetection(detection, &best_match);
+
+ // Need to get a non-const version of the model, or create a new one if it
+ // wasn't given.
+ ObjectModelBase* model =
+ const_cast<ObjectModelBase*>(detection.GetObjectModel());
+
+ if (best_match != NULL) {
+ if (model != best_match->GetModel()) {
+ CHECK_ALWAYS(detector_->AllowSpontaneousDetections(),
+ "Model for object changed but spontaneous detections not allowed!");
+ }
+ best_match->OnDetection(model,
+ detection.GetObjectBoundingBox(),
+ detection.GetMatchScore(),
+ curr_time_, *frame2_);
+ } else if (no_collisions && spontaneous_detections_allowed) {
+ if (detection.GetMatchScore() > kMinimumMatchScore) {
+ LOGV("No match, adding it!");
+ const ObjectModelBase* model = detection.GetObjectModel();
+ std::ostringstream ss;
+ // TODO(andrewharp): Generate this in a more general fashion.
+ ss << "hand_" << num_detected_++;
+ std::string object_name = ss.str();
+ MaybeAddObject(object_name, *frame2_->GetImage(),
+ detection.GetObjectBoundingBox(), model);
+ }
+ }
+ }
+}
+
+
+void ObjectTracker::DetectTargets() {
+ // Detect all object model types that we're currently tracking.
+ std::vector<const ObjectModelBase*> object_models;
+ detector_->GetObjectModels(&object_models);
+ if (object_models.size() == 0) {
+ LOGV("No objects to search for, aborting.");
+ return;
+ }
+
+ LOGV("Trying to detect %zu models", object_models.size());
+
+ LOGV("Creating test vector!");
+ std::vector<BoundingSquare> positions;
+
+ for (TrackedObjectMap::iterator object_iter = objects_.begin();
+ object_iter != objects_.end(); ++object_iter) {
+ TrackedObject* const tracked_object = object_iter->second;
+
+#if DEBUG_PREDATOR
+ positions.push_back(GetCenteredSquare(
+ frame2_->GetImage()->GetContainingBox(), 32.0f));
+#else
+ const BoundingBox& position = tracked_object->GetPosition();
+
+ const float square_size = MAX(
+ kScanMinSquareSize / (kLastKnownPositionScaleFactor *
+ kLastKnownPositionScaleFactor),
+ MIN(position.GetWidth(),
+ position.GetHeight())) / kLastKnownPositionScaleFactor;
+
+ FillWithSquares(frame2_->GetImage()->GetContainingBox(),
+ tracked_object->GetPosition(),
+ square_size,
+ kScanMinSquareSize,
+ kLastKnownPositionScaleFactor,
+ &positions);
+#endif
+  }
+
+ LOGV("Created test vector!");
+
+ std::vector<Detection> detections;
+ LOGV("Detecting!");
+ detector_->Detect(positions, &detections);
+ LOGV("Found %zu detections", detections.size());
+
+ TimeLog("Finished detection.");
+
+ ProcessDetections(&detections);
+
+ TimeLog("iterated over detections");
+
+ LOGV("Done detecting!");
+}
+
+
+void ObjectTracker::TrackObjects() {
+ // TODO(andrewharp): Correlation should be allowed to remove objects too.
+ const bool automatic_removal_allowed = detector_.get() != NULL ?
+ detector_->AllowSpontaneousDetections() : false;
+
+ LOGV("Tracking %zu objects!", objects_.size());
+ std::vector<std::string> dead_objects;
+ for (TrackedObjectMap::iterator iter = objects_.begin();
+ iter != objects_.end(); iter++) {
+ TrackedObject* object = iter->second;
+ const BoundingBox tracked_position = TrackBox(
+ object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]);
+ object->UpdatePosition(tracked_position, curr_time_, *frame2_, false);
+
+ if (automatic_removal_allowed &&
+ object->GetNumConsecutiveFramesBelowThreshold() >
+ kMaxNumDetectionFailures * 5) {
+ dead_objects.push_back(iter->first);
+ }
+ }
+
+ if (detector_ != NULL && automatic_removal_allowed) {
+ for (std::vector<std::string>::iterator iter = dead_objects.begin();
+ iter != dead_objects.end(); iter++) {
+ LOGE("Removing object! %s", iter->c_str());
+ ForgetTarget(*iter);
+ }
+ }
+ TimeLog("Tracked all objects.");
+
+ LOGV("%zu objects tracked!", objects_.size());
+}
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/examples/android/jni/object_tracking/object_tracker.h
new file mode 100644
index 0000000000..3d2a9af360
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.h
@@ -0,0 +1,271 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+
+#include <map>
+#include <string>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
+
+namespace tf_tracking {
+
+typedef std::map<const std::string, TrackedObject*> TrackedObjectMap;
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TrackedObjectMap& map) {
+ for (TrackedObjectMap::const_iterator iter = map.begin();
+ iter != map.end(); ++iter) {
+ const TrackedObject& tracked_object = *iter->second;
+ const std::string& key = iter->first;
+ stream << key << ": " << tracked_object;
+ }
+ return stream;
+}
+
+
+// ObjectTracker is the highest-level class in the tracking/detection framework.
+// It handles basic image processing, keypoint detection, keypoint tracking,
+// object tracking, and object detection/relocalization.
+class ObjectTracker {
+ public:
+ ObjectTracker(const TrackerConfig* const config,
+ ObjectDetectorBase* const detector);
+ virtual ~ObjectTracker();
+
+ virtual void NextFrame(const uint8* const new_frame,
+ const int64 timestamp,
+ const float* const alignment_matrix_2x3) {
+ NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
+ }
+
+ // Called upon the arrival of a new frame of raw data.
+ // Does all image processing, keypoint detection, and object
+ // tracking/detection for registered objects.
+ // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
+ // represents the main transformation that has happened between the last
+ // and the current frame.
+  // Argument uv_frame optionally supplies the frame's UV (chroma) data and
+  // may be NULL.
+ virtual void NextFrame(const uint8* const new_frame,
+ const uint8* const uv_frame,
+ const int64 timestamp,
+ const float* const alignment_matrix_2x3);
+
+ virtual void RegisterNewObjectWithAppearance(
+ const std::string& id, const uint8* const new_frame,
+ const BoundingBox& bounding_box);
+
+ // Updates the position of a tracked object, given that it was known to be at
+ // a certain position at some point in the past.
+ virtual void SetPreviousPositionOfObject(const std::string& id,
+ const BoundingBox& bounding_box,
+ const int64 timestamp);
+
+ // Sets the current position of the object in the most recent frame provided.
+ virtual void SetCurrentPositionOfObject(const std::string& id,
+ const BoundingBox& bounding_box);
+
+ // Tells the ObjectTracker to stop tracking a target.
+ void ForgetTarget(const std::string& id);
+
+ // Fills the given out_data buffer with the latest detected keypoint
+ // correspondences, first scaled by scale_factor (to adjust for downsampling
+ // that may have occurred elsewhere), then packed in a fixed-point format.
+ int GetKeypointsPacked(uint16* const out_data,
+ const float scale_factor) const;
+
+ // Copy the keypoint arrays after computeFlow is called.
+ // out_data should be at least kMaxKeypoints * kKeypointStep long.
+  // Currently, its format is [x1 y1 found x2 y2 score type] repeated N times,
+ // where N is the number of keypoints tracked. N is returned as the result.
+ int GetKeypoints(const bool only_found, float* const out_data) const;
+
+ // Returns the current position of a box, given that it was at a certain
+ // position at the given time.
+ BoundingBox TrackBox(const BoundingBox& region,
+ const int64 timestamp) const;
+
+ // Returns the number of frames that have been passed to NextFrame().
+ inline int GetNumFrames() const {
+ return num_frames_;
+ }
+
+ inline bool HaveObject(const std::string& id) const {
+ return objects_.find(id) != objects_.end();
+ }
+
+ // Returns the TrackedObject associated with the given id.
+ inline const TrackedObject* GetObject(const std::string& id) const {
+ TrackedObjectMap::const_iterator iter = objects_.find(id);
+ CHECK_ALWAYS(iter != objects_.end(),
+ "Unknown object key! \"%s\"", id.c_str());
+ TrackedObject* const object = iter->second;
+ return object;
+ }
+
+ // Returns the TrackedObject associated with the given id.
+ inline TrackedObject* GetObject(const std::string& id) {
+ TrackedObjectMap::iterator iter = objects_.find(id);
+ CHECK_ALWAYS(iter != objects_.end(),
+ "Unknown object key! \"%s\"", id.c_str());
+ TrackedObject* const object = iter->second;
+ return object;
+ }
+
+ bool IsObjectVisible(const std::string& id) const {
+ SCHECK(HaveObject(id), "Don't have this object.");
+
+ const TrackedObject* object = GetObject(id);
+ return object->IsVisible();
+ }
+
+ virtual void Draw(const int canvas_width, const int canvas_height,
+ const float* const frame_to_canvas) const;
+
+ protected:
+ // Creates a new tracked object at the given position.
+ // If an object model is provided, then that model will be associated with the
+ // object. If not, a new model may be created from the appearance at the
+ // initial position and registered with the object detector.
+ virtual TrackedObject* MaybeAddObject(const std::string& id,
+ const Image<uint8>& image,
+ const BoundingBox& bounding_box,
+ const ObjectModelBase* object_model);
+
+ // Find the keypoints in the frame before the current frame.
+ // If only one frame exists, keypoints will be found in that frame.
+ void ComputeKeypoints(const bool cached_ok = false);
+
+ // Finds the correspondences for all the points in the current pair of frames.
+ // Stores the results in the given FramePair.
+ void FindCorrespondences(FramePair* const curr_change) const;
+
+ inline int GetNthIndexFromEnd(const int offset) const {
+ return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
+ }
+
+ BoundingBox TrackBox(const BoundingBox& region,
+ const FramePair& frame_pair) const;
+
+ inline void IncrementFrameIndex() {
+    // Advance the current frame-change index.
+ ++num_frames_;
+ ++curr_num_frame_pairs_;
+
+ // If we've got too many, push up the start of the queue.
+ if (curr_num_frame_pairs_ > kNumFrames) {
+ first_frame_index_ = GetNthIndexFromStart(1);
+ --curr_num_frame_pairs_;
+ }
+ }
+
+ inline int GetNthIndexFromStart(const int offset) const {
+ SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
+ "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_);
+ return (first_frame_index_ + offset) % kNumFrames;
+ }
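  // Worked example (illustrative only), pretending kNumFrames were 4: after
  // six calls to IncrementFrameIndex(), num_frames_ == 6,
  // curr_num_frame_pairs_ == 4 and first_frame_index_ == 2, so the stored
  // pairs occupy physical slots 2, 3, 0, 1 from oldest to newest:
  //   GetNthIndexFromStart(0) == 2    GetNthIndexFromEnd(0) == 1  // newest
  //   GetNthIndexFromStart(3) == 1    GetNthIndexFromEnd(3) == 2  // oldest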
+
+ void TrackObjects();
+
+ const std::unique_ptr<const TrackerConfig> config_;
+
+ const int frame_width_;
+ const int frame_height_;
+
+ int64 curr_time_;
+
+ int num_frames_;
+
+ TrackedObjectMap objects_;
+
+ FlowCache flow_cache_;
+
+ KeypointDetector keypoint_detector_;
+
+ int curr_num_frame_pairs_;
+ int first_frame_index_;
+
+ std::unique_ptr<ImageData> frame1_;
+ std::unique_ptr<ImageData> frame2_;
+
+ FramePair frame_pairs_[kNumFrames];
+
+ std::unique_ptr<ObjectDetectorBase> detector_;
+
+ int num_detected_;
+
+ private:
+ void TrackTarget(TrackedObject* const object);
+
+ bool GetBestObjectForDetection(
+ const Detection& detection, TrackedObject** match) const;
+
+ void ProcessDetections(std::vector<Detection>* const detections);
+
+ void DetectTargets();
+
+ // Temp object used in ObjectTracker::CreateNewExample.
+ mutable std::vector<BoundingSquare> squares;
+
+ friend std::ostream& operator<<(std::ostream& stream,
+ const ObjectTracker& tracker);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const ObjectTracker& tracker) {
+ stream << "Frame size: " << tracker.frame_width_ << "x"
+ << tracker.frame_height_ << std::endl;
+
+ stream << "Num frames: " << tracker.num_frames_ << std::endl;
+
+ stream << "Curr time: " << tracker.curr_time_ << std::endl;
+
+ const int first_frame_index = tracker.GetNthIndexFromStart(0);
+ const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];
+
+ const int last_frame_index = tracker.GetNthIndexFromEnd(0);
+ const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];
+
+ stream << "first frame: " << first_frame_index << ","
+ << first_frame_pair.end_time_ << " "
+ << "last frame: " << last_frame_index << ","
+ << last_frame_pair.end_time_ << " diff: "
+ << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
+ << std::endl;
+
+ stream << "Tracked targets:";
+ stream << tracker.objects_;
+
+ return stream;
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
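For orientation, a minimal sketch of driving this API directly from C++. It is
not part of this commit: the TrackOneTarget wrapper, its frame pointers,
dimensions, box coordinates, and timestamps are hypothetical placeholders, and
it simply mirrors the JNI bindings below.

    #include "tensorflow/examples/android/jni/object_tracking/config.h"
    #include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"

    using namespace tf_tracking;

    void TrackOneTarget(const uint8* frame_t1, const uint8* frame_t2,
                        int width, int height) {
      // The tracker takes ownership of the config (config_ is a unique_ptr).
      TrackerConfig* const config = new TrackerConfig(Size(width, height));
      config->always_track = true;

      ObjectTracker tracker(config, NULL);  // NULL detector: tracking only.

      // Feed the first frame; timestamps must be positive and increasing.
      tracker.NextFrame(frame_t1, 1, NULL);

      // Register a target using its appearance in that frame.
      tracker.RegisterNewObjectWithAppearance(
          "target", frame_t1, BoundingBox(10.0f, 10.0f, 90.0f, 90.0f));

      // Feed the next frame and read back the tracked position.
      tracker.NextFrame(frame_t2, 2, NULL);
      if (tracker.HaveObject("target") && tracker.IsObjectVisible("target")) {
        const BoundingBox position = tracker.GetObject("target")->GetPosition();
        LOGV("Now at (%.1f, %.1f)", position.left_, position.top_);
      }
    }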
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc b/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc
new file mode 100644
index 0000000000..30c5974654
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc
@@ -0,0 +1,463 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <android/log.h>
+#include <jni.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <cstdint>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+#define OBJECT_TRACKER_METHOD(METHOD_NAME) \
+ Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT
+
+JniIntField object_tracker_field("nativeObjectTracker");
+
+ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
+ ObjectTracker* const object_tracker =
+ reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz));
+ CHECK_ALWAYS(object_tracker != NULL, "null object tracker!");
+ return object_tracker;
+}
+
+void set_object_tracker(JNIEnv* env, jobject thiz,
+ const ObjectTracker* object_tracker) {
+ object_tracker_field.set(env, thiz,
+ reinterpret_cast<intptr_t>(object_tracker));
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
+ jint width, jint height,
+ jboolean always_track);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
+ jobject thiz);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2, jbyteArray frame_data);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2, jlong timestamp);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2);
+
+JNIEXPORT
+jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
+ jbyteArray y_data,
+ jbyteArray uv_data,
+ jlong timestamp,
+ jfloatArray vg_matrix_2x3);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
+ JNIEnv* env, jobject thiz, jfloat scale_factor);
+
+JNIEXPORT
+jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
+ JNIEnv* env, jobject thiz, jboolean only_found_);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
+ JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
+ jfloat position_y1, jfloat position_x2, jfloat position_y2,
+ jfloatArray delta);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj,
+ jint view_width,
+ jint view_height,
+ jfloatArray delta);
+
+JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
+ JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
+ jbyteArray input, jint factor, jbyteArray output);
+
+#ifdef __cplusplus
+}
+#endif
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
+ jint width, jint height,
+ jboolean always_track) {
+ LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz);
+ const Size image_size(width, height);
+ TrackerConfig* const tracker_config = new TrackerConfig(image_size);
+ tracker_config->always_track = always_track;
+
+  // TODO: no detector is wired up here, so the tracker runs in tracking-only mode.
+ ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL);
+ set_object_tracker(env, thiz, tracker);
+ LOGI("Initialized!");
+
+  CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker,
+               "Failure to set object tracker!");
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
+ jobject thiz) {
+ delete get_object_tracker(env, thiz);
+ set_object_tracker(env, thiz, NULL);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2, jbyteArray frame_data) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
+ x2, y2);
+
+ jboolean iCopied = JNI_FALSE;
+
+  // Get the frame's pixel data from the Java byte array.
+ jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied);
+
+ BoundingBox bounding_box(x1, y1, x2, y2);
+ get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance(
+ id_str, reinterpret_cast<const uint8*>(pixels), bounding_box);
+
+ env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT);
+
+ env->ReleaseStringUTFChars(object_id, id_str);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2, jlong timestamp) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ LOGI(
+ "Registering the position of %s at %.2f,%.2f,%.2f,%.2f"
+ " at time %lld",
+ id_str, x1, y1, x2, y2, static_cast<int64>(timestamp));
+
+ get_object_tracker(env, thiz)->SetPreviousPositionOfObject(
+ id_str, BoundingBox(x1, y1, x2, y2), timestamp);
+
+ env->ReleaseStringUTFChars(object_id, id_str);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
+ x2, y2);
+
+ get_object_tracker(env, thiz)->SetCurrentPositionOfObject(
+ id_str, BoundingBox(x1, y1, x2, y2));
+
+ env->ReleaseStringUTFChars(object_id, id_str);
+}
+
+JNIEXPORT
+jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str);
+ env->ReleaseStringUTFChars(object_id, id_str);
+ return haveObject;
+}
+
+JNIEXPORT
+jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str);
+ env->ReleaseStringUTFChars(object_id, id_str);
+ return visible;
+}
+
+JNIEXPORT
+jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+ const TrackedObject* const object =
+ get_object_tracker(env, thiz)->GetObject(id_str);
+ env->ReleaseStringUTFChars(object_id, id_str);
+ jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str());
+ return model_name;
+}
+
+JNIEXPORT
+jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const float correlation =
+ get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation();
+ env->ReleaseStringUTFChars(object_id, id_str);
+ return correlation;
+}
+
+JNIEXPORT
+jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const float match_score =
+ get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value;
+ env->ReleaseStringUTFChars(object_id, id_str);
+ return match_score;
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) {
+ jboolean iCopied = JNI_FALSE;
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const BoundingBox bounding_box =
+ get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition();
+ env->ReleaseStringUTFChars(object_id, id_str);
+
+ jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied);
+ bounding_box.CopyToArray(reinterpret_cast<float*>(rect));
+ env->ReleaseFloatArrayElements(rect_array, rect, 0);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
+ jbyteArray y_data,
+ jbyteArray uv_data,
+ jlong timestamp,
+ jfloatArray vg_matrix_2x3) {
+ TimeLog("Starting object tracker");
+
+ jboolean iCopied = JNI_FALSE;
+
+ float vision_gyro_matrix_array[6];
+ jfloat* jmat = NULL;
+
+ if (vg_matrix_2x3 != NULL) {
+ // Copy the alignment matrix into a float array.
+ jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied);
+ for (int i = 0; i < 6; ++i) {
+ vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]);
+ }
+ }
+  // Get the Y-plane (and optional UV) pixel data from the Java byte arrays.
+ jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied);
+ jbyte* uv_pixels =
+ uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL;
+
+ TimeLog("Got elements");
+
+ // Add the frame to the object tracker object.
+ get_object_tracker(env, thiz)->NextFrame(
+ reinterpret_cast<uint8*>(pixels), reinterpret_cast<uint8*>(uv_pixels),
+ timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL);
+
+ env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT);
+
+ if (uv_data != NULL) {
+ env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT);
+ }
+
+ if (vg_matrix_2x3 != NULL) {
+ env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT);
+ }
+
+ TimeLog("Released elements");
+
+ PrintTimeLog();
+ ResetTimeLog();
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ get_object_tracker(env, thiz)->ForgetTarget(id_str);
+
+ env->ReleaseStringUTFChars(object_id, id_str);
+}
+
+JNIEXPORT
+jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
+ JNIEnv* env, jobject thiz, jboolean only_found) {
+ jfloat keypoint_arr[kMaxKeypoints * kKeypointStep];
+
+ const int number_of_keypoints =
+ get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr);
+
+ // Create and return the array that will be passed back to Java.
+ jfloatArray keypoints =
+ env->NewFloatArray(number_of_keypoints * kKeypointStep);
+ if (keypoints == NULL) {
+ LOGE("null array!");
+ return NULL;
+ }
+ env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep,
+ keypoint_arr);
+
+ return keypoints;
+}
+
+JNIEXPORT
+jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
+ JNIEnv* env, jobject thiz, jfloat scale_factor) {
+  // Each keypoint is packed as two (x, y) pairs of uint16 values, 2 bytes each.
+ const int bytes_per_keypoint = sizeof(uint16) * 2 * 2;
+ jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint];
+
+ const int number_of_keypoints =
+ get_object_tracker(env, thiz)->GetKeypointsPacked(
+ reinterpret_cast<uint16*>(keypoint_arr), scale_factor);
+
+ // Create and return the array that will be passed back to Java.
+ jbyteArray keypoints =
+ env->NewByteArray(number_of_keypoints * bytes_per_keypoint);
+
+ if (keypoints == NULL) {
+ LOGE("null array!");
+ return NULL;
+ }
+
+ env->SetByteArrayRegion(
+ keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr);
+
+ return keypoints;
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
+ JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
+ jfloat position_y1, jfloat position_x2, jfloat position_y2,
+ jfloatArray delta) {
+ jfloat point_arr[4];
+
+ const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox(
+ BoundingBox(position_x1, position_y1, position_x2, position_y2),
+ timestamp);
+
+ new_position.CopyToArray(point_arr);
+ env->SetFloatArrayRegion(delta, 0, 4, point_arr);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(drawNative)(
+ JNIEnv* env, jobject thiz, jint view_width, jint view_height,
+ jfloatArray frame_to_canvas_arr) {
+ ObjectTracker* object_tracker = get_object_tracker(env, thiz);
+ if (object_tracker != NULL) {
+ jfloat* frame_to_canvas =
+ env->GetFloatArrayElements(frame_to_canvas_arr, NULL);
+
+ object_tracker->Draw(view_width, view_height, frame_to_canvas);
+ env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas,
+ JNI_ABORT);
+ }
+}
+
+JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
+ JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
+ jbyteArray input, jint factor, jbyteArray output) {
+ if (input == NULL || output == NULL) {
+ LOGW("Received null arrays, hopefully this is a test!");
+ return;
+ }
+
+ jbyte* const input_array = env->GetByteArrayElements(input, 0);
+ jbyte* const output_array = env->GetByteArrayElements(output, 0);
+
+ {
+ tf_tracking::Image<uint8> full_image(
+ width, height, reinterpret_cast<uint8*>(input_array), false);
+
+ const int new_width = (width + factor - 1) / factor;
+ const int new_height = (height + factor - 1) / factor;
+
+ tf_tracking::Image<uint8> downsampled_image(
+ new_width, new_height, reinterpret_cast<uint8*>(output_array), false);
+
+ downsampled_image.DownsampleAveraged(reinterpret_cast<uint8*>(input_array),
+ row_stride, factor);
+ }
+
+ env->ReleaseByteArrayElements(input, input_array, JNI_ABORT);
+ env->ReleaseByteArrayElements(output, output_array, 0);
+}
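// For example (illustrative only): with width == 642, height == 482 and
// factor == 4, the rounded-up output dimensions are
//   new_width  == (642 + 3) / 4 == 161
//   new_height == (482 + 3) / 4 == 121
// so the caller receives a 161 x 121 downsampled image.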
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.cc b/tensorflow/examples/android/jni/object_tracking/optical_flow.cc
new file mode 100644
index 0000000000..fab0a3155d
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.cc
@@ -0,0 +1,490 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
+#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+namespace tf_tracking {
+
+OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config)
+ : config_(config),
+ frame1_(NULL),
+ frame2_(NULL),
+ working_size_(config->image_size) {}
+
+
+void OpticalFlow::NextFrame(const ImageData* const image_data) {
+ // Special case for the first frame: make sure the image ends up in
+ // frame1_ so that keypoint detection can be done on it if desired.
+ frame1_ = (frame1_ == NULL) ? image_data : frame2_;
+ frame2_ = image_data;
+}
+
+
+// Static core of the optical flow computation: the Lucas-Kanade algorithm.
+bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8>& img_I,
+ const Image<uint8>& img_J,
+ const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const float p_x,
+ const float p_y,
+ float* out_g_x,
+ float* out_g_y) {
+ float g_x = *out_g_x;
+ float g_y = *out_g_y;
+ // Get values for frame 1. They remain constant through the inner
+ // iteration loop.
+ float vals_I[kFlowArraySize];
+ float vals_I_x[kFlowArraySize];
+ float vals_I_y[kFlowArraySize];
+
+ const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1;
+ const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize);
+
+#if USE_FIXED_POINT_FLOW
+ const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1;
+ const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1;
+#else
+ const float real_x_max = I_x.width_less_one_ - EPSILON;
+ const float real_y_max = I_x.height_less_one_ - EPSILON;
+#endif
+
+ // Get the window around the original point.
+ const float src_left_real = p_x - kWindowSizeFloat;
+ const float src_top_real = p_y - kWindowSizeFloat;
+ float* vals_I_ptr = vals_I;
+ float* vals_I_x_ptr = vals_I_x;
+ float* vals_I_y_ptr = vals_I_y;
+#if USE_FIXED_POINT_FLOW
+ // Source integer coordinates.
+ const int src_left_fixed = RealToFixed1616(src_left_real);
+ const int src_top_fixed = RealToFixed1616(src_top_real);
+
+ for (int y = 0; y < kPatchSize; ++y) {
+ const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max);
+
+ for (int x = 0; x < kPatchSize; ++x) {
+ const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max);
+
+ *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y);
+ *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
+ *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
+ }
+ }
+#else
+ for (int y = 0; y < kPatchSize; ++y) {
+ const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max);
+
+ for (int x = 0; x < kPatchSize; ++x) {
+ const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max);
+
+ *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos);
+ *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos);
+ *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos);
+ }
+ }
+#endif
+
+ // Compute the spatial gradient matrix about point p.
+ float G[] = { 0, 0, 0, 0 };
+ CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G);
+
+ // Find the inverse of G.
+ float G_inv[4];
+ if (!Invert2x2(G, G_inv)) {
+ return false;
+ }
+
+#if NORMALIZE
+ const float mean_I = ComputeMean(vals_I, kFlowArraySize);
+ const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I);
+#endif
+
+ // Iterate kNumIterations times or until we converge.
+ for (int iteration = 0; iteration < kNumIterations; ++iteration) {
+ // Get values for frame 2.
+ float vals_J[kFlowArraySize];
+
+ // Get the window around the destination point.
+ const float left_real = p_x + g_x - kWindowSizeFloat;
+ const float top_real = p_y + g_y - kWindowSizeFloat;
+ float* vals_J_ptr = vals_J;
+#if USE_FIXED_POINT_FLOW
+ // The top-left sub-pixel is set for the current iteration (in 16:16
+ // fixed). This is constant over one iteration.
+ const int left_fixed = RealToFixed1616(left_real);
+ const int top_fixed = RealToFixed1616(top_real);
+
+ for (int win_y = 0; win_y < kPatchSize; ++win_y) {
+ const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max);
+ for (int win_x = 0; win_x < kPatchSize; ++win_x) {
+ const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max);
+ *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y);
+ }
+ }
+#else
+ for (int win_y = 0; win_y < kPatchSize; ++win_y) {
+ const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max);
+ for (int win_x = 0; win_x < kPatchSize; ++win_x) {
+ const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max);
+ *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos);
+ }
+ }
+#endif
+
+#if NORMALIZE
+ const float mean_J = ComputeMean(vals_J, kFlowArraySize);
+ const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J);
+
+ // TODO(andrewharp): Probably better to completely detect and handle the
+ // "corner case" where the patch is fully outside the image diagonally.
+ const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f;
+#endif
+
+ // Compute image mismatch vector.
+ float b_x = 0.0f;
+ float b_y = 0.0f;
+
+ vals_I_ptr = vals_I;
+ vals_J_ptr = vals_J;
+ vals_I_x_ptr = vals_I_x;
+ vals_I_y_ptr = vals_I_y;
+
+ for (int win_y = 0; win_y < kPatchSize; ++win_y) {
+ for (int win_x = 0; win_x < kPatchSize; ++win_x) {
+#if NORMALIZE
+ // Normalized Image difference.
+ const float dI =
+ (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio;
+#else
+ const float dI = *vals_I_ptr++ - *vals_J_ptr++;
+#endif
+ b_x += dI * *vals_I_x_ptr++;
+ b_y += dI * *vals_I_y_ptr++;
+ }
+ }
+
+ // Optical flow... solve n = G^-1 * b
+ const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y);
+ const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y);
+
+ // Update best guess with residual displacement from this level and
+ // iteration.
+ g_x += n_x;
+ g_y += n_y;
+
+ // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y);
+
+ // Abort early if we're already below the threshold.
+ if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) {
+ break;
+ }
+ } // Iteration.
+
+ // Copy value back into output.
+ *out_g_x = g_x;
+ *out_g_y = g_y;
+ return true;
+}
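// For reference, in the usual Lucas-Kanade notation the routine above builds
// the spatial gradient matrix once from the frame-1 patch,
//   G = sum over the window of [ Ix*Ix  Ix*Iy ; Ix*Iy  Iy*Iy ],
// then on each iteration forms the mismatch vector
//   b = sum over the window of (I - J) * [ Ix ; Iy ]
// and refines the guess with n = G^-1 * b, g <- g + n, stopping early once
// |n| falls below kTrackingAbortThreshold.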
+
+
+// Pointwise flow using translational 2-DOF ESM (Efficient Second-order Minimization).
+bool OpticalFlow::FindFlowAtPoint_ESM(const Image<uint8>& img_I,
+ const Image<uint8>& img_J,
+ const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const Image<int32>& J_x,
+ const Image<int32>& J_y,
+ const float p_x,
+ const float p_y,
+ float* out_g_x,
+ float* out_g_y) {
+ float g_x = *out_g_x;
+ float g_y = *out_g_y;
+ const float area_inv = 1.0f / static_cast<float>(kFlowArraySize);
+
+ // Get values for frame 1. They remain constant through the inner
+ // iteration loop.
+ uint8 vals_I[kFlowArraySize];
+ uint8 vals_J[kFlowArraySize];
+ int16 src_gradient_x[kFlowArraySize];
+ int16 src_gradient_y[kFlowArraySize];
+
+ // TODO(rspring): try out the IntegerPatchAlign() method once
+ // the code for that is in ../common.
+ const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize);
+ const int src_left_fixed = RealToFixed1616(p_x - wsize_float);
+ const int src_top_fixed = RealToFixed1616(p_y - wsize_float);
+ const int patch_size = 2 * kFlowIntegrationWindowSize + 1;
+
+ // Create the keypoint template patch from a subpixel location.
+ if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
+ patch_size, patch_size, vals_I) ||
+ !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
+ patch_size, patch_size,
+ src_gradient_x) ||
+ !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
+ patch_size, patch_size,
+ src_gradient_y)) {
+ return false;
+ }
+
+ int bright_offset = 0;
+ int sum_diff = 0;
+
+ // The top-left sub-pixel is set for the current iteration (in 16:16 fixed).
+ // This is constant over one iteration.
+ int left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
+ int top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
+
+ // The truncated version gives the most top-left pixel that is used.
+ int left_trunc = left_fixed >> 16;
+ int top_trunc = top_fixed >> 16;
+
+ // Compute an initial brightness offset.
+ if (kDoBrightnessNormalize &&
+ left_trunc >= 0 && top_trunc >= 0 &&
+ (left_trunc + patch_size) < img_J.width_less_one_ &&
+ (top_trunc + patch_size) < img_J.height_less_one_) {
+ int templ_index = 0;
+ const uint8* j_row = img_J[top_trunc] + left_trunc;
+
+ const int j_stride = img_J.stride();
+
+ for (int y = 0; y < patch_size; ++y, j_row += j_stride) {
+ for (int x = 0; x < patch_size; ++x) {
+ sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++];
+ }
+ }
+
+ bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv);
+ }
+
+ // Iterate kNumIterations times or until we go out of image.
+ for (int iteration = 0; iteration < kNumIterations; ++iteration) {
+ int jtj[3] = { 0, 0, 0 };
+ int jtr[2] = { 0, 0 };
+ sum_diff = 0;
+
+ // Extract the target image values.
+ // Extract the gradient from the target image patch and accumulate to
+ // the gradient of the source image patch.
+ if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed,
+ patch_size, patch_size,
+ vals_J)) {
+ break;
+ }
+
+ const uint8* templ_row = vals_I;
+ const uint8* extract_row = vals_J;
+ const int16* src_dx_row = src_gradient_x;
+ const int16* src_dy_row = src_gradient_y;
+
+ for (int y = 0; y < patch_size; ++y, templ_row += patch_size,
+ src_dx_row += patch_size, src_dy_row += patch_size,
+ extract_row += patch_size) {
+ const int fp_y = top_fixed + (y << 16);
+ for (int x = 0; x < patch_size; ++x) {
+ const int fp_x = left_fixed + (x << 16);
+ int32 target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y);
+ int32 target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y);
+
+ // Combine the two Jacobians.
+ // Right-shift by one to account for the fact that we add
+ // two Jacobians.
+ int32 dx = (src_dx_row[x] + target_dx) >> 1;
+ int32 dy = (src_dy_row[x] + target_dy) >> 1;
+
+ // The current residual b - h(q) == extracted - (template + offset)
+ int32 diff = static_cast<int32>(extract_row[x]) -
+ static_cast<int32>(templ_row[x]) -
+ bright_offset;
+
+ jtj[0] += dx * dx;
+ jtj[1] += dx * dy;
+ jtj[2] += dy * dy;
+
+ jtr[0] += dx * diff;
+ jtr[1] += dy * diff;
+
+ sum_diff += diff;
+ }
+ }
+
+ const float jtr1_float = static_cast<float>(jtr[0]);
+ const float jtr2_float = static_cast<float>(jtr[1]);
+
+ // Add some baseline stability to the system.
+ jtj[0] += kEsmRegularizer;
+ jtj[2] += kEsmRegularizer;
+
+ const int64 prod1 = static_cast<int64>(jtj[0]) * jtj[2];
+ const int64 prod2 = static_cast<int64>(jtj[1]) * jtj[1];
+
+ // One ESM step.
+ const float jtj_1[4] = { static_cast<float>(jtj[2]),
+ static_cast<float>(-jtj[1]),
+ static_cast<float>(-jtj[1]),
+ static_cast<float>(jtj[0]) };
+ const double det_inv = 1.0 / static_cast<double>(prod1 - prod2);
+
+ g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float);
+ g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float);
+
+ if (kDoBrightnessNormalize) {
+ bright_offset +=
+ static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f);
+ }
+
+ // Update top left position.
+ left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
+ top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
+
+ left_trunc = left_fixed >> 16;
+ top_trunc = top_fixed >> 16;
+
+    // Abort the iterations if we move outside the image borders.
+ if (left_trunc < 0 || top_trunc < 0 ||
+ (left_trunc + patch_size) >= J_x.width_less_one_ ||
+ (top_trunc + patch_size) >= J_y.height_less_one_) {
+ break;
+ }
+ } // Iteration.
+
+ // Copy value back into output.
+ *out_g_x = g_x;
+ *out_g_y = g_y;
+ return true;
+}
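+
+// Note on the iteration above: each pass solves the regularized 2x2 system
+//
+//   [ jtj[0]  jtj[1] ] [ dx ]   [ jtr[0] ]
+//   [ jtj[1]  jtj[2] ] [ dy ] = [ jtr[1] ]
+//
+// in closed form by multiplying with the adjugate matrix (jtj_1) and the
+// reciprocal of the determinant (det_inv), then subtracts the step from
+// (g_x, g_y). Adding kEsmRegularizer to the diagonal keeps the determinant
+// away from zero on low-texture patches.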
+
+
+bool OpticalFlow::FindFlowAtPointReversible(
+ const int level, const float u_x, const float u_y,
+ const bool reverse_flow,
+ float* flow_x, float* flow_y) const {
+ const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_;
+ const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_;
+
+ // Images I (prev) and J (next).
+ const Image<uint8>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2);
+ const Image<uint8>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2);
+
+ // Computed gradients.
+ const Image<int32>& I_x = *frame_a.GetSpatialX(level);
+ const Image<int32>& I_y = *frame_a.GetSpatialY(level);
+ const Image<int32>& J_x = *frame_b.GetSpatialX(level);
+ const Image<int32>& J_y = *frame_b.GetSpatialY(level);
+
+ // Shrink factor from original.
+ const float shrink_factor = (1 << level);
+
+ // Image position vector (p := u^l), scaled for this level.
+ const float scaled_p_x = u_x / shrink_factor;
+ const float scaled_p_y = u_y / shrink_factor;
+
+ float scaled_flow_x = *flow_x / shrink_factor;
+ float scaled_flow_y = *flow_y / shrink_factor;
+
+ // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level,
+ // scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y);
+
+ const bool success = kUseEsm ?
+ FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y,
+ scaled_p_x, scaled_p_y,
+ &scaled_flow_x, &scaled_flow_y) :
+ FindFlowAtPoint_LK(img_I, img_J, I_x, I_y,
+ scaled_p_x, scaled_p_y,
+ &scaled_flow_x, &scaled_flow_y);
+
+ *flow_x = scaled_flow_x * shrink_factor;
+ *flow_y = scaled_flow_y * shrink_factor;
+
+ return success;
+}
+
+
+bool OpticalFlow::FindFlowAtPointSingleLevel(
+ const int level,
+ const float u_x, const float u_y,
+ const bool filter_by_fb_error,
+ float* flow_x, float* flow_y) const {
+ if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) {
+ return false;
+ }
+
+ if (filter_by_fb_error) {
+ const float new_position_x = u_x + *flow_x;
+ const float new_position_y = u_y + *flow_y;
+
+ float reverse_flow_x = 0.0f;
+ float reverse_flow_y = 0.0f;
+
+ // Now find the backwards flow and confirm it lines up with the original
+ // starting point.
+ if (!FindFlowAtPointReversible(level, new_position_x, new_position_y,
+ true,
+ &reverse_flow_x, &reverse_flow_y)) {
+ LOGE("Backward error!");
+ return false;
+ }
+
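+    // Accept the flow only if the forward and backward estimates roughly
+    // cancel out: the residual ||forward + backward|| must stay below a fixed
+    // fraction (kMaxForwardBackwardErrorAllowed) of the forward flow length.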
+ const float discrepancy_length =
+ sqrtf(Square(*flow_x + reverse_flow_x) +
+ Square(*flow_y + reverse_flow_y));
+
+ const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y));
+
+ return discrepancy_length <
+ (kMaxForwardBackwardErrorAllowed * flow_length);
+ }
+
+ return true;
+}
+
+
+// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm.
+// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details.
+bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y,
+ const bool filter_by_fb_error,
+ float* flow_x, float* flow_y) const {
+ const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment,
+ kNumPyramidLevels - kNumCacheLevels);
+
+ // For every level in the pyramid, update the coordinates of the best match.
+ for (int l = max_level - 1; l >= 0; --l) {
+ if (!FindFlowAtPointSingleLevel(l, u_x, u_y,
+ filter_by_fb_error, flow_x, flow_y)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+} // namespace tf_tracking
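
For orientation, the sketch below shows how a caller might query sparse flow
with the functions defined in this file, assuming an OpticalFlow instance that
has already been fed consecutive frames via NextFrame(). The PointFlow struct
and ComputeSparseFlow() helper are hypothetical illustrations, not part of this
change.

    // Sketch only: assumes optical_flow.h is on the include path and that
    // `flow` has already received at least two frames through NextFrame().
    #include <utility>
    #include <vector>

    #include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"

    struct PointFlow {
      float x, y;    // Query position, in full-resolution coordinates.
      float dx, dy;  // Estimated displacement, also full resolution.
      bool valid;    // False if the solve or the forward-backward check failed.
    };

    std::vector<PointFlow> ComputeSparseFlow(
        const tf_tracking::OpticalFlow& flow,
        const std::vector<std::pair<float, float> >& points) {
      std::vector<PointFlow> results;
      for (const std::pair<float, float>& p : points) {
        PointFlow r = {p.first, p.second, 0.0f, 0.0f, false};
        // Coarse-to-fine estimate with forward-backward filtering enabled.
        r.valid = flow.FindFlowAtPointPyramidal(p.first, p.second,
                                                true /* filter_by_fb_error */,
                                                &r.dx, &r.dy);
        results.push_back(r);
      }
      return results;
    }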
diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.h b/tensorflow/examples/android/jni/object_tracking/optical_flow.h
new file mode 100644
index 0000000000..1329927b99
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.h
@@ -0,0 +1,111 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+class FlowCache;
+
+// Class encapsulating all the data and logic necessary for performing optical
+// flow.
+class OpticalFlow {
+ public:
+ explicit OpticalFlow(const OpticalFlowConfig* const config);
+
+ // Add a new frame to the optical flow. Will update all the non-keypoint
+ // related member variables.
+ //
+  // image_data should wrap the image pyramid and gradient images computed
+  // from the most recent grayscale camera frame, at the original frame_width
+  // and frame_height used to initialize the OpticalFlow object.
+ void NextFrame(const ImageData* const image_data);
+
+ // An implementation of the Lucas-Kanade Optical Flow algorithm.
+ static bool FindFlowAtPoint_LK(const Image<uint8>& img_I,
+ const Image<uint8>& img_J,
+ const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const float p_x,
+ const float p_y,
+ float* out_g_x,
+ float* out_g_y);
+
+ // Pointwise flow using translational 2dof ESM.
+ static bool FindFlowAtPoint_ESM(const Image<uint8>& img_I,
+ const Image<uint8>& img_J,
+ const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const Image<int32>& J_x,
+ const Image<int32>& J_y,
+ const float p_x,
+ const float p_y,
+ float* out_g_x,
+ float* out_g_y);
+
+ // Finds the flow using a specific level, in either direction.
+ // If reversed, the coordinates are in the context of the latest
+ // frame, not the frame before it.
+ // All coordinates used in parameters are global, not scaled.
+ bool FindFlowAtPointReversible(
+ const int level, const float u_x, const float u_y,
+ const bool reverse_flow,
+ float* final_x, float* final_y) const;
+
+ // Finds the flow using a specific level, filterable by forward-backward
+ // error. All coordinates used in parameters are global, not scaled.
+ bool FindFlowAtPointSingleLevel(const int level,
+ const float u_x, const float u_y,
+ const bool filter_by_fb_error,
+ float* flow_x, float* flow_y) const;
+
+ // Pyramidal optical-flow using all levels.
+ bool FindFlowAtPointPyramidal(const float u_x, const float u_y,
+ const bool filter_by_fb_error,
+ float* flow_x, float* flow_y) const;
+
+ private:
+ const OpticalFlowConfig* const config_;
+
+ const ImageData* frame1_;
+ const ImageData* frame2_;
+
+ // Size of the internally allocated images (after original is downsampled).
+ const Size working_size_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpticalFlow);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/sprite.h b/tensorflow/examples/android/jni/object_tracking/sprite.h
new file mode 100755
index 0000000000..6240591cf2
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/sprite.h
@@ -0,0 +1,205 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
+
+#include <GLES/gl.h>
+#include <GLES/glext.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+
+#ifndef __RENDER_OPENGL__
+#error sprite.h should not be included if OpenGL is not enabled by platform.h
+#endif
+
+namespace tf_tracking {
+
+// This class encapsulates the logic necessary to load and render image data
+// at the same aspect ratio as the original source.
+class Sprite {
+ public:
+  // Only create Sprites when you have an OpenGL context.
+ explicit Sprite(const Image<uint8>& image) {
+ LoadTexture(image, NULL);
+ }
+
+ Sprite(const Image<uint8>& image, const BoundingBox* const area) {
+ LoadTexture(image, area);
+ }
+
+  // Also, try to only delete a Sprite while holding an OpenGL context.
+ ~Sprite() {
+ glDeleteTextures(1, &texture_);
+ }
+
+ inline int GetWidth() const {
+ return actual_width_;
+ }
+
+ inline int GetHeight() const {
+ return actual_height_;
+ }
+
+  // Draw the sprite from (0, 0) to (original width, original height) in the
+  // current reference frame. Any transformations desired must be applied
+  // before calling this function.
+ void Draw() const {
+ const float float_width = static_cast<float>(actual_width_);
+ const float float_height = static_cast<float>(actual_height_);
+
+ // Where it gets rendered to.
+ const float vertices[] = { 0.0f, 0.0f, 0.0f,
+ 0.0f, float_height, 0.0f,
+ float_width, 0.0f, 0.0f,
+ float_width, float_height, 0.0f,
+ };
+
+ // The coordinates the texture gets drawn from.
+ const float max_x = float_width / texture_width_;
+ const float max_y = float_height / texture_height_;
+ const float textureVertices[] = {
+ 0, 0,
+ 0, max_y,
+ max_x, 0,
+ max_x, max_y,
+ };
+
+ glEnable(GL_TEXTURE_2D);
+ glBindTexture(GL_TEXTURE_2D, texture_);
+
+ glEnableClientState(GL_VERTEX_ARRAY);
+ glEnableClientState(GL_TEXTURE_COORD_ARRAY);
+
+ glVertexPointer(3, GL_FLOAT, 0, vertices);
+ glTexCoordPointer(2, GL_FLOAT, 0, textureVertices);
+
+ glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
+
+ glDisableClientState(GL_VERTEX_ARRAY);
+ glDisableClientState(GL_TEXTURE_COORD_ARRAY);
+ }
+
+ private:
+ inline int GetNextPowerOfTwo(const int number) const {
+ int power_of_two = 1;
+ while (power_of_two < number) {
+ power_of_two *= 2;
+ }
+ return power_of_two;
+ }
+
+ // TODO(andrewharp): Allow sprites to have their textures reloaded.
+ void LoadTexture(const Image<uint8>& texture_source,
+ const BoundingBox* const area) {
+ glEnable(GL_TEXTURE_2D);
+
+ glGenTextures(1, &texture_);
+
+ glBindTexture(GL_TEXTURE_2D, texture_);
+
+ int left = 0;
+ int top = 0;
+
+ if (area != NULL) {
+ // If a sub-region was provided to pull the texture from, use that.
+ left = area->left_;
+ top = area->top_;
+ actual_width_ = area->GetWidth();
+ actual_height_ = area->GetHeight();
+ } else {
+ actual_width_ = texture_source.GetWidth();
+ actual_height_ = texture_source.GetHeight();
+ }
+
+ // The textures must be a power of two, so find the sizes that are large
+ // enough to contain the image data.
+ texture_width_ = GetNextPowerOfTwo(actual_width_);
+ texture_height_ = GetNextPowerOfTwo(actual_height_);
+
+ bool allocated_data = false;
+ uint8* texture_data;
+
+ // Except in the lucky case where we're not using a sub-region of the
+    // original image AND the source data has dimensions that are powers of two,
+ // care must be taken to copy data at the appropriate source and destination
+ // strides so that the final block can be copied directly into texture
+ // memory.
+ // TODO(andrewharp): Figure out if data can be pulled directly from the
+ // source image with some alignment modifications.
+ if (left != 0 || top != 0 ||
+ actual_width_ != texture_source.GetWidth() ||
+ actual_height_ != texture_source.GetHeight()) {
+ texture_data = new uint8[actual_width_ * actual_height_];
+
+ for (int y = 0; y < actual_height_; ++y) {
+ memcpy(texture_data + actual_width_ * y,
+ texture_source[top + y] + left,
+ actual_width_ * sizeof(uint8));
+ }
+ allocated_data = true;
+ } else {
+ // Cast away const-ness because for some reason glTexSubImage2D wants
+ // a non-const data pointer.
+ texture_data = const_cast<uint8*>(texture_source.data());
+ }
+
+ glTexImage2D(GL_TEXTURE_2D,
+ 0,
+ GL_LUMINANCE,
+ texture_width_,
+ texture_height_,
+ 0,
+ GL_LUMINANCE,
+ GL_UNSIGNED_BYTE,
+ NULL);
+
+ glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
+ glTexSubImage2D(GL_TEXTURE_2D,
+ 0,
+ 0,
+ 0,
+ actual_width_,
+ actual_height_,
+ GL_LUMINANCE,
+ GL_UNSIGNED_BYTE,
+ texture_data);
+
+ if (allocated_data) {
+      delete[] texture_data;
+ }
+
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
+ }
+
+ // The id for the texture on the GPU.
+ GLuint texture_;
+
+ // The width and height to be used for display purposes, referring to the
+ // dimensions of the original texture.
+ int actual_width_;
+ int actual_height_;
+
+ // The allocated dimensions of the texture data, which must be powers of 2.
+ int texture_width_;
+ int texture_height_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Sprite);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
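
A note on LoadTexture() above: OpenGL ES 1.x only guarantees support for
power-of-two texture dimensions, which is why GetNextPowerOfTwo() rounds the
width and height up before the texture is allocated and the image is uploaded
into its top-left corner. A common branch-free way to compute the same rounding
for 32-bit values is sketched below; it matches GetNextPowerOfTwo() for inputs
of 1 or greater and is shown only as an illustration, not as part of this
change.

    // Illustrative alternative to Sprite::GetNextPowerOfTwo() for inputs >= 1.
    #include <stdint.h>

    inline int NextPowerOfTwo(uint32_t v) {
      --v;
      v |= v >> 1;
      v |= v >> 2;
      v |= v >> 4;
      v |= v >> 8;
      v |= v >> 16;
      return static_cast<int>(v + 1);
    }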
diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.cc b/tensorflow/examples/android/jni/object_tracking/time_log.cc
new file mode 100644
index 0000000000..cb1f3c23c8
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/time_log.cc
@@ -0,0 +1,29 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+
+using namespace tensorflow;
+
+#ifdef LOG_TIME
+// Storage for logging functionality.
+int num_time_logs = 0;
+LogEntry time_logs[NUM_LOGS];
+
+int num_avg_entries = 0;
+AverageEntry avg_entries[NUM_LOGS];
+#endif
diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.h b/tensorflow/examples/android/jni/object_tracking/time_log.h
new file mode 100644
index 0000000000..ec539a1b3b
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/time_log.h
@@ -0,0 +1,138 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utility functions for performance profiling.
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#ifdef LOG_TIME
+
+// Blend constant for running average.
+#define ALPHA 0.98f
+#define NUM_LOGS 100
+
+struct LogEntry {
+ const char* id;
+ int64 time_stamp;
+};
+
+struct AverageEntry {
+ const char* id;
+ float average_duration;
+};
+
+// Storage for keeping track of this frame's values.
+extern int num_time_logs;
+extern LogEntry time_logs[NUM_LOGS];
+
+// Storage for keeping track of average values (each entry may not be printed
+// out each frame).
+extern AverageEntry avg_entries[NUM_LOGS];
+extern int num_avg_entries;
+
+// Call this at the start of a logging phase.
+inline static void ResetTimeLog() {
+ num_time_logs = 0;
+}
+
+
+// Log a message to be printed out when printTimeLog is called, along with the
+// amount of time in ms that has passed since the last call to this function.
+inline static void TimeLog(const char* const str) {
+ LOGV("%s", str);
+ if (num_time_logs >= NUM_LOGS) {
+ LOGE("Out of log entries!");
+ return;
+ }
+
+ time_logs[num_time_logs].id = str;
+ time_logs[num_time_logs].time_stamp = CurrentThreadTimeNanos();
+ ++num_time_logs;
+}
+
+
+inline static float Blend(float old_val, float new_val) {
+ return ALPHA * old_val + (1.0f - ALPHA) * new_val;
+}
+
+
+inline static float UpdateAverage(const char* str, const float new_val) {
+ for (int entry_num = 0; entry_num < num_avg_entries; ++entry_num) {
+ AverageEntry* const entry = avg_entries + entry_num;
+ if (str == entry->id) {
+ entry->average_duration = Blend(entry->average_duration, new_val);
+ return entry->average_duration;
+ }
+ }
+
+  if (num_avg_entries >= NUM_LOGS) {
+    LOGE("Too many log entries!");
+    return new_val;
+  }
+
+ // If it wasn't there already, add it.
+ avg_entries[num_avg_entries].id = str;
+ avg_entries[num_avg_entries].average_duration = new_val;
+ ++num_avg_entries;
+
+ return new_val;
+}
+
+
+// Prints out all the timeLog statements in chronological order with the
+// interval that passed between subsequent statements. The total time between
+// the first and last statements is printed last.
+inline static void PrintTimeLog() {
+ LogEntry* last_time = time_logs;
+
+ float average_running_total = 0.0f;
+
+ for (int i = 0; i < num_time_logs; ++i) {
+ LogEntry* const this_time = time_logs + i;
+
+ const float curr_time =
+ (this_time->time_stamp - last_time->time_stamp) / 1000000.0f;
+
+ const float avg_time = UpdateAverage(this_time->id, curr_time);
+ average_running_total += avg_time;
+
+ LOGD("%32s: %6.3fms %6.4fms", this_time->id, curr_time, avg_time);
+ last_time = this_time;
+ }
+
+ const float total_time =
+ (last_time->time_stamp - time_logs->time_stamp) / 1000000.0f;
+
+ LOGD("TOTAL TIME: %6.3fms %6.4fms\n",
+ total_time, average_running_total);
+ LOGD(" ");
+}
+#else
+inline static void ResetTimeLog() {}
+
+inline static void TimeLog(const char* const str) {
+ LOGV("%s", str);
+}
+
+inline static void PrintTimeLog() {}
+#endif
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
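
The helpers above are intended to be used once per frame: ResetTimeLog() at the
start, TimeLog() after each stage of interest, and PrintTimeLog() at the end.
PrintTimeLog() reports each stage's delta alongside a running average blended
as avg = ALPHA * avg + (1 - ALPHA) * new, so with ALPHA = 0.98 the average
reflects roughly the last fifty frames. A minimal usage sketch, with a
hypothetical ProcessFrame() as the caller:

    // Sketch only. Stage names should be string literals (or otherwise stable
    // pointers), since UpdateAverage() matches entries by pointer identity.
    void ProcessFrame() {
      ResetTimeLog();
      TimeLog("Frame start");
      // ... downsample the frame, compute optical flow ...
      TimeLog("Flow computed");
      // ... run the detector, update tracked objects ...
      TimeLog("Detection done");
      PrintTimeLog();
    }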
diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.cc b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc
new file mode 100644
index 0000000000..823fb3a90e
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc
@@ -0,0 +1,163 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
+
+namespace tf_tracking {
+
+static const float kInitialDistance = 20.0f;
+
+static void InitNormalized(const Image<uint8>& src_image,
+ const BoundingBox& position,
+ Image<float>* const dst_image) {
+ BoundingBox scaled_box(position);
+ CopyArea(src_image, scaled_box, dst_image);
+ NormalizeImage(dst_image);
+}
+
+TrackedObject::TrackedObject(const std::string& id,
+ const Image<uint8>& image,
+ const BoundingBox& bounding_box,
+ ObjectModelBase* const model)
+ : id_(id),
+ last_known_position_(bounding_box),
+ last_detection_position_(bounding_box),
+ position_last_computed_time_(-1),
+ object_model_(model),
+ last_detection_thumbnail_(kNormalizedThumbnailSize,
+ kNormalizedThumbnailSize),
+ last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize),
+ tracked_correlation_(0.0f),
+ tracked_match_score_(0.0),
+ num_consecutive_frames_below_threshold_(0),
+ allowable_detection_distance_(Square(kInitialDistance)) {
+ InitNormalized(image, bounding_box, &last_detection_thumbnail_);
+}
+
+TrackedObject::~TrackedObject() {}
+
+void TrackedObject::UpdatePosition(const BoundingBox& new_position,
+ const int64 timestamp,
+ const ImageData& image_data,
+                                   const bool authoritative) {
+ last_known_position_ = new_position;
+ position_last_computed_time_ = timestamp;
+
+ InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_);
+
+ const float last_localization_correlation = ComputeCrossCorrelation(
+ last_detection_thumbnail_.data(),
+ last_frame_thumbnail_.data(),
+ last_frame_thumbnail_.data_size_);
+ LOGV("Tracked correlation to last localization: %.6f",
+ last_localization_correlation);
+
+ // Correlation to object model, if it exists.
+ if (object_model_ != NULL) {
+ tracked_correlation_ =
+ object_model_->GetMaxCorrelation(last_frame_thumbnail_);
+ LOGV("Tracked correlation to model: %.6f",
+ tracked_correlation_);
+
+ tracked_match_score_ =
+ object_model_->GetMatchScore(new_position, image_data);
+ LOGV("Tracked match score with model: %.6f",
+ tracked_match_score_.value);
+ } else {
+ // If there's no model to check against, set the tracked correlation to
+ // simply be the correlation to the last set position.
+ tracked_correlation_ = last_localization_correlation;
+ tracked_match_score_ = MatchScore(0.0f);
+ }
+
+ // Determine if it's still being tracked.
+ if (tracked_correlation_ >= kMinimumCorrelationForTracking &&
+ tracked_match_score_ >= kMinimumMatchScore) {
+ num_consecutive_frames_below_threshold_ = 0;
+
+ if (object_model_ != NULL) {
+ object_model_->TrackStep(last_known_position_, *image_data.GetImage(),
+                                *image_data.GetIntegralImage(), authoritative);
+ }
+ } else if (tracked_match_score_ < kMatchScoreForImmediateTermination) {
+ if (num_consecutive_frames_below_threshold_ < 1000) {
+ LOGD("Tracked match score is way too low (%.6f), aborting track.",
+ tracked_match_score_.value);
+ }
+
+ // Add an absurd amount of missed frames so that all heuristics will
+ // consider it a lost track.
+ num_consecutive_frames_below_threshold_ += 1000;
+
+ if (object_model_ != NULL) {
+ object_model_->TrackLost();
+ }
+ } else {
+ ++num_consecutive_frames_below_threshold_;
+ allowable_detection_distance_ *= 1.1f;
+ }
+}
+
+void TrackedObject::OnDetection(ObjectModelBase* const model,
+ const BoundingBox& detection_position,
+ const MatchScore match_score,
+ const int64 timestamp,
+ const ImageData& image_data) {
+ const float overlap = detection_position.PascalScore(last_known_position_);
+ if (overlap > kPositionOverlapThreshold) {
+ // If the position agreement with the current tracked position is good
+ // enough, lock all the current unlocked examples.
+ object_model_->TrackConfirmed();
+ num_consecutive_frames_below_threshold_ = 0;
+ }
+
+ // Before relocalizing, make sure the new proposed position is better than
+ // the existing position by a small amount to prevent thrashing.
+ if (match_score <= tracked_match_score_ + kMatchScoreBuffer) {
+ LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f",
+ match_score.value, tracked_match_score_.value,
+ kMatchScoreBuffer.value);
+ return;
+ }
+
+ LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to "
+ "(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f",
+ last_known_position_.left_, last_known_position_.top_,
+ last_known_position_.GetWidth(), last_known_position_.GetHeight(),
+ detection_position.left_, detection_position.top_,
+ detection_position.GetWidth(), detection_position.GetHeight(),
+ match_score.value, tracked_match_score_.value);
+
+ if (overlap < kPositionOverlapThreshold) {
+ // The path might be good, it might be bad, but it's no longer a path
+ // since we're moving the box to a new position, so just nuke it from
+ // orbit to be safe.
+ object_model_->TrackLost();
+ }
+
+ object_model_ = model;
+
+ // Reset the last detected appearance.
+ InitNormalized(
+ *image_data.GetImage(), detection_position, &last_detection_thumbnail_);
+
+ num_consecutive_frames_below_threshold_ = 0;
+ last_detection_position_ = detection_position;
+
+ UpdatePosition(detection_position, timestamp, image_data, false);
+ allowable_detection_distance_ = Square(kInitialDistance);
+}
+
+} // namespace tf_tracking
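
OnDetection() above gates relocalization on PascalScore(), which by its name is
presumably the PASCAL VOC overlap criterion, i.e. intersection over union of
the two boxes. The sketch below illustrates that measure using the BoundingBox
fields seen elsewhere in this change; it is not the actual PascalScore()
implementation.

    // Sketch only: assumes axis-aligned boxes with public left_/top_/right_/
    // bottom_ fields and GetWidth()/GetHeight() accessors, as used above.
    #include <algorithm>

    float IntersectionOverUnion(const tf_tracking::BoundingBox& a,
                                const tf_tracking::BoundingBox& b) {
      const float ix = std::max(0.0f, std::min(a.right_, b.right_) -
                                          std::max(a.left_, b.left_));
      const float iy = std::max(0.0f, std::min(a.bottom_, b.bottom_) -
                                          std::max(a.top_, b.top_));
      const float intersection = ix * iy;
      const float union_area = a.GetWidth() * a.GetHeight() +
                               b.GetWidth() * b.GetHeight() - intersection;
      return union_area > 0.0f ? intersection / union_area : 0.0f;
    }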
diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.h b/tensorflow/examples/android/jni/object_tracking/tracked_object.h
new file mode 100644
index 0000000000..5580cd2b89
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.h
@@ -0,0 +1,191 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
+
+#ifdef __RENDER_OPENGL__
+#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h"
+#endif
+#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
+
+namespace tf_tracking {
+
+// A TrackedObject is a specific instance of an ObjectModel, with a known
+// position in the world.
+// It provides the last known position and number of recent detection failures,
+// in addition to the more general appearance data associated with the object
+// class (which is in ObjectModel).
+// TODO(andrewharp): Make getters/setters follow styleguide.
+class TrackedObject {
+ public:
+ TrackedObject(const std::string& id,
+ const Image<uint8>& image,
+ const BoundingBox& bounding_box,
+ ObjectModelBase* const model);
+
+ ~TrackedObject();
+
+ void UpdatePosition(const BoundingBox& new_position,
+ const int64 timestamp,
+ const ImageData& image_data,
+                      const bool authoritative);
+
+ // This method is called when the tracked object is detected at a
+ // given position, and allows the associated Model to grow and/or prune
+ // itself based on where the detection occurred.
+ void OnDetection(ObjectModelBase* const model,
+ const BoundingBox& detection_position,
+ const MatchScore match_score,
+ const int64 timestamp,
+ const ImageData& image_data);
+
+ // Called when there's no detection of the tracked object. This will cause
+ // a tracking failure after enough consecutive failures if the area under
+ // the current bounding box also doesn't meet a minimum correlation threshold
+ // with the model.
+ void OnDetectionFailure() {}
+
+ inline bool IsVisible() const {
+ return tracked_correlation_ >= kMinimumCorrelationForTracking ||
+ num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures;
+ }
+
+ inline float GetCorrelation() {
+ return tracked_correlation_;
+ }
+
+ inline MatchScore GetMatchScore() {
+ return tracked_match_score_;
+ }
+
+ inline BoundingBox GetPosition() const {
+ return last_known_position_;
+ }
+
+ inline BoundingBox GetLastDetectionPosition() const {
+ return last_detection_position_;
+ }
+
+ inline const ObjectModelBase* GetModel() const {
+ return object_model_;
+ }
+
+ inline const std::string& GetName() const {
+ return id_;
+ }
+
+ inline void Draw() const {
+#ifdef __RENDER_OPENGL__
+ if (tracked_correlation_ < kMinimumCorrelationForTracking) {
+ glColor4f(MAX(0.0f, -tracked_correlation_),
+ MAX(0.0f, tracked_correlation_),
+ 0.0f,
+ 1.0f);
+ } else {
+ glColor4f(MAX(0.0f, -tracked_correlation_),
+ MAX(0.0f, tracked_correlation_),
+ 1.0f,
+ 1.0f);
+ }
+
+ // Render the box itself.
+ BoundingBox temp_box(last_known_position_);
+ DrawBox(temp_box);
+
+ // Render a box inside this one (in case the actual box is hidden).
+ const float kBufferSize = 1.0f;
+ temp_box.left_ -= kBufferSize;
+ temp_box.top_ -= kBufferSize;
+ temp_box.right_ += kBufferSize;
+ temp_box.bottom_ += kBufferSize;
+ DrawBox(temp_box);
+
+ // Render one outside as well.
+ temp_box.left_ -= -2.0f * kBufferSize;
+ temp_box.top_ -= -2.0f * kBufferSize;
+ temp_box.right_ += -2.0f * kBufferSize;
+ temp_box.bottom_ += -2.0f * kBufferSize;
+ DrawBox(temp_box);
+#endif
+ }
+
+ // Get current object's num_consecutive_frames_below_threshold_.
+ inline int64 GetNumConsecutiveFramesBelowThreshold() {
+ return num_consecutive_frames_below_threshold_;
+ }
+
+ // Reset num_consecutive_frames_below_threshold_ to 0.
+ inline void resetNumConsecutiveFramesBelowThreshold() {
+ num_consecutive_frames_below_threshold_ = 0;
+ }
+
+ inline float GetAllowableDistanceSquared() const {
+ return allowable_detection_distance_;
+ }
+
+ private:
+ // The unique id used throughout the system to identify this
+ // tracked object.
+ const std::string id_;
+
+ // The last known position of the object.
+ BoundingBox last_known_position_;
+
+  // The position at which the object was last detected.
+ BoundingBox last_detection_position_;
+
+ // When the position was last computed.
+ int64 position_last_computed_time_;
+
+ // The object model this tracked object is representative of.
+ ObjectModelBase* object_model_;
+
+ Image<float> last_detection_thumbnail_;
+
+ Image<float> last_frame_thumbnail_;
+
+ // The correlation of the object model with the preview frame at its last
+ // tracked position.
+ float tracked_correlation_;
+
+ MatchScore tracked_match_score_;
+
+ // The number of consecutive frames that the tracked position for this object
+ // has been under the correlation threshold.
+ int num_consecutive_frames_below_threshold_;
+
+ float allowable_detection_distance_;
+
+ friend std::ostream& operator<<(std::ostream& stream,
+ const TrackedObject& tracked_object);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject);
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TrackedObject& tracked_object) {
+ stream << tracked_object.id_
+ << " " << tracked_object.last_known_position_
+ << " " << tracked_object.position_last_computed_time_
+ << " " << tracked_object.num_consecutive_frames_below_threshold_
+ << " " << tracked_object.object_model_
+ << " " << tracked_object.tracked_correlation_;
+ return stream;
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/utils.h b/tensorflow/examples/android/jni/object_tracking/utils.h
new file mode 100644
index 0000000000..cbdfc408c6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/utils.h
@@ -0,0 +1,386 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
+
+#include <math.h>
+#include <stdlib.h>
+#include <time.h>
+
+#include <cmath> // for std::abs(float)
+
+#ifndef HAVE_CLOCK_GETTIME
+// Use gettimeofday() instead of clock_gettime().
+#include <sys/time.h>
+#endif  // ifndef HAVE_CLOCK_GETTIME
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+using namespace tensorflow;
+
+// TODO(andrewharp): clean up these macros to use the codebase standard.
+
+// A very small number, generally used as the tolerance for accumulated
+// floating point errors in bounds-checks.
+#define EPSILON 0.00001f
+
+#define SAFE_DELETE(pointer) {\
+ if ((pointer) != NULL) {\
+ LOGV("Safe deleting pointer: %s", #pointer);\
+ delete (pointer);\
+ (pointer) = NULL;\
+ } else {\
+ LOGV("Pointer already null: %s", #pointer);\
+ }\
+}
+
+
+#ifdef __GOOGLE__
+
+#define CHECK_ALWAYS(condition, format, ...) {\
+ CHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
+}
+
+#define SCHECK(condition, format, ...) {\
+ DCHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
+}
+
+#else
+
+#define CHECK_ALWAYS(condition, format, ...) {\
+ if (!(condition)) {\
+ LOGE("CHECK FAILED (%s): " format, #condition, ##__VA_ARGS__);\
+ abort();\
+ }\
+}
+
+#ifdef SANITY_CHECKS
+#define SCHECK(condition, format, ...) {\
+ CHECK_ALWAYS(condition, format, ##__VA_ARGS__);\
+}
+#else
+#define SCHECK(condition, format, ...) {}
+#endif // SANITY_CHECKS
+
+#endif // __GOOGLE__
+
+
+#ifndef MAX
+#define MAX(a, b) (((a) > (b)) ? (a) : (b))
+#endif
+#ifndef MIN
+#define MIN(a, b) (((a) > (b)) ? (b) : (a))
+#endif
+
+
+
+inline static int64 CurrentThreadTimeNanos() {
+#ifdef HAVE_CLOCK_GETTIME
+ struct timespec tm;
+ clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tm);
+ return tm.tv_sec * 1000000000LL + tm.tv_nsec;
+#else
+ struct timeval tv;
+ gettimeofday(&tv, NULL);
+  return tv.tv_sec * 1000000000LL + tv.tv_usec * 1000LL;
+#endif
+}
+
+
+inline static int64 CurrentRealTimeMillis() {
+#ifdef HAVE_CLOCK_GETTIME
+ struct timespec tm;
+ clock_gettime(CLOCK_MONOTONIC, &tm);
+ return tm.tv_sec * 1000LL + tm.tv_nsec / 1000000LL;
+#else
+ struct timeval tv;
+ gettimeofday(&tv, NULL);
+  return tv.tv_sec * 1000LL + tv.tv_usec / 1000;
+#endif
+}
+
+
+template<typename T>
+inline static T Square(const T a) {
+ return a * a;
+}
+
+
+template<typename T>
+inline static T Clip(const T a, const T floor, const T ceil) {
+ SCHECK(ceil >= floor, "Bounds mismatch!");
+ return (a <= floor) ? floor : ((a >= ceil) ? ceil : a);
+}
+
+
+template<typename T>
+inline static int Floor(const T a) {
+ return static_cast<int>(a);
+}
+
+
+template<typename T>
+inline static int Ceil(const T a) {
+ return Floor(a) + 1;
+}
+
+
+template<typename T>
+inline static bool InRange(const T a, const T min, const T max) {
+ return (a >= min) && (a <= max);
+}
+
+
+inline static bool ValidIndex(const int a, const int max) {
+ return (a >= 0) && (a < max);
+}
+
+
+inline bool NearlyEqual(const float a, const float b, const float tolerance) {
+ return std::abs(a - b) < tolerance;
+}
+
+
+inline bool NearlyEqual(const float a, const float b) {
+ return NearlyEqual(a, b, EPSILON);
+}
+
+
+template<typename T>
+inline static int Round(const float a) {
+  return static_cast<int>(
+      (a - static_cast<float>(floor(a)) > 0.5f) ? ceil(a) : floor(a));
+}
+
+
+template<typename T>
+inline static void Swap(T* const a, T* const b) {
+ // Cache out the VALUE of what's at a.
+ T tmp = *a;
+ *a = *b;
+
+ *b = tmp;
+}
+
+
+static inline float randf() {
+ return rand() / static_cast<float>(RAND_MAX);
+}
+
+static inline float randf(const float min_value, const float max_value) {
+ return randf() * (max_value - min_value) + min_value;
+}
+
+static inline uint16 RealToFixed115(const float real_number) {
+ SCHECK(InRange(real_number, 0.0f, 2048.0f),
+ "Value out of range! %.2f", real_number);
+
+ static const float kMult = 32.0f;
+ const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
+ return static_cast<uint16>(real_number * kMult + round_add);
+}
+
+static inline float FixedToFloat115(const uint16 fp_number) {
+ const float kDiv = 32.0f;
+ return (static_cast<float>(fp_number) / kDiv);
+}
+
+static inline int RealToFixed1616(const float real_number) {
+ static const float kMult = 65536.0f;
+ SCHECK(InRange(real_number, -kMult, kMult),
+ "Value out of range! %.2f", real_number);
+
+ const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
+ return static_cast<int>(real_number * kMult + round_add);
+}
+
+static inline float FixedToFloat1616(const int fp_number) {
+ const float kDiv = 65536.0f;
+ return (static_cast<float>(fp_number) / kDiv);
+}
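+
+// Example: RealToFixed1616(1.5f) == 98304 (1.5 * 65536), and
+// FixedToFloat1616(98304) == 1.5f. Shifting a 16:16 value right by 16, as the
+// flow code does, recovers the truncated integer part.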
+
+template<typename T>
+// produces numbers in range [0,2*M_PI] (rather than -PI,PI)
+inline T FastAtan2(const T y, const T x) {
+ static const T coeff_1 = (T)(M_PI / 4.0);
+ static const T coeff_2 = (T)(3.0 * coeff_1);
+ const T abs_y = fabs(y);
+ T angle;
+ if (x >= 0) {
+ T r = (x - abs_y) / (x + abs_y);
+ angle = coeff_1 - coeff_1 * r;
+ } else {
+ T r = (x + abs_y) / (abs_y - x);
+ angle = coeff_2 - coeff_1 * r;
+ }
+  static const T TWO_PI = (T)(2.0 * M_PI);
+  return y < 0 ? TWO_PI - angle : angle;
+}
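+
+// Example: FastAtan2(1.0f, 0.0f) returns M_PI / 2 exactly and
+// FastAtan2(-1.0f, 0.0f) returns 3 * M_PI / 2; between those anchor angles the
+// linear approximation can deviate from the true atan2 by a few degrees.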
+
+#define NELEMS(X) (sizeof(X) / sizeof(X[0]))
+
+namespace tf_tracking {
+
+#ifdef __ARM_NEON
+float ComputeMeanNeon(const float* const values, const int num_vals);
+
+float ComputeStdDevNeon(const float* const values, const int num_vals,
+ const float mean);
+
+float ComputeWeightedMeanNeon(const float* const values,
+ const float* const weights, const int num_vals);
+
+float ComputeCrossCorrelationNeon(const float* const values1,
+ const float* const values2,
+ const int num_vals);
+#endif
+
+inline float ComputeMeanCpu(const float* const values, const int num_vals) {
+ // Get mean.
+ float sum = values[0];
+ for (int i = 1; i < num_vals; ++i) {
+ sum += values[i];
+ }
+ return sum / static_cast<float>(num_vals);
+}
+
+
+inline float ComputeMean(const float* const values, const int num_vals) {
+ return
+#ifdef __ARM_NEON
+ (num_vals >= 8) ? ComputeMeanNeon(values, num_vals) :
+#endif
+ ComputeMeanCpu(values, num_vals);
+}
+
+
+inline float ComputeStdDevCpu(const float* const values,
+ const int num_vals,
+ const float mean) {
+ // Get Std dev.
+ float squared_sum = 0.0f;
+ for (int i = 0; i < num_vals; ++i) {
+ squared_sum += Square(values[i] - mean);
+ }
+ return sqrt(squared_sum / static_cast<float>(num_vals));
+}
+
+
+inline float ComputeStdDev(const float* const values,
+ const int num_vals,
+ const float mean) {
+ return
+#ifdef __ARM_NEON
+ (num_vals >= 8) ? ComputeStdDevNeon(values, num_vals, mean) :
+#endif
+ ComputeStdDevCpu(values, num_vals, mean);
+}
+
+
+// TODO(andrewharp): Accelerate with NEON.
+inline float ComputeWeightedMean(const float* const values,
+ const float* const weights,
+ const int num_vals) {
+ float sum = 0.0f;
+ float total_weight = 0.0f;
+ for (int i = 0; i < num_vals; ++i) {
+ sum += values[i] * weights[i];
+ total_weight += weights[i];
+ }
+  return sum / total_weight;
+}
+
+
+inline float ComputeCrossCorrelationCpu(const float* const values1,
+ const float* const values2,
+ const int num_vals) {
+ float sxy = 0.0f;
+ for (int offset = 0; offset < num_vals; ++offset) {
+ sxy += values1[offset] * values2[offset];
+ }
+
+ const float cross_correlation = sxy / num_vals;
+
+ return cross_correlation;
+}
+
+
+inline float ComputeCrossCorrelation(const float* const values1,
+ const float* const values2,
+ const int num_vals) {
+ return
+#ifdef __ARM_NEON
+ (num_vals >= 8) ? ComputeCrossCorrelationNeon(values1, values2, num_vals)
+ :
+#endif
+ ComputeCrossCorrelationCpu(values1, values2, num_vals);
+}
+
+
+inline void NormalizeNumbers(float* const values, const int num_vals) {
+ // Find the mean and then subtract so that the new mean is 0.0.
+ const float mean = ComputeMean(values, num_vals);
+ VLOG(2) << "Mean is " << mean;
+ float* curr_data = values;
+ for (int i = 0; i < num_vals; ++i) {
+ *curr_data -= mean;
+ curr_data++;
+ }
+
+ // Now divide by the std deviation so the new standard deviation is 1.0.
+ // The numbers might all be identical (and thus shifted to 0.0 now),
+ // so only scale by the standard deviation if this is not the case.
+ const float std_dev = ComputeStdDev(values, num_vals, 0.0f);
+ if (std_dev > 0.0f) {
+ VLOG(2) << "Std dev is " << std_dev;
+ curr_data = values;
+ for (int i = 0; i < num_vals; ++i) {
+ *curr_data /= std_dev;
+ curr_data++;
+ }
+ }
+}
+
+
+// Returns the determinant of a 2x2 matrix.
+template<class T>
+inline T FindDeterminant2x2(const T* const a) {
+ // Determinant: (ad - bc)
+ return a[0] * a[3] - a[1] * a[2];
+}
+
+
+// Finds the inverse of a 2x2 matrix.
+// Returns true upon success, false if the matrix is not invertible.
+template<class T>
+inline bool Invert2x2(const T* const a, float* const a_inv) {
+ const float det = static_cast<float>(FindDeterminant2x2(a));
+ if (fabs(det) < EPSILON) {
+ return false;
+ }
+ const float inv_det = 1.0f / det;
+
+ a_inv[0] = inv_det * static_cast<float>(a[3]); // d
+ a_inv[1] = inv_det * static_cast<float>(-a[1]); // -b
+ a_inv[2] = inv_det * static_cast<float>(-a[2]); // -c
+ a_inv[3] = inv_det * static_cast<float>(a[0]); // a
+
+ return true;
+}
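+
+// Example: for a = {4, 7, 2, 6} (row-major), FindDeterminant2x2 returns
+// 4 * 6 - 7 * 2 = 10 and Invert2x2 fills a_inv with {0.6, -0.7, -0.2, 0.4}.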
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/utils_neon.cc b/tensorflow/examples/android/jni/object_tracking/utils_neon.cc
new file mode 100755
index 0000000000..5a5250e32e
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/utils_neon.cc
@@ -0,0 +1,151 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NEON implementations of Image methods for compatible devices. Control
+// should never enter this compilation unit on incompatible devices.
+
+#ifdef __ARM_NEON
+
+#include <arm_neon.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+namespace tf_tracking {
+
+inline static float GetSum(const float32x4_t& values) {
+ static float32_t summed_values[4];
+ vst1q_f32(summed_values, values);
+ return summed_values[0]
+ + summed_values[1]
+ + summed_values[2]
+ + summed_values[3];
+}
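+
+// Note: the static scratch buffer above makes GetSum non-reentrant, so it must
+// not be called from multiple threads at once. On AArch64 builds the same
+// horizontal add is available as the single intrinsic vaddvq_f32().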
+
+
+float ComputeMeanNeon(const float* const values, const int num_vals) {
+ SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
+
+ const float32_t* const arm_vals = (const float32_t* const) values;
+ float32x4_t accum = vdupq_n_f32(0.0f);
+
+ int offset = 0;
+ for (; offset <= num_vals - 4; offset += 4) {
+ accum = vaddq_f32(accum, vld1q_f32(&arm_vals[offset]));
+ }
+
+ // Pull the accumulated values into a single variable.
+ float sum = GetSum(accum);
+
+ // Get the remaining 1 to 3 values.
+ for (; offset < num_vals; ++offset) {
+ sum += values[offset];
+ }
+
+ const float mean_neon = sum / static_cast<float>(num_vals);
+
+#ifdef SANITY_CHECKS
+ const float mean_cpu = ComputeMeanCpu(values, num_vals);
+ SCHECK(NearlyEqual(mean_neon, mean_cpu, EPSILON * num_vals),
+ "Neon mismatch with CPU mean! %.10f vs %.10f",
+ mean_neon, mean_cpu);
+#endif
+
+ return mean_neon;
+}
+
+
+float ComputeStdDevNeon(const float* const values,
+ const int num_vals, const float mean) {
+ SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
+
+ const float32_t* const arm_vals = (const float32_t* const) values;
+ const float32x4_t mean_vec = vdupq_n_f32(-mean);
+
+ float32x4_t accum = vdupq_n_f32(0.0f);
+
+ int offset = 0;
+ for (; offset <= num_vals - 4; offset += 4) {
+ const float32x4_t deltas =
+ vaddq_f32(mean_vec, vld1q_f32(&arm_vals[offset]));
+
+ accum = vmlaq_f32(accum, deltas, deltas);
+ }
+
+ // Pull the accumulated values into a single variable.
+ float squared_sum = GetSum(accum);
+
+ // Get the remaining 1 to 3 values.
+ for (; offset < num_vals; ++offset) {
+ squared_sum += Square(values[offset] - mean);
+ }
+
+ const float std_dev_neon = sqrt(squared_sum / static_cast<float>(num_vals));
+
+#ifdef SANITY_CHECKS
+ const float std_dev_cpu = ComputeStdDevCpu(values, num_vals, mean);
+ SCHECK(NearlyEqual(std_dev_neon, std_dev_cpu, EPSILON * num_vals),
+ "Neon mismatch with CPU std dev! %.10f vs %.10f",
+ std_dev_neon, std_dev_cpu);
+#endif
+
+ return std_dev_neon;
+}
+
+
+float ComputeCrossCorrelationNeon(const float* const values1,
+ const float* const values2,
+ const int num_vals) {
+ SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
+
+ const float32_t* const arm_vals1 = (const float32_t* const) values1;
+ const float32_t* const arm_vals2 = (const float32_t* const) values2;
+
+ float32x4_t accum = vdupq_n_f32(0.0f);
+
+ int offset = 0;
+ for (; offset <= num_vals - 4; offset += 4) {
+ accum = vmlaq_f32(accum,
+ vld1q_f32(&arm_vals1[offset]),
+ vld1q_f32(&arm_vals2[offset]));
+ }
+
+ // Pull the accumulated values into a single variable.
+ float sxy = GetSum(accum);
+
+ // Get the remaining 1 to 3 values.
+ for (; offset < num_vals; ++offset) {
+ sxy += values1[offset] * values2[offset];
+ }
+
+ const float cross_correlation_neon = sxy / num_vals;
+
+#ifdef SANITY_CHECKS
+ const float cross_correlation_cpu =
+ ComputeCrossCorrelationCpu(values1, values2, num_vals);
+ SCHECK(NearlyEqual(cross_correlation_neon, cross_correlation_cpu,
+ EPSILON * num_vals),
+ "Neon mismatch with CPU cross correlation! %.10f vs %.10f",
+ cross_correlation_neon, cross_correlation_cpu);
+#endif
+
+ return cross_correlation_neon;
+}
+
+} // namespace tf_tracking
+
+#endif // __ARM_NEON
diff --git a/tensorflow/examples/android/proto/box_coder.proto b/tensorflow/examples/android/proto/box_coder.proto
new file mode 100644
index 0000000000..8576294110
--- /dev/null
+++ b/tensorflow/examples/android/proto/box_coder.proto
@@ -0,0 +1,42 @@
+syntax = "proto2";
+
+package org_tensorflow_demo;
+
+// Prior for a single feature (like minimum x coordinate, width, area, etc.)
+message BoxCoderPrior {
+ optional float mean = 1 [default = 0.0];
+ optional float stddev = 2 [default = 1.0];
+};
+
+// Box encoding/decoding configuration for a single box.
+message BoxCoderOptions {
+  // The number of priors must match the number of encoded values, which is
+  // derived from the use_... flags below.
+ repeated BoxCoderPrior priors = 1;
+
+ // Minimum/maximum X/Y of the four corners are used as features.
+ // Order: MinX, MinY, MaxX, MaxY.
+ // Number of values: 4.
+ optional bool use_corners = 2 [default = true];
+
+ // Width and height of the box in this order.
+ // Number of values: 2.
+ optional bool use_width_height = 3 [default = false];
+
+ // Coordinates of the center of the box.
+ // Order: X, Y.
+ // Number of values: 2.
+ optional bool use_center = 4 [default = false];
+
+ // Area of the box.
+ // Number of values: 1.
+ optional bool use_area = 5 [default = false];
+};
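+
+// Example (illustrative only): with the default use_corners = true and the
+// other use_* flags left false, a BoxCoderOptions message carries exactly four
+// priors, in MinX, MinY, MaxX, MaxY order:
+//
+//   priors { mean: 0.0 stddev: 0.1 }  # MinX
+//   priors { mean: 0.0 stddev: 0.1 }  # MinY
+//   priors { mean: 0.0 stddev: 0.2 }  # MaxX
+//   priors { mean: 0.0 stddev: 0.2 }  # MaxY
+//
+// The mean/stddev values here are placeholders, not the priors shipped with
+// the demo.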
+
+// Options for MultiBoxCoder, which is an encoder/decoder for a fixed number
+// of boxes.
+// A list of BoxCoderOptions that allows for storing multiple box coder options
+// in a single file.
+message MultiBoxCoderOptions {
+ repeated BoxCoderOptions box_coder = 1;
+};
diff --git a/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml
new file mode 100644
index 0000000000..674f25785a
--- /dev/null
+++ b/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent">
+
+ <org.tensorflow.demo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"/>
+
+ <org.tensorflow.demo.OverlayView
+ android:id="@+id/overlay"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"/>
+
+</FrameLayout>
diff --git a/tensorflow/examples/android/res/values/base-strings.xml b/tensorflow/examples/android/res/values/base-strings.xml
index 93cfe0dac2..f6c57d5030 100644
--- a/tensorflow/examples/android/res/values/base-strings.xml
+++ b/tensorflow/examples/android/res/values/base-strings.xml
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
- Copyright 2013 The TensorFlow Authors. All Rights Reserved.
+ Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,5 +17,6 @@
<resources>
<string name="app_name">TensorFlow Demo</string>
- <string name="activity_name_classification">TF Classification</string>
+ <string name="activity_name_classification">TF Classify</string>
+ <string name="activity_name_detection">TF Detect</string>
</resources>
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java
index e498c9e28f..2f16ded6c2 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java
@@ -17,7 +17,6 @@ package org.tensorflow.demo;
import android.graphics.Bitmap;
import android.graphics.RectF;
-
import java.util.List;
/**
@@ -44,10 +43,8 @@ public interface Classifier {
*/
private final Float confidence;
- /**
- * Optional location within the source image for the location of the recognized object.
- */
- private final RectF location;
+ /** Optional location within the source image for the location of the recognized object. */
+ private RectF location;
public Recognition(
final String id, final String title, final Float confidence, final RectF location) {
@@ -73,6 +70,10 @@ public interface Classifier {
return new RectF(location);
}
+ public void setLocation(RectF location) {
+ this.location = location;
+ }
+
@Override
public String toString() {
String resultString = "";
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
new file mode 100644
index 0000000000..d75136485a
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
@@ -0,0 +1,317 @@
+/*
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tensorflow.demo;
+
+import android.graphics.Bitmap;
+import android.graphics.Bitmap.Config;
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Paint.Style;
+import android.graphics.RectF;
+import android.media.Image;
+import android.media.Image.Plane;
+import android.media.ImageReader;
+import android.media.ImageReader.OnImageAvailableListener;
+import android.os.SystemClock;
+import android.os.Trace;
+import android.util.Size;
+import android.util.TypedValue;
+import android.view.Display;
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Vector;
+import org.tensorflow.demo.OverlayView.DrawCallback;
+import org.tensorflow.demo.env.BorderedText;
+import org.tensorflow.demo.env.ImageUtils;
+import org.tensorflow.demo.env.Logger;
+import org.tensorflow.demo.tracking.MultiBoxTracker;
+
+/**
+ * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track
+ * objects.
+ */
+public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
+ private static final Logger LOGGER = new Logger();
+
+ private static final int NUM_LOCATIONS = 784;
+ private static final int INPUT_SIZE = 224;
+ private static final int IMAGE_MEAN = 128;
+ private static final float IMAGE_STD = 128;
+ private static final String INPUT_NAME = "ResizeBilinear";
+ private static final String OUTPUT_NAMES = "output_locations/Reshape,output_scores/Reshape";
+
+ private static final String MODEL_FILE = "file:///android_asset/multibox_model.pb";
+ private static final String LOCATION_FILE = "file:///android_asset/multibox_location_priors.pb";
+
+ // Minimum detection confidence to track a detection.
+ private static final float MINIMUM_CONFIDENCE = 0.1f;
+
+ private static final boolean SAVE_PREVIEW_BITMAP = false;
+
+ private static final boolean MAINTAIN_ASPECT = false;
+
+ private static final float TEXT_SIZE_DIP = 18;
+
+ private Integer sensorOrientation;
+
+ private TensorFlowMultiBoxDetector detector;
+
+ private int previewWidth = 0;
+ private int previewHeight = 0;
+ private byte[][] yuvBytes;
+ private int[] rgbBytes = null;
+ private Bitmap rgbFrameBitmap = null;
+ private Bitmap croppedBitmap = null;
+
+ private boolean computing = false;
+
+ private long timestamp = 0;
+
+ private Matrix frameToCropTransform;
+ private Matrix cropToFrameTransform;
+
+ private Bitmap cropCopyBitmap;
+
+ private MultiBoxTracker tracker;
+
+ private byte[] luminance;
+
+ private BorderedText borderedText;
+
+ private long lastProcessingTimeMs;
+
+ @Override
+ public void onPreviewSizeChosen(final Size size, final int rotation) {
+ final float textSizePx =
+ TypedValue.applyDimension(
+ TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
+ borderedText = new BorderedText(textSizePx);
+
+ tracker = new MultiBoxTracker(getResources().getDisplayMetrics());
+
+ detector = new TensorFlowMultiBoxDetector();
+ try {
+ detector.initializeTensorFlow(
+ getAssets(),
+ MODEL_FILE,
+ LOCATION_FILE,
+ NUM_LOCATIONS,
+ INPUT_SIZE,
+ IMAGE_MEAN,
+ IMAGE_STD,
+ INPUT_NAME,
+ OUTPUT_NAMES);
+ } catch (final IOException e) {
+ LOGGER.e(e, "Exception!");
+ }
+
+ previewWidth = size.getWidth();
+ previewHeight = size.getHeight();
+
+ final Display display = getWindowManager().getDefaultDisplay();
+ final int screenOrientation = display.getRotation();
+
+ LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
+
+ sensorOrientation = rotation + screenOrientation;
+
+ LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
+ rgbBytes = new int[previewWidth * previewHeight];
+ rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
+ croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
+
+ frameToCropTransform =
+ ImageUtils.getTransformationMatrix(
+ previewWidth, previewHeight,
+ INPUT_SIZE, INPUT_SIZE,
+ sensorOrientation, MAINTAIN_ASPECT);
+
+ cropToFrameTransform = new Matrix();
+ frameToCropTransform.invert(cropToFrameTransform);
+ yuvBytes = new byte[3][];
+
+ addCallback(
+ new DrawCallback() {
+ @Override
+ public void drawCallback(final Canvas canvas) {
+ final Bitmap copy = cropCopyBitmap;
+
+ tracker.draw(canvas);
+
+ if (!isDebug()) {
+ return;
+ }
+
+ tracker.drawDebug(canvas);
+
+ if (copy != null) {
+ final Matrix matrix = new Matrix();
+ final float scaleFactor = 2;
+ matrix.postScale(scaleFactor, scaleFactor);
+ matrix.postTranslate(
+ canvas.getWidth() - copy.getWidth() * scaleFactor,
+ canvas.getHeight() - copy.getHeight() * scaleFactor);
+ canvas.drawBitmap(copy, matrix, new Paint());
+
+ final Vector<String> lines = new Vector<String>();
+ lines.add("Frame: " + previewWidth + "x" + previewHeight);
+ lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
+ lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
+ lines.add("Rotation: " + sensorOrientation);
+ lines.add("Inference time: " + lastProcessingTimeMs + "ms");
+
+ int lineNum = 0;
+ for (final String line : lines) {
+ borderedText.drawText(
+ canvas,
+ 10,
+ canvas.getHeight() - 10 - borderedText.getTextSize() * lineNum,
+ line);
+ ++lineNum;
+ }
+ }
+ }
+ });
+ }
+
+ @Override
+ public void onImageAvailable(final ImageReader reader) {
+ Image image = null;
+
+ ++timestamp;
+ final long currTimestamp = timestamp;
+
+ try {
+ image = reader.acquireLatestImage();
+
+ if (image == null) {
+ return;
+ }
+
+ Trace.beginSection("imageAvailable");
+
+ final Plane[] planes = image.getPlanes();
+ fillBytes(planes, yuvBytes);
+
+ tracker.onFrame(
+ previewWidth,
+ previewHeight,
+ planes[0].getRowStride(),
+ sensorOrientation,
+ yuvBytes[0],
+ timestamp);
+
+ requestRender();
+
+ // No mutex needed as this method is not reentrant.
+      if (computing) {
+        image.close();
+        // Balance the "imageAvailable" trace section begun above before returning early.
+        Trace.endSection();
+        return;
+      }
+ computing = true;
+
+ final int yRowStride = planes[0].getRowStride();
+ final int uvRowStride = planes[1].getRowStride();
+ final int uvPixelStride = planes[1].getPixelStride();
+ ImageUtils.convertYUV420ToARGB8888(
+ yuvBytes[0],
+ yuvBytes[1],
+ yuvBytes[2],
+ rgbBytes,
+ previewWidth,
+ previewHeight,
+ yRowStride,
+ uvRowStride,
+ uvPixelStride,
+ false);
+
+ image.close();
+ } catch (final Exception e) {
+ if (image != null) {
+ image.close();
+ }
+ LOGGER.e(e, "Exception!");
+ Trace.endSection();
+ return;
+ }
+
+ rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
+ final Canvas canvas = new Canvas(croppedBitmap);
+ canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
+
+ // For examining the actual TF input.
+ if (SAVE_PREVIEW_BITMAP) {
+ ImageUtils.saveBitmap(croppedBitmap);
+ }
+
+ if (luminance == null) {
+ luminance = new byte[yuvBytes[0].length];
+ }
+ System.arraycopy(yuvBytes[0], 0, luminance, 0, luminance.length);
+
+ runInBackground(
+ new Runnable() {
+ @Override
+ public void run() {
+ final long startTime = SystemClock.uptimeMillis();
+ final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
+ lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
+
+ cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
+ final Canvas canvas = new Canvas(cropCopyBitmap);
+ final Paint paint = new Paint();
+ paint.setColor(Color.RED);
+ paint.setStyle(Style.STROKE);
+ paint.setStrokeWidth(2.0f);
+
+ final List<Classifier.Recognition> mappedRecognitions =
+ new LinkedList<Classifier.Recognition>();
+
+ for (final Classifier.Recognition result : results) {
+ final RectF location = result.getLocation();
+ if (location != null && result.getConfidence() >= MINIMUM_CONFIDENCE) {
+ canvas.drawRect(location, paint);
+
+ cropToFrameTransform.mapRect(location);
+ result.setLocation(location);
+ mappedRecognitions.add(result);
+ }
+ }
+
+ tracker.trackResults(mappedRecognitions, luminance, currTimestamp);
+
+ requestRender();
+ computing = false;
+ }
+ });
+
+ Trace.endSection();
+ }
+
+ @Override
+ protected int getLayoutId() {
+ return R.layout.camera_connection_fragment_tracking;
+ }
+
+ @Override
+ protected int getDesiredPreviewFrameSize() {
+ return INPUT_SIZE;
+ }
+}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
new file mode 100644
index 0000000000..66e25304d3
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
@@ -0,0 +1,218 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo;
+
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+import android.graphics.RectF;
+import android.os.Trace;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.PriorityQueue;
+import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
+import org.tensorflow.demo.env.Logger;
+
+/**
+ * A detector for general purpose object detection as described in Scalable Object Detection using
+ * Deep Neural Networks (https://arxiv.org/abs/1312.2249).
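+ *
+ * <p>Illustrative usage sketch (constants mirror DetectorActivity; assetManager and croppedBitmap
+ * are supplied by the caller):
+ * <pre>{@code
+ * TensorFlowMultiBoxDetector detector = new TensorFlowMultiBoxDetector();
+ * detector.initializeTensorFlow(
+ *     assetManager,
+ *     "file:///android_asset/multibox_model.pb",
+ *     "file:///android_asset/multibox_location_priors.pb",
+ *     784, 224, 128, 128f,
+ *     "ResizeBilinear",
+ *     "output_locations/Reshape,output_scores/Reshape");
+ * List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
+ * detector.close();
+ * }</pre>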
+ */
+public class TensorFlowMultiBoxDetector implements Classifier {
+ private static final Logger LOGGER = new Logger();
+
+ static {
+ System.loadLibrary("tensorflow_demo");
+ }
+
+  // Only return up to this many results.
+ private static final int MAX_RESULTS = Integer.MAX_VALUE;
+
+ // Config values.
+ private String inputName;
+ private int inputSize;
+ private int imageMean;
+ private float imageStd;
+
+ // Pre-allocated buffers.
+ private int[] intValues;
+ private float[] floatValues;
+ private float[] outputLocations;
+ private float[] outputScores;
+ private String[] outputNames;
+ private int numLocations;
+
+ private TensorFlowInferenceInterface inferenceInterface;
+
+ private float[] boxPriors;
+
+ /**
+ * Initializes a native TensorFlow session for classifying images.
+ *
+ * @param assetManager The asset manager to be used to load assets.
+ * @param modelFilename The filepath of the model GraphDef protocol buffer.
+   * @param locationFilename The filepath of the protocol buffer containing the box location priors.
+   * @param numLocations The number of box locations (priors) output by the model.
+ * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
+ * @param imageMean The assumed mean of the image values.
+ * @param imageStd The assumed std of the image values.
+ * @param inputName The label of the image input node.
+   * @param outputName A comma-separated list of output node labels (locations first, then scores).
+ * @return The native return value, 0 indicating success.
+ * @throws IOException
+ */
+ public int initializeTensorFlow(
+ final AssetManager assetManager,
+ final String modelFilename,
+ final String locationFilename,
+ final int numLocations,
+ final int inputSize,
+ final int imageMean,
+ final float imageStd,
+ final String inputName,
+ final String outputName)
+ throws IOException {
+ this.inputName = inputName;
+ this.inputSize = inputSize;
+ this.imageMean = imageMean;
+ this.imageStd = imageStd;
+ this.numLocations = numLocations;
+
+ this.boxPriors = new float[numLocations * 8];
+
+ loadCoderOptions(assetManager, locationFilename, boxPriors);
+
+ // Pre-allocate buffers.
+ outputNames = outputName.split(",");
+ intValues = new int[inputSize * inputSize];
+ floatValues = new float[inputSize * inputSize * 3];
+ outputScores = new float[numLocations];
+ outputLocations = new float[numLocations * 4];
+
+ inferenceInterface = new TensorFlowInferenceInterface();
+
+ return inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
+ }
+
+  // Loads the box location priors (BoxCoderOptions) from the given asset via native code.
+ private native void loadCoderOptions(
+ AssetManager assetManager, String locationFilename, float[] boxPriors);
+
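+  // Each of the numLocations boxes has 8 prior values: (mean, stdDev) pairs for its 4 coordinates.
+  // Decoding undoes that normalization (encoding * stdDev + mean) and clamps each value to [0, 1].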
+ private float[] decodeLocationsEncoding(final float[] locationEncoding) {
+ final float[] locations = new float[locationEncoding.length];
+ boolean nonZero = false;
+ for (int i = 0; i < numLocations; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ final float currEncoding = locationEncoding[4 * i + j];
+ nonZero = nonZero || currEncoding != 0.0f;
+
+ final float mean = boxPriors[i * 8 + j * 2];
+ final float stdDev = boxPriors[i * 8 + j * 2 + 1];
+ float currentLocation = currEncoding * stdDev + mean;
+ currentLocation = Math.max(currentLocation, 0.0f);
+ currentLocation = Math.min(currentLocation, 1.0f);
+ locations[4 * i + j] = currentLocation;
+ }
+ }
+
+ if (!nonZero) {
+ LOGGER.w("No non-zero encodings; check log for inference errors.");
+ }
+ return locations;
+ }
+
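+  // Converts the raw score logits to probabilities with a sigmoid: 1 / (1 + exp(-x)).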
+ private float[] decodeScoresEncoding(final float[] scoresEncoding) {
+ final float[] scores = new float[scoresEncoding.length];
+ for (int i = 0; i < scoresEncoding.length; ++i) {
+ scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i])));
+ }
+ return scores;
+ }
+
+ @Override
+ public List<Recognition> recognizeImage(final Bitmap bitmap) {
+ // Log this method so that it can be analyzed with systrace.
+ Trace.beginSection("recognizeImage");
+
+ Trace.beginSection("preprocessBitmap");
+ // Preprocess the image data from 0-255 int to normalized float based
+ // on the provided parameters.
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+
+ for (int i = 0; i < intValues.length; ++i) {
+ floatValues[i * 3 + 0] = ((intValues[i] & 0xFF) - imageMean) / imageStd;
+ floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd;
+ floatValues[i * 3 + 2] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd;
+ }
+ Trace.endSection(); // preprocessBitmap
+
+ // Copy the input data into TensorFlow.
+ Trace.beginSection("fillNodeFloat");
+ inferenceInterface.fillNodeFloat(
+ inputName, new int[] {1, inputSize, inputSize, 3}, floatValues);
+ Trace.endSection();
+
+ // Run the inference call.
+ Trace.beginSection("runInference");
+ inferenceInterface.runInference(outputNames);
+ Trace.endSection();
+
+ // Copy the output Tensor back into the output array.
+ Trace.beginSection("readNodeFloat");
+ final float[] outputScoresEncoding = new float[numLocations];
+ final float[] outputLocationsEncoding = new float[numLocations * 4];
+ inferenceInterface.readNodeFloat(outputNames[0], outputLocationsEncoding);
+ inferenceInterface.readNodeFloat(outputNames[1], outputScoresEncoding);
+ Trace.endSection();
+
+ outputLocations = decodeLocationsEncoding(outputLocationsEncoding);
+ outputScores = decodeScoresEncoding(outputScoresEncoding);
+
+ // Find the best detections.
+ final PriorityQueue<Recognition> pq =
+ new PriorityQueue<Recognition>(
+ 1,
+ new Comparator<Recognition>() {
+ @Override
+ public int compare(final Recognition lhs, final Recognition rhs) {
+ // Intentionally reversed to put high confidence at the head of the queue.
+ return Float.compare(rhs.getConfidence(), lhs.getConfidence());
+ }
+ });
+
+ // Scale them back to the input size.
+ for (int i = 0; i < outputScores.length; ++i) {
+ final RectF detection =
+ new RectF(
+ outputLocations[4 * i] * inputSize,
+ outputLocations[4 * i + 1] * inputSize,
+ outputLocations[4 * i + 2] * inputSize,
+ outputLocations[4 * i + 3] * inputSize);
+ pq.add(new Recognition("" + i, "" + i, outputScores[i], detection));
+ }
+
+    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
+    // Evaluate the result count once up front; polling inside the loop shrinks pq.size().
+    final int numResults = Math.min(pq.size(), MAX_RESULTS);
+    for (int i = 0; i < numResults; ++i) {
+      recognitions.add(pq.poll());
+    }
+ Trace.endSection(); // "recognizeImage"
+ return recognitions;
+ }
+
+ @Override
+ public void close() {
+ inferenceInterface.close();
+ }
+}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
new file mode 100644
index 0000000000..24e5cb57df
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
@@ -0,0 +1,381 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo.tracking;
+
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Paint.Cap;
+import android.graphics.Paint.Join;
+import android.graphics.Paint.Style;
+import android.graphics.RectF;
+import android.util.DisplayMetrics;
+import android.util.Pair;
+import android.util.TypedValue;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+
+import org.tensorflow.demo.Classifier.Recognition;
+import org.tensorflow.demo.env.BorderedText;
+import org.tensorflow.demo.env.ImageUtils;
+import org.tensorflow.demo.env.Logger;
+
+/**
+ * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
+ * objects to new detections.
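+ *
+ * <p>Typical usage (sketch following DetectorActivity): call onFrame() with the luminance plane of
+ * every preview frame, trackResults() whenever new detections arrive, and draw() from the overlay's
+ * draw callback to render the currently tracked boxes.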
+ */
+public class MultiBoxTracker {
+ private final Logger logger = new Logger();
+
+ private static final float TEXT_SIZE_DIP = 18;
+
+  // Maximum fraction of a box that may be overlapped by another box at detection time; beyond
+  // this, the lower-scored box (new or old) will be removed.
+ private static final float MAX_OVERLAP = 0.35f;
+
+ private static final float MIN_SIZE = 16.0f;
+
+ // Allow replacement of the tracked box with new results if
+ // correlation has dropped below this level.
+ private static final float MARGINAL_CORRELATION = 0.75f;
+
+ // Consider object to be lost if correlation falls below this threshold.
+ private static final float MIN_CORRELATION = 0.3f;
+
+ private static final int[] COLORS = {
+ Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA
+ };
+
+ private final Queue<Integer> availableColors = new LinkedList<Integer>();
+
+ public ObjectTracker objectTracker;
+
+ final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();
+
+ private static class TrackedRecognition {
+ ObjectTracker.TrackedObject trackedObject;
+ float detectionConfidence;
+ int color;
+ }
+
+ private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();
+
+ private final Paint boxPaint = new Paint();
+
+ private final float textSizePx;
+ private final BorderedText borderedText;
+
+ private Matrix frameToCanvasMatrix;
+
+ private int frameWidth;
+ private int frameHeight;
+
+ private int sensorOrientation;
+
+ public MultiBoxTracker(final DisplayMetrics metrics) {
+ for (final int color : COLORS) {
+ availableColors.add(color);
+ }
+
+ boxPaint.setColor(Color.RED);
+ boxPaint.setStyle(Style.STROKE);
+ boxPaint.setStrokeWidth(12.0f);
+ boxPaint.setStrokeCap(Cap.ROUND);
+ boxPaint.setStrokeJoin(Join.ROUND);
+ boxPaint.setStrokeMiter(100);
+
+ textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, metrics);
+ borderedText = new BorderedText(textSizePx);
+ }
+
+ private Matrix getFrameToCanvasMatrix() {
+ return frameToCanvasMatrix;
+ }
+
+ public synchronized void drawDebug(final Canvas canvas) {
+ final Paint textPaint = new Paint();
+ textPaint.setColor(Color.WHITE);
+ textPaint.setTextSize(60.0f);
+
+ final Paint boxPaint = new Paint();
+ boxPaint.setColor(Color.RED);
+ boxPaint.setAlpha(200);
+ boxPaint.setStyle(Style.STROKE);
+
+ for (final Pair<Float, RectF> detection : screenRects) {
+ final RectF rect = detection.second;
+ canvas.drawRect(rect, boxPaint);
+ canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
+ borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
+ }
+
+ if (objectTracker == null) {
+ return;
+ }
+
+ // Draw correlations.
+ for (final TrackedRecognition recognition : trackedObjects) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+
+ final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
+
+ if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
+ final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
+ borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
+ }
+ }
+
+ final Matrix matrix = getFrameToCanvasMatrix();
+ objectTracker.drawDebug(canvas, matrix);
+ }
+
+ public synchronized void trackResults(
+ final List<Recognition> results, final byte[] frame, final long timestamp) {
+ logger.i("Processing %d results from %d", results.size(), timestamp);
+ processResults(timestamp, results, frame);
+ }
+
+ public synchronized void draw(final Canvas canvas) {
+ if (objectTracker == null) {
+ return;
+ }
+
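+    // The frame is assumed to be rotated 90 degrees relative to the canvas, so width and height
+    // are swapped when computing the scale that fits the frame inside the canvas.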
+ // TODO(andrewharp): This may not work for non-90 deg rotations.
+ final float multiplier =
+ Math.min(canvas.getWidth() / (float) frameHeight, canvas.getHeight() / (float) frameWidth);
+ frameToCanvasMatrix =
+ ImageUtils.getTransformationMatrix(
+ frameWidth,
+ frameHeight,
+ (int) (multiplier * frameHeight),
+ (int) (multiplier * frameWidth),
+ sensorOrientation,
+ false);
+
+ for (final TrackedRecognition recognition : trackedObjects) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+
+ final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
+
+ if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
+ boxPaint.setColor(recognition.color);
+
+ final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
+ canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);
+
+ final String labelString = String.format("%.2f", recognition.detectionConfidence);
+ borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
+ }
+ }
+ }
+
+ public synchronized void onFrame(
+ final int w,
+ final int h,
+ final int rowStride,
+      final int sensorOrientation,
+ final byte[] frame,
+ final long timestamp) {
+ if (objectTracker == null) {
+ ObjectTracker.clearInstance();
+
+ logger.i("Initializing ObjectTracker: %dx%d", w, h);
+ objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
+ frameWidth = w;
+ frameHeight = h;
+      this.sensorOrientation = sensorOrientation;
+ }
+
+ objectTracker.nextFrame(frame, null, timestamp, null, true);
+
+ // Clean up any objects not worth tracking any more.
+ final LinkedList<TrackedRecognition> copyList =
+ new LinkedList<TrackedRecognition>(trackedObjects);
+ for (final TrackedRecognition recognition : copyList) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+ final float correlation = trackedObject.getCurrentCorrelation();
+ if (correlation < MIN_CORRELATION) {
+ logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
+ trackedObject.stopTracking();
+ trackedObjects.remove(recognition);
+
+ availableColors.add(recognition.color);
+ }
+ }
+ }
+
+ private void processResults(
+ final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
+ final List<Pair<Float, RectF>> rectsToTrack = new LinkedList<Pair<Float, RectF>>();
+
+ screenRects.clear();
+ final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());
+
+ for (final Recognition result : results) {
+ if (result.getLocation() == null) {
+ continue;
+ }
+ final RectF detectionFrameRect = new RectF(result.getLocation());
+
+ final RectF detectionScreenRect = new RectF();
+ rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);
+
+ logger.v(
+ "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect);
+
+ screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));
+
+ if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
+ logger.w("Degenerate rectangle! " + detectionFrameRect);
+ continue;
+ }
+
+ rectsToTrack.add(new Pair<Float, RectF>(result.getConfidence(), detectionFrameRect));
+ }
+
+ if (rectsToTrack.isEmpty()) {
+ logger.v("Nothing to track, aborting.");
+ return;
+ }
+
+ if (objectTracker == null) {
+ logger.w("No ObjectTracker, can't track anything!");
+ return;
+ }
+
+ logger.i("%d rects to track", rectsToTrack.size());
+ for (final Pair<Float, RectF> potential : rectsToTrack) {
+ handleDetection(originalFrame, timestamp, potential);
+ }
+ }
+
+ private void handleDetection(
+ final byte[] frameCopy, final long timestamp, final Pair<Float, RectF> potential) {
+ final ObjectTracker.TrackedObject potentialObject =
+ objectTracker.trackObject(potential.second, timestamp, frameCopy);
+
+ final float potentialCorrelation = potentialObject.getCurrentCorrelation();
+ logger.v(
+ "Tracked object went from %s to %s with correlation %.2f",
+ potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);
+
+ if (potentialCorrelation < MARGINAL_CORRELATION) {
+ logger.v("Correlation too low to begin tracking %s.", potentialObject);
+ potentialObject.stopTracking();
+ return;
+ }
+
+ final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();
+
+ float maxIntersect = 0.0f;
+
+ // This is the current tracked object whose color we will take. If left null we'll take the
+ // first one from the color queue.
+ TrackedRecognition recogToReplace = null;
+
+ // Look for intersections that will be overridden by this object or an intersection that would
+ // prevent this one from being placed.
+ for (final TrackedRecognition trackedRecognition : trackedObjects) {
+ final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
+ final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
+ final RectF intersection = new RectF();
+ final boolean intersects = intersection.setIntersect(a, b);
+
+ final float intersectAmount =
+ intersection.width()
+ * intersection.height()
+ / Math.min(a.width() * a.height(), b.width() * b.height());
+
+ // If there is an intersection with this currently tracked box above the maximum overlap
+ // percentage allowed, either the new recognition needs to be dismissed or the old
+ // recognition needs to be removed and possibly replaced with the new one.
+ if (intersects && intersectAmount > MAX_OVERLAP) {
+ if (potential.first < trackedRecognition.detectionConfidence
+ && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
+ // If track for the existing object is still going strong and the detection score was
+ // good, reject this new object.
+ potentialObject.stopTracking();
+ return;
+ } else {
+ removeList.add(trackedRecognition);
+
+ // Let the previously tracked object with max intersection amount donate its color to
+ // the new object.
+ if (intersectAmount > maxIntersect) {
+ maxIntersect = intersectAmount;
+ recogToReplace = trackedRecognition;
+ }
+ }
+ }
+ }
+
+    // If we're already tracking the maximum number of objects and no intersections were found to
+    // bump anything off, pick the worst currently tracked object to remove, provided it is also
+    // worse than this candidate object.
+ if (availableColors.isEmpty() && removeList.isEmpty()) {
+ for (final TrackedRecognition candidate : trackedObjects) {
+ if (candidate.detectionConfidence < potential.first) {
+ if (recogToReplace == null
+ || candidate.detectionConfidence < recogToReplace.detectionConfidence) {
+ // Save it so that we use this color for the new object.
+ recogToReplace = candidate;
+ }
+ }
+ }
+ if (recogToReplace != null) {
+ logger.v("Found non-intersecting object to remove.");
+ removeList.add(recogToReplace);
+ } else {
+ logger.v("No non-intersecting object found to remove");
+ }
+ }
+
+ // Remove everything that got intersected.
+ for (final TrackedRecognition trackedRecognition : removeList) {
+ logger.v(
+ "Removing tracked object %s with detection confidence %.2f, correlation %.2f",
+ trackedRecognition.trackedObject,
+ trackedRecognition.detectionConfidence,
+ trackedRecognition.trackedObject.getCurrentCorrelation());
+ trackedRecognition.trackedObject.stopTracking();
+ trackedObjects.remove(trackedRecognition);
+ if (trackedRecognition != recogToReplace) {
+ availableColors.add(trackedRecognition.color);
+ }
+ }
+
+ if (recogToReplace == null && availableColors.isEmpty()) {
+ logger.e("No room to track this object, aborting.");
+ potentialObject.stopTracking();
+ return;
+ }
+
+ // Finally safe to say we can track this object.
+ logger.v(
+ "Tracking object %s with detection confidence %.2f at position %s",
+ potentialObject, potential.first, potential.second);
+ final TrackedRecognition trackedRecognition = new TrackedRecognition();
+ trackedRecognition.detectionConfidence = potential.first;
+ trackedRecognition.trackedObject = potentialObject;
+
+ // Use the color from a replaced object before taking one from the color queue.
+ trackedRecognition.color =
+ recogToReplace != null ? recogToReplace.color : availableColors.poll();
+ trackedObjects.add(trackedRecognition);
+ }
+}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
new file mode 100644
index 0000000000..211d8077a3
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
@@ -0,0 +1,649 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo.tracking;
+
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.PointF;
+import android.graphics.RectF;
+import android.graphics.Typeface;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Vector;
+import javax.microedition.khronos.opengles.GL10;
+import org.tensorflow.demo.env.Logger;
+import org.tensorflow.demo.env.Size;
+
+/**
+ * True object detector/tracker class that tracks objects across consecutive preview frames.
+ * It provides a simplified Java interface to the analogous native object defined by
+ * jni/client_vision/tracking/object_tracker.*.
+ *
+ * Currently, the ObjectTracker is a singleton due to native code restrictions, and so must
+ * be allocated by ObjectTracker.getInstance(). In addition, release() should be called
+ * as soon as the ObjectTracker is no longer needed, and before a new one is created.
+ *
+ * nextFrame() should be called as new frames become available, preferably as often as possible.
+ *
+ * After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects
+ * are associated with the ObjectTracker that created them, and are only valid while that
+ * ObjectTracker still exists.
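+ *
+ * <p>Illustrative lifecycle sketch (width, height, rowStride, frameData, box and timestamp are
+ * supplied by the caller):
+ * <pre>{@code
+ * ObjectTracker tracker = ObjectTracker.getInstance(width, height, rowStride, true);
+ * tracker.nextFrame(frameData, null, timestamp, null, false);
+ * ObjectTracker.TrackedObject obj = tracker.trackObject(box, timestamp, frameData);
+ * RectF currentPosition = obj.getTrackedPositionInPreviewFrame();
+ * obj.stopTracking();
+ * ObjectTracker.clearInstance();
+ * }</pre>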
+ */
+public class ObjectTracker {
+ private final Logger logger = new Logger();
+
+ private static final boolean DRAW_TEXT = false;
+
+ /**
+ * How many history points to keep track of and draw in the red history line.
+ */
+ private static final int MAX_DEBUG_HISTORY_SIZE = 30;
+
+ /**
+ * How many frames of optical flow deltas to record.
+ * TODO(andrewharp): Push this down to the native level so it can be polled
+   * efficiently into an array for upload, instead of keeping a duplicate
+ * copy in Java.
+ */
+ private static final int MAX_FRAME_HISTORY_SIZE = 200;
+
+ private static final int DOWNSAMPLE_FACTOR = 2;
+
+ private final byte[] downsampledFrame;
+
+ protected static ObjectTracker instance;
+
+ private final Map<String, TrackedObject> trackedObjects;
+
+ private long lastTimestamp;
+
+ private FrameChange lastKeypoints;
+
+ private final Vector<PointF> debugHistory;
+
+ private final LinkedList<TimestampedDeltas> timestampedDeltas;
+
+ protected final int frameWidth;
+ protected final int frameHeight;
+ private final int rowStride;
+ protected final boolean alwaysTrack;
+
+ private static class TimestampedDeltas {
+ final long timestamp;
+ final byte[] deltas;
+
+ public TimestampedDeltas(final long timestamp, final byte[] deltas) {
+ this.timestamp = timestamp;
+ this.deltas = deltas;
+ }
+ }
+
+ /**
+ * A simple class that records keypoint information, which includes
+ * local location, score and type. This will be used in calculating
+ * FrameChange.
+ */
+ public static class Keypoint {
+ public final float x;
+ public final float y;
+ public final float score;
+ public final int type;
+
+ public Keypoint(final float x, final float y) {
+ this.x = x;
+ this.y = y;
+ this.score = 0;
+ this.type = -1;
+ }
+
+ public Keypoint(final float x, final float y, final float score, final int type) {
+ this.x = x;
+ this.y = y;
+ this.score = score;
+ this.type = type;
+ }
+
+ Keypoint delta(final Keypoint other) {
+ return new Keypoint(this.x - other.x, this.y - other.y);
+ }
+ }
+
+ /**
+   * A simple class that records the change in position of a keypoint between two frames
+   * and can compute the resulting Keypoint delta. It is used in calculating the frame
+   * translation delta for optical flow.
+ */
+ public static class PointChange {
+ public final Keypoint keypointA;
+ public final Keypoint keypointB;
+ Keypoint pointDelta;
+ private final boolean wasFound;
+
+ public PointChange(final float x1, final float y1,
+ final float x2, final float y2,
+ final float score, final int type,
+ final boolean wasFound) {
+ this.wasFound = wasFound;
+
+ keypointA = new Keypoint(x1, y1, score, type);
+ keypointB = new Keypoint(x2, y2);
+ }
+
+ public Keypoint getDelta() {
+ if (pointDelta == null) {
+ pointDelta = keypointB.delta(keypointA);
+ }
+ return pointDelta;
+ }
+ }
+
+ /** A class that records a timestamped frame translation delta for optical flow. */
+ public static class FrameChange {
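+    // Keypoints arrive from native code packed as KEYPOINT_STEP floats apiece:
+    // x1, y1, wasFound, x2, y2, score, type (coordinates are in downsampled-frame space).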
+ public static final int KEYPOINT_STEP = 7;
+
+ public final Vector<PointChange> pointDeltas;
+
+ private final float minScore;
+ private final float maxScore;
+
+ public FrameChange(final float[] framePoints) {
+ float minScore = 100.0f;
+ float maxScore = -100.0f;
+
+ pointDeltas = new Vector<PointChange>(framePoints.length / KEYPOINT_STEP);
+
+ for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) {
+ final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR;
+ final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR;
+
+ final boolean wasFound = framePoints[i + 2] > 0.0f;
+
+ final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR;
+ final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR;
+ final float score = framePoints[i + 5];
+ final int type = (int) framePoints[i + 6];
+
+ minScore = Math.min(minScore, score);
+ maxScore = Math.max(maxScore, score);
+
+ pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound));
+ }
+
+ this.minScore = minScore;
+ this.maxScore = maxScore;
+ }
+ }
+
+ public static synchronized ObjectTracker getInstance(
+ final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
+ if (instance == null) {
+ instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack);
+ instance.init();
+ } else {
+ throw new RuntimeException(
+ "Tried to create a new objectracker before releasing the old one!");
+ }
+ return instance;
+ }
+
+ public static synchronized void clearInstance() {
+ if (instance != null) {
+ instance.release();
+ }
+ }
+
+ protected ObjectTracker(
+ final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
+ this.frameWidth = frameWidth;
+ this.frameHeight = frameHeight;
+ this.rowStride = rowStride;
+ this.alwaysTrack = alwaysTrack;
+ this.timestampedDeltas = new LinkedList<TimestampedDeltas>();
+
+ trackedObjects = new HashMap<String, TrackedObject>();
+
+ debugHistory = new Vector<PointF>(MAX_DEBUG_HISTORY_SIZE);
+
+    downsampledFrame =
+        new byte
+            [(frameWidth + DOWNSAMPLE_FACTOR - 1)
+                / DOWNSAMPLE_FACTOR
+                * (frameHeight + DOWNSAMPLE_FACTOR - 1)
+                / DOWNSAMPLE_FACTOR];
+ }
+
+ protected void init() {
+ // The native tracker never sees the full frame, so pre-scale dimensions
+ // by the downsample factor.
+ initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack);
+ }
+
+ private final float[] matrixValues = new float[9];
+
+ private long downsampledTimestamp;
+
+ @SuppressWarnings("unused")
+ public synchronized void drawOverlay(final GL10 gl,
+ final Size cameraViewSize, final Matrix matrix) {
+ final Matrix tempMatrix = new Matrix(matrix);
+ tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR);
+ tempMatrix.getValues(matrixValues);
+ drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues);
+ }
+
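+  // Feeds a new luminance frame to the native tracker: downsamples once per timestamp, runs the
+  // optical-flow update, records the packed keypoint deltas, and refreshes every TrackedObject's
+  // position.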
+ public synchronized void nextFrame(
+ final byte[] frameData, final byte[] uvData,
+ final long timestamp, final float[] transformationMatrix,
+ final boolean updateDebugInfo) {
+ if (downsampledTimestamp != timestamp) {
+ ObjectTracker.downsampleImageNative(
+ frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
+ downsampledTimestamp = timestamp;
+ }
+
+    // Run Lucas-Kanade using the full-frame initializer.
+ nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix);
+
+ timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR)));
+ while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) {
+ timestampedDeltas.removeFirst();
+ }
+
+ for (final TrackedObject trackedObject : trackedObjects.values()) {
+ trackedObject.updateTrackedPosition();
+ }
+
+ if (updateDebugInfo) {
+ updateDebugHistory();
+ }
+
+ lastTimestamp = timestamp;
+ }
+
+ public synchronized void release() {
+ releaseMemoryNative();
+ synchronized (ObjectTracker.class) {
+ instance = null;
+ }
+ }
+
+ private void drawHistoryDebug(final Canvas canvas) {
+ drawHistoryPoint(
+ canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2);
+ }
+
+ private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) {
+ final Paint p = new Paint();
+ p.setAntiAlias(false);
+ p.setTypeface(Typeface.SERIF);
+
+ p.setColor(Color.RED);
+ p.setStrokeWidth(2.0f);
+
+ // Draw the center circle.
+ p.setColor(Color.GREEN);
+ canvas.drawCircle(startX, startY, 3.0f, p);
+
+ p.setColor(Color.RED);
+
+ // Iterate through in backwards order.
+ synchronized (debugHistory) {
+ final int numPoints = debugHistory.size();
+ float lastX = startX;
+ float lastY = startY;
+ for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) {
+ final PointF delta = debugHistory.get(numPoints - keypointNum - 1);
+ final float newX = lastX + delta.x;
+ final float newY = lastY + delta.y;
+ canvas.drawLine(lastX, lastY, newX, newY, p);
+ lastX = newX;
+ lastY = newY;
+ }
+ }
+ }
+
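+  // Maps a float in [0, 1] to an integer color channel value in [0, 255].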
+ private static int floatToChar(final float value) {
+ return Math.max(0, Math.min((int) (value * 255.999f), 255));
+ }
+
+ private void drawKeypointsDebug(final Canvas canvas) {
+ final Paint p = new Paint();
+ if (lastKeypoints == null) {
+ return;
+ }
+ final int keypointSize = 3;
+
+ final float minScore = lastKeypoints.minScore;
+ final float maxScore = lastKeypoints.maxScore;
+
+ for (final PointChange keypoint : lastKeypoints.pointDeltas) {
+ if (keypoint.wasFound) {
+ final int r =
+ floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore));
+ final int b =
+ floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore));
+
+ final int color = 0xFF000000 | (r << 16) | b;
+ p.setColor(color);
+
+ final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y,
+ keypoint.keypointB.x, keypoint.keypointB.y};
+ canvas.drawRect(screenPoints[2] - keypointSize,
+ screenPoints[3] - keypointSize,
+ screenPoints[2] + keypointSize,
+ screenPoints[3] + keypointSize, p);
+ p.setColor(Color.CYAN);
+ canvas.drawLine(screenPoints[2], screenPoints[3],
+ screenPoints[0], screenPoints[1], p);
+
+ if (DRAW_TEXT) {
+ p.setColor(Color.WHITE);
+ canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score,
+ keypoint.keypointA.x, keypoint.keypointA.y, p);
+ }
+ } else {
+ p.setColor(Color.YELLOW);
+ final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y};
+ canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p);
+ }
+ }
+ }
+
+ private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX,
+ final float positionY, final float radius) {
+ final RectF currPosition = getCurrentPosition(timestamp,
+ new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius));
+ return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY);
+ }
+
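+  // Maps a full-resolution rect into the downsampled frame, asks the native tracker where that
+  // region is now, and scales the answer back up to full resolution.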
+ private synchronized RectF getCurrentPosition(final long timestamp, final RectF
+ oldPosition) {
+ final RectF downscaledFrameRect = downscaleRect(oldPosition);
+
+ final float[] delta = new float[4];
+ getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top,
+ downscaledFrameRect.right, downscaledFrameRect.bottom, delta);
+
+ final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
+
+ return upscaleRect(newPosition);
+ }
+
+ private void updateDebugHistory() {
+ lastKeypoints = new FrameChange(getKeypointsNative(false));
+
+ if (lastTimestamp == 0) {
+ return;
+ }
+
+ final PointF delta =
+ getAccumulatedDelta(
+ lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100);
+
+ synchronized (debugHistory) {
+ debugHistory.add(delta);
+
+ while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) {
+ debugHistory.remove(0);
+ }
+ }
+ }
+
+ public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) {
+ canvas.save();
+ canvas.setMatrix(frameToCanvas);
+
+ drawHistoryDebug(canvas);
+ drawKeypointsDebug(canvas);
+
+ canvas.restore();
+ }
+
+ public Vector<String> getDebugText() {
+ final Vector<String> lines = new Vector<String>();
+
+ if (lastKeypoints != null) {
+ lines.add("Num keypoints " + lastKeypoints.pointDeltas.size());
+ lines.add("Min score: " + lastKeypoints.minScore);
+ lines.add("Max score: " + lastKeypoints.maxScore);
+ }
+
+ return lines;
+ }
+
+ public synchronized List<byte[]> pollAccumulatedFlowData(final long endFrameTime) {
+ final List<byte[]> frameDeltas = new ArrayList<byte[]>();
+ while (timestampedDeltas.size() > 0) {
+ final TimestampedDeltas currentDeltas = timestampedDeltas.peek();
+ if (currentDeltas.timestamp <= endFrameTime) {
+ frameDeltas.add(currentDeltas.deltas);
+ timestampedDeltas.removeFirst();
+ } else {
+ break;
+ }
+ }
+
+ return frameDeltas;
+ }
+
+ private RectF downscaleRect(final RectF fullFrameRect) {
+ return new RectF(
+ fullFrameRect.left / DOWNSAMPLE_FACTOR,
+ fullFrameRect.top / DOWNSAMPLE_FACTOR,
+ fullFrameRect.right / DOWNSAMPLE_FACTOR,
+ fullFrameRect.bottom / DOWNSAMPLE_FACTOR);
+ }
+
+ private RectF upscaleRect(final RectF downsampledFrameRect) {
+ return new RectF(
+ downsampledFrameRect.left * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.top * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.right * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR);
+ }
+
+ /**
+ * A TrackedObject represents a native TrackedObject, and provides access to the
+ * relevant native tracking information available after every frame update. They may
+   * be safely passed around and accessed externally, but will become invalid after
+   * stopTracking() is called or the ObjectTracker that created them is deactivated.
+ *
+ * @author andrewharp@google.com (Andrew Harp)
+ */
+ public class TrackedObject {
+ private final String id;
+
+ private long lastExternalPositionTime;
+
+ private RectF lastTrackedPosition;
+ private boolean visibleInLastFrame;
+
+ private boolean isDead;
+
+ TrackedObject(final RectF position, final long timestamp, final byte[] data) {
+ isDead = false;
+
+ id = Integer.toString(this.hashCode());
+
+ lastExternalPositionTime = timestamp;
+
+ synchronized (ObjectTracker.this) {
+ registerInitialAppearance(position, data);
+ setPreviousPosition(position, timestamp);
+ trackedObjects.put(id, this);
+ }
+ }
+
+ public void stopTracking() {
+ checkValidObject();
+
+ synchronized (ObjectTracker.this) {
+ isDead = true;
+ forgetNative(id);
+ trackedObjects.remove(id);
+ }
+ }
+
+ public float getCurrentCorrelation() {
+ checkValidObject();
+ return ObjectTracker.this.getCurrentCorrelation(id);
+ }
+
+ void registerInitialAppearance(final RectF position, final byte[] data) {
+ final RectF externalPosition = downscaleRect(position);
+ registerNewObjectWithAppearanceNative(id,
+ externalPosition.left, externalPosition.top,
+ externalPosition.right, externalPosition.bottom,
+ data);
+ }
+
+ synchronized void setPreviousPosition(final RectF position, final long timestamp) {
+ checkValidObject();
+ synchronized (ObjectTracker.this) {
+ if (lastExternalPositionTime > timestamp) {
+ logger.w("Tried to use older position time!");
+ return;
+ }
+ final RectF externalPosition = downscaleRect(position);
+ lastExternalPositionTime = timestamp;
+
+ setPreviousPositionNative(id,
+ externalPosition.left, externalPosition.top,
+ externalPosition.right, externalPosition.bottom,
+ lastExternalPositionTime);
+
+ updateTrackedPosition();
+ }
+ }
+
+ void setCurrentPosition(final RectF position) {
+ checkValidObject();
+ final RectF downsampledPosition = downscaleRect(position);
+ synchronized (ObjectTracker.this) {
+ setCurrentPositionNative(id,
+ downsampledPosition.left, downsampledPosition.top,
+ downsampledPosition.right, downsampledPosition.bottom);
+ }
+ }
+
+ private synchronized void updateTrackedPosition() {
+ checkValidObject();
+
+ final float[] delta = new float[4];
+ getTrackedPositionNative(id, delta);
+ lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
+
+ visibleInLastFrame = isObjectVisible(id);
+ }
+
+ public synchronized RectF getTrackedPositionInPreviewFrame() {
+ checkValidObject();
+
+ if (lastTrackedPosition == null) {
+ return null;
+ }
+ return upscaleRect(lastTrackedPosition);
+ }
+
+ synchronized long getLastExternalPositionTime() {
+ return lastExternalPositionTime;
+ }
+
+ public synchronized boolean visibleInLastPreviewFrame() {
+ return visibleInLastFrame;
+ }
+
+ private void checkValidObject() {
+ if (isDead) {
+ throw new RuntimeException("TrackedObject already removed from tracking!");
+ } else if (ObjectTracker.this != instance) {
+ throw new RuntimeException("TrackedObject created with another ObjectTracker!");
+ }
+ }
+ }
+
+ public synchronized TrackedObject trackObject(
+ final RectF position, final long timestamp, final byte[] frameData) {
+ if (downsampledTimestamp != timestamp) {
+ ObjectTracker.downsampleImageNative(
+ frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
+ downsampledTimestamp = timestamp;
+ }
+ return new TrackedObject(position, timestamp, downsampledFrame);
+ }
+
+ public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) {
+ return new TrackedObject(position, lastTimestamp, frameData);
+ }
+
+ /*********************** NATIVE CODE *************************************/
+
+  /** This will contain an opaque pointer to the native ObjectTracker. */
+ private int nativeObjectTracker;
+
+ private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack);
+
+ protected native void registerNewObjectWithAppearanceNative(
+ String objectId, float x1, float y1, float x2, float y2, byte[] data);
+
+ protected native void setPreviousPositionNative(
+ String objectId, float x1, float y1, float x2, float y2, long timestamp);
+
+ protected native void setCurrentPositionNative(
+ String objectId, float x1, float y1, float x2, float y2);
+
+ protected native void forgetNative(String key);
+
+ protected native String getModelIdNative(String key);
+
+ protected native boolean haveObject(String key);
+ protected native boolean isObjectVisible(String key);
+ protected native float getCurrentCorrelation(String key);
+
+ protected native float getMatchScore(String key);
+
+ protected native void getTrackedPositionNative(String key, float[] points);
+
+ protected native void nextFrameNative(
+ byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix);
+
+ protected native void releaseMemoryNative();
+
+ protected native void getCurrentPositionNative(long timestamp,
+ final float positionX1, final float positionY1,
+ final float positionX2, final float positionY2,
+ final float[] delta);
+
+ protected native byte[] getKeypointsPacked(float scaleFactor);
+
+ protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints);
+
+ protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas);
+
+ protected static native void downsampleImageNative(
+ int width, int height, int rowStride, byte[] input, int factor, byte[] output);
+
+ static {
+ System.loadLibrary("tensorflow_demo");
+ }
+}