From a30b9926bd7d5276d6ff35af9428dee3e77b7dcb Mon Sep 17 00:00:00 2001 From: Andrew Harp Date: Mon, 28 Nov 2016 15:41:52 -0800 Subject: Adding person detection/tracking sample Activity to Android TensorFlow demo. Change: 140411483 --- WORKSPACE | 7 + tensorflow/examples/android/AndroidManifest.xml | 9 + tensorflow/examples/android/BUILD | 19 + tensorflow/examples/android/README.md | 39 +- tensorflow/examples/android/jni/box_coder_jni.cc | 92 +++ .../examples/android/jni/object_tracking/config.h | 300 +++++++++ .../android/jni/object_tracking/flow_cache.h | 306 +++++++++ .../android/jni/object_tracking/frame_pair.cc | 308 +++++++++ .../android/jni/object_tracking/frame_pair.h | 103 +++ .../examples/android/jni/object_tracking/geom.h | 319 ++++++++++ .../android/jni/object_tracking/gl_utils.h | 55 ++ .../android/jni/object_tracking/image-inl.h | 642 +++++++++++++++++++ .../examples/android/jni/object_tracking/image.h | 346 +++++++++++ .../android/jni/object_tracking/image_data.h | 270 ++++++++ .../android/jni/object_tracking/image_neon.cc | 270 ++++++++ .../android/jni/object_tracking/image_utils.h | 301 +++++++++ .../android/jni/object_tracking/integral_image.h | 187 ++++++ .../android/jni/object_tracking/jni_utils.h | 62 ++ .../android/jni/object_tracking/keypoint.h | 48 ++ .../jni/object_tracking/keypoint_detector.cc | 549 ++++++++++++++++ .../jni/object_tracking/keypoint_detector.h | 133 ++++ .../android/jni/object_tracking/log_streaming.h | 37 ++ .../android/jni/object_tracking/object_detector.cc | 27 + .../android/jni/object_tracking/object_detector.h | 232 +++++++ .../android/jni/object_tracking/object_model.h | 101 +++ .../android/jni/object_tracking/object_tracker.cc | 690 +++++++++++++++++++++ .../android/jni/object_tracking/object_tracker.h | 271 ++++++++ .../jni/object_tracking/object_tracker_jni.cc | 463 ++++++++++++++ .../android/jni/object_tracking/optical_flow.cc | 490 +++++++++++++++ .../android/jni/object_tracking/optical_flow.h | 111 ++++ .../examples/android/jni/object_tracking/sprite.h | 205 ++++++ .../android/jni/object_tracking/time_log.cc | 29 + .../android/jni/object_tracking/time_log.h | 138 +++++ .../android/jni/object_tracking/tracked_object.cc | 163 +++++ .../android/jni/object_tracking/tracked_object.h | 191 ++++++ .../examples/android/jni/object_tracking/utils.h | 386 ++++++++++++ .../android/jni/object_tracking/utils_neon.cc | 151 +++++ tensorflow/examples/android/proto/box_coder.proto | 42 ++ .../layout/camera_connection_fragment_tracking.xml | 30 + .../examples/android/res/values/base-strings.xml | 5 +- .../src/org/tensorflow/demo/Classifier.java | 11 +- .../src/org/tensorflow/demo/DetectorActivity.java | 317 ++++++++++ .../demo/TensorFlowMultiBoxDetector.java | 218 +++++++ .../tensorflow/demo/tracking/MultiBoxTracker.java | 381 ++++++++++++ .../tensorflow/demo/tracking/ObjectTracker.java | 649 +++++++++++++++++++ 45 files changed, 9684 insertions(+), 19 deletions(-) create mode 100644 tensorflow/examples/android/jni/box_coder_jni.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/config.h create mode 100644 tensorflow/examples/android/jni/object_tracking/flow_cache.h create mode 100644 tensorflow/examples/android/jni/object_tracking/frame_pair.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/frame_pair.h create mode 100644 tensorflow/examples/android/jni/object_tracking/geom.h create mode 100755 tensorflow/examples/android/jni/object_tracking/gl_utils.h create mode 100644 
tensorflow/examples/android/jni/object_tracking/image-inl.h create mode 100644 tensorflow/examples/android/jni/object_tracking/image.h create mode 100644 tensorflow/examples/android/jni/object_tracking/image_data.h create mode 100644 tensorflow/examples/android/jni/object_tracking/image_neon.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/image_utils.h create mode 100755 tensorflow/examples/android/jni/object_tracking/integral_image.h create mode 100644 tensorflow/examples/android/jni/object_tracking/jni_utils.h create mode 100644 tensorflow/examples/android/jni/object_tracking/keypoint.h create mode 100644 tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/keypoint_detector.h create mode 100644 tensorflow/examples/android/jni/object_tracking/log_streaming.h create mode 100644 tensorflow/examples/android/jni/object_tracking/object_detector.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/object_detector.h create mode 100644 tensorflow/examples/android/jni/object_tracking/object_model.h create mode 100644 tensorflow/examples/android/jni/object_tracking/object_tracker.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/object_tracker.h create mode 100644 tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/optical_flow.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/optical_flow.h create mode 100755 tensorflow/examples/android/jni/object_tracking/sprite.h create mode 100644 tensorflow/examples/android/jni/object_tracking/time_log.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/time_log.h create mode 100644 tensorflow/examples/android/jni/object_tracking/tracked_object.cc create mode 100644 tensorflow/examples/android/jni/object_tracking/tracked_object.h create mode 100644 tensorflow/examples/android/jni/object_tracking/utils.h create mode 100755 tensorflow/examples/android/jni/object_tracking/utils_neon.cc create mode 100644 tensorflow/examples/android/proto/box_coder.proto create mode 100644 tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml create mode 100644 tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java create mode 100644 tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java create mode 100644 tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java create mode 100644 tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java diff --git a/WORKSPACE b/WORKSPACE index 30aba396b8..20c0285084 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -29,6 +29,13 @@ new_http_archive( sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364" ) +new_http_archive( + name = "mobile_multibox", + build_file = "models.BUILD", + url = "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1.zip", + sha256 = "b4c178fd6236dcf0a20d25d07c45eebe85281263978c6a6f1dfc49d75befc45f" +) + # TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT new_http_archive( diff --git a/tensorflow/examples/android/AndroidManifest.xml b/tensorflow/examples/android/AndroidManifest.xml index 0a48d3d50b..e388734564 100644 --- a/tensorflow/examples/android/AndroidManifest.xml +++ b/tensorflow/examples/android/AndroidManifest.xml @@ -41,6 +41,15 @@ + + + + + + + diff --git 
a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index beb8337702..088133fed6 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -35,6 +35,7 @@ cc_binary( "notap", ], deps = [ + ":demo_proto_lib_cc", "//tensorflow/contrib/android:android_tensorflow_inference_jni", "//tensorflow/core:android_tensorflow_lib", LINKER_SCRIPT, @@ -60,6 +61,7 @@ android_binary( assets = [ "//tensorflow/examples/android/assets:asset_files", "@inception5h//:model_files", + "@mobile_multibox//:model_files", ], assets_dir = "", custom_package = "org.tensorflow.demo", @@ -111,3 +113,20 @@ filegroup( ) exports_files(["AndroidManifest.xml"]) + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", +) + +tf_proto_library( + name = "demo_proto_lib", + srcs = glob( + ["**/*.proto"], + ), + cc_api_version = 2, + visibility = ["//visibility:public"], +) + +# ----------------------------------------------------------------------------- +# Google-internal targets go here (must be at the end). diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index b0465f7faa..b6556cdef4 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -1,11 +1,24 @@ # TensorFlow Android Camera Demo -This folder contains a simple camera-based demo application utilizing TensorFlow. +This folder contains an example application utilizing TensorFlow for Android +devices. ## Description -This demo uses a Google Inception model to classify camera frames in real-time, -displaying the top results in an overlay on the camera image. +The demos in this folder are designed to give straightforward samples of using +TensorFlow in mobile applications. Inference is done using the Java JNI API +exposed by `tensorflow/contrib/android`. + +Current samples: + +1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java): + Uses the [Google Inception](https://arxiv.org/abs/1409.4842) + model to classify camera frames in real-time, displaying the top results + in an overlay on the camera image. +2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java): + Demonstrates a model based on [Scalable Object Detection + using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to + localize and track people in the camera preview in real-time. ## To build/install/run @@ -19,9 +32,9 @@ installed on your system. 3. The Android SDK and build tools may be obtained from: https://developer.android.com/tools/revisions/build-tools.html -The Android entries in [`/WORKSPACE`](../../../WORKSPACE#L2-L13) must be -uncommented with the paths filled in appropriately depending on where you -installed the NDK and SDK. Otherwise an error such as: +The Android entries in [`/WORKSPACE`](../../../WORKSPACE#L2-L13) +must be uncommented with the paths filled in appropriately depending on where +you installed the NDK and SDK. Otherwise an error such as: "The external label '//external:android/sdk' is not bound to anything" will be reported. @@ -29,19 +42,21 @@ The TensorFlow `GraphDef` that contains the model definition and weights is not packaged in the repo because of its size. It will be downloaded automatically via a new_http_archive defined in WORKSPACE. -**Optional**: If you wish to place the model in your assets manually (E.g. 
for -non-Bazel builds), remove the -`inception_5` entry in `BUILD` and download the archive yourself to the -`assets` directory in the source tree: +**Optional**: If you wish to place the models in your assets manually (E.g. for +non-Bazel builds), remove the `inception_5` and `mobile_multibox` entries in +`BUILD` and download the archives yourself to the `assets` directory in the +source tree: ```bash $ curl -L https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip -o /tmp/inception5h.zip +$ curl -L https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1.zip -o /tmp/mobile_multibox_v1.zip $ unzip /tmp/inception5h.zip -d tensorflow/examples/android/assets/ +$ unzip /tmp/mobile_multibox_v1.zip -d tensorflow/examples/android/assets/ ``` -The labels file describing the possible classification will also be in the -assets directory. +The associated label and box prior files for the models will also be extracted +into the assets directory. After editing your WORKSPACE file to update the SDK/NDK configuration, you may build the APK. Run this from your workspace root: diff --git a/tensorflow/examples/android/jni/box_coder_jni.cc b/tensorflow/examples/android/jni/box_coder_jni.cc new file mode 100644 index 0000000000..be85414fc1 --- /dev/null +++ b/tensorflow/examples/android/jni/box_coder_jni.cc @@ -0,0 +1,92 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file loads the box coder mappings. + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/android/jni/jni_utils.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/proto/box_coder.pb.h" + +#define TENSORFLOW_METHOD(METHOD_NAME) \ + Java_org_tensorflow_demo_TensorFlowMultiBoxDetector_##METHOD_NAME // NOLINT + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +JNIEXPORT void JNICALL TENSORFLOW_METHOD(loadCoderOptions)( + JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring location, + jfloatArray priors); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +JNIEXPORT void JNICALL TENSORFLOW_METHOD(loadCoderOptions)( + JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring location, + jfloatArray priors) { + AAssetManager* const asset_manager = + AAssetManager_fromJava(env, java_asset_manager); + LOG(INFO) << "Acquired AssetManager."; + + const std::string location_str = GetString(env, location); + + org_tensorflow_demo::MultiBoxCoderOptions multi_options; + + LOG(INFO) << "Reading file to proto: " << location_str; + ReadFileToProtoOrDie(asset_manager, location_str.c_str(), &multi_options); + + LOG(INFO) << "Read file. 
" << multi_options.box_coder_size() << " entries."; + + jboolean iCopied = JNI_FALSE; + jfloat* values = env->GetFloatArrayElements(priors, &iCopied); + + const int array_length = env->GetArrayLength(priors); + LOG(INFO) << "Array length: " << array_length + << " (/8 = " << (array_length / 8) << ")"; + CHECK_EQ(array_length % 8, 0); + + const int num_items = + std::min(array_length / 8, multi_options.box_coder_size()); + + for (int i = 0; i < num_items; ++i) { + const org_tensorflow_demo::BoxCoderOptions& options = + multi_options.box_coder(i); + + for (int j = 0; j < 4; ++j) { + const org_tensorflow_demo::BoxCoderPrior& prior = options.priors(j); + values[i * 8 + j * 2] = prior.mean(); + values[i * 8 + j * 2 + 1] = prior.stddev(); + } + } + env->ReleaseFloatArrayElements(priors, values, 0); + + LOG(INFO) << "Read " << num_items << " options"; +} diff --git a/tensorflow/examples/android/jni/object_tracking/config.h b/tensorflow/examples/android/jni/object_tracking/config.h new file mode 100644 index 0000000000..86e9fc71b6 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/config.h @@ -0,0 +1,300 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ + +#include + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" + +namespace tf_tracking { + +// Arbitrary keypoint type ids for labeling the origin of tracked keypoints. +enum KeypointType { + KEYPOINT_TYPE_DEFAULT = 0, + KEYPOINT_TYPE_FAST = 1, + KEYPOINT_TYPE_INTEREST = 2 +}; + +// Struct that can be used to more richly store the results of a detection +// than a single number, while still maintaining comparability. +struct MatchScore { + explicit MatchScore(double val) : value(val) {} + MatchScore() { value = 0.0; } + + double value; + + MatchScore& operator+(const MatchScore& rhs) { + value += rhs.value; + return *this; + } + + friend std::ostream& operator<<(std::ostream& stream, + const MatchScore& detection) { + stream << detection.value; + return stream; + } +}; +inline bool operator< (const MatchScore& cC1, const MatchScore& cC2) { + return cC1.value < cC2.value; +} +inline bool operator> (const MatchScore& cC1, const MatchScore& cC2) { + return cC1.value > cC2.value; +} +inline bool operator>= (const MatchScore& cC1, const MatchScore& cC2) { + return cC1.value >= cC2.value; +} +inline bool operator<= (const MatchScore& cC1, const MatchScore& cC2) { + return cC1.value <= cC2.value; +} + +// Fixed seed used for all random number generators. +static const int kRandomNumberSeed = 11111; + +// TODO(andrewharp): Move as many of these settings as possible into a settings +// object which can be passed in from Java at runtime. + +// Whether or not to use ESM instead of LK flow. 
+static const bool kUseEsm = false; + +// This constant gets added to the diagonal of the Hessian +// before solving for translation in 2dof ESM. +// It ensures better behavior especially in the absence of +// strong texture. +static const int kEsmRegularizer = 20; + +// Do we want to brightness-normalize each keypoint patch when we compute +// its flow using ESM? +static const bool kDoBrightnessNormalize = true; + +// Whether or not to use fixed-point interpolated pixel lookups in optical flow. +#define USE_FIXED_POINT_FLOW 1 + +// Whether to normalize keypoint windows for intensity in LK optical flow. +// This is a define for now because it helps keep the code streamlined. +#define NORMALIZE 1 + +// Number of keypoints to store per frame. +static const int kMaxKeypoints = 76; + +// Keypoint detection. +static const int kMaxTempKeypoints = 1024; + +// Number of floats each keypoint takes up when exporting to an array. +static const int kKeypointStep = 7; + +// Number of frame deltas to keep around in the circular queue. +static const int kNumFrames = 512; + +// Number of iterations to do tracking on each keypoint at each pyramid level. +static const int kNumIterations = 3; + +// The number of bins (on a side) to divide each bin from the previous +// cache level into. Higher numbers will decrease performance by increasing +// cache misses, but mean that cache hits are more locally relevant. +static const int kCacheBranchFactor = 2; + +// Number of levels to put in the cache. +// Each level of the cache is a square grid of bins, length: +// branch_factor^(level - 1) on each side. +// +// This may be greater than kNumPyramidLevels. Setting it to 0 means no +// caching is enabled. +static const int kNumCacheLevels = 3; + +// The level at which the cache pyramid gets cut off and replaced by a matrix +// transform if such a matrix has been provided to the cache. +static const int kCacheCutoff = 1; + +static const int kNumPyramidLevels = 4; + +// The minimum number of keypoints needed in an object's area. +static const int kMaxKeypointsForObject = 16; + +// Minimum number of pyramid levels to use after getting cached value. +// This allows fine-scale adjustment from the cached value, which is taken +// from the center of the corresponding top cache level box. +// Can be [0, kNumPyramidLevels). +static const int kMinNumPyramidLevelsToUseForAdjustment = 1; + +// Window size to integrate over to find local image derivative. +static const int kFlowIntegrationWindowSize = 3; + +// Total area of integration windows. +static const int kFlowArraySize = + (2 * kFlowIntegrationWindowSize + 1) * (2 * kFlowIntegrationWindowSize + 1); + +// Error that's considered good enough to early abort tracking. +static const float kTrackingAbortThreshold = 0.03f; + +// Maximum number of deviations a keypoint-correspondence delta can be from the +// weighted average before being thrown out for region-based queries. +static const float kNumDeviations = 2.0f; + +// The length of the allowed delta between the forward and the backward +// flow deltas in terms of the length of the forward flow vector. +static const float kMaxForwardBackwardErrorAllowed = 0.5f; + +// Threshold for pixels to be considered different. +static const int kFastDiffAmount = 10; + +// How far from edge of frame to stop looking for FAST keypoints. +static const int kFastBorderBuffer = 10; + +// Determines if non-detected arbitrary keypoints should be added to regions. +// This will help if no keypoints have been detected in the region yet. 
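+// Arbitrary keypoints are laid out on a small fixed grid inside the region
+// (see kNumToAddAsCandidates and kClosestPercent below) rather than being
+// taken from detected image features.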
+static const bool kAddArbitraryKeypoints = true; + +// How many arbitrary keypoints to add along each axis as candidates for each +// region? +static const int kNumToAddAsCandidates = 1; + +// In terms of region dimensions, how closely can we place keypoints +// next to each other? +static const float kClosestPercent = 0.6f; + +// How many FAST qualifying pixels must be connected to a pixel for it to be +// considered a candidate keypoint for Harris filtering. +static const int kMinNumConnectedForFastKeypoint = 8; + +// Size of the window to integrate over for Harris filtering. +// Compare to kFlowIntegrationWindowSize. +static const int kHarrisWindowSize = 2; + + +// DETECTOR PARAMETERS + +// Before relocalizing, make sure the new proposed position is better than +// the existing position by a small amount to prevent thrashing. +static const MatchScore kMatchScoreBuffer(0.01f); + +// Minimum score a tracked object can have and still be considered a match. +// TODO(andrewharp): Make this a per detector thing. +static const MatchScore kMinimumMatchScore(0.5f); + +static const float kMinimumCorrelationForTracking = 0.4f; + +static const MatchScore kMatchScoreForImmediateTermination(0.0f); + +// Run the detector every N frames. +static const int kDetectEveryNFrames = 4; + +// How many features does each feature_set contain? +static const int kFeaturesPerFeatureSet = 10; + +// The number of FeatureSets managed by the object detector. +// More FeatureSets can increase recall at the cost of performance. +static const int kNumFeatureSets = 7; + +// How many FeatureSets must respond affirmatively for a candidate descriptor +// and position to be given more thorough attention? +static const int kNumFeatureSetsForCandidate = 2; + +// How large the thumbnails used for correlation validation are. Used for both +// width and height. +static const int kNormalizedThumbnailSize = 11; + +// The area of intersection divided by union for the bounding boxes that tells +// if this tracking has slipped enough to invalidate all unlocked examples. +static const float kPositionOverlapThreshold = 0.6f; + +// The number of detection failures allowed before an object goes invisible. +// Tracking will still occur, so if it is actually still being tracked and +// comes back into a detectable position, it's likely to be found. +static const int kMaxNumDetectionFailures = 4; + + +// Minimum square size to scan with sliding window. +static const float kScanMinSquareSize = 16.0f; + +// Minimum square size to scan with sliding window. +static const float kScanMaxSquareSize = 64.0f; + +// Scale difference for consecutive scans of the sliding window. +static const float kScanScaleFactor = sqrtf(2.0f); + +// Step size for sliding window. +static const int kScanStepSize = 10; + + +// How tightly to pack the descriptor boxes for confirmed exemplars. +static const float kLockedScaleFactor = 1 / sqrtf(2.0f); + +// How tightly to pack the descriptor boxes for unconfirmed exemplars. +static const float kUnlockedScaleFactor = 1 / 2.0f; + +// How tightly the boxes to scan centered at the last known position will be +// packed. +static const float kLastKnownPositionScaleFactor = 1.0f / sqrtf(2.0f); + +// The bounds on how close a new object example must be to existing object +// examples for detection to be valid. 
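+// Candidates whose correlation with the existing examples falls below the
+// minimum are rejected as unreliable; those above the maximum are treated as
+// near-duplicates that add no new information.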
+static const float kMinCorrelationForNewExample = 0.75f; +static const float kMaxCorrelationForNewExample = 0.99f; + + +// The number of safe tries an exemplar has after being created before +// missed detections count against it. +static const int kFreeTries = 5; + +// A false positive is worth this many missed detections. +static const int kFalsePositivePenalty = 5; + +struct ObjectDetectorConfig { + const Size image_size; + + explicit ObjectDetectorConfig(const Size& image_size) + : image_size(image_size) {} + virtual ~ObjectDetectorConfig() = default; +}; + +struct KeypointDetectorConfig { + const Size image_size; + + bool detect_skin; + + explicit KeypointDetectorConfig(const Size& image_size) + : image_size(image_size), + detect_skin(false) {} +}; + + +struct OpticalFlowConfig { + const Size image_size; + + explicit OpticalFlowConfig(const Size& image_size) + : image_size(image_size) {} +}; + +struct TrackerConfig { + const Size image_size; + KeypointDetectorConfig keypoint_detector_config; + OpticalFlowConfig flow_config; + bool always_track; + + float object_box_scale_factor_for_features; + + explicit TrackerConfig(const Size& image_size) + : image_size(image_size), + keypoint_detector_config(image_size), + flow_config(image_size), + always_track(false), + object_box_scale_factor_for_features(1.0f) {} +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/flow_cache.h b/tensorflow/examples/android/jni/object_tracking/flow_cache.h new file mode 100644 index 0000000000..8813ab6d71 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/flow_cache.h @@ -0,0 +1,306 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" + +namespace tf_tracking { + +// Class that helps OpticalFlow to speed up flow computation +// by caching coarse-grained flow. 
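+// Each cache level is a square grid of bins. A bin lazily stores the coarse
+// flow computed at the corresponding image pyramid level, so later queries
+// that fall into the same bin can skip the coarse levels and only refine the
+// estimate on the finer pyramid levels.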
+class FlowCache { + public: + explicit FlowCache(const OpticalFlowConfig* const config) + : config_(config), + image_size_(config->image_size), + optical_flow_(config), + fullframe_matrix_(NULL) { + for (int i = 0; i < kNumCacheLevels; ++i) { + const int curr_dims = BlockDimForCacheLevel(i); + has_cache_[i] = new Image(curr_dims, curr_dims); + displacements_[i] = new Image(curr_dims, curr_dims); + } + } + + ~FlowCache() { + for (int i = 0; i < kNumCacheLevels; ++i) { + SAFE_DELETE(has_cache_[i]); + SAFE_DELETE(displacements_[i]); + } + delete[](fullframe_matrix_); + fullframe_matrix_ = NULL; + } + + void NextFrame(ImageData* const new_frame, + const float* const align_matrix23) { + ClearCache(); + SetFullframeAlignmentMatrix(align_matrix23); + optical_flow_.NextFrame(new_frame); + } + + void ClearCache() { + for (int i = 0; i < kNumCacheLevels; ++i) { + has_cache_[i]->Clear(false); + } + delete[](fullframe_matrix_); + fullframe_matrix_ = NULL; + } + + // Finds the flow at a point, using the cache for performance. + bool FindFlowAtPoint(const float u_x, const float u_y, + float* const flow_x, float* const flow_y) const { + // Get the best guess from the cache. + const Point2f guess_from_cache = LookupGuess(u_x, u_y); + + *flow_x = guess_from_cache.x; + *flow_y = guess_from_cache.y; + + // Now refine the guess using the image pyramid. + for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1; + pyramid_level >= 0; --pyramid_level) { + if (!optical_flow_.FindFlowAtPointSingleLevel( + pyramid_level, u_x, u_y, false, flow_x, flow_y)) { + return false; + } + } + + return true; + } + + // Determines the displacement of a point, and uses that to calculate a new + // position. + // Returns true iff the displacement determination worked and the new position + // is in the image. + bool FindNewPositionOfPoint(const float u_x, const float u_y, + float* final_x, float* final_y) const { + float flow_x; + float flow_y; + if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) { + return false; + } + + // Add in the displacement to get the final position. + *final_x = u_x + flow_x; + *final_y = u_y + flow_y; + + // Assign the best guess, if we're still in the image. + if (InRange(*final_x, 0.0f, static_cast(image_size_.width) - 1) && + InRange(*final_y, 0.0f, static_cast(image_size_.height) - 1)) { + return true; + } else { + return false; + } + } + + // Comparison function for qsort. + static int Compare(const void* a, const void* b) { + return *reinterpret_cast(a) - + *reinterpret_cast(b); + } + + // Returns the median flow within the given bounding box as determined + // by a grid_width x grid_height grid. 
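+  // Flow is sampled at grid_width x grid_height points spread over the box,
+  // and the per-axis median of the successful samples is returned, making the
+  // estimate robust to individual tracking failures.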
+ Point2f GetMedianFlow(const BoundingBox& bounding_box, + const bool filter_by_fb_error, + const int grid_width, + const int grid_height) const { + const int kMaxPoints = 100; + SCHECK(grid_width * grid_height <= kMaxPoints, + "Too many points for Median flow!"); + + const BoundingBox valid_box = bounding_box.Intersect( + BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1)); + + if (valid_box.GetArea() <= 0.0f) { + return Point2f(0, 0); + } + + float x_deltas[kMaxPoints]; + float y_deltas[kMaxPoints]; + + int curr_offset = 0; + for (int i = 0; i < grid_width; ++i) { + for (int j = 0; j < grid_height; ++j) { + const float x_in = valid_box.left_ + + (valid_box.GetWidth() * i) / (grid_width - 1); + + const float y_in = valid_box.top_ + + (valid_box.GetHeight() * j) / (grid_height - 1); + + float curr_flow_x; + float curr_flow_y; + const bool success = FindNewPositionOfPoint(x_in, y_in, + &curr_flow_x, &curr_flow_y); + + if (success) { + x_deltas[curr_offset] = curr_flow_x; + y_deltas[curr_offset] = curr_flow_y; + ++curr_offset; + } else { + LOGW("Tracking failure!"); + } + } + } + + if (curr_offset > 0) { + qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare); + qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare); + + return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]); + } + + LOGW("No points were valid!"); + return Point2f(0, 0); + } + + void SetFullframeAlignmentMatrix(const float* const align_matrix23) { + if (align_matrix23 != NULL) { + if (fullframe_matrix_ == NULL) { + fullframe_matrix_ = new float[6]; + } + + memcpy(fullframe_matrix_, align_matrix23, + 6 * sizeof(fullframe_matrix_[0])); + } + } + + private: + Point2f LookupGuessFromLevel( + const int cache_level, const float x, const float y) const { + // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level); + + // Cutoff at the target level and use the matrix transform instead. + if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) { + const float xnew = x * fullframe_matrix_[0] + + y * fullframe_matrix_[1] + + fullframe_matrix_[2]; + const float ynew = x * fullframe_matrix_[3] + + y * fullframe_matrix_[4] + + fullframe_matrix_[5]; + + return Point2f(xnew - x, ynew - y); + } + + const int level_dim = BlockDimForCacheLevel(cache_level); + const int pixels_per_cache_block_x = + (image_size_.width + level_dim - 1) / level_dim; + const int pixels_per_cache_block_y = + (image_size_.height + level_dim - 1) / level_dim; + const int index_x = x / pixels_per_cache_block_x; + const int index_y = y / pixels_per_cache_block_y; + + Point2f displacement; + if (!(*has_cache_[cache_level])[index_y][index_x]) { + (*has_cache_[cache_level])[index_y][index_x] = true; + + // Get the lower cache level's best guess, if it exists. + displacement = cache_level >= kNumCacheLevels - 1 ? + Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y); + // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level, + // best_guess.x, best_guess.y); + + // Find the center of the block. 
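+      // The cached displacement for a bin is always measured at the bin's
+      // center, so any query that falls into this bin reuses the same value.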
+ const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x; + const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y; + const int pyramid_level = PyramidLevelForCacheLevel(cache_level); + + // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] " + // "Querying %5.2f, %5.2f at pyramid level %d, ", + // cache_level, index_x, index_y, + // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y, + // center_x, center_y, pyramid_level); + + // TODO(andrewharp): Turn on FB error filtering. + const bool success = optical_flow_.FindFlowAtPointSingleLevel( + pyramid_level, center_x, center_y, false, + &displacement.x, &displacement.y); + + if (!success) { + LOGV("Computation of cached value failed for level %d!", cache_level); + } + + // Store the value for later use. + (*displacements_[cache_level])[index_y][index_x] = displacement; + } else { + displacement = (*displacements_[cache_level])[index_y][index_x]; + } + + // LOGI("Returning %5.2f, %5.2f for level %d", + // displacement.x, displacement.y, cache_level); + return displacement; + } + + Point2f LookupGuess(const float x, const float y) const { + if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) { + return Point2f(0, 0); + } + + // LOGI("Looking up guess at %5.2f %5.2f.", x, y); + if (kNumCacheLevels > 0) { + return LookupGuessFromLevel(0, x, y); + } else { + return Point2f(0, 0); + } + } + + // Returns the number of cache bins in each dimension for a given level + // of the cache. + int BlockDimForCacheLevel(const int cache_level) const { + // The highest (coarsest) cache level has a block dim of kCacheBranchFactor, + // thus if there are 4 cache levels, requesting level 3 (0-based) should + // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2, + // and so on. + int block_dim = kNumCacheLevels; + for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level; + --curr_level) { + block_dim *= kCacheBranchFactor; + } + return block_dim; + } + + // Returns the level of the image pyramid that a given cache level maps to. + int PyramidLevelForCacheLevel(const int cache_level) const { + // Higher cache and pyramid levels have smaller dimensions. The highest + // cache level should refer to the highest image pyramid level. The + // lower, finer image pyramid levels are uncached (assuming + // kNumCacheLevels < kNumPyramidLevels). + return cache_level + (kNumPyramidLevels - kNumCacheLevels); + } + + const OpticalFlowConfig* const config_; + + const Size image_size_; + OpticalFlow optical_flow_; + + float* fullframe_matrix_; + + // Whether this value is currently present in the cache. + Image* has_cache_[kNumCacheLevels]; + + // The cached displacement values. + Image* displacements_[kNumCacheLevels]; + + TF_DISALLOW_COPY_AND_ASSIGN(FlowCache); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.cc b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc new file mode 100644 index 0000000000..fa86e2363c --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc @@ -0,0 +1,308 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" + +namespace tf_tracking { + +void FramePair::Init(const int64 start_time, const int64 end_time) { + start_time_ = start_time; + end_time_ = end_time; + memset(optical_flow_found_keypoint_, false, + sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints); + number_of_keypoints_ = 0; +} + +void FramePair::AdjustBox(const BoundingBox box, + float* const translation_x, + float* const translation_y, + float* const scale_x, + float* const scale_y) const { + static float weights[kMaxKeypoints]; + static Point2f deltas[kMaxKeypoints]; + memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints); + + BoundingBox resized_box(box); + resized_box.Scale(0.4f, 0.4f); + FillWeights(resized_box, weights); + FillTranslations(deltas); + + const Point2f translation = GetWeightedMedian(weights, deltas); + + *translation_x = translation.x; + *translation_y = translation.y; + + const Point2f old_center = box.GetCenter(); + const int good_scale_points = + FillScales(old_center, translation, weights, deltas); + + // Default scale factor is 1 for x and y. + *scale_x = 1.0f; + *scale_y = 1.0f; + + // The assumption is that all deltas that make it to this stage with a + // correspondending optical_flow_found_keypoint_[i] == true are not in + // themselves degenerate. + // + // The degeneracy with scale arose because if the points are too close to the + // center of the objects, the scale ratio determination might be incalculable. + // + // The check for kMinNumInRange is not a degeneracy check, but merely an + // attempt to ensure some sort of stability. The actual degeneracy check is in + // the comparison to EPSILON in FillScales (which I've updated to return the + // number good remaining as well). + static const int kMinNumInRange = 5; + if (good_scale_points >= kMinNumInRange) { + const float scale_factor = GetWeightedMedianScale(weights, deltas); + + if (scale_factor > 0.0f) { + *scale_x = scale_factor; + *scale_y = scale_factor; + } + } +} + +int FramePair::FillWeights(const BoundingBox& box, + float* const weights) const { + // Compute the max score. + float max_score = -FLT_MAX; + float min_score = FLT_MAX; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (optical_flow_found_keypoint_[i]) { + max_score = MAX(max_score, frame1_keypoints_[i].score_); + min_score = MIN(min_score, frame1_keypoints_[i].score_); + } + } + + int num_in_range = 0; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + weights[i] = 0.0f; + continue; + } + + const bool in_box = box.Contains(frame1_keypoints_[i].pos_); + if (in_box) { + ++num_in_range; + } + + // The weighting based off distance. Anything within the bounding box + // has a weight of 1, and everything outside of that is within the range + // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio. 
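+    // For an out-of-box point this works out to
+    //   kOutOfBoxMultiplier * min((h/2)^2 / dy^2, (w/2)^2 / dx^2),
+    // where (dx, dy) is the keypoint's offset from the box center.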
+ float distance_score = 1.0f; + if (!in_box) { + const Point2f initial = box.GetCenter(); + const float sq_x_dist = + Square(initial.x - frame1_keypoints_[i].pos_.x); + const float sq_y_dist = + Square(initial.y - frame1_keypoints_[i].pos_.y); + const float squared_half_width = Square(box.GetWidth() / 2.0f); + const float squared_half_height = Square(box.GetHeight() / 2.0f); + + static const float kOutOfBoxMultiplier = 0.5f; + distance_score = kOutOfBoxMultiplier * + MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist); + } + + // The weighting based on relative score strength. kBaseScore - 1.0f. + float intrinsic_score = 1.0f; + if (max_score > min_score) { + static const float kBaseScore = 0.5f; + intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) / + (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore; + } + + // The final score will be in the range [0, 1]. + weights[i] = distance_score * intrinsic_score; + } + + return num_in_range; +} + +void FramePair::FillTranslations(Point2f* const translations) const { + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + continue; + } + translations[i].x = + frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x; + translations[i].y = + frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y; + } +} + +int FramePair::FillScales(const Point2f& old_center, + const Point2f& translation, + float* const weights, + Point2f* const scales) const { + int num_good = 0; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + continue; + } + + const Keypoint keypoint1 = frame1_keypoints_[i]; + const Keypoint keypoint2 = frame2_keypoints_[i]; + + const float dist1_x = keypoint1.pos_.x - old_center.x; + const float dist1_y = keypoint1.pos_.y - old_center.y; + + const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x; + const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y; + + // Make sure that the scale makes sense; points too close to the center + // will result in either NaNs or infinite results for scale due to + // limited tracking and floating point resolution. + // Also check that the parity of the points is the same with respect to + // x and y, as we can't really make sense of data that has flipped. + if (((dist2_x > EPSILON && dist1_x > EPSILON) || + (dist2_x < -EPSILON && dist1_x < -EPSILON)) && + ((dist2_y > EPSILON && dist1_y > EPSILON) || + (dist2_y < -EPSILON && dist1_y < -EPSILON))) { + scales[i].x = dist2_x / dist1_x; + scales[i].y = dist2_y / dist1_y; + ++num_good; + } else { + weights[i] = 0.0f; + scales[i].x = 1.0f; + scales[i].y = 1.0f; + } + } + return num_good; +} + +struct WeightedDelta { + float weight; + float delta; +}; + +// Sort by delta, not by weight. +inline int WeightedDeltaCompare(const void* const a, const void* const b) { + return (reinterpret_cast(a)->delta - + reinterpret_cast(b)->delta) <= 0 ? 1 : -1; +} + +// Returns the median delta from a sorted set of weighted deltas. +static float GetMedian(const int num_items, + const WeightedDelta* const weighted_deltas, + const float sum) { + if (num_items == 0 || sum < EPSILON) { + return 0.0f; + } + + float current_weight = 0.0f; + const float target_weight = sum / 2.0f; + for (int i = 0; i < num_items; ++i) { + if (weighted_deltas[i].weight > 0.0f) { + current_weight += weighted_deltas[i].weight; + if (current_weight >= target_weight) { + return weighted_deltas[i].delta; + } + } + } + LOGW("Median not found! 
%d points, sum of %.2f", num_items, sum); + return 0.0f; +} + +Point2f FramePair::GetWeightedMedian( + const float* const weights, const Point2f* const deltas) const { + Point2f median_delta; + + // TODO(andrewharp): only sort deltas that could possibly have an effect. + static WeightedDelta weighted_deltas[kMaxKeypoints]; + + // Compute median X value. + { + float total_weight = 0.0f; + + // Compute weighted mean and deltas. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i].delta = deltas[i].x; + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta), + WeightedDeltaCompare); + median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight); + } + + // Compute median Y value. + { + float total_weight = 0.0f; + + // Compute weighted mean and deltas. + for (int i = 0; i < kMaxKeypoints; ++i) { + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + weighted_deltas[i].delta = deltas[i].y; + if (weight > 0.0f) { + total_weight += weight; + } + } + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta), + WeightedDeltaCompare); + median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight); + } + + return median_delta; +} + +float FramePair::GetWeightedMedianScale( + const float* const weights, const Point2f* const deltas) const { + float median_delta; + + // TODO(andrewharp): only sort deltas that could possibly have an effect. + static WeightedDelta weighted_deltas[kMaxKeypoints * 2]; + + // Compute median scale value across x and y. + { + float total_weight = 0.0f; + + // Add X values. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i].delta = deltas[i].x; + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + + // Add Y values. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y; + const float weight = weights[i]; + weighted_deltas[i + kMaxKeypoints].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + + qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta), + WeightedDeltaCompare); + + median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight); + } + + return median_delta; +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.h b/tensorflow/examples/android/jni/object_tracking/frame_pair.h new file mode 100644 index 0000000000..3f2559a5e0 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.h @@ -0,0 +1,103 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ + +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" + +namespace tf_tracking { + +// A class that records keypoint correspondences from pairs of +// consecutive frames. +class FramePair { + public: + FramePair() + : start_time_(0), + end_time_(0), + number_of_keypoints_(0) {} + + // Cleans up the FramePair so that they can be reused. + void Init(const int64 start_time, const int64 end_time); + + void AdjustBox(const BoundingBox box, + float* const translation_x, + float* const translation_y, + float* const scale_x, + float* const scale_y) const; + + private: + // Returns the weighted median of the given deltas, computed independently on + // x and y. Returns 0,0 in case of failure. The assumption is that a + // translation of 0.0 in the degenerate case is the best that can be done, and + // should not be considered an error. + // + // In the case of scale, a slight exception is made just to be safe and + // there is a check for 0.0 explicitly, but that shouldn't ever be possible to + // happen naturally because of the non-zero + parity checks in FillScales. + Point2f GetWeightedMedian(const float* const weights, + const Point2f* const deltas) const; + + float GetWeightedMedianScale(const float* const weights, + const Point2f* const deltas) const; + + // Weights points based on the query_point and cutoff_dist. + int FillWeights(const BoundingBox& box, + float* const weights) const; + + // Fills in the array of deltas with the translations of the points + // between frames. + void FillTranslations(Point2f* const translations) const; + + // Fills in the array of deltas with the relative scale factor of points + // relative to a given center. Has the ability to override the weight to 0 if + // a degenerate scale is detected. + // Translation is the amount the center of the box has moved from one frame to + // the next. + int FillScales(const Point2f& old_center, + const Point2f& translation, + float* const weights, + Point2f* const scales) const; + + // TODO(andrewharp): Make these private. + public: + // The time at frame1. + int64 start_time_; + + // The time at frame2. + int64 end_time_; + + // This array will contain the keypoints found in frame 1. + Keypoint frame1_keypoints_[kMaxKeypoints]; + + // Contain the locations of the keypoints from frame 1 in frame 2. + Keypoint frame2_keypoints_[kMaxKeypoints]; + + // The number of keypoints in frame 1. + int number_of_keypoints_; + + // Keeps track of which keypoint correspondences were actually found from one + // frame to another. + // The i-th element of this array will be non-zero if and only if the i-th + // keypoint of frame 1 was found in frame 2. + bool optical_flow_found_keypoint_[kMaxKeypoints]; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(FramePair); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/geom.h b/tensorflow/examples/android/jni/object_tracking/geom.h new file mode 100644 index 0000000000..5d5249cd97 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/geom.h @@ -0,0 +1,319 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ + +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +namespace tf_tracking { + +struct Size { + Size(const int width, const int height) : width(width), height(height) {} + + int width; + int height; +}; + + +class Point2f { + public: + Point2f() : x(0.0f), y(0.0f) {} + Point2f(const float x, const float y) : x(x), y(y) {} + + inline Point2f operator- (const Point2f& that) const { + return Point2f(this->x - that.x, this->y - that.y); + } + + inline Point2f operator+ (const Point2f& that) const { + return Point2f(this->x + that.x, this->y + that.y); + } + + inline Point2f& operator+= (const Point2f& that) { + this->x += that.x; + this->y += that.y; + return *this; + } + + inline Point2f& operator-= (const Point2f& that) { + this->x -= that.x; + this->y -= that.y; + return *this; + } + + inline Point2f operator- (const Point2f& that) { + return Point2f(this->x - that.x, this->y - that.y); + } + + inline float LengthSquared() { + return Square(this->x) + Square(this->y); + } + + inline float Length() { + return sqrtf(LengthSquared()); + } + + inline float DistanceSquared(const Point2f& that) { + return Square(this->x - that.x) + Square(this->y - that.y); + } + + inline float Distance(const Point2f& that) { + return sqrtf(DistanceSquared(that)); + } + + float x; + float y; +}; + +inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) { + stream << point.x << "," << point.y; + return stream; +} + +class BoundingBox { + public: + BoundingBox() + : left_(0), + top_(0), + right_(0), + bottom_(0) {} + + BoundingBox(const BoundingBox& bounding_box) + : left_(bounding_box.left_), + top_(bounding_box.top_), + right_(bounding_box.right_), + bottom_(bounding_box.bottom_) { + SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_); + SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_); + } + + BoundingBox(const float left, + const float top, + const float right, + const float bottom) + : left_(left), + top_(top), + right_(right), + bottom_(bottom) { + SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_); + SCHECK(top_ < bottom_, "Bounds out of whack! 
%.2f vs %.2f!", top_, bottom_); + } + + BoundingBox(const Point2f& point1, const Point2f& point2) + : left_(MIN(point1.x, point2.x)), + top_(MIN(point1.y, point2.y)), + right_(MAX(point1.x, point2.x)), + bottom_(MAX(point1.y, point2.y)) {} + + inline void CopyToArray(float* const bounds_array) const { + bounds_array[0] = left_; + bounds_array[1] = top_; + bounds_array[2] = right_; + bounds_array[3] = bottom_; + } + + inline float GetWidth() const { + return right_ - left_; + } + + inline float GetHeight() const { + return bottom_ - top_; + } + + inline float GetArea() const { + const float width = GetWidth(); + const float height = GetHeight(); + if (width <= 0 || height <= 0) { + return 0.0f; + } + + return width * height; + } + + inline Point2f GetCenter() const { + return Point2f((left_ + right_) / 2.0f, + (top_ + bottom_) / 2.0f); + } + + inline bool ValidBox() const { + return GetArea() > 0.0f; + } + + // Returns a bounding box created from the overlapping area of these two. + inline BoundingBox Intersect(const BoundingBox& that) const { + const float new_left = MAX(this->left_, that.left_); + const float new_right = MIN(this->right_, that.right_); + + if (new_left >= new_right) { + return BoundingBox(); + } + + const float new_top = MAX(this->top_, that.top_); + const float new_bottom = MIN(this->bottom_, that.bottom_); + + if (new_top >= new_bottom) { + return BoundingBox(); + } + + return BoundingBox(new_left, new_top, new_right, new_bottom); + } + + // Returns a bounding box that can contain both boxes. + inline BoundingBox Union(const BoundingBox& that) const { + return BoundingBox(MIN(this->left_, that.left_), + MIN(this->top_, that.top_), + MAX(this->right_, that.right_), + MAX(this->bottom_, that.bottom_)); + } + + inline float PascalScore(const BoundingBox& that) const { + SCHECK(GetArea() > 0.0f, "Empty bounding box!"); + SCHECK(that.GetArea() > 0.0f, "Empty bounding box!"); + + const float intersect_area = this->Intersect(that).GetArea(); + + if (intersect_area <= 0) { + return 0; + } + + const float score = + intersect_area / (GetArea() + that.GetArea() - intersect_area); + SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score); + return score; + } + + inline bool Intersects(const BoundingBox& that) const { + return InRange(that.left_, left_, right_) + || InRange(that.right_, left_, right_) + || InRange(that.top_, top_, bottom_) + || InRange(that.bottom_, top_, bottom_); + } + + // Returns whether another bounding box is completely inside of this bounding + // box. Sharing edges is ok. 
+ inline bool Contains(const BoundingBox& that) const { + return that.left_ >= left_ && + that.right_ <= right_ && + that.top_ >= top_ && + that.bottom_ <= bottom_; + } + + inline bool Contains(const Point2f& point) const { + return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_); + } + + inline void Shift(const Point2f shift_amount) { + left_ += shift_amount.x; + top_ += shift_amount.y; + right_ += shift_amount.x; + bottom_ += shift_amount.y; + } + + inline void ScaleOrigin(const float scale_x, const float scale_y) { + left_ *= scale_x; + right_ *= scale_x; + top_ *= scale_y; + bottom_ *= scale_y; + } + + inline void Scale(const float scale_x, const float scale_y) { + const Point2f center = GetCenter(); + const float half_width = GetWidth() / 2.0f; + const float half_height = GetHeight() / 2.0f; + + left_ = center.x - half_width * scale_x; + right_ = center.x + half_width * scale_x; + + top_ = center.y - half_height * scale_y; + bottom_ = center.y + half_height * scale_y; + } + + float left_; + float top_; + float right_; + float bottom_; +}; +inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) { + stream << "[" << box.left_ << " - " << box.right_ + << ", " << box.top_ << " - " << box.bottom_ + << ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]"; + return stream; +} + + +class BoundingSquare { + public: + BoundingSquare(const float x, const float y, const float size) + : x_(x), y_(y), size_(size) {} + + explicit BoundingSquare(const BoundingBox& box) + : x_(box.left_), y_(box.top_), size_(box.GetWidth()) { +#ifdef SANITY_CHECKS + if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) { + LOG(WARNING) << "This is not a square: " << box << std::endl; + } +#endif + } + + inline BoundingBox ToBoundingBox() const { + return BoundingBox(x_, y_, x_ + size_, y_ + size_); + } + + inline bool ValidBox() { + return size_ > 0.0f; + } + + inline void Shift(const Point2f shift_amount) { + x_ += shift_amount.x; + y_ += shift_amount.y; + } + + inline void Scale(const float scale) { + const float new_size = size_ * scale; + const float position_diff = (new_size - size_) / 2.0f; + x_ -= position_diff; + y_ -= position_diff; + size_ = new_size; + } + + float x_; + float y_; + float size_; +}; +inline std::ostream& operator<<(std::ostream& stream, + const BoundingSquare& square) { + stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]"; + return stream; +} + + +inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box, + const float size) { + const float width_diff = (original_box.GetWidth() - size) / 2.0f; + const float height_diff = (original_box.GetHeight() - size) / 2.0f; + return BoundingSquare(original_box.left_ + width_diff, + original_box.top_ + height_diff, + size); +} + +inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) { + return GetCenteredSquare( + original_box, MIN(original_box.GetWidth(), original_box.GetHeight())); +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/gl_utils.h b/tensorflow/examples/android/jni/object_tracking/gl_utils.h new file mode 100755 index 0000000000..bd5c233f4f --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/gl_utils.h @@ -0,0 +1,55 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ + +#include +#include + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" + +namespace tf_tracking { + +// Draws a box at the given position. +inline static void DrawBox(const BoundingBox& bounding_box) { + const GLfloat line[] = { + bounding_box.left_, bounding_box.bottom_, + bounding_box.left_, bounding_box.top_, + bounding_box.left_, bounding_box.top_, + bounding_box.right_, bounding_box.top_, + bounding_box.right_, bounding_box.top_, + bounding_box.right_, bounding_box.bottom_, + bounding_box.right_, bounding_box.bottom_, + bounding_box.left_, bounding_box.bottom_ + }; + + glVertexPointer(2, GL_FLOAT, 0, line); + glEnableClientState(GL_VERTEX_ARRAY); + + glDrawArrays(GL_LINES, 0, 8); +} + + +// Changes the coordinate system such that drawing to an arbitrary square in +// the world can thereafter be drawn to using coordinates 0 - 1. +inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) { + glScalef(square.size_, square.size_, 1.0f); + glTranslatef(square.x_ / square.size_, square.y_ / square.size_, 0.0f); +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image-inl.h b/tensorflow/examples/android/jni/object_tracking/image-inl.h new file mode 100644 index 0000000000..18123cef01 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image-inl.h @@ -0,0 +1,642 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +namespace tf_tracking { + +template +Image::Image(const int width, const int height) + : width_less_one_(width - 1), + height_less_one_(height - 1), + data_size_(width * height), + own_data_(true), + width_(width), + height_(height), + stride_(width) { + Allocate(); +} + +template +Image::Image(const Size& size) + : width_less_one_(size.width - 1), + height_less_one_(size.height - 1), + data_size_(size.width * size.height), + own_data_(true), + width_(size.width), + height_(size.height), + stride_(size.width) { + Allocate(); +} + +// Constructor that creates an image from preallocated data. +// Note: The image takes ownership of the data lifecycle, unless own_data is +// set to false. +template +Image::Image(const int width, const int height, T* const image_data, + const bool own_data) : + width_less_one_(width - 1), + height_less_one_(height - 1), + data_size_(width * height), + own_data_(own_data), + width_(width), + height_(height), + stride_(width) { + image_data_ = image_data; + SCHECK(image_data_ != NULL, "Can't create image with NULL data!"); +} + +template +Image::~Image() { + if (own_data_) { + delete[] image_data_; + } + image_data_ = NULL; +} + +template +template +bool Image::ExtractPatchAtSubpixelFixed1616(const int fp_x, + const int fp_y, + const int patchwidth, + const int patchheight, + DstType* to_data) const { + // Calculate weights. + const int trunc_x = fp_x >> 16; + const int trunc_y = fp_y >> 16; + + if (trunc_x < 0 || trunc_y < 0 || + (trunc_x + patchwidth) >= width_less_one_ || + (trunc_y + patchheight) >= height_less_one_) { + return false; + } + + // Now walk over destination patch and fill from interpolated source image. + for (int y = 0; y < patchheight; ++y, to_data += patchwidth) { + for (int x = 0; x < patchwidth; ++x) { + to_data[x] = + static_cast(GetPixelInterpFixed1616(fp_x + (x << 16), + fp_y + (y << 16))); + } + } + + return true; +} + +template +Image* Image::Crop( + const int left, const int top, const int right, const int bottom) const { + SCHECK(left >= 0 && left < width_, "out of bounds at %d!", left); + SCHECK(right >= 0 && right < width_, "out of bounds at %d!", right); + SCHECK(top >= 0 && top < height_, "out of bounds at %d!", top); + SCHECK(bottom >= 0 && bottom < height_, "out of bounds at %d!", bottom); + + SCHECK(left <= right, "mismatch!"); + SCHECK(top <= bottom, "mismatch!"); + + const int new_width = right - left + 1; + const int new_height = bottom - top + 1; + + Image* const cropped_image = new Image(new_width, new_height); + + for (int y = 0; y < new_height; ++y) { + memcpy((*cropped_image)[y], ((*this)[y + top] + left), + new_width * sizeof(T)); + } + + return cropped_image; +} + +template +inline float Image::GetPixelInterp(const float x, const float y) const { + // Do int conversion one time. + const int floored_x = static_cast(x); + const int floored_y = static_cast(y); + + // Note: it might be the case that the *_[min|max] values are clipped, and + // these (the a b c d vals) aren't (for speed purposes), but that doesn't + // matter. 
We'll just be blending the pixel with itself in that case anyway. + const float b = x - floored_x; + const float a = 1.0f - b; + + const float d = y - floored_y; + const float c = 1.0f - d; + + SCHECK(ValidInterpPixel(x, y), + "x or y out of bounds! %.2f [0 - %d), %.2f [0 - %d)", + x, width_less_one_, y, height_less_one_); + + const T* const pix_ptr = (*this)[floored_y] + floored_x; + + // Get the pixel values surrounding this point. + const T& p1 = pix_ptr[0]; + const T& p2 = pix_ptr[1]; + const T& p3 = pix_ptr[width_]; + const T& p4 = pix_ptr[width_ + 1]; + + // Simple bilinear interpolation between four reference pixels. + // If x is the value requested: + // a b + // ------- + // c |p1 p2| + // | x | + // d |p3 p4| + // ------- + return c * ((a * p1) + (b * p2)) + + d * ((a * p3) + (b * p4)); +} + + +template +inline T Image::GetPixelInterpFixed1616( + const int fp_x_whole, const int fp_y_whole) const { + static const int kFixedPointOne = 0x00010000; + static const int kFixedPointHalf = 0x00008000; + static const int kFixedPointTruncateMask = 0xFFFF0000; + + int trunc_x = fp_x_whole & kFixedPointTruncateMask; + int trunc_y = fp_y_whole & kFixedPointTruncateMask; + const int fp_x = fp_x_whole - trunc_x; + const int fp_y = fp_y_whole - trunc_y; + + // Scale the truncated values back to regular ints. + trunc_x >>= 16; + trunc_y >>= 16; + + const int one_minus_fp_x = kFixedPointOne - fp_x; + const int one_minus_fp_y = kFixedPointOne - fp_y; + + const T* trunc_start = (*this)[trunc_y] + trunc_x; + + const T a = trunc_start[0]; + const T b = trunc_start[1]; + const T c = trunc_start[stride_]; + const T d = trunc_start[stride_ + 1]; + + return ((one_minus_fp_y * static_cast(one_minus_fp_x * a + fp_x * b) + + fp_y * static_cast(one_minus_fp_x * c + fp_x * d) + + kFixedPointHalf) >> 32); +} + +template +inline bool Image::ValidPixel(const int x, const int y) const { + return InRange(x, ZERO, width_less_one_) && + InRange(y, ZERO, height_less_one_); +} + +template +inline BoundingBox Image::GetContainingBox() const { + return BoundingBox( + 0, 0, width_less_one_ - EPSILON, height_less_one_ - EPSILON); +} + +template +inline bool Image::Contains(const BoundingBox& bounding_box) const { + // TODO(andrewharp): Come up with a more elegant way of ensuring that bounds + // are ok. + return GetContainingBox().Contains(bounding_box); +} + +template +inline bool Image::ValidInterpPixel(const float x, const float y) const { + // Exclusive of max because we can be more efficient if we don't handle + // interpolating on or past the last pixel. + return (x >= ZERO) && (x < width_less_one_) && + (y >= ZERO) && (y < height_less_one_); +} + +template +void Image::DownsampleAveraged(const T* const original, const int stride, + const int factor) { +#ifdef __ARM_NEON + if (factor == 4 || factor == 2) { + DownsampleAveragedNeon(original, stride, factor); + return; + } +#endif + + // TODO(andrewharp): delete or enable this for non-uint8 downsamples. + const int pixels_per_block = factor * factor; + + // For every pixel in resulting image. + for (int y = 0; y < height_; ++y) { + const int orig_y = y * factor; + const int y_bound = orig_y + factor; + + // Sum up the original pixels. + for (int x = 0; x < width_; ++x) { + const int orig_x = x * factor; + const int x_bound = orig_x + factor; + + // Making this int32 because type U or T might overflow. + int32 pixel_sum = 0; + + // Grab all the pixels that make up this pixel. 
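GetPixelInterpFixed1616() earlier in this file does bilinear interpolation entirely in 16:16 fixed point: the integer pixel coordinate sits in the top 16 bits, the subpixel fraction in the bottom 16, and two fractional multiplies stack to 32 fractional bits, hence the final shift by 32. A rough standalone sketch of the same blending on a raw 8-bit buffer; the names and the rounding constant here are illustrative, not the patch's code.

#include <cstdint>
#include <cstdio>

// Convert a real coordinate to 16:16 fixed point (analogous to the patch's
// RealToFixed1616 helper).
static int32_t ToFixed1616(float x) { return static_cast<int32_t>(x * 65536.0f); }

// Bilinear sample of an 8-bit image stored row-major with the given stride.
// fp_x / fp_y are 16:16 fixed-point coordinates and must leave room for the
// (x + 1, y + 1) neighbors, as ValidInterpPixel enforces in the patch.
static uint8_t SampleFixed1616(const uint8_t* img, int stride,
                               int32_t fp_x, int32_t fp_y) {
  const int x = fp_x >> 16, y = fp_y >> 16;               // Integer parts.
  const int64_t fx = fp_x & 0xFFFF, fy = fp_y & 0xFFFF;   // Fractions, 0..65535.
  const int64_t gx = 65536 - fx, gy = 65536 - fy;
  const uint8_t a = img[y * stride + x],       b = img[y * stride + x + 1];
  const uint8_t c = img[(y + 1) * stride + x], d = img[(y + 1) * stride + x + 1];
  // Two 16.16 multiplies stack to 32 fractional bits, hence the final >> 32;
  // 1 << 31 rounds to nearest at that precision.
  const int64_t value =
      (gy * (gx * a + fx * b) + fy * (gx * c + fx * d) + (1LL << 31)) >> 32;
  return static_cast<uint8_t>(value);
}

int main() {
  const uint8_t img[2 * 2] = {0, 100, 100, 200};
  // Sample the exact center of the 2x2 patch: (0 + 100 + 100 + 200) / 4 = 100.
  std::printf("%d\n", SampleFixed1616(img, 2, ToFixed1616(0.5f), ToFixed1616(0.5f)));
  return 0;
}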
+ for (int curr_y = orig_y; curr_y < y_bound; ++curr_y) { + const T* p = original + curr_y * stride + orig_x; + + for (int curr_x = orig_x; curr_x < x_bound; ++curr_x) { + pixel_sum += *p++; + } + } + + (*this)[y][x] = pixel_sum / pixels_per_block; + } + } +} + +template +void Image::DownsampleInterpolateNearest(const Image& original) { + // Calculating the scaling factors based on target image size. + const float factor_x = static_cast(original.GetWidth()) / + static_cast(width_); + const float factor_y = static_cast(original.GetHeight()) / + static_cast(height_); + + // Calculating initial offset in x-axis. + const float offset_x = 0.5f * (original.GetWidth() - width_) / width_; + + // Calculating initial offset in y-axis. + const float offset_y = 0.5f * (original.GetHeight() - height_) / height_; + + float orig_y = offset_y; + + // For every pixel in resulting image. + for (int y = 0; y < height_; ++y) { + float orig_x = offset_x; + + // Finding nearest pixel on y-axis. + const int nearest_y = static_cast(orig_y + 0.5f); + const T* row_data = original[nearest_y]; + + T* pixel_ptr = (*this)[y]; + + for (int x = 0; x < width_; ++x) { + // Finding nearest pixel on x-axis. + const int nearest_x = static_cast(orig_x + 0.5f); + + *pixel_ptr++ = row_data[nearest_x]; + + orig_x += factor_x; + } + + orig_y += factor_y; + } +} + +template +void Image::DownsampleInterpolateLinear(const Image& original) { + // TODO(andrewharp): Turn this into a general compare sizes/bulk + // copy method. + if (original.GetWidth() == GetWidth() && + original.GetHeight() == GetHeight() && + original.stride() == stride()) { + memcpy(image_data_, original.data(), data_size_ * sizeof(T)); + return; + } + + // Calculating the scaling factors based on target image size. + const float factor_x = static_cast(original.GetWidth()) / + static_cast(width_); + const float factor_y = static_cast(original.GetHeight()) / + static_cast(height_); + + // Calculating initial offset in x-axis. + const float offset_x = 0; + const int offset_x_fp = RealToFixed1616(offset_x); + + // Calculating initial offset in y-axis. + const float offset_y = 0; + const int offset_y_fp = RealToFixed1616(offset_y); + + // Get the fixed point scaling factor value. + // Shift by 8 so we can fit everything into a 4 byte int later for speed + // reasons. This means the precision is limited to 1 / 256th of a pixel, + // but this should be good enough. + const int factor_x_fp = RealToFixed1616(factor_x) >> 8; + const int factor_y_fp = RealToFixed1616(factor_y) >> 8; + + int src_y_fp = offset_y_fp >> 8; + + static const int kFixedPointOne8 = 0x00000100; + static const int kFixedPointHalf8 = 0x00000080; + static const int kFixedPointTruncateMask8 = 0xFFFFFF00; + + // For every pixel in resulting image. + for (int y = 0; y < height_; ++y) { + int src_x_fp = offset_x_fp >> 8; + + int trunc_y = src_y_fp & kFixedPointTruncateMask8; + const int fp_y = src_y_fp - trunc_y; + + // Scale the truncated values back to regular ints. + trunc_y >>= 8; + + const int one_minus_fp_y = kFixedPointOne8 - fp_y; + + T* pixel_ptr = (*this)[y]; + + // Make sure not to read from an invalid row. 
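DownsampleAveraged() above shrinks the image by an integer factor by averaging each factor x factor block of source pixels, with the NEON path covering factors 2 and 4. A scalar sketch of the same idea on a plain grayscale buffer, independent of the Image class; the function and variable names are illustrative.

#include <cstdint>
#include <vector>

// Average every factor x factor block of src (w x h, row-major) into dst,
// which ends up (w / factor) x (h / factor). Mirrors the naive path above.
void BlockAverageDownsample(const uint8_t* src, int w, int h, int factor,
                            std::vector<uint8_t>* dst) {
  const int out_w = w / factor, out_h = h / factor;
  dst->assign(out_w * out_h, 0);
  for (int y = 0; y < out_h; ++y) {
    for (int x = 0; x < out_w; ++x) {
      int sum = 0;  // 32-bit so a block of 8-bit pixels cannot overflow.
      for (int dy = 0; dy < factor; ++dy) {
        for (int dx = 0; dx < factor; ++dx) {
          sum += src[(y * factor + dy) * w + (x * factor + dx)];
        }
      }
      (*dst)[y * out_w + x] = static_cast<uint8_t>(sum / (factor * factor));
    }
  }
}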
+ const int trunc_y_b = MIN(original.height_less_one_, trunc_y + 1); + const T* other_top_ptr = original[trunc_y]; + const T* other_bot_ptr = original[trunc_y_b]; + + int last_trunc_x = -1; + int trunc_x = -1; + + T a = 0; + T b = 0; + T c = 0; + T d = 0; + + for (int x = 0; x < width_; ++x) { + trunc_x = src_x_fp & kFixedPointTruncateMask8; + + const int fp_x = (src_x_fp - trunc_x) >> 8; + + // Scale the truncated values back to regular ints. + trunc_x >>= 8; + + // It's possible we're reading from the same pixels + if (trunc_x != last_trunc_x) { + // Make sure not to read from an invalid column. + const int trunc_x_b = MIN(original.width_less_one_, trunc_x + 1); + a = other_top_ptr[trunc_x]; + b = other_top_ptr[trunc_x_b]; + c = other_bot_ptr[trunc_x]; + d = other_bot_ptr[trunc_x_b]; + last_trunc_x = trunc_x; + } + + const int one_minus_fp_x = kFixedPointOne8 - fp_x; + + const int32 value = + ((one_minus_fp_y * one_minus_fp_x * a + fp_x * b) + + (fp_y * one_minus_fp_x * c + fp_x * d) + + kFixedPointHalf8) >> 16; + + *pixel_ptr++ = value; + + src_x_fp += factor_x_fp; + } + src_y_fp += factor_y_fp; + } +} + +template +void Image::DownsampleSmoothed3x3(const Image& original) { + for (int y = 0; y < height_; ++y) { + const int orig_y = Clip(2 * y, ZERO, original.height_less_one_); + const int min_y = Clip(orig_y - 1, ZERO, original.height_less_one_); + const int max_y = Clip(orig_y + 1, ZERO, original.height_less_one_); + + for (int x = 0; x < width_; ++x) { + const int orig_x = Clip(2 * x, ZERO, original.width_less_one_); + const int min_x = Clip(orig_x - 1, ZERO, original.width_less_one_); + const int max_x = Clip(orig_x + 1, ZERO, original.width_less_one_); + + // Center. + int32 pixel_sum = original[orig_y][orig_x] * 4; + + // Sides. + pixel_sum += (original[orig_y][max_x] + + original[orig_y][min_x] + + original[max_y][orig_x] + + original[min_y][orig_x]) * 2; + + // Diagonals. + pixel_sum += (original[min_y][max_x] + + original[min_y][min_x] + + original[max_y][max_x] + + original[max_y][min_x]); + + (*this)[y][x] = pixel_sum >> 4; // 16 + } + } +} + +template +void Image::DownsampleSmoothed5x5(const Image& original) { + const int max_x = original.width_less_one_; + const int max_y = original.height_less_one_; + + // The JY Bouget paper on Lucas-Kanade recommends a + // [1/16 1/4 3/8 1/4 1/16]^2 filter. + // This works out to a [1 4 6 4 1]^2 / 256 array, precomputed below. + static const int window_radius = 2; + static const int window_size = window_radius*2 + 1; + static const int window_weights[] = {1, 4, 6, 4, 1, // 16 + + 4, 16, 24, 16, 4, // 64 + + 6, 24, 36, 24, 6, // 96 + + 4, 16, 24, 16, 4, // 64 + + 1, 4, 6, 4, 1}; // 16 = 256 + + // We'll multiply and sum with the the whole numbers first, then divide by + // the total weight to normalize at the last moment. + for (int y = 0; y < height_; ++y) { + for (int x = 0; x < width_; ++x) { + int32 pixel_sum = 0; + + const int* w = window_weights; + const int start_x = Clip((x << 1) - window_radius, ZERO, max_x); + + // Clip the boundaries to the size of the image. + for (int window_y = 0; window_y < window_size; ++window_y) { + const int start_y = + Clip((y << 1) - window_radius + window_y, ZERO, max_y); + + const T* p = original[start_y] + start_x; + + for (int window_x = 0; window_x < window_size; ++window_x) { + pixel_sum += *p++ * *w++; + } + } + + // Conversion to type T will happen here after shifting right 8 bits to + // divide by 256. 
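In other words, the 5x5 kernel used by DownsampleSmoothed5x5 is the outer product of {1, 4, 6, 4, 1} with itself, and its weights total 16 * 16 = 256, which is why a right-shift by 8 is exact normalization. A trivial standalone check:

#include <cstdio>

int main() {
  // Outer product of the separable kernel {1, 4, 6, 4, 1} with itself.
  const int k[5] = {1, 4, 6, 4, 1};
  int total = 0;
  for (int y = 0; y < 5; ++y)
    for (int x = 0; x < 5; ++x) total += k[y] * k[x];
  std::printf("total weight = %d\n", total);  // Prints 256.
  return 0;
}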
+ (*this)[y][x] = pixel_sum >> 8; + } + } +} + +template +template +inline T Image::ScharrPixelX(const Image& original, + const int center_x, const int center_y) const { + const int min_x = Clip(center_x - 1, ZERO, original.width_less_one_); + const int max_x = Clip(center_x + 1, ZERO, original.width_less_one_); + const int min_y = Clip(center_y - 1, ZERO, original.height_less_one_); + const int max_y = Clip(center_y + 1, ZERO, original.height_less_one_); + + // Convolution loop unrolled for performance... + return (3 * (original[min_y][max_x] + + original[max_y][max_x] + - original[min_y][min_x] + - original[max_y][min_x]) + + 10 * (original[center_y][max_x] + - original[center_y][min_x])) / 32; +} + +template +template +inline T Image::ScharrPixelY(const Image& original, + const int center_x, const int center_y) const { + const int min_x = Clip(center_x - 1, 0, original.width_less_one_); + const int max_x = Clip(center_x + 1, 0, original.width_less_one_); + const int min_y = Clip(center_y - 1, 0, original.height_less_one_); + const int max_y = Clip(center_y + 1, 0, original.height_less_one_); + + // Convolution loop unrolled for performance... + return (3 * (original[max_y][min_x] + + original[max_y][max_x] + - original[min_y][min_x] + - original[min_y][max_x]) + + 10 * (original[max_y][center_x] + - original[min_y][center_x])) / 32; +} + +template +template +inline void Image::ScharrX(const Image& original) { + for (int y = 0; y < height_; ++y) { + for (int x = 0; x < width_; ++x) { + SetPixel(x, y, ScharrPixelX(original, x, y)); + } + } +} + +template +template +inline void Image::ScharrY(const Image& original) { + for (int y = 0; y < height_; ++y) { + for (int x = 0; x < width_; ++x) { + SetPixel(x, y, ScharrPixelY(original, x, y)); + } + } +} + +template +template +void Image::DerivativeX(const Image& original) { + for (int y = 0; y < height_; ++y) { + const U* const source_row = original[y]; + T* const dest_row = (*this)[y]; + + // Compute first pixel. Approximated with forward difference. + dest_row[0] = source_row[1] - source_row[0]; + + // All the pixels in between. Central difference method. + const U* source_prev_pixel = source_row; + T* dest_pixel = dest_row + 1; + const U* source_next_pixel = source_row + 2; + for (int x = 1; x < width_less_one_; ++x) { + *dest_pixel++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++); + } + + // Last pixel. Approximated with backward difference. + dest_row[width_less_one_] = + source_row[width_less_one_] - source_row[width_less_one_ - 1]; + } +} + +template +template +void Image::DerivativeY(const Image& original) { + const int src_stride = original.stride(); + + // Compute 1st row. Approximated with forward difference. + { + const U* const src_row = original[0]; + T* dest_row = (*this)[0]; + for (int x = 0; x < width_; ++x) { + dest_row[x] = src_row[x + src_stride] - src_row[x]; + } + } + + // Compute all rows in between using central difference. + for (int y = 1; y < height_less_one_; ++y) { + T* dest_row = (*this)[y]; + + const U* source_prev_pixel = original[y - 1]; + const U* source_next_pixel = original[y + 1]; + for (int x = 0; x < width_; ++x) { + *dest_row++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++); + } + } + + // Compute last row. Approximated with backward difference. 
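The unrolled ScharrPixelX()/ScharrPixelY() arithmetic above is equivalent to a 3x3 convolution with the Scharr kernel, weights {3, 10, 3} in the outer columns (or rows) divided by 32. A standalone sketch with the kernel written out explicitly; the border clamping and names are illustrative, following the spirit of the Clip() calls rather than the exact code.

#include <cstdio>

// Writing the kernel out makes the 3/10/3 weighting explicit.
static const int kScharrX[3][3] = {
    {-3,  0,  3},
    {-10, 0, 10},
    {-3,  0,  3},
};

// Horizontal gradient at (x, y) of a w x h row-major grayscale image, with
// clamped borders.
int ScharrX(const unsigned char* img, int w, int h, int x, int y) {
  int sum = 0;
  for (int ky = -1; ky <= 1; ++ky) {
    for (int kx = -1; kx <= 1; ++kx) {
      int sx = x + kx, sy = y + ky;
      sx = sx < 0 ? 0 : (sx >= w ? w - 1 : sx);
      sy = sy < 0 ? 0 : (sy >= h ? h - 1 : sy);
      sum += kScharrX[ky + 1][kx + 1] * img[sy * w + sx];
    }
  }
  return sum / 32;
}

int main() {
  // Horizontal ramp of slope 1: the X gradient at the center should be 1.
  const unsigned char img[3 * 3] = {0, 1, 2, 0, 1, 2, 0, 1, 2};
  std::printf("%d\n", ScharrX(img, 3, 3, 1, 1));
  return 0;
}

The vertical kernel is simply the transpose, which is what the unrolled ScharrPixelY() computes.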
+ { + const U* const src_row = original[height_less_one_]; + T* dest_row = (*this)[height_less_one_]; + for (int x = 0; x < width_; ++x) { + dest_row[x] = src_row[x] - src_row[x - src_stride]; + } + } +} + +template +template +inline T Image::ConvolvePixel3x3(const Image& original, + const int* const filter, + const int center_x, const int center_y, + const int total) const { + int32 sum = 0; + for (int filter_y = 0; filter_y < 3; ++filter_y) { + const int y = Clip(center_y - 1 + filter_y, 0, original.GetHeight()); + for (int filter_x = 0; filter_x < 3; ++filter_x) { + const int x = Clip(center_x - 1 + filter_x, 0, original.GetWidth()); + sum += original[y][x] * filter[filter_y * 3 + filter_x]; + } + } + return sum / total; +} + +template +template +inline void Image::Convolve3x3(const Image& original, + const int32* const filter) { + int32 sum = 0; + for (int i = 0; i < 9; ++i) { + sum += abs(filter[i]); + } + for (int y = 0; y < height_; ++y) { + for (int x = 0; x < width_; ++x) { + SetPixel(x, y, ConvolvePixel3x3(original, filter, x, y, sum)); + } + } +} + +template +inline void Image::FromArray(const T* const pixels, const int stride, + const int factor) { + if (factor == 1 && stride == width_) { + // If not subsampling, memcpy per line should be faster. + memcpy(this->image_data_, pixels, data_size_ * sizeof(T)); + return; + } + + DownsampleAveraged(pixels, stride, factor); +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image.h b/tensorflow/examples/android/jni/object_tracking/image.h new file mode 100644 index 0000000000..29b0adbda8 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image.h @@ -0,0 +1,346 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +using namespace tensorflow; + +// TODO(andrewharp): Make this a cast to uint32 if/when we go unsigned for +// operations. 
+#define ZERO 0 + +#ifdef SANITY_CHECKS + #define CHECK_PIXEL(IMAGE, X, Y) {\ + SCHECK((IMAGE)->ValidPixel((X), (Y)), \ + "CHECK_PIXEL(%d,%d) in %dx%d image.", \ + static_cast(X), static_cast(Y), \ + (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\ + } + + #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {\ + SCHECK((IMAGE)->validInterpPixel((X), (Y)), \ + "CHECK_PIXEL_INTERP(%.2f, %.2f) in %dx%d image.", \ + static_cast(X), static_cast(Y), \ + (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\ + } +#else + #define CHECK_PIXEL(image, x, y) {} + #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {} +#endif + +namespace tf_tracking { + +#ifdef SANITY_CHECKS +// Class which exists solely to provide bounds checking for array-style image +// data access. +template +class RowData { + public: + RowData(T* const row_data, const int max_col) + : row_data_(row_data), max_col_(max_col) {} + + inline T& operator[](const int col) const { + SCHECK(InRange(col, 0, max_col_), + "Column out of range: %d (%d max)", col, max_col_); + return row_data_[col]; + } + + inline operator T*() const { + return row_data_; + } + + private: + T* const row_data_; + const int max_col_; +}; +#endif + +// Naive templated sorting function. +template +int Comp(const void* a, const void* b) { + const T val1 = *reinterpret_cast(a); + const T val2 = *reinterpret_cast(b); + + if (val1 == val2) { + return 0; + } else if (val1 < val2) { + return -1; + } else { + return 1; + } +} + +// TODO(andrewharp): Make explicit which operations support negative numbers or +// struct/class types in image data (possibly create fast multi-dim array class +// for data where pixel arithmetic does not make sense). + +// Image class optimized for working on numeric arrays as grayscale image data. +// Supports other data types as a 2D array class, so long as no pixel math +// operations are called (convolution, downsampling, etc). +template +class Image { + public: + Image(const int width, const int height); + explicit Image(const Size& size); + + // Constructor that creates an image from preallocated data. + // Note: The image takes ownership of the data lifecycle, unless own_data is + // set to false. + Image(const int width, const int height, T* const image_data, + const bool own_data = true); + + ~Image(); + + // Extract a pixel patch from this image, starting at a subpixel location. + // Uses 16:16 fixed point format for representing real values and doing the + // bilinear interpolation. + // + // Arguments fp_x and fp_y tell the subpixel position in fixed point format, + // patchwidth/patchheight give the size of the patch in pixels and + // to_data must be a valid pointer to a *contiguous* destination data array. + template + bool ExtractPatchAtSubpixelFixed1616(const int fp_x, + const int fp_y, + const int patchwidth, + const int patchheight, + DstType* to_data) const; + + Image* Crop( + const int left, const int top, const int right, const int bottom) const; + + inline int GetWidth() const { return width_; } + inline int GetHeight() const { return height_; } + + // Bilinearly sample a value between pixels. Values must be within the image. + inline float GetPixelInterp(const float x, const float y) const; + + // Bilinearly sample a pixels at a subpixel position using fixed point + // arithmetic. + // Avoids float<->int conversions. + // Values must be within the image. + // Arguments fp_x and fp_y tell the subpixel position in + // 16:16 fixed point format. + // + // Important: This function only makes sense for integer-valued images, such + // as Image or Image etc. 
+ inline T GetPixelInterpFixed1616(const int fp_x_whole, + const int fp_y_whole) const; + + // Returns true iff the pixel is in the image's boundaries. + inline bool ValidPixel(const int x, const int y) const; + + inline BoundingBox GetContainingBox() const; + + inline bool Contains(const BoundingBox& bounding_box) const; + + inline T GetMedianValue() { + qsort(image_data_, data_size_, sizeof(image_data_[0]), Comp); + return image_data_[data_size_ >> 1]; + } + + // Returns true iff the pixel is in the image's boundaries for interpolation + // purposes. + // TODO(andrewharp): check in interpolation follow-up change. + inline bool ValidInterpPixel(const float x, const float y) const; + + // Safe lookup with boundary enforcement. + inline T GetPixelClipped(const int x, const int y) const { + return (*this)[Clip(y, ZERO, height_less_one_)] + [Clip(x, ZERO, width_less_one_)]; + } + +#ifdef SANITY_CHECKS + inline RowData operator[](const int row) { + SCHECK(InRange(row, 0, height_less_one_), + "Row out of range: %d (%d max)", row, height_less_one_); + return RowData(image_data_ + row * stride_, width_less_one_); + } + + inline const RowData operator[](const int row) const { + SCHECK(InRange(row, 0, height_less_one_), + "Row out of range: %d (%d max)", row, height_less_one_); + return RowData(image_data_ + row * stride_, width_less_one_); + } +#else + inline T* operator[](const int row) { + return image_data_ + row * stride_; + } + + inline const T* operator[](const int row) const { + return image_data_ + row * stride_; + } +#endif + + const T* data() const { return image_data_; } + + inline int stride() const { return stride_; } + + // Clears image to a single value. + inline void Clear(const T& val) { + memset(image_data_, val, sizeof(*image_data_) * data_size_); + } + +#ifdef __ARM_NEON + void Downsample2x32ColumnsNeon(const uint8* const original, + const int stride, + const int orig_x); + + void Downsample4x32ColumnsNeon(const uint8* const original, + const int stride, + const int orig_x); + + void DownsampleAveragedNeon(const uint8* const original, const int stride, + const int factor); +#endif + + // Naive downsampler that reduces image size by factor by averaging pixels in + // blocks of size factor x factor. + void DownsampleAveraged(const T* const original, const int stride, + const int factor); + + // Naive downsampler that reduces image size by factor by averaging pixels in + // blocks of size factor x factor. + inline void DownsampleAveraged(const Image& original, const int factor) { + DownsampleAveraged(original.data(), original.GetWidth(), factor); + } + + // Native downsampler that reduces image size using nearest interpolation + void DownsampleInterpolateNearest(const Image& original); + + // Native downsampler that reduces image size using fixed-point bilinear + // interpolation + void DownsampleInterpolateLinear(const Image& original); + + // Relatively efficient downsampling of an image by a factor of two with a + // low-pass 3x3 smoothing operation thrown in. + void DownsampleSmoothed3x3(const Image& original); + + // Relatively efficient downsampling of an image by a factor of two with a + // low-pass 5x5 smoothing operation thrown in. + void DownsampleSmoothed5x5(const Image& original); + + // Optimized Scharr filter on a single pixel in the X direction. + // Scharr filters are like central-difference operators, but have more + // rotational symmetry in their response because they also consider the + // diagonal neighbors. 
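GetMedianValue() above leans on C's qsort() plus the templated three-way comparator Comp to sort the raw pixel buffer in place and pick the middle element (the upper middle for even sizes, since it indexes data_size_ >> 1). A standalone sketch of the same pattern for an 8-bit buffer; the names are illustrative.

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Three-way comparator for qsort, equivalent to the templated comparator above
// instantiated for uint8_t.
static int CompU8(const void* a, const void* b) {
  const uint8_t v1 = *static_cast<const uint8_t*>(a);
  const uint8_t v2 = *static_cast<const uint8_t*>(b);
  return (v1 == v2) ? 0 : (v1 < v2 ? -1 : 1);
}

int main() {
  uint8_t pixels[] = {9, 2, 7, 4, 5};
  std::qsort(pixels, 5, sizeof(pixels[0]), CompU8);
  std::printf("median = %d\n", pixels[5 >> 1]);  // Prints 5.
  return 0;
}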
+ template + inline T ScharrPixelX(const Image& original, + const int center_x, const int center_y) const; + + // Optimized Scharr filter on a single pixel in the X direction. + // Scharr filters are like central-difference operators, but have more + // rotational symmetry in their response because they also consider the + // diagonal neighbors. + template + inline T ScharrPixelY(const Image& original, + const int center_x, const int center_y) const; + + // Convolve the image with a Scharr filter in the X direction. + // Much faster than an equivalent generic convolution. + template + inline void ScharrX(const Image& original); + + // Convolve the image with a Scharr filter in the Y direction. + // Much faster than an equivalent generic convolution. + template + inline void ScharrY(const Image& original); + + static inline T HalfDiff(int32 first, int32 second) { + return (second - first) / 2; + } + + template + void DerivativeX(const Image& original); + + template + void DerivativeY(const Image& original); + + // Generic function for convolving pixel with 3x3 filter. + // Filter pixels should be in row major order. + template + inline T ConvolvePixel3x3(const Image& original, + const int* const filter, + const int center_x, const int center_y, + const int total) const; + + // Generic function for convolving an image with a 3x3 filter. + // TODO(andrewharp): Generalize this for any size filter. + template + inline void Convolve3x3(const Image& original, + const int32* const filter); + + // Load this image's data from a data array. The data at pixels is assumed to + // have dimensions equivalent to this image's dimensions * factor. + inline void FromArray(const T* const pixels, const int stride, + const int factor = 1); + + // Copy the image back out to an appropriately sized data array. + inline void ToArray(T* const pixels) const { + // If not subsampling, memcpy should be faster. + memcpy(pixels, this->image_data_, data_size_ * sizeof(T)); + } + + // Precompute these for efficiency's sake as they're used by a lot of + // clipping code and loop code. + // TODO(andrewharp): make these only accessible by other Images. + const int width_less_one_; + const int height_less_one_; + + // The raw size of the allocated data. + const int data_size_; + + private: + inline void Allocate() { + image_data_ = new T[data_size_]; + if (image_data_ == NULL) { + LOGE("Couldn't allocate image data!"); + } + } + + T* image_data_; + + bool own_data_; + + const int width_; + const int height_; + + // The image stride (offset to next row). + // TODO(andrewharp): Make sure that stride is honored in all code. + const int stride_; + + TF_DISALLOW_COPY_AND_ASSIGN(Image); +}; + +template +inline std::ostream& operator<<(std::ostream& stream, const Image& image) { + for (int y = 0; y < image.GetHeight(); ++y) { + for (int x = 0; x < image.GetWidth(); ++x) { + stream << image[y][x] << " "; + } + stream << std::endl; + } + return stream; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image_data.h b/tensorflow/examples/android/jni/object_tracking/image_data.h new file mode 100644 index 0000000000..16b1864ee6 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image_data.h @@ -0,0 +1,270 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/image_utils.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" + +using namespace tensorflow; + +namespace tf_tracking { + +// Class that encapsulates all bulky processed data for a frame. +class ImageData { + public: + explicit ImageData(const int width, const int height) + : uv_frame_width_(width << 1), + uv_frame_height_(height << 1), + timestamp_(0), + image_(width, height) { + InitPyramid(width, height); + ResetComputationCache(); + } + + private: + void ResetComputationCache() { + uv_data_computed_ = false; + integral_image_computed_ = false; + for (int i = 0; i < kNumPyramidLevels; ++i) { + spatial_x_computed_[i] = false; + spatial_y_computed_[i] = false; + pyramid_sqrt2_computed_[i * 2] = false; + pyramid_sqrt2_computed_[i * 2 + 1] = false; + } + } + + void InitPyramid(const int width, const int height) { + int level_width = width; + int level_height = height; + + for (int i = 0; i < kNumPyramidLevels; ++i) { + pyramid_sqrt2_[i * 2] = NULL; + pyramid_sqrt2_[i * 2 + 1] = NULL; + spatial_x_[i] = NULL; + spatial_y_[i] = NULL; + + level_width /= 2; + level_height /= 2; + } + + // Alias the first pyramid level to image_. + pyramid_sqrt2_[0] = &image_; + } + + public: + ~ImageData() { + // The first pyramid level is actually an alias to image_, + // so make sure it doesn't get deleted here. 
+ pyramid_sqrt2_[0] = NULL; + + for (int i = 0; i < kNumPyramidLevels; ++i) { + SAFE_DELETE(pyramid_sqrt2_[i * 2]); + SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]); + SAFE_DELETE(spatial_x_[i]); + SAFE_DELETE(spatial_y_[i]); + } + } + + void SetData(const uint8* const new_frame, const int stride, + const int64 timestamp, const int downsample_factor) { + SetData(new_frame, NULL, stride, timestamp, downsample_factor); + } + + void SetData(const uint8* const new_frame, + const uint8* const uv_frame, + const int stride, + const int64 timestamp, const int downsample_factor) { + ResetComputationCache(); + + timestamp_ = timestamp; + + TimeLog("SetData!"); + + pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor); + pyramid_sqrt2_computed_[0] = true; + TimeLog("Downsampled image"); + + if (uv_frame != NULL) { + if (u_data_.get() == NULL) { + u_data_.reset(new Image(uv_frame_width_, uv_frame_height_)); + v_data_.reset(new Image(uv_frame_width_, uv_frame_height_)); + } + + GetUV(uv_frame, u_data_.get(), v_data_.get()); + uv_data_computed_ = true; + TimeLog("Copied UV data"); + } else { + LOGV("No uv data!"); + } + +#ifdef LOG_TIME + // If profiling is enabled, precompute here to make it easier to distinguish + // total costs. + Precompute(); +#endif + } + + inline const uint64 GetTimestamp() const { + return timestamp_; + } + + inline const Image* GetImage() const { + SCHECK(pyramid_sqrt2_computed_[0], "image not set!"); + return pyramid_sqrt2_[0]; + } + + const Image* GetPyramidSqrt2Level(const int level) const { + if (!pyramid_sqrt2_computed_[level]) { + SCHECK(level != 0, "Level equals 0!"); + if (level == 1) { + const Image& upper_level = *GetPyramidSqrt2Level(0); + if (pyramid_sqrt2_[level] == NULL) { + const int new_width = + (static_cast(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2; + const int new_height = + (static_cast(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 * + 2; + + pyramid_sqrt2_[level] = new Image(new_width, new_height); + } + pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level); + } else { + const Image& upper_level = *GetPyramidSqrt2Level(level - 2); + if (pyramid_sqrt2_[level] == NULL) { + pyramid_sqrt2_[level] = new Image( + upper_level.GetWidth() / 2, upper_level.GetHeight() / 2); + } + pyramid_sqrt2_[level]->DownsampleAveraged( + upper_level.data(), upper_level.stride(), 2); + } + pyramid_sqrt2_computed_[level] = true; + } + return pyramid_sqrt2_[level]; + } + + inline const Image* GetSpatialX(const int level) const { + if (!spatial_x_computed_[level]) { + const Image& src = *GetPyramidSqrt2Level(level * 2); + if (spatial_x_[level] == NULL) { + spatial_x_[level] = new Image(src.GetWidth(), src.GetHeight()); + } + spatial_x_[level]->DerivativeX(src); + spatial_x_computed_[level] = true; + } + return spatial_x_[level]; + } + + inline const Image* GetSpatialY(const int level) const { + if (!spatial_y_computed_[level]) { + const Image& src = *GetPyramidSqrt2Level(level * 2); + if (spatial_y_[level] == NULL) { + spatial_y_[level] = new Image(src.GetWidth(), src.GetHeight()); + } + spatial_y_[level]->DerivativeY(src); + spatial_y_computed_[level] = true; + } + return spatial_y_[level]; + } + + // The integral image is currently only used for object detection, so lazily + // initialize it on request. 
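GetPyramidSqrt2Level() below builds an interleaved pyramid: level 1 is level 0 shrunk by 1/sqrt(2) and rounded to an even size, and every later level is simply half the level two steps above it, so adjacent levels differ by a factor of sqrt(2). A small sketch that prints the resulting level sizes for a 640x480 frame under the same rounding rules; kNumPyramidLevels actually comes from config.h (not shown here), so the value 4 is only illustrative.

#include <cmath>
#include <cstdio>

int main() {
  const int kNumPyramidLevels = 4;  // Illustrative; the real value is in config.h.
  int w[kNumPyramidLevels * 2], h[kNumPyramidLevels * 2];
  w[0] = 640; h[0] = 480;
  for (int level = 1; level < kNumPyramidLevels * 2; ++level) {
    if (level == 1) {
      // Scale by 1/sqrt(2) and round to an even size, as in the patch.
      w[1] = (static_cast<int>(w[0] / std::sqrt(2.f)) + 1) / 2 * 2;
      h[1] = (static_cast<int>(h[0] / std::sqrt(2.f)) + 1) / 2 * 2;
    } else {
      // Every other level is half of the level two steps up.
      w[level] = w[level - 2] / 2;
      h[level] = h[level - 2] / 2;
    }
    std::printf("level %d: %dx%d\n", level, w[level], h[level]);
  }
  return 0;
}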
+ inline const IntegralImage* GetIntegralImage() const { + if (integral_image_.get() == NULL) { + integral_image_.reset(new IntegralImage(image_)); + } else if (!integral_image_computed_) { + integral_image_->Recompute(image_); + } + integral_image_computed_ = true; + return integral_image_.get(); + } + + inline const Image* GetU() const { + SCHECK(uv_data_computed_, "UV data not provided!"); + return u_data_.get(); + } + + inline const Image* GetV() const { + SCHECK(uv_data_computed_, "UV data not provided!"); + return v_data_.get(); + } + + private: + void Precompute() { + // Create the smoothed pyramids. + for (int i = 0; i < kNumPyramidLevels * 2; i += 2) { + (void) GetPyramidSqrt2Level(i); + } + TimeLog("Created smoothed pyramids"); + + // Create the smoothed pyramids. + for (int i = 1; i < kNumPyramidLevels * 2; i += 2) { + (void) GetPyramidSqrt2Level(i); + } + TimeLog("Created smoothed sqrt pyramids"); + + // Create the spatial derivatives for frame 1. + for (int i = 0; i < kNumPyramidLevels; ++i) { + (void) GetSpatialX(i); + (void) GetSpatialY(i); + } + TimeLog("Created spatial derivatives"); + + (void) GetIntegralImage(); + TimeLog("Got integral image!"); + } + + const int uv_frame_width_; + const int uv_frame_height_; + + int64 timestamp_; + + Image image_; + + bool uv_data_computed_; + std::unique_ptr > u_data_; + std::unique_ptr > v_data_; + + mutable bool spatial_x_computed_[kNumPyramidLevels]; + mutable Image* spatial_x_[kNumPyramidLevels]; + + mutable bool spatial_y_computed_[kNumPyramidLevels]; + mutable Image* spatial_y_[kNumPyramidLevels]; + + // Mutable so the lazy initialization can work when this class is const. + // Whether or not the integral image has been computed for the current image. + mutable bool integral_image_computed_; + mutable std::unique_ptr integral_image_; + + mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2]; + mutable Image* pyramid_sqrt2_[kNumPyramidLevels * 2]; + + TF_DISALLOW_COPY_AND_ASSIGN(ImageData); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image_neon.cc b/tensorflow/examples/android/jni/object_tracking/image_neon.cc new file mode 100644 index 0000000000..ddd8447bf3 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image_neon.cc @@ -0,0 +1,270 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NEON implementations of Image methods for compatible devices. Control +// should never enter this compilation unit on incompatible devices. 
+ +#ifdef __ARM_NEON + +#include + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/image_utils.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +using namespace tensorflow; + +namespace tf_tracking { + +// This function does the bulk of the work. +template <> +void Image::Downsample2x32ColumnsNeon(const uint8* const original, + const int stride, + const int orig_x) { + // Divide input x offset by 2 to find output offset. + const int new_x = orig_x >> 1; + + // Initial offset into top row. + const uint8* offset = original + orig_x; + + // This points to the leftmost pixel of our 8 horizontally arranged + // pixels in the destination data. + uint8* ptr_dst = (*this)[0] + new_x; + + // Sum along vertical columns. + // Process 32x2 input pixels and 16x1 output pixels per iteration. + for (int new_y = 0; new_y < height_; ++new_y) { + uint16x8_t accum1 = vdupq_n_u16(0); + uint16x8_t accum2 = vdupq_n_u16(0); + + // Go top to bottom across the four rows of input pixels that make up + // this output row. + for (int row_num = 0; row_num < 2; ++row_num) { + // First 16 bytes. + { + // Load 16 bytes of data from current offset. + const uint8x16_t curr_data1 = vld1q_u8(offset); + + // Pairwise add and accumulate into accum vectors (16 bit to account + // for values above 255). + accum1 = vpadalq_u8(accum1, curr_data1); + } + + // Second 16 bytes. + { + // Load 16 bytes of data from current offset. + const uint8x16_t curr_data2 = vld1q_u8(offset + 16); + + // Pairwise add and accumulate into accum vectors (16 bit to account + // for values above 255). + accum2 = vpadalq_u8(accum2, curr_data2); + } + + // Move offset down one row. + offset += stride; + } + + // Divide by 4 (number of input pixels per output + // pixel) and narrow data from 16 bits per pixel to 8 bpp. + const uint8x8_t tmp_pix1 = vqshrn_n_u16(accum1, 2); + const uint8x8_t tmp_pix2 = vqshrn_n_u16(accum2, 2); + + // Concatenate 8x1 pixel strips into 16x1 pixel strip. + const uint8x16_t allpixels = vcombine_u8(tmp_pix1, tmp_pix2); + + // Copy all pixels from composite 16x1 vector into output strip. + vst1q_u8(ptr_dst, allpixels); + + ptr_dst += stride_; + } +} + +// This function does the bulk of the work. +template <> +void Image::Downsample4x32ColumnsNeon(const uint8* const original, + const int stride, + const int orig_x) { + // Divide input x offset by 4 to find output offset. + const int new_x = orig_x >> 2; + + // Initial offset into top row. + const uint8* offset = original + orig_x; + + // This points to the leftmost pixel of our 8 horizontally arranged + // pixels in the destination data. + uint8* ptr_dst = (*this)[0] + new_x; + + // Sum along vertical columns. + // Process 32x4 input pixels and 8x1 output pixels per iteration. + for (int new_y = 0; new_y < height_; ++new_y) { + uint16x8_t accum1 = vdupq_n_u16(0); + uint16x8_t accum2 = vdupq_n_u16(0); + + // Go top to bottom across the four rows of input pixels that make up + // this output row. + for (int row_num = 0; row_num < 4; ++row_num) { + // First 16 bytes. + { + // Load 16 bytes of data from current offset. + const uint8x16_t curr_data1 = vld1q_u8(offset); + + // Pairwise add and accumulate into accum vectors (16 bit to account + // for values above 255). + accum1 = vpadalq_u8(accum1, curr_data1); + } + + // Second 16 bytes. 
+ { + // Load 16 bytes of data from current offset. + const uint8x16_t curr_data2 = vld1q_u8(offset + 16); + + // Pairwise add and accumulate into accum vectors (16 bit to account + // for values above 255). + accum2 = vpadalq_u8(accum2, curr_data2); + } + + // Move offset down one row. + offset += stride; + } + + // Add and widen, then divide by 16 (number of input pixels per output + // pixel) and narrow data from 32 bits per pixel to 16 bpp. + const uint16x4_t tmp_pix1 = vqshrn_n_u32(vpaddlq_u16(accum1), 4); + const uint16x4_t tmp_pix2 = vqshrn_n_u32(vpaddlq_u16(accum2), 4); + + // Combine 4x1 pixel strips into 8x1 pixel strip and narrow from + // 16 bits to 8 bits per pixel. + const uint8x8_t allpixels = vmovn_u16(vcombine_u16(tmp_pix1, tmp_pix2)); + + // Copy all pixels from composite 8x1 vector into output strip. + vst1_u8(ptr_dst, allpixels); + + ptr_dst += stride_; + } +} + + +// Hardware accelerated downsampling method for supported devices. +// Requires that image size be a multiple of 16 pixels in each dimension, +// and that downsampling be by a factor of 2 or 4. +template <> +void Image::DownsampleAveragedNeon(const uint8* const original, + const int stride, const int factor) { + // TODO(andrewharp): stride is a bad approximation for the src image's width. + // Better to pass that in directly. + SCHECK(width_ * factor <= stride, "Uh oh!"); + const int last_starting_index = width_ * factor - 32; + + // We process 32 input pixels lengthwise at a time. + // The output per pass of this loop is an 8 wide by downsampled height tall + // pixel strip. + int orig_x = 0; + for (; orig_x <= last_starting_index; orig_x += 32) { + if (factor == 2) { + Downsample2x32ColumnsNeon(original, stride, orig_x); + } else { + Downsample4x32ColumnsNeon(original, stride, orig_x); + } + } + + // If a last pass is required, push it to the left enough so that it never + // goes out of bounds. This will result in some extra computation on devices + // whose frame widths are multiples of 16 and not 32. + if (orig_x < last_starting_index + 32) { + if (factor == 2) { + Downsample2x32ColumnsNeon(original, stride, last_starting_index); + } else { + Downsample4x32ColumnsNeon(original, stride, last_starting_index); + } + } +} + + +// Puts the image gradient matrix about a pixel into the 2x2 float array G. +// vals_x should be an array of the window x gradient values, whose indices +// can be in any order but are parallel to the vals_y entries. +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details. +void CalculateGNeon(const float* const vals_x, const float* const vals_y, + const int num_vals, float* const G) { + const float32_t* const arm_vals_x = (const float32_t*) vals_x; + const float32_t* const arm_vals_y = (const float32_t*) vals_y; + + // Running sums. + float32x4_t xx = vdupq_n_f32(0.0f); + float32x4_t xy = vdupq_n_f32(0.0f); + float32x4_t yy = vdupq_n_f32(0.0f); + + // Maximum index we can load 4 consecutive values from. + // e.g. if there are 81 values, our last full pass can be from index 77: + // 81-4=>77 (77, 78, 79, 80) + const int max_i = num_vals - 4; + + // Defined here because we want to keep track of how many values were + // processed by NEON, so that we can finish off the remainder the normal + // way. + int i = 0; + + // Process values 4 at a time, accumulating the sums of + // the pixel-wise x*x, x*y, and y*y values. + for (; i <= max_i; i += 4) { + // Load xs + float32x4_t x = vld1q_f32(arm_vals_x + i); + + // Multiply x*x and accumulate. 
+ xx = vmlaq_f32(xx, x, x); + + // Load ys + float32x4_t y = vld1q_f32(arm_vals_y + i); + + // Multiply x*y and accumulate. + xy = vmlaq_f32(xy, x, y); + + // Multiply y*y and accumulate. + yy = vmlaq_f32(yy, y, y); + } + + static float32_t xx_vals[4]; + static float32_t xy_vals[4]; + static float32_t yy_vals[4]; + + vst1q_f32(xx_vals, xx); + vst1q_f32(xy_vals, xy); + vst1q_f32(yy_vals, yy); + + // Accumulated values are store in sets of 4, we have to manually add + // the last bits together. + for (int j = 0; j < 4; ++j) { + G[0] += xx_vals[j]; + G[1] += xy_vals[j]; + G[3] += yy_vals[j]; + } + + // Finishes off last few values (< 4) from above. + for (; i < num_vals; ++i) { + G[0] += Square(vals_x[i]); + G[1] += vals_x[i] * vals_y[i]; + G[3] += Square(vals_y[i]); + } + + // The matrix is symmetric, so this is a given. + G[2] = G[1]; +} + +} // namespace tf_tracking + +#endif diff --git a/tensorflow/examples/android/jni/object_tracking/image_utils.h b/tensorflow/examples/android/jni/object_tracking/image_utils.h new file mode 100644 index 0000000000..5357a9352f --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image_utils.h @@ -0,0 +1,301 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +using namespace tensorflow; + +namespace tf_tracking { + +inline void GetUV( + const uint8* const input, Image* const u, Image* const v) { + const uint8* pUV = input; + + for (int row = 0; row < u->GetHeight(); ++row) { + uint8* u_curr = (*u)[row]; + uint8* v_curr = (*v)[row]; + for (int col = 0; col < u->GetWidth(); ++col) { +#ifdef __APPLE__ + *u_curr++ = *pUV++; + *v_curr++ = *pUV++; +#else + *v_curr++ = *pUV++; + *u_curr++ = *pUV++; +#endif + } + } +} + +// Marks every point within a circle of a given radius on the given boolean +// image true. +template +inline static void MarkImage(const int x, const int y, const int radius, + Image* const img) { + SCHECK(img->ValidPixel(x, y), "Marking invalid pixel in image! %d, %d", x, y); + + // Precomputed for efficiency. + const int squared_radius = Square(radius); + + // Mark every row in the circle. + for (int d_y = 0; d_y <= radius; ++d_y) { + const int squared_y_dist = Square(d_y); + + const int min_y = MAX(y - d_y, 0); + const int max_y = MIN(y + d_y, img->height_less_one_); + + // The max d_x of the circle must be strictly greater or equal to + // radius - d_y for any positive d_y. 
Thus, starting from radius - d_y will + // reduce the number of iterations required as compared to starting from + // either 0 and counting up or radius and counting down. + for (int d_x = radius - d_y; d_x <= radius; ++d_x) { + // The first time this critera is met, we know the width of the circle at + // this row (without using sqrt). + if (squared_y_dist + Square(d_x) >= squared_radius) { + const int min_x = MAX(x - d_x, 0); + const int max_x = MIN(x + d_x, img->width_less_one_); + + // Mark both above and below the center row. + bool* const top_row_start = (*img)[min_y] + min_x; + bool* const bottom_row_start = (*img)[max_y] + min_x; + + const int x_width = max_x - min_x + 1; + memset(top_row_start, true, sizeof(*top_row_start) * x_width); + memset(bottom_row_start, true, sizeof(*bottom_row_start) * x_width); + + // This row is marked, time to move on to the next row. + break; + } + } + } +} + +#ifdef __ARM_NEON +void CalculateGNeon( + const float* const vals_x, const float* const vals_y, + const int num_vals, float* const G); +#endif + +// Puts the image gradient matrix about a pixel into the 2x2 float array G. +// vals_x should be an array of the window x gradient values, whose indices +// can be in any order but are parallel to the vals_y entries. +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details. +inline void CalculateG(const float* const vals_x, const float* const vals_y, + const int num_vals, float* const G) { +#ifdef __ARM_NEON + CalculateGNeon(vals_x, vals_y, num_vals, G); + return; +#endif + + // Non-accelerated version. + for (int i = 0; i < num_vals; ++i) { + G[0] += Square(vals_x[i]); + G[1] += vals_x[i] * vals_y[i]; + G[3] += Square(vals_y[i]); + } + + // The matrix is symmetric, so this is a given. + G[2] = G[1]; +} + + +inline void CalculateGInt16(const int16* const vals_x, + const int16* const vals_y, + const int num_vals, int* const G) { + // Non-accelerated version. + for (int i = 0; i < num_vals; ++i) { + G[0] += Square(vals_x[i]); + G[1] += vals_x[i] * vals_y[i]; + G[3] += Square(vals_y[i]); + } + + // The matrix is symmetric, so this is a given. + G[2] = G[1]; +} + + +// Puts the image gradient matrix about a pixel into the 2x2 float array G. +// Looks up interpolated pixels, then calls above method for implementation. +inline void CalculateG(const int window_radius, + const float center_x, const float center_y, + const Image& I_x, const Image& I_y, + float* const G) { + SCHECK(I_x.ValidPixel(center_x, center_y), "Problem in calculateG!"); + + // Hardcoded to allow for a max window radius of 5 (9 pixels x 9 pixels). + static const int kMaxWindowRadius = 5; + SCHECK(window_radius <= kMaxWindowRadius, + "Window %d > %d!", window_radius, kMaxWindowRadius); + + // Diameter of window is 2 * radius + 1 for center pixel. + static const int kWindowBufferSize = + (kMaxWindowRadius * 2 + 1) * (kMaxWindowRadius * 2 + 1); + + // Preallocate buffers statically for efficiency. 
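CalculateG() accumulates the 2x2 gradient (structure-tensor) matrix G = [sum(Ix*Ix), sum(Ix*Iy); sum(Ix*Iy), sum(Iy*Iy)] that Lucas-Kanade style tracking solves against. A standalone sketch that builds G the same way and additionally reports its smaller eigenvalue, the usual "good features to track" measure of how trackable a window is; the eigenvalue part is added context, not something this patch computes at this point.

#include <cmath>
#include <cstdio>

// Accumulates the 2x2 gradient matrix over a window, analogous to CalculateG
// above (the patch's version adds into a caller-zeroed array instead of
// zeroing internally).
void GradientMatrix(const float* ix, const float* iy, int n, float* g) {
  g[0] = g[1] = g[2] = g[3] = 0.0f;
  for (int i = 0; i < n; ++i) {
    g[0] += ix[i] * ix[i];
    g[1] += ix[i] * iy[i];
    g[3] += iy[i] * iy[i];
  }
  g[2] = g[1];  // Symmetric.
}

int main() {
  const float ix[] = {1.0f, 0.5f, -1.0f, 2.0f};
  const float iy[] = {0.0f, 1.0f, 1.0f, -0.5f};
  float g[4];
  GradientMatrix(ix, iy, 4, g);
  // Smaller eigenvalue of a symmetric 2x2 matrix: (trace - sqrt(trace^2 - 4*det)) / 2.
  const float trace = g[0] + g[3];
  const float det = g[0] * g[3] - g[1] * g[2];
  const float min_eig = 0.5f * (trace - std::sqrt(trace * trace - 4.0f * det));
  std::printf("G = [%.2f %.2f; %.2f %.2f], min eigenvalue = %.3f\n",
              g[0], g[1], g[2], g[3], min_eig);  // min eigenvalue = 1.750.
  return 0;
}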
+ static int16 vals_x[kWindowBufferSize]; + static int16 vals_y[kWindowBufferSize]; + + const int src_left_fixed = RealToFixed1616(center_x - window_radius); + const int src_top_fixed = RealToFixed1616(center_y - window_radius); + + int16* vals_x_ptr = vals_x; + int16* vals_y_ptr = vals_y; + + const int window_size = 2 * window_radius + 1; + for (int y = 0; y < window_size; ++y) { + const int fp_y = src_top_fixed + (y << 16); + + for (int x = 0; x < window_size; ++x) { + const int fp_x = src_left_fixed + (x << 16); + + *vals_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y); + *vals_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y); + } + } + + int32 g_temp[] = {0, 0, 0, 0}; + CalculateGInt16(vals_x, vals_y, window_size * window_size, g_temp); + + for (int i = 0; i < 4; ++i) { + G[i] = g_temp[i]; + } +} + +inline float ImageCrossCorrelation(const Image& image1, + const Image& image2, + const int x_offset, const int y_offset) { + SCHECK(image1.GetWidth() == image2.GetWidth() && + image1.GetHeight() == image2.GetHeight(), + "Dimension mismatch! %dx%d vs %dx%d", + image1.GetWidth(), image1.GetHeight(), + image2.GetWidth(), image2.GetHeight()); + + const int num_pixels = image1.GetWidth() * image1.GetHeight(); + const float* data1 = image1.data(); + const float* data2 = image2.data(); + return ComputeCrossCorrelation(data1, data2, num_pixels); +} + +// Copies an arbitrary region of an image to another (floating point) +// image, scaling as it goes using bilinear interpolation. +inline void CopyArea(const Image& image, + const BoundingBox& area_to_copy, + Image* const patch_image) { + VLOG(2) << "Copying from: " << area_to_copy << std::endl; + + const int patch_width = patch_image->GetWidth(); + const int patch_height = patch_image->GetHeight(); + + const float x_dist_between_samples = patch_width > 0 ? + area_to_copy.GetWidth() / (patch_width - 1) : 0; + + const float y_dist_between_samples = patch_height > 0 ? + area_to_copy.GetHeight() / (patch_height - 1) : 0; + + for (int y_index = 0; y_index < patch_height; ++y_index) { + const float sample_y = + y_index * y_dist_between_samples + area_to_copy.top_; + + for (int x_index = 0; x_index < patch_width; ++x_index) { + const float sample_x = + x_index * x_dist_between_samples + area_to_copy.left_; + + if (image.ValidInterpPixel(sample_x, sample_y)) { + // TODO(andrewharp): Do area averaging when downsampling. + (*patch_image)[y_index][x_index] = + image.GetPixelInterp(sample_x, sample_y); + } else { + (*patch_image)[y_index][x_index] = -1.0f; + } + } + } +} + + +// Takes a floating point image and normalizes it in-place. +// +// First, negative values will be set to the mean of the non-negative pixels +// in the image. +// +// Then, the resulting will be normalized such that it has mean value of 0.0 and +// a standard deviation of 1.0. +inline void NormalizeImage(Image* const image) { + const float* const data_ptr = image->data(); + + // Copy only the non-negative values to some temp memory. + float running_sum = 0.0f; + int num_data_gte_zero = 0; + { + float* const curr_data = (*image)[0]; + for (int i = 0; i < image->data_size_; ++i) { + if (curr_data[i] >= 0.0f) { + running_sum += curr_data[i]; + ++num_data_gte_zero; + } else { + curr_data[i] = -1.0f; + } + } + } + + // If none of the pixels are valid, just set the entire thing to 0.0f. 
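A compact restatement of the normalization described above, on a plain float buffer: invalid (negative) pixels are pinned to the mean of the valid ones, i.e. to zero after mean subtraction, and the whole patch is then scaled to unit standard deviation. This is a sketch of the scheme, not a byte-for-byte copy of NormalizeImage(); the names are illustrative.

#include <cmath>
#include <cstdio>
#include <vector>

void NormalizePatch(std::vector<float>* patch) {
  // Mean of the non-negative (valid) pixels only.
  float sum = 0.0f;
  int valid = 0;
  for (float v : *patch) {
    if (v >= 0.0f) { sum += v; ++valid; }
  }
  if (valid == 0) { patch->assign(patch->size(), 0.0f); return; }
  const float mean = sum / valid;

  // Subtract the mean from valid pixels; invalid ones become exactly 0.
  float sq_sum = 0.0f;
  for (float& v : *patch) {
    v = (v < 0.0f) ? 0.0f : v - mean;
    sq_sum += v * v;
  }

  // Scale to unit standard deviation (computed against an expected mean of 0).
  const float std_dev = std::sqrt(sq_sum / patch->size());
  if (std_dev > 0.0f) {
    for (float& v : *patch) v /= std_dev;
  }
}

int main() {
  std::vector<float> patch = {10.0f, 20.0f, -1.0f, 30.0f};
  NormalizePatch(&patch);
  for (float v : patch) std::printf("%.3f ", v);  // -1.414 0.000 0.000 1.414
  std::printf("\n");
  return 0;
}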
+ if (num_data_gte_zero == 0) { + image->Clear(0.0f); + return; + } + + const float corrected_mean = running_sum / num_data_gte_zero; + + float* curr_data = (*image)[0]; + for (int i = 0; i < image->data_size_; ++i) { + const float curr_val = *curr_data; + *curr_data++ = curr_val < 0 ? 0 : curr_val - corrected_mean; + } + + const float std_dev = ComputeStdDev(data_ptr, image->data_size_, 0.0f); + + if (std_dev > 0.0f) { + curr_data = (*image)[0]; + for (int i = 0; i < image->data_size_; ++i) { + *curr_data++ /= std_dev; + } + +#ifdef SANITY_CHECKS + LOGV("corrected_mean: %1.2f std_dev: %1.2f", corrected_mean, std_dev); + const float correlation = + ComputeCrossCorrelation(image->data(), + image->data(), + image->data_size_); + + if (std::abs(correlation - 1.0f) > EPSILON) { + LOG(ERROR) << "Bad image!" << std::endl; + LOG(ERROR) << *image << std::endl; + } + + SCHECK(std::abs(correlation - 1.0f) < EPSILON, + "Correlation wasn't 1.0f: %.10f", correlation); +#endif + } +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/integral_image.h b/tensorflow/examples/android/jni/object_tracking/integral_image.h new file mode 100755 index 0000000000..28b2045572 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/integral_image.h @@ -0,0 +1,187 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +namespace tf_tracking { + +typedef uint8 Code; + +class IntegralImage: public Image { + public: + explicit IntegralImage(const Image& image_base) : + Image(image_base.GetWidth(), image_base.GetHeight()) { + Recompute(image_base); + } + + IntegralImage(const int width, const int height) : + Image(width, height) {} + + void Recompute(const Image& image_base) { + SCHECK(image_base.GetWidth() == GetWidth() && + image_base.GetHeight() == GetHeight(), "Dimensions don't match!"); + + // Sum along first row. + { + int x_sum = 0; + for (int x = 0; x < image_base.GetWidth(); ++x) { + x_sum += image_base[0][x]; + (*this)[0][x] = x_sum; + } + } + + // Sum everything else. + for (int y = 1; y < image_base.GetHeight(); ++y) { + uint32* curr_sum = (*this)[y]; + + // Previously summed pointers. + const uint32* up_one = (*this)[y - 1]; + + // Current value pointer. 
+ const uint8* curr_delta = image_base[y]; + + uint32 row_till_now = 0; + + for (int x = 0; x < GetWidth(); ++x) { + // Add the one above and the one to the left. + row_till_now += *curr_delta; + *curr_sum = *up_one + row_till_now; + + // Scoot everything along. + ++curr_sum; + ++up_one; + ++curr_delta; + } + } + + SCHECK(VerifyData(image_base), "Images did not match!"); + } + + bool VerifyData(const Image& image_base) { + for (int y = 0; y < GetHeight(); ++y) { + for (int x = 0; x < GetWidth(); ++x) { + uint32 curr_val = (*this)[y][x]; + + if (x > 0) { + curr_val -= (*this)[y][x - 1]; + } + + if (y > 0) { + curr_val -= (*this)[y - 1][x]; + } + + if (x > 0 && y > 0) { + curr_val += (*this)[y - 1][x - 1]; + } + + if (curr_val != image_base[y][x]) { + LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]); + return false; + } + + if (GetRegionSum(x, y, x, y) != curr_val) { + LOGE("Mismatch!"); + } + } + } + + return true; + } + + // Returns the sum of all pixels in the specified region. + inline uint32 GetRegionSum(const int x1, const int y1, + const int x2, const int y2) const { + SCHECK(x1 >= 0 && y1 >= 0 && + x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(), + "indices out of bounds! %d-%d / %d, %d-%d / %d, ", + x1, x2, GetWidth(), y1, y2, GetHeight()); + + const uint32 everything = (*this)[y2][x2]; + + uint32 sum = everything; + if (x1 > 0 && y1 > 0) { + // Most common case. + const uint32 left = (*this)[y2][x1 - 1]; + const uint32 top = (*this)[y1 - 1][x2]; + const uint32 top_left = (*this)[y1 - 1][x1 - 1]; + + sum = everything - left - top + top_left; + SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d", + everything, left, top, top_left, sum, x1, y1, x2, y2); + } else if (x1 > 0) { + // Flush against top of image. + // Subtract out the region to the left only. + const uint32 top = (*this)[y2][x1 - 1]; + sum = everything - top; + SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, top, sum); + } else if (y1 > 0) { + // Flush against left side of image. + // Subtract out the region above only. + const uint32 left = (*this)[y1 - 1][x2]; + sum = everything - left; + SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, left, sum); + } + + SCHECK(sum >= 0, "Negative sum!"); + + return sum; + } + + // Returns the 2bit code associated with this region, which represents + // the overall gradient. + inline Code GetCode(const BoundingBox& bounding_box) const { + return GetCode(bounding_box.left_, bounding_box.top_, + bounding_box.right_, bounding_box.bottom_); + } + + inline Code GetCode(const int x1, const int y1, + const int x2, const int y2) const { + SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d", + x1, y1, x2, y2); + + // Gradient computed vertically. + const int box_height = (y2 - y1) / 2; + const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height); + const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2); + const bool vertical_code = top_sum > bottom_sum; + + // Gradient computed horizontally. + const int box_width = (x2 - x1) / 2; + const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2); + const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2); + const bool horizontal_code = left_sum > right_sum; + + const Code final_code = (vertical_code << 1) | horizontal_code; + + SCHECK(InRange(final_code, static_cast(0), static_cast(3)), + "Invalid code! %d", final_code); + + // Returns a value 0-3. 
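+    // Bit 1 is set when the top half of the box sums brighter than the bottom
+    // half, and bit 0 when the left half sums brighter than the right half.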
+ return final_code; + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/examples/android/jni/object_tracking/jni_utils.h new file mode 100644 index 0000000000..92458536b6 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/jni_utils.h @@ -0,0 +1,62 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ + +#include + +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +// The JniIntField class is used to access Java fields from native code. This +// technique of hiding pointers to native objects in opaque Java fields is how +// the Android hardware libraries work. This reduces the amount of static +// native methods and makes it easier to manage the lifetime of native objects. +class JniIntField { + public: + JniIntField(const char* field_name) : field_name_(field_name), field_ID_(0) {} + + int get(JNIEnv* env, jobject thiz) { + if (field_ID_ == 0) { + jclass cls = env->GetObjectClass(thiz); + CHECK_ALWAYS(cls != 0, "Unable to find class"); + field_ID_ = env->GetFieldID(cls, field_name_, "I"); + CHECK_ALWAYS(field_ID_ != 0, + "Unable to find field %s. (Check proguard cfg)", field_name_); + } + + return env->GetIntField(thiz, field_ID_); + } + + void set(JNIEnv* env, jobject thiz, int value) { + if (field_ID_ == 0) { + jclass cls = env->GetObjectClass(thiz); + CHECK_ALWAYS(cls != 0, "Unable to find class"); + field_ID_ = env->GetFieldID(cls, field_name_, "I"); + CHECK_ALWAYS(field_ID_ != 0, + "Unable to find field %s (Check proguard cfg)", field_name_); + } + + env->SetIntField(thiz, field_ID_, value); + } + + private: + const char* const field_name_; + + // This is just a cache + jfieldID field_ID_; +}; + +#endif diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint.h b/tensorflow/examples/android/jni/object_tracking/keypoint.h new file mode 100644 index 0000000000..82917261cb --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/keypoint.h @@ -0,0 +1,48 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" + +namespace tf_tracking { + +// For keeping track of keypoints. +struct Keypoint { + Keypoint() : pos_(0.0f, 0.0f), score_(0.0f), type_(0) {} + Keypoint(const float x, const float y) + : pos_(x, y), score_(0.0f), type_(0) {} + + Point2f pos_; + float score_; + uint8 type_; +}; + +inline std::ostream& operator<<(std::ostream& stream, const Keypoint keypoint) { + return stream << "[" << keypoint.pos_ << ", " + << keypoint.score_ << ", " << keypoint.type_ << "]"; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc new file mode 100644 index 0000000000..6cc6b4e73f --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc @@ -0,0 +1,549 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Various keypoint detecting functions. 
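+//
+// Roughly, FindKeypoints() below proceeds in stages: keypoints that survived
+// optical flow in the previous frame are carried over, new FAST corners are
+// detected in one quadrant of the frame per call, an optional grid of extra
+// candidates is seeded inside each tracked box, and the pooled candidates are
+// then Harris-scored, sorted, and thinned so that every region of interest
+// keeps a spatially well-separated subset.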
+ +#include + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" + +namespace tf_tracking { + +static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) { + return Square(vec1[0] - vec2[0]) + Square(vec1[1] - vec2[1]); +} + +void KeypointDetector::ScoreKeypoints(const ImageData& image_data, + const int num_candidates, + Keypoint* const candidate_keypoints) { + const Image& I_x = *image_data.GetSpatialX(0); + const Image& I_y = *image_data.GetSpatialY(0); + + if (config_->detect_skin) { + const Image& u_data = *image_data.GetU(); + const Image& v_data = *image_data.GetV(); + + static const int reference[] = {111, 155}; + + // Score all the keypoints. + for (int i = 0; i < num_candidates; ++i) { + Keypoint* const keypoint = candidate_keypoints + i; + + const int x_pos = keypoint->pos_.x * 2; + const int y_pos = keypoint->pos_.y * 2; + + const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]}; + keypoint->score_ = + HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) / + GetDistSquaredBetween(reference, curr_color); + } + } else { + // Score all the keypoints. + for (int i = 0; i < num_candidates; ++i) { + Keypoint* const keypoint = candidate_keypoints + i; + keypoint->score_ = + HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y); + } + } +} + + +inline int KeypointCompare(const void* const a, const void* const b) { + return (reinterpret_cast(a)->score_ - + reinterpret_cast(b)->score_) <= 0 ? 1 : -1; +} + + +// Quicksorts detected keypoints by score. +void KeypointDetector::SortKeypoints(const int num_candidates, + Keypoint* const candidate_keypoints) const { + qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare); + +#ifdef SANITY_CHECKS + // Verify that the array got sorted. + float last_score = FLT_MAX; + for (int i = 0; i < num_candidates; ++i) { + const float curr_score = candidate_keypoints[i].score_; + + // Scores should be monotonically increasing. + SCHECK(last_score >= curr_score, + "Quicksort failure! %d: %.5f > %d: %.5f (%d total)", + i - 1, last_score, i, curr_score, num_candidates); + + last_score = curr_score; + } +#endif +} + + +int KeypointDetector::SelectKeypointsInBox( + const BoundingBox& box, + const Keypoint* const candidate_keypoints, + const int num_candidates, + const int max_keypoints, + const int num_existing_keypoints, + const Keypoint* const existing_keypoints, + Keypoint* const final_keypoints) const { + if (max_keypoints <= 0) { + return 0; + } + + // This is the distance within which keypoints may be placed to each other + // within this box, roughly based on the box dimensions. + const int distance = + MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f); + + // First, mark keypoints that already happen to be inside this region. Ignore + // keypoints that are outside it, however close they might be. 
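+  // The interest map is a boolean mask over the frame: MarkImage() paints a
+  // filled circle of radius `distance` around each accepted keypoint, and any
+  // later candidate landing on a marked pixel is rejected below.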
+ interest_map_->Clear(false); + for (int i = 0; i < num_existing_keypoints; ++i) { + const Keypoint& candidate = existing_keypoints[i]; + + const int x_pos = candidate.pos_.x; + const int y_pos = candidate.pos_.y; + if (box.Contains(candidate.pos_)) { + MarkImage(x_pos, y_pos, distance, interest_map_.get()); + } + } + + // Now, go through and check which keypoints will still fit in the box. + int num_keypoints_selected = 0; + for (int i = 0; i < num_candidates; ++i) { + const Keypoint& candidate = candidate_keypoints[i]; + + const int x_pos = candidate.pos_.x; + const int y_pos = candidate.pos_.y; + + if (!box.Contains(candidate.pos_) || + !interest_map_->ValidPixel(x_pos, y_pos)) { + continue; + } + + if (!(*interest_map_)[y_pos][x_pos]) { + final_keypoints[num_keypoints_selected++] = candidate; + if (num_keypoints_selected >= max_keypoints) { + break; + } + MarkImage(x_pos, y_pos, distance, interest_map_.get()); + } + } + return num_keypoints_selected; +} + + +void KeypointDetector::SelectKeypoints( + const std::vector& boxes, + const Keypoint* const candidate_keypoints, + const int num_candidates, + FramePair* const curr_change) const { + // Now select all the interesting keypoints that fall insider our boxes. + curr_change->number_of_keypoints_ = 0; + for (std::vector::const_iterator iter = boxes.begin(); + iter != boxes.end(); ++iter) { + const BoundingBox bounding_box = *iter; + + // Count up keypoints that have already been selected, and fall within our + // box. + int num_keypoints_already_in_box = 0; + for (int i = 0; i < curr_change->number_of_keypoints_; ++i) { + if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) { + ++num_keypoints_already_in_box; + } + } + + const int max_keypoints_to_find_in_box = + MIN(kMaxKeypointsForObject - num_keypoints_already_in_box, + kMaxKeypoints - curr_change->number_of_keypoints_); + + const int num_new_keypoints_in_box = SelectKeypointsInBox( + bounding_box, + candidate_keypoints, + num_candidates, + max_keypoints_to_find_in_box, + curr_change->number_of_keypoints_, + curr_change->frame1_keypoints_, + curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_); + + curr_change->number_of_keypoints_ += num_new_keypoints_in_box; + + LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_); + } +} + + +// Walks along the given circle checking for pixels above or below the center. +// Returns a score, or 0 if the keypoint did not pass the criteria. +// +// Parameters: +// circle_perimeter: the circumference in pixels of the circle. +// threshold: the minimum number of contiguous pixels that must be above or +// below the center value. +// center_ptr: the location of the center pixel in memory +// offsets: the relative offsets from the center pixel of the edge pixels. +inline int TestCircle(const int circle_perimeter, const int threshold, + const uint8* const center_ptr, + const int* offsets) { + // Get the actual value of the center pixel for easier reference later on. + const int center_value = static_cast(*center_ptr); + + // Number of total pixels to check. Have to wrap around some in case + // the contiguous section is split by the array edges. + const int num_total = circle_perimeter + threshold - 1; + + int num_above = 0; + int above_diff = 0; + + int num_below = 0; + int below_diff = 0; + + // Used to tell when this is definitely not going to meet the threshold so we + // can early abort. 
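+  // At iteration i there are (num_total - 1 - i) samples left, so a contiguous
+  // run can still reach `threshold` only if the current count is already at
+  // least threshold - (num_total - 1 - i); minimum_by_now tracks exactly that
+  // bound as i advances.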
+ int minimum_by_now = threshold - num_total + 1; + + // Go through every pixel along the perimeter of the circle, and then around + // again a little bit. + for (int i = 0; i < num_total; ++i) { + // This should be faster than mod. + const int perim_index = i < circle_perimeter ? i : i - circle_perimeter; + + // This gets the value of the current pixel along the perimeter by using + // a precomputed offset. + const int curr_value = + static_cast(center_ptr[offsets[perim_index]]); + + const int difference = curr_value - center_value; + + if (difference > kFastDiffAmount) { + above_diff += difference; + ++num_above; + + num_below = 0; + below_diff = 0; + + if (num_above >= threshold) { + return above_diff; + } + } else if (difference < -kFastDiffAmount) { + below_diff += difference; + ++num_below; + + num_above = 0; + above_diff = 0; + + if (num_below >= threshold) { + return below_diff; + } + } else { + num_above = 0; + num_below = 0; + above_diff = 0; + below_diff = 0; + } + + // See if there's any chance of making the threshold. + if (MAX(num_above, num_below) < minimum_by_now) { + // Didn't pass. + return 0; + } + ++minimum_by_now; + } + + // Didn't pass. + return 0; +} + + +// Returns a score in the range [0.0, positive infinity) which represents the +// relative likelihood of a point being a corner. +float KeypointDetector::HarrisFilter(const Image& I_x, + const Image& I_y, + const float x, const float y) const { + if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) && + I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) { + // Image gradient matrix. + float G[] = { 0, 0, 0, 0 }; + CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G); + + const float dx = G[0]; + const float dy = G[3]; + const float dxy = G[1]; + + // Harris-Nobel corner score. + return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN); + } + + return 0.0f; +} + + +int KeypointDetector::AddExtraCandidatesForBoxes( + const std::vector& boxes, + const int max_num_keypoints, + Keypoint* const keypoints) const { + int num_keypoints_added = 0; + + for (std::vector::const_iterator iter = boxes.begin(); + iter != boxes.end(); ++iter) { + const BoundingBox box = *iter; + + for (int i = 0; i < kNumToAddAsCandidates; ++i) { + for (int j = 0; j < kNumToAddAsCandidates; ++j) { + if (num_keypoints_added >= max_num_keypoints) { + LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints); + return num_keypoints_added; + } + + Keypoint curr_keypoint = keypoints[num_keypoints_added++]; + curr_keypoint.pos_ = Point2f( + box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates, + box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates); + curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST; + } + } + } + + return num_keypoints_added; +} + + +void KeypointDetector::FindKeypoints(const ImageData& image_data, + const std::vector& rois, + const FramePair& prev_change, + FramePair* const curr_change) { + // Copy keypoints from second frame of last pass to temp keypoints of this + // pass. + int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_); + + const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints; + number_of_tmp_keypoints += + FindFastKeypoints(image_data, max_num_fast, + tmp_keypoints_ + number_of_tmp_keypoints); + + TimeLog("Found FAST keypoints"); + + if (number_of_tmp_keypoints >= kMaxTempKeypoints) { + LOGW("Hit cap of %d for temporary keypoints (FAST)! 
%d keypoints", + kMaxTempKeypoints, number_of_tmp_keypoints); + } + + if (kAddArbitraryKeypoints) { + // Add some for each object prior to scoring. + const int max_num_box_keypoints = + kMaxTempKeypoints - number_of_tmp_keypoints; + number_of_tmp_keypoints += + AddExtraCandidatesForBoxes(rois, max_num_box_keypoints, + tmp_keypoints_ + number_of_tmp_keypoints); + TimeLog("Added box keypoints"); + + if (number_of_tmp_keypoints >= kMaxTempKeypoints) { + LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints", + kMaxTempKeypoints, number_of_tmp_keypoints); + } + } + + // Score them... + LOGV("Scoring %d keypoints!", number_of_tmp_keypoints); + ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_); + TimeLog("Scored keypoints"); + + // Now pare it down a bit. + SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_); + TimeLog("Sorted keypoints"); + + LOGV("%d keypoints to select from!", number_of_tmp_keypoints); + + SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change); + TimeLog("Selected keypoints"); + + LOGV("Picked %d (%d max) final keypoints out of %d potential.", + curr_change->number_of_keypoints_, + kMaxKeypoints, number_of_tmp_keypoints); +} + + +int KeypointDetector::CopyKeypoints(const FramePair& prev_change, + Keypoint* const new_keypoints) { + int number_of_keypoints = 0; + + // Caching values from last pass, just copy and compact. + for (int i = 0; i < prev_change.number_of_keypoints_; ++i) { + if (prev_change.optical_flow_found_keypoint_[i]) { + new_keypoints[number_of_keypoints] = + prev_change.frame2_keypoints_[i]; + + new_keypoints[number_of_keypoints].score_ = + prev_change.frame1_keypoints_[i].score_; + + ++number_of_keypoints; + } + } + + TimeLog("Copied keypoints"); + return number_of_keypoints; +} + + +// FAST keypoint detector. +int KeypointDetector::FindFastKeypoints(const Image& frame, + const int quadrant, + const int downsample_factor, + const int max_num_keypoints, + Keypoint* const keypoints) { + /* + // Reference for a circle of diameter 7. + const int circle[] = {0, 0, 1, 1, 1, 0, 0, + 0, 1, 0, 0, 0, 1, 0, + 1, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 0, 0, 0, 1, + 0, 1, 0, 0, 0, 1, 0, + 0, 0, 1, 1, 1, 0, 0}; + const int circle_offset[] = + {2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46}; + */ + + // Quick test of compass directions. Any length 16 circle with a break of up + // to 4 pixels will have at least 3 of these 4 pixels active. + static const int short_circle_perimeter = 4; + static const int short_threshold = 3; + static const int short_circle_x[] = { -3, 0, +3, 0 }; + static const int short_circle_y[] = { 0, -3, 0, +3 }; + + // Precompute image offsets. + int short_offsets[short_circle_perimeter]; + for (int i = 0; i < short_circle_perimeter; ++i) { + short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth(); + } + + // Large circle values. + static const int full_circle_perimeter = 16; + static const int full_threshold = 12; + static const int full_circle_x[] = + { -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 }; + static const int full_circle_y[] = + { -3, -3, -3, -2, -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 }; + + // Precompute image offsets. 
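+  // Each offset is dy * image_width + dx, i.e. the element displacement from
+  // the center pixel to one of the 16 samples on the radius-3 circle.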
+ int full_offsets[full_circle_perimeter]; + for (int i = 0; i < full_circle_perimeter; ++i) { + full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth(); + } + + const int scratch_stride = frame.stride(); + + keypoint_scratch_->Clear(0); + + // Set up the bounds on the region to test based on the passed-in quadrant. + const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer; + const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer; + const int start_x = + kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width); + const int start_y = + kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height); + const int end_x = start_x + quadrant_width; + const int end_y = start_y + quadrant_height; + + // Loop through once to find FAST keypoint clumps. + for (int img_y = start_y; img_y < end_y; ++img_y) { + const uint8* curr_pixel_ptr = frame[img_y] + start_x; + + for (int img_x = start_x; img_x < end_x; ++img_x) { + // Only insert it if it meets the quick minimum requirements test. + if (TestCircle(short_circle_perimeter, short_threshold, + curr_pixel_ptr, short_offsets) != 0) { + // Longer test for actual keypoint score.. + const int fast_score = TestCircle(full_circle_perimeter, + full_threshold, + curr_pixel_ptr, + full_offsets); + + // Non-zero score means the keypoint was found. + if (fast_score != 0) { + uint8* const center_ptr = (*keypoint_scratch_)[img_y] + img_x; + + // Increase the keypoint count on this pixel and the pixels in all + // 4 cardinal directions. + *center_ptr += 5; + *(center_ptr - 1) += 1; + *(center_ptr + 1) += 1; + *(center_ptr - scratch_stride) += 1; + *(center_ptr + scratch_stride) += 1; + } + } + + ++curr_pixel_ptr; + } // x + } // y + + TimeLog("Found FAST keypoints."); + + int num_keypoints = 0; + // Loop through again and Harris filter pixels in the center of clumps. + // We can shrink the window by 1 pixel on every side. + for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) { + const uint8* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x; + + for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) { + if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) { + Keypoint* const keypoint = keypoints + num_keypoints; + keypoint->pos_ = Point2f( + img_x * downsample_factor, img_y * downsample_factor); + keypoint->score_ = 0; + keypoint->type_ = KEYPOINT_TYPE_FAST; + + ++num_keypoints; + if (num_keypoints >= max_num_keypoints) { + return num_keypoints; + } + } + + ++curr_pixel_ptr; + } // x + } // y + + TimeLog("Picked FAST keypoints."); + + return num_keypoints; +} + +int KeypointDetector::FindFastKeypoints(const ImageData& image_data, + const int max_num_keypoints, + Keypoint* const keypoints) { + int downsample_factor = 1; + int num_found = 0; + + // TODO(andrewharp): Get this working for multiple image scales. + for (int i = 0; i < 1; ++i) { + const Image& frame = *image_data.GetPyramidSqrt2Level(i); + num_found += FindFastKeypoints( + frame, fast_quadrant_, + downsample_factor, max_num_keypoints, keypoints + num_found); + downsample_factor *= 2; + } + + // Increment the current quadrant. 
+ fast_quadrant_ = (fast_quadrant_ + 1) % 4; + + return num_found; +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h new file mode 100644 index 0000000000..6cdd5dde11 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h @@ -0,0 +1,133 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" + +using namespace tensorflow; + +namespace tf_tracking { + +struct Keypoint; + +class KeypointDetector { + public: + explicit KeypointDetector(const KeypointDetectorConfig* const config) + : config_(config), + keypoint_scratch_(new Image(config_->image_size)), + interest_map_(new Image(config_->image_size)), + fast_quadrant_(0) { + interest_map_->Clear(false); + } + + ~KeypointDetector() {} + + // Finds a new set of keypoints for the current frame, picked from the current + // set of keypoints and also from a set discovered via a keypoint detector. + // Special attention is applied to make sure that keypoints are distributed + // within the supplied ROIs. + void FindKeypoints(const ImageData& image_data, + const std::vector& rois, + const FramePair& prev_change, + FramePair* const curr_change); + + private: + // Compute the corneriness of a point in the image. + float HarrisFilter(const Image& I_x, const Image& I_y, + const float x, const float y) const; + + // Adds a grid of candidate keypoints to the given box, up to + // max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower. + int AddExtraCandidatesForBoxes( + const std::vector& boxes, + const int max_num_keypoints, + Keypoint* const keypoints) const; + + // Scan the frame for potential keypoints using the FAST keypoint detector. + // Quadrant is an argument 0-3 which refers to the quadrant of the image in + // which to detect keypoints. + int FindFastKeypoints(const Image& frame, + const int quadrant, + const int downsample_factor, + const int max_num_keypoints, + Keypoint* const keypoints); + + int FindFastKeypoints(const ImageData& image_data, + const int max_num_keypoints, + Keypoint* const keypoints); + + // Score a bunch of candidate keypoints. Assigns the scores to the input + // candidate_keypoints array entries. 
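+  // Scores are Harris corner responses; when the config's detect_skin flag is
+  // set they are additionally divided by the squared UV-space distance to a
+  // fixed reference color, favoring skin-colored corners.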
+ void ScoreKeypoints(const ImageData& image_data, + const int num_candidates, + Keypoint* const candidate_keypoints); + + void SortKeypoints(const int num_candidates, + Keypoint* const candidate_keypoints) const; + + // Selects a set of keypoints falling within the supplied box such that the + // most highly rated keypoints are picked first, and so that none of them are + // too close together. + int SelectKeypointsInBox( + const BoundingBox& box, + const Keypoint* const candidate_keypoints, + const int num_candidates, + const int max_keypoints, + const int num_existing_keypoints, + const Keypoint* const existing_keypoints, + Keypoint* const final_keypoints) const; + + // Selects from the supplied sorted keypoint pool a set of keypoints that will + // best cover the given set of boxes, such that each box is covered at a + // resolution proportional to its size. + void SelectKeypoints( + const std::vector& boxes, + const Keypoint* const candidate_keypoints, + const int num_candidates, + FramePair* const frame_change) const; + + // Copies and compacts the found keypoints in the second frame of prev_change + // into the array at new_keypoints. + static int CopyKeypoints(const FramePair& prev_change, + Keypoint* const new_keypoints); + + const KeypointDetectorConfig* const config_; + + // Scratch memory for keypoint candidacy detection and non-max suppression. + std::unique_ptr > keypoint_scratch_; + + // Regions of the image to pay special attention to. + std::unique_ptr > interest_map_; + + // The current quadrant of the image to detect FAST keypoints in. + // Keypoint detection is staggered for performance reasons. Every four frames + // a full scan of the frame will have been performed. + int fast_quadrant_; + + Keypoint tmp_keypoints_[kMaxTempKeypoints]; +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/log_streaming.h b/tensorflow/examples/android/jni/object_tracking/log_streaming.h new file mode 100644 index 0000000000..e68945cc72 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/log_streaming.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ + +#include +#include + +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +using namespace tensorflow; + +namespace tf_tracking { + +#define LOGV(...) +#define LOGD(...) +#define LOGI(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__); +#define LOGW(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__); +#define LOGE(...) 
LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__); + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.cc b/tensorflow/examples/android/jni/object_tracking/object_detector.cc new file mode 100644 index 0000000000..7f65716fdf --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_detector.cc @@ -0,0 +1,27 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NOTE: no native object detectors are currently provided or used by the code +// in this directory. This class remains mainly for historical reasons. +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. + +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" + +namespace tf_tracking { + +// This is here so that the vtable gets created properly. +ObjectDetectorBase::~ObjectDetectorBase() {} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.h b/tensorflow/examples/android/jni/object_tracking/object_detector.h new file mode 100644 index 0000000000..043f606e1d --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_detector.h @@ -0,0 +1,232 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NOTE: no native object detectors are currently provided or used by the code +// in this directory. This class remains mainly for historical reasons. +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. + +// Defines the ObjectDetector class that is the main interface for detecting +// ObjectModelBases in frames. 
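+//
+// A minimal usage sketch, assuming a concrete ObjectDetector subclass and an
+// ImageData built for the current frame (both names illustrative here):
+//
+//   detector->SetImageData(&image_data);
+//   std::vector<Detection> detections;
+//   detector->Detect(candidate_squares, &detections);  // Squares from, e.g.,
+//                                                      // FillWithSquares below.
+//   for (const Detection& d : detections) {
+//     const BoundingBox box = d.GetObjectBoundingBox();
+//     // ... use box and d.GetMatchScore() ...
+//   }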
+ +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#ifdef __RENDER_OPENGL__ +#include "tensorflow/examples/android/jni/object_tracking/sprite.h" +#endif +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/object_model.h" + +namespace tf_tracking { + +// Adds BoundingSquares to a vector such that the first square added is centered +// in the position given and of square_size, and the remaining squares are added +// concentrentically, scaling down by scale_factor until the minimum threshold +// size is passed. +// Squares that do not fall completely within image_bounds will not be added. +static inline void FillWithSquares( + const BoundingBox& image_bounds, + const BoundingBox& position, + const float starting_square_size, + const float smallest_square_size, + const float scale_factor, + std::vector* const squares) { + BoundingSquare descriptor_area = + GetCenteredSquare(position, starting_square_size); + + SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor); + + // Use a do/while loop to ensure that at least one descriptor is created. + do { + if (image_bounds.Contains(descriptor_area.ToBoundingBox())) { + squares->push_back(descriptor_area); + } + descriptor_area.Scale(scale_factor); + } while (descriptor_area.size_ >= smallest_square_size - EPSILON); + LOGV("Created %zu squares starting from size %.2f to min size %.2f " + "using scale factor: %.2f", + squares->size(), starting_square_size, smallest_square_size, + scale_factor); +} + + +// Represents a potential detection of a specific ObjectExemplar and Descriptor +// at a specific position in the image. +class Detection { + public: + explicit Detection(const ObjectModelBase* const object_model, + const MatchScore match_score, + const BoundingBox& bounding_box) + : object_model_(object_model), + match_score_(match_score), + bounding_box_(bounding_box) {} + + Detection(const Detection& other) + : object_model_(other.object_model_), + match_score_(other.match_score_), + bounding_box_(other.bounding_box_) {} + + virtual ~Detection() {} + + inline BoundingBox GetObjectBoundingBox() const { + return bounding_box_; + } + + inline MatchScore GetMatchScore() const { + return match_score_; + } + + inline const ObjectModelBase* GetObjectModel() const { + return object_model_; + } + + inline bool Intersects(const Detection& other) { + // Check if any of the four axes separates us, there must be at least one. + return bounding_box_.Intersects(other.bounding_box_); + } + + struct Comp { + inline bool operator()(const Detection& a, const Detection& b) const { + return a.match_score_ > b.match_score_; + } + }; + + // TODO(andrewharp): add accessors to update these instead. 
+ const ObjectModelBase* object_model_; + MatchScore match_score_; + BoundingBox bounding_box_; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const Detection& detection) { + const BoundingBox actual_area = detection.GetObjectBoundingBox(); + stream << actual_area; + return stream; +} + +class ObjectDetectorBase { + public: + explicit ObjectDetectorBase(const ObjectDetectorConfig* const config) + : config_(config), + image_data_(NULL) {} + + virtual ~ObjectDetectorBase(); + + // Sets the current image data. All calls to ObjectDetector other than + // FillDescriptors use the image data last set. + inline void SetImageData(const ImageData* const image_data) { + image_data_ = image_data; + } + + // Main entry point into the detection algorithm. + // Scans the frame for candidates, tweaks them, and fills in the + // given std::vector of Detection objects with acceptable matches. + virtual void Detect(const std::vector& positions, + std::vector* const detections) const = 0; + + virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0; + + virtual void DeleteObjectModel(const std::string& name) = 0; + + virtual void GetObjectModels( + std::vector* models) const = 0; + + // Creates a new ObjectExemplar from the given position in the context of + // the last frame passed to NextFrame. + // Will return null in the case that there's no room for a descriptor to be + // created in the example area, or the example area is not completely + // contained within the frame. + virtual void UpdateModel( + const Image& base_image, + const IntegralImage& integral_image, + const BoundingBox& bounding_box, + const bool locked, + ObjectModelBase* model) const = 0; + + virtual void Draw() const = 0; + + virtual bool AllowSpontaneousDetections() = 0; + + protected: + const std::unique_ptr config_; + + // The latest frame data, upon which all detections will be performed. + // Not owned by this object, just provided for reference by ObjectTracker + // via SetImageData(). + const ImageData* image_data_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase); +}; + +template +class ObjectDetector : public ObjectDetectorBase { + public: + explicit ObjectDetector(const ObjectDetectorConfig* const config) + : ObjectDetectorBase(config) {} + + virtual ~ObjectDetector() { + typename std::map::const_iterator it = + object_models_.begin(); + for (; it != object_models_.end(); ++it) { + ModelType* model = it->second; + delete model; + } + } + + virtual void DeleteObjectModel(const std::string& name) { + ModelType* model = object_models_[name]; + CHECK_ALWAYS(model != NULL, "Model was null!"); + object_models_.erase(name); + SAFE_DELETE(model); + } + + virtual void GetObjectModels( + std::vector* models) const { + typename std::map::const_iterator it = + object_models_.begin(); + for (; it != object_models_.end(); ++it) { + models->push_back(it->second); + } + } + + virtual bool AllowSpontaneousDetections() { + return false; + } + + protected: + std::map object_models_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_model.h b/tensorflow/examples/android/jni/object_tracking/object_model.h new file mode 100644 index 0000000000..2d359668b2 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_model.h @@ -0,0 +1,101 @@ +/* Copyright 2016 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NOTE: no native object detectors are currently provided or used by the code +// in this directory. This class remains mainly for historical reasons. +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. + +// Contains ObjectModelBase declaration. + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ + +#ifdef __RENDER_OPENGL__ +#include +#include +#endif + +#include + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#ifdef __RENDER_OPENGL__ +#include "tensorflow/examples/android/jni/object_tracking/sprite.h" +#endif +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" + +namespace tf_tracking { + +// The ObjectModelBase class represents all the known appearance information for +// an object. It is not a specific instance of the object in the world, +// but just the general appearance information that enables detection. An +// ObjectModelBase can be reused across multiple-instances of TrackedObjects. +class ObjectModelBase { + public: + ObjectModelBase(const std::string& name) : name_(name) {} + + virtual ~ObjectModelBase() {} + + // Called when the next step in an ongoing track occurs. + virtual void TrackStep( + const BoundingBox& position, const Image& image, + const IntegralImage& integral_image, const bool authoritative) {} + + // Called when an object track is lost. + virtual void TrackLost() {} + + // Called when an object track is confirmed as legitimate. 
+ virtual void TrackConfirmed() {} + + virtual float GetMaxCorrelation(const Image& patch_image) const = 0; + + virtual MatchScore GetMatchScore( + const BoundingBox& position, const ImageData& image_data) const = 0; + + virtual void Draw(float* const depth) const = 0; + + inline const std::string& GetName() const { + return name_; + } + + protected: + const std::string name_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ObjectModelBase); +}; + +template +class ObjectModel : public ObjectModelBase { + public: + ObjectModel(const DetectorType* const detector, + const std::string& name) + : ObjectModelBase(name), detector_(detector) {} + + protected: + const DetectorType* const detector_; + + TF_DISALLOW_COPY_AND_ASSIGN(ObjectModel); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.cc b/tensorflow/examples/android/jni/object_tracking/object_tracker.cc new file mode 100644 index 0000000000..1d867b934b --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.cc @@ -0,0 +1,690 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef __RENDER_OPENGL__ +#include +#include +#endif + +#include +#include + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" +#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" + +namespace tf_tracking { + +ObjectTracker::ObjectTracker(const TrackerConfig* const config, + ObjectDetectorBase* const detector) + : config_(config), + frame_width_(config->image_size.width), + frame_height_(config->image_size.height), + curr_time_(0), + num_frames_(0), + flow_cache_(&config->flow_config), + keypoint_detector_(&config->keypoint_detector_config), + curr_num_frame_pairs_(0), + first_frame_index_(0), + frame1_(new ImageData(frame_width_, frame_height_)), + frame2_(new ImageData(frame_width_, frame_height_)), + detector_(detector), + num_detected_(0) { + for (int i = 0; i < kNumFrames; ++i) { + frame_pairs_[i].Init(-1, -1); + } +} + + 
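+// A rough usage sketch from the hosting JNI layer (buffer and variable names
+// below are illustrative only, not part of this change):
+//
+//   ObjectTracker tracker(&tracker_config, detector);          // Once.
+//   ...
+//   tracker.NextFrame(y_plane, uv_plane, timestamp_ns, alignment_2x3);
+//   tracker.RegisterNewObjectWithAppearance("person0", y_plane, initial_box);
+//   ...
+//   // On later frames, after NextFrame():
+//   const BoundingBox current = tracker.TrackBox(initial_box, timestamp_ns);
+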
+ObjectTracker::~ObjectTracker() { + for (TrackedObjectMap::iterator iter = objects_.begin(); + iter != objects_.end(); iter++) { + TrackedObject* object = iter->second; + SAFE_DELETE(object); + } +} + + +// Finds the correspondences for all the points in the current pair of frames. +// Stores the results in the given FramePair. +void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const { + // Keypoints aren't found until they're found. + memset(frame_pair->optical_flow_found_keypoint_, false, + sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints); + TimeLog("Cleared old found keypoints"); + + int num_keypoints_found = 0; + + // For every keypoint... + for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) { + Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat; + Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat; + + if (flow_cache_.FindNewPositionOfPoint( + keypoint1->pos_.x, keypoint1->pos_.y, + &keypoint2->pos_.x, &keypoint2->pos_.y)) { + frame_pair->optical_flow_found_keypoint_[i_feat] = true; + ++num_keypoints_found; + } + } + + TimeLog("Found correspondences"); + + LOGV("Found %d of %d keypoint correspondences", + num_keypoints_found, frame_pair->number_of_keypoints_); +} + + +void ObjectTracker::NextFrame(const uint8* const new_frame, + const uint8* const uv_frame, + const int64 timestamp, + const float* const alignment_matrix_2x3) { + IncrementFrameIndex(); + LOGV("Received frame %d", num_frames_); + + FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0); + curr_change->Init(curr_time_, timestamp); + + CHECK_ALWAYS(curr_time_ < timestamp, + "Timestamp must monotonically increase! Went from %lld to %lld" + " on frame %d.", + curr_time_, timestamp, num_frames_); + curr_time_ = timestamp; + + // Swap the frames. + frame1_.swap(frame2_); + + frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1); + + if (detector_.get() != NULL) { + detector_->SetImageData(frame2_.get()); + } + + flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3); + + if (num_frames_ == 1) { + // This must be the first frame, so abort. + return; + } + + if (config_->always_track || objects_.size() > 0) { + LOGV("Tracking %zu targets", objects_.size()); + ComputeKeypoints(true); + TimeLog("Keypoints computed!"); + + FindCorrespondences(curr_change); + TimeLog("Flow computed!"); + + TrackObjects(); + } + TimeLog("Targets tracked!"); + + if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) { + DetectTargets(); + } + TimeLog("Detected objects."); +} + + +TrackedObject* ObjectTracker::MaybeAddObject( + const std::string& id, + const Image& source_image, + const BoundingBox& bounding_box, + const ObjectModelBase* object_model) { + // Train the detector if this is a new object. + if (objects_.find(id) != objects_.end()) { + return objects_[id]; + } + + // Need to get a non-const version of the model, or create a new one if it + // wasn't given. + ObjectModelBase* model = NULL; + if (detector_ != NULL) { + // If a detector is registered, then this new object must have a model. 
+ CHECK_ALWAYS(object_model != NULL, "No model given!"); + model = detector_->CreateObjectModel(object_model->GetName()); + } + TrackedObject* const object = + new TrackedObject(id, source_image, bounding_box, model); + + objects_[id] = object; + return object; +} + + +void ObjectTracker::RegisterNewObjectWithAppearance( + const std::string& id, const uint8* const new_frame, + const BoundingBox& bounding_box) { + ObjectModelBase* object_model = NULL; + + Image image(frame_width_, frame_height_); + image.FromArray(new_frame, frame_width_, 1); + + if (detector_ != NULL) { + object_model = detector_->CreateObjectModel(id); + CHECK_ALWAYS(object_model != NULL, "Null object model!"); + + const IntegralImage integral_image(image); + object_model->TrackStep(bounding_box, image, integral_image, true); + } + + // Create an object at this position. + CHECK_ALWAYS(!HaveObject(id), "Already have this object!"); + if (objects_.find(id) == objects_.end()) { + TrackedObject* const object = + MaybeAddObject(id, image, bounding_box, object_model); + CHECK_ALWAYS(object != NULL, "Object not created!"); + } +} + + +void ObjectTracker::SetPreviousPositionOfObject(const std::string& id, + const BoundingBox& bounding_box, + const int64 timestamp) { + CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %lld", timestamp); + CHECK_ALWAYS(timestamp <= curr_time_, + "Timestamp too great! %lld vs %lld", timestamp, curr_time_); + + TrackedObject* const object = GetObject(id); + + // Track this bounding box from the past to the current time. + const BoundingBox current_position = TrackBox(bounding_box, timestamp); + + object->UpdatePosition(current_position, curr_time_, *frame2_, false); + + VLOG(2) << "Set tracked position for " << id << " to " << bounding_box + << std::endl; +} + + +void ObjectTracker::SetCurrentPositionOfObject( + const std::string& id, const BoundingBox& bounding_box) { + SetPreviousPositionOfObject(id, bounding_box, curr_time_); +} + + +void ObjectTracker::ForgetTarget(const std::string& id) { + LOGV("Forgetting object %s", id.c_str()); + TrackedObject* const object = GetObject(id); + delete object; + objects_.erase(id); + + if (detector_ != NULL) { + detector_->DeleteObjectModel(id); + } +} + + +int ObjectTracker::GetKeypointsPacked(uint16* const out_data, + const float scale) const { + const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)]; + uint16* curr_data = out_data; + int num_keypoints = 0; + + for (int i = 0; i < change.number_of_keypoints_; ++i) { + if (change.optical_flow_found_keypoint_[i]) { + ++num_keypoints; + const Point2f& point1 = change.frame1_keypoints_[i].pos_; + *curr_data++ = RealToFixed115(point1.x * scale); + *curr_data++ = RealToFixed115(point1.y * scale); + + const Point2f& point2 = change.frame2_keypoints_[i].pos_; + *curr_data++ = RealToFixed115(point2.x * scale); + *curr_data++ = RealToFixed115(point2.y * scale); + } + } + + return num_keypoints; +} + + +int ObjectTracker::GetKeypoints(const bool only_found, + float* const out_data) const { + int curr_keypoint = 0; + const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)]; + + for (int i = 0; i < change.number_of_keypoints_; ++i) { + if (!only_found || change.optical_flow_found_keypoint_[i]) { + const int base = curr_keypoint * kKeypointStep; + out_data[base + 0] = change.frame1_keypoints_[i].pos_.x; + out_data[base + 1] = change.frame1_keypoints_[i].pos_.y; + + out_data[base + 2] = + change.optical_flow_found_keypoint_[i] ? 
1.0f : -1.0f; + out_data[base + 3] = change.frame2_keypoints_[i].pos_.x; + out_data[base + 4] = change.frame2_keypoints_[i].pos_.y; + + out_data[base + 5] = change.frame1_keypoints_[i].score_; + out_data[base + 6] = change.frame1_keypoints_[i].type_; + ++curr_keypoint; + } + } + + LOGV("Got %d keypoints.", curr_keypoint); + + return curr_keypoint; +} + + +BoundingBox ObjectTracker::TrackBox(const BoundingBox& region, + const FramePair& frame_pair) const { + float translation_x; + float translation_y; + + float scale_x; + float scale_y; + + BoundingBox tracked_box(region); + frame_pair.AdjustBox( + tracked_box, &translation_x, &translation_y, &scale_x, &scale_y); + + tracked_box.Shift(Point2f(translation_x, translation_y)); + + if (scale_x > 0 && scale_y > 0) { + tracked_box.Scale(scale_x, scale_y); + } + return tracked_box; +} + + +BoundingBox ObjectTracker::TrackBox(const BoundingBox& region, + const int64 timestamp) const { + CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %lld", timestamp); + CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!"); + + // Anything that ended before the requested timestamp is of no concern to us. + bool found_it = false; + int num_frames_back = -1; + for (int i = 0; i < curr_num_frame_pairs_; ++i) { + const FramePair& frame_pair = + frame_pairs_[GetNthIndexFromEnd(i)]; + + if (frame_pair.end_time_ <= timestamp) { + num_frames_back = i - 1; + + if (num_frames_back > 0) { + LOGV("Went %d out of %d frames before finding frame. (index: %d)", + num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i)); + } + + found_it = true; + break; + } + } + + if (!found_it) { + LOGW("History did not go back far enough! %lld vs %lld", + frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - + frame_pairs_[GetNthIndexFromStart(0)].end_time_, + frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp); + } + + // Loop over all the frames in the queue, tracking the accumulated delta + // of the point from frame to frame. It's possible the point could + // go out of frame, but keep tracking as best we can, using points near + // the edge of the screen where it went out of bounds. + BoundingBox tracked_box(region); + for (int i = num_frames_back; i >= 0; --i) { + const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)]; + SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!"); + tracked_box = TrackBox(tracked_box, frame_pair); + } + return tracked_box; +} + + +// Converts a row-major 3x3 2d transformation matrix to a column-major 4x4 +// 3d transformation matrix. +inline void Convert3x3To4x4( + const float* const in_matrix, float* const out_matrix) { + // X + out_matrix[0] = in_matrix[0]; + out_matrix[1] = in_matrix[3]; + out_matrix[2] = 0.0f; + out_matrix[3] = 0.0f; + + // Y + out_matrix[4] = in_matrix[1]; + out_matrix[5] = in_matrix[4]; + out_matrix[6] = 0.0f; + out_matrix[7] = 0.0f; + + // Z + out_matrix[8] = 0.0f; + out_matrix[9] = 0.0f; + out_matrix[10] = 1.0f; + out_matrix[11] = 0.0f; + + // Translation + out_matrix[12] = in_matrix[2]; + out_matrix[13] = in_matrix[5]; + out_matrix[14] = 0.0f; + out_matrix[15] = 1.0f; +} + + +void ObjectTracker::Draw(const int canvas_width, const int canvas_height, + const float* const frame_to_canvas) const { +#ifdef __RENDER_OPENGL__ + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); + + glMatrixMode(GL_PROJECTION); + glLoadIdentity(); + + glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f); + + // To make Y go the right direction (0 at top of frame). 
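For reference, the Convert3x3To4x4 helper defined above takes the six used entries a_0 ... a_5 of the row-major input (the implied third row is 0 0 1) and writes them column-major, which is the homogeneous matrix layout glMultMatrixf expects:

  M = \begin{pmatrix} a_0 & a_1 & 0 & a_2 \\ a_3 & a_4 & 0 & a_5 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix}

so out_matrix[0..3] holds the first column (a_0, a_3, 0, 0) and out_matrix[12..15] the translation column (a_2, a_5, 0, 1). The glScalef/glTranslatef pair that follows then flips Y so that image row 0 lands at the top of the canvas, as the comment above notes.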
+ glScalef(1.0f, -1.0f, 1.0f); + glTranslatef(0.0f, -canvas_height, 0.0f); + + glMatrixMode(GL_MODELVIEW); + glLoadIdentity(); + + glPushMatrix(); + + // Apply the frame to canvas transformation. + static GLfloat transformation[16]; + Convert3x3To4x4(frame_to_canvas, transformation); + glMultMatrixf(transformation); + + // Draw tracked object bounding boxes. + for (TrackedObjectMap::const_iterator iter = objects_.begin(); + iter != objects_.end(); ++iter) { + TrackedObject* tracked_object = iter->second; + tracked_object->Draw(); + } + + static const bool kRenderDebugPyramid = false; + if (kRenderDebugPyramid) { + glColor4f(1.0f, 1.0f, 1.0f, 1.0f); + for (int i = 0; i < kNumPyramidLevels * 2; ++i) { + Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw(); + } + } + + static const bool kRenderDebugDerivative = false; + if (kRenderDebugDerivative) { + glColor4f(1.0f, 1.0f, 1.0f, 1.0f); + for (int i = 0; i < kNumPyramidLevels; ++i) { + const Image& dx = *frame1_->GetSpatialX(i); + Image render_image(dx.GetWidth(), dx.GetHeight()); + for (int y = 0; y < dx.GetHeight(); ++y) { + const int32* dx_ptr = dx[y]; + uint8* dst_ptr = render_image[y]; + for (int x = 0; x < dx.GetWidth(); ++x) { + *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255); + } + } + + Sprite(render_image).Draw(); + } + } + + if (detector_ != NULL) { + glDisable(GL_CULL_FACE); + detector_->Draw(); + } + glPopMatrix(); +#endif +} + +static void AddQuadrants(const BoundingBox& box, + std::vector* boxes) { + const Point2f center = box.GetCenter(); + + float x1 = box.left_; + float x2 = center.x; + float x3 = box.right_; + + float y1 = box.top_; + float y2 = center.y; + float y3 = box.bottom_; + + // Upper left. + boxes->push_back(BoundingBox(x1, y1, x2, y2)); + + // Upper right. + boxes->push_back(BoundingBox(x2, y1, x3, y2)); + + // Bottom left. + boxes->push_back(BoundingBox(x1, y2, x2, y3)); + + // Bottom right. + boxes->push_back(BoundingBox(x2, y2, x3, y3)); + + // Whole thing. + boxes->push_back(box); +} + +void ObjectTracker::ComputeKeypoints(const bool cached_ok) { + const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)]; + FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)]; + + std::vector boxes; + + for (TrackedObjectMap::iterator object_iter = objects_.begin(); + object_iter != objects_.end(); ++object_iter) { + BoundingBox box = object_iter->second->GetPosition(); + box.Scale(config_->object_box_scale_factor_for_features, + config_->object_box_scale_factor_for_features); + AddQuadrants(box, &boxes); + } + + AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes); + + keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change); +} + + +// Given a vector of detections and a model, simply returns the Detection for +// that model with the highest correlation. 
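GetBestObjectForDetection below gates candidate matches on BoundingBox::PascalScore(). Judging by the name, this is presumably the PASCAL VOC-style overlap criterion (intersection area over union area); a minimal self-contained sketch of that criterion, using a hypothetical Box struct rather than the BoundingBox class from geom.h:

#include <algorithm>

// Hypothetical axis-aligned box with left < right and top < bottom.
struct Box {
  float left, top, right, bottom;
};

// PASCAL-style overlap score: intersection over union, in [0, 1].
float PascalOverlap(const Box& a, const Box& b) {
  const float iw = std::min(a.right, b.right) - std::max(a.left, b.left);
  const float ih = std::min(a.bottom, b.bottom) - std::max(a.top, b.top);
  if (iw <= 0.0f || ih <= 0.0f) {
    return 0.0f;  // Boxes do not intersect.
  }
  const float intersection = iw * ih;
  const float area_a = (a.right - a.left) * (a.bottom - a.top);
  const float area_b = (b.right - b.left) * (b.bottom - b.top);
  return intersection / (area_a + area_b - intersection);
}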
+bool ObjectTracker::GetBestObjectForDetection( + const Detection& detection, TrackedObject** match) const { + TrackedObject* best_match = NULL; + float best_overlap = -FLT_MAX; + + LOGV("Looking for matches in %zu objects!", objects_.size()); + for (TrackedObjectMap::const_iterator object_iter = objects_.begin(); + object_iter != objects_.end(); ++object_iter) { + TrackedObject* const tracked_object = object_iter->second; + + const float overlap = tracked_object->GetPosition().PascalScore( + detection.GetObjectBoundingBox()); + + if (!detector_->AllowSpontaneousDetections() && + (detection.GetObjectModel() != tracked_object->GetModel())) { + if (overlap > 0.0f) { + return false; + } + continue; + } + + const float jump_distance = + (tracked_object->GetPosition().GetCenter() - + detection.GetObjectBoundingBox().GetCenter()).LengthSquared(); + + const float allowed_distance = + tracked_object->GetAllowableDistanceSquared(); + + LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f", + jump_distance, allowed_distance, overlap); + + // TODO(andrewharp): No need to do this verification twice, eliminate + // one of the score checks (the other being in OnDetection). + if (jump_distance < allowed_distance && + overlap > best_overlap && + tracked_object->GetMatchScore() + kMatchScoreBuffer < + detection.GetMatchScore()) { + best_match = tracked_object; + best_overlap = overlap; + } else if (overlap > 0.0f) { + return false; + } + } + + *match = best_match; + return true; +} + + +void ObjectTracker::ProcessDetections( + std::vector* const detections) { + LOGV("Initial detection done, iterating over %zu detections now.", + detections->size()); + + const bool spontaneous_detections_allowed = + detector_->AllowSpontaneousDetections(); + for (std::vector::const_iterator it = detections->begin(); + it != detections->end(); ++it) { + const Detection& detection = *it; + SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()), + "Frame does not contain bounding box!"); + + TrackedObject* best_match = NULL; + + const bool no_collisions = + GetBestObjectForDetection(detection, &best_match); + + // Need to get a non-const version of the model, or create a new one if it + // wasn't given. + ObjectModelBase* model = + const_cast(detection.GetObjectModel()); + + if (best_match != NULL) { + if (model != best_match->GetModel()) { + CHECK_ALWAYS(detector_->AllowSpontaneousDetections(), + "Model for object changed but spontaneous detections not allowed!"); + } + best_match->OnDetection(model, + detection.GetObjectBoundingBox(), + detection.GetMatchScore(), + curr_time_, *frame2_); + } else if (no_collisions && spontaneous_detections_allowed) { + if (detection.GetMatchScore() > kMinimumMatchScore) { + LOGV("No match, adding it!"); + const ObjectModelBase* model = detection.GetObjectModel(); + std::ostringstream ss; + // TODO(andrewharp): Generate this in a more general fashion. + ss << "hand_" << num_detected_++; + std::string object_name = ss.str(); + MaybeAddObject(object_name, *frame2_->GetImage(), + detection.GetObjectBoundingBox(), model); + } + } + } +} + + +void ObjectTracker::DetectTargets() { + // Detect all object model types that we're currently tracking. 
+ std::vector object_models; + detector_->GetObjectModels(&object_models); + if (object_models.size() == 0) { + LOGV("No objects to search for, aborting."); + return; + } + + LOGV("Trying to detect %zu models", object_models.size()); + + LOGV("Creating test vector!"); + std::vector positions; + + for (TrackedObjectMap::iterator object_iter = objects_.begin(); + object_iter != objects_.end(); ++object_iter) { + TrackedObject* const tracked_object = object_iter->second; + +#if DEBUG_PREDATOR + positions.push_back(GetCenteredSquare( + frame2_->GetImage()->GetContainingBox(), 32.0f)); +#else + const BoundingBox& position = tracked_object->GetPosition(); + + const float square_size = MAX( + kScanMinSquareSize / (kLastKnownPositionScaleFactor * + kLastKnownPositionScaleFactor), + MIN(position.GetWidth(), + position.GetHeight())) / kLastKnownPositionScaleFactor; + + FillWithSquares(frame2_->GetImage()->GetContainingBox(), + tracked_object->GetPosition(), + square_size, + kScanMinSquareSize, + kLastKnownPositionScaleFactor, + &positions); + } +#endif + + LOGV("Created test vector!"); + + std::vector detections; + LOGV("Detecting!"); + detector_->Detect(positions, &detections); + LOGV("Found %zu detections", detections.size()); + + TimeLog("Finished detection."); + + ProcessDetections(&detections); + + TimeLog("iterated over detections"); + + LOGV("Done detecting!"); +} + + +void ObjectTracker::TrackObjects() { + // TODO(andrewharp): Correlation should be allowed to remove objects too. + const bool automatic_removal_allowed = detector_.get() != NULL ? + detector_->AllowSpontaneousDetections() : false; + + LOGV("Tracking %zu objects!", objects_.size()); + std::vector dead_objects; + for (TrackedObjectMap::iterator iter = objects_.begin(); + iter != objects_.end(); iter++) { + TrackedObject* object = iter->second; + const BoundingBox tracked_position = TrackBox( + object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]); + object->UpdatePosition(tracked_position, curr_time_, *frame2_, false); + + if (automatic_removal_allowed && + object->GetNumConsecutiveFramesBelowThreshold() > + kMaxNumDetectionFailures * 5) { + dead_objects.push_back(iter->first); + } + } + + if (detector_ != NULL && automatic_removal_allowed) { + for (std::vector::iterator iter = dead_objects.begin(); + iter != dead_objects.end(); iter++) { + LOGE("Removing object! %s", iter->c_str()); + ForgetTarget(*iter); + } + } + TimeLog("Tracked all objects."); + + LOGV("%zu objects tracked!", objects_.size()); +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/examples/android/jni/object_tracking/object_tracker.h new file mode 100644 index 0000000000..3d2a9af360 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.h @@ -0,0 +1,271 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+
+#include <map>
+#include <string>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
+
+namespace tf_tracking {
+
+typedef std::map<std::string, TrackedObject*> TrackedObjectMap;
+
+inline std::ostream& operator<<(std::ostream& stream,
+                                const TrackedObjectMap& map) {
+  for (TrackedObjectMap::const_iterator iter = map.begin();
+       iter != map.end(); ++iter) {
+    const TrackedObject& tracked_object = *iter->second;
+    const std::string& key = iter->first;
+    stream << key << ": " << tracked_object;
+  }
+  return stream;
+}
+
+
+// ObjectTracker is the highest-level class in the tracking/detection framework.
+// It handles basic image processing, keypoint detection, keypoint tracking,
+// object tracking, and object detection/relocalization.
+class ObjectTracker {
+ public:
+  ObjectTracker(const TrackerConfig* const config,
+                ObjectDetectorBase* const detector);
+  virtual ~ObjectTracker();
+
+  virtual void NextFrame(const uint8* const new_frame,
+                         const int64 timestamp,
+                         const float* const alignment_matrix_2x3) {
+    NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
+  }
+
+  // Called upon the arrival of a new frame of raw data.
+  // Does all image processing, keypoint detection, and object
+  // tracking/detection for registered objects.
+  // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
+  // represents the main transformation that has happened between the last
+  // and the current frame.
+  // Argument align_level is the pyramid level (where 0 == finest) that
+  // the matrix is valid for.
+  virtual void NextFrame(const uint8* const new_frame,
+                         const uint8* const uv_frame,
+                         const int64 timestamp,
+                         const float* const alignment_matrix_2x3);
+
+  virtual void RegisterNewObjectWithAppearance(
+      const std::string& id, const uint8* const new_frame,
+      const BoundingBox& bounding_box);
+
+  // Updates the position of a tracked object, given that it was known to be at
+  // a certain position at some point in the past.
+  virtual void SetPreviousPositionOfObject(const std::string& id,
+                                           const BoundingBox& bounding_box,
+                                           const int64 timestamp);
+
+  // Sets the current position of the object in the most recent frame provided.
+  virtual void SetCurrentPositionOfObject(const std::string& id,
+                                          const BoundingBox& bounding_box);
+
+  // Tells the ObjectTracker to stop tracking a target.
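Together with ForgetTarget() declared next, the methods above cover the full native-side lifecycle. An illustrative sketch of that flow (the frame buffer, dimensions, and timestamps here are hypothetical, no detector is attached, and the call sequence mirrors the initNative/nextFrameNative path in object_tracker_jni.cc later in this patch):

#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"

using namespace tensorflow;    // for uint8/int64, as the .cc files in this patch do
using namespace tf_tracking;

void ExampleTrackerLifecycle(const uint8* const y_plane) {
  // The tracker keeps the config in a unique_ptr, so hand over a heap object.
  TrackerConfig* const config = new TrackerConfig(Size(640, 480));
  config->always_track = true;
  ObjectTracker* const tracker = new ObjectTracker(config, /*detector=*/NULL);

  // Timestamps (milliseconds) must increase monotonically from frame to frame.
  tracker->NextFrame(y_plane, /*uv_frame=*/NULL, /*timestamp=*/33,
                     /*alignment_matrix_2x3=*/NULL);

  // Begin tracking an object observed in the most recent frame.
  tracker->RegisterNewObjectWithAppearance(
      "target_0", y_plane, BoundingBox(10.0f, 10.0f, 110.0f, 110.0f));

  // ... further NextFrame() calls; query progress via GetObject("target_0") ...

  // Stop tracking and release the tracker (and, with it, the config).
  tracker->ForgetTarget("target_0");
  delete tracker;
}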
+ void ForgetTarget(const std::string& id); + + // Fills the given out_data buffer with the latest detected keypoint + // correspondences, first scaled by scale_factor (to adjust for downsampling + // that may have occurred elsewhere), then packed in a fixed-point format. + int GetKeypointsPacked(uint16* const out_data, + const float scale_factor) const; + + // Copy the keypoint arrays after computeFlow is called. + // out_data should be at least kMaxKeypoints * kKeypointStep long. + // Currently, its format is [x1 y1 found x2 y2 score] repeated N times, + // where N is the number of keypoints tracked. N is returned as the result. + int GetKeypoints(const bool only_found, float* const out_data) const; + + // Returns the current position of a box, given that it was at a certain + // position at the given time. + BoundingBox TrackBox(const BoundingBox& region, + const int64 timestamp) const; + + // Returns the number of frames that have been passed to NextFrame(). + inline int GetNumFrames() const { + return num_frames_; + } + + inline bool HaveObject(const std::string& id) const { + return objects_.find(id) != objects_.end(); + } + + // Returns the TrackedObject associated with the given id. + inline const TrackedObject* GetObject(const std::string& id) const { + TrackedObjectMap::const_iterator iter = objects_.find(id); + CHECK_ALWAYS(iter != objects_.end(), + "Unknown object key! \"%s\"", id.c_str()); + TrackedObject* const object = iter->second; + return object; + } + + // Returns the TrackedObject associated with the given id. + inline TrackedObject* GetObject(const std::string& id) { + TrackedObjectMap::iterator iter = objects_.find(id); + CHECK_ALWAYS(iter != objects_.end(), + "Unknown object key! \"%s\"", id.c_str()); + TrackedObject* const object = iter->second; + return object; + } + + bool IsObjectVisible(const std::string& id) const { + SCHECK(HaveObject(id), "Don't have this object."); + + const TrackedObject* object = GetObject(id); + return object->IsVisible(); + } + + virtual void Draw(const int canvas_width, const int canvas_height, + const float* const frame_to_canvas) const; + + protected: + // Creates a new tracked object at the given position. + // If an object model is provided, then that model will be associated with the + // object. If not, a new model may be created from the appearance at the + // initial position and registered with the object detector. + virtual TrackedObject* MaybeAddObject(const std::string& id, + const Image& image, + const BoundingBox& bounding_box, + const ObjectModelBase* object_model); + + // Find the keypoints in the frame before the current frame. + // If only one frame exists, keypoints will be found in that frame. + void ComputeKeypoints(const bool cached_ok = false); + + // Finds the correspondences for all the points in the current pair of frames. + // Stores the results in the given FramePair. + void FindCorrespondences(FramePair* const curr_change) const; + + inline int GetNthIndexFromEnd(const int offset) const { + return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset); + } + + BoundingBox TrackBox(const BoundingBox& region, + const FramePair& frame_pair) const; + + inline void IncrementFrameIndex() { + // Move the current framechange index up. + ++num_frames_; + ++curr_num_frame_pairs_; + + // If we've got too many, push up the start of the queue. 
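  // As a worked illustration of this frame-pair ring buffer (assuming both
  // counters start at zero and using a hypothetical kNumFrames of 4; the real
  // constant is defined in config.h): after six calls to IncrementFrameIndex()
  // we have curr_num_frame_pairs_ == 4 and first_frame_index_ == 2, so
  // GetNthIndexFromStart(0..3) yields 2, 3, 0, 1 and GetNthIndexFromEnd(0)
  // yields 1, the slot holding the most recent frame pair.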
+ if (curr_num_frame_pairs_ > kNumFrames) { + first_frame_index_ = GetNthIndexFromStart(1); + --curr_num_frame_pairs_; + } + } + + inline int GetNthIndexFromStart(const int offset) const { + SCHECK(offset >= 0 && offset < curr_num_frame_pairs_, + "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_); + return (first_frame_index_ + offset) % kNumFrames; + } + + void TrackObjects(); + + const std::unique_ptr config_; + + const int frame_width_; + const int frame_height_; + + int64 curr_time_; + + int num_frames_; + + TrackedObjectMap objects_; + + FlowCache flow_cache_; + + KeypointDetector keypoint_detector_; + + int curr_num_frame_pairs_; + int first_frame_index_; + + std::unique_ptr frame1_; + std::unique_ptr frame2_; + + FramePair frame_pairs_[kNumFrames]; + + std::unique_ptr detector_; + + int num_detected_; + + private: + void TrackTarget(TrackedObject* const object); + + bool GetBestObjectForDetection( + const Detection& detection, TrackedObject** match) const; + + void ProcessDetections(std::vector* const detections); + + void DetectTargets(); + + // Temp object used in ObjectTracker::CreateNewExample. + mutable std::vector squares; + + friend std::ostream& operator<<(std::ostream& stream, + const ObjectTracker& tracker); + + TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker); +}; + +inline std::ostream& operator<<(std::ostream& stream, + const ObjectTracker& tracker) { + stream << "Frame size: " << tracker.frame_width_ << "x" + << tracker.frame_height_ << std::endl; + + stream << "Num frames: " << tracker.num_frames_ << std::endl; + + stream << "Curr time: " << tracker.curr_time_ << std::endl; + + const int first_frame_index = tracker.GetNthIndexFromStart(0); + const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index]; + + const int last_frame_index = tracker.GetNthIndexFromEnd(0); + const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index]; + + stream << "first frame: " << first_frame_index << "," + << first_frame_pair.end_time_ << " " + << "last frame: " << last_frame_index << "," + << last_frame_pair.end_time_ << " diff: " + << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms" + << std::endl; + + stream << "Tracked targets:"; + stream << tracker.objects_; + + return stream; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc b/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc new file mode 100644 index 0000000000..30c5974654 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc @@ -0,0 +1,463 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h" + +using namespace tensorflow; + +namespace tf_tracking { + +#define OBJECT_TRACKER_METHOD(METHOD_NAME) \ + Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT + +JniIntField object_tracker_field("nativeObjectTracker"); + +ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) { + ObjectTracker* const object_tracker = + reinterpret_cast(object_tracker_field.get(env, thiz)); + CHECK_ALWAYS(object_tracker != NULL, "null object tracker!"); + return object_tracker; +} + +void set_object_tracker(JNIEnv* env, jobject thiz, + const ObjectTracker* object_tracker) { + object_tracker_field.set(env, thiz, + reinterpret_cast(object_tracker)); +} + +#ifdef __cplusplus +extern "C" { +#endif +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz, + jint width, jint height, + jboolean always_track); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env, + jobject thiz); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2, jbyteArray frame_data); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2, jlong timestamp); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2); + +JNIEXPORT +jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz, + jstring object_id); + +JNIEXPORT +jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env, + jobject thiz, + jstring object_id); + +JNIEXPORT +jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env, + jobject thiz, + jstring object_id); + +JNIEXPORT +jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env, + jobject thiz, + jstring object_id); + +JNIEXPORT +jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz, + jstring object_id); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz, + jbyteArray y_data, + jbyteArray uv_data, + jlong timestamp, + jfloatArray vg_matrix_2x3); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz, + jstring object_id); + +JNIEXPORT +jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)( + JNIEnv* env, jobject thiz, jfloat scale_factor); + +JNIEXPORT +jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)( + JNIEnv* env, jobject thiz, jboolean only_found_); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)( + JNIEnv* env, jobject thiz, jlong 
timestamp, jfloat position_x1, + jfloat position_y1, jfloat position_x2, jfloat position_y2, + jfloatArray delta); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj, + jint view_width, + jint view_height, + jfloatArray delta); + +JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)( + JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride, + jbyteArray input, jint factor, jbyteArray output); + +#ifdef __cplusplus +} +#endif + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz, + jint width, jint height, + jboolean always_track) { + LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz); + const Size image_size(width, height); + TrackerConfig* const tracker_config = new TrackerConfig(image_size); + tracker_config->always_track = always_track; + + // XXX detector + ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL); + set_object_tracker(env, thiz, tracker); + LOGI("Initialized!"); + + CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker, + "Failure to set hand tracker!"); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env, + jobject thiz) { + delete get_object_tracker(env, thiz); + set_object_tracker(env, thiz, NULL); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2, jbyteArray frame_data) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1, + x2, y2); + + jboolean iCopied = JNI_FALSE; + + // Copy image into currFrame. + jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied); + + BoundingBox bounding_box(x1, y1, x2, y2); + get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance( + id_str, reinterpret_cast(pixels), bounding_box); + + env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT); + + env->ReleaseStringUTFChars(object_id, id_str); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2, jlong timestamp) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + LOGI( + "Registering the position of %s at %.2f,%.2f,%.2f,%.2f" + " at time %lld", + id_str, x1, y1, x2, y2, static_cast(timestamp)); + + get_object_tracker(env, thiz)->SetPreviousPositionOfObject( + id_str, BoundingBox(x1, y1, x2, y2), timestamp); + + env->ReleaseStringUTFChars(object_id, id_str); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1, + x2, y2); + + get_object_tracker(env, thiz)->SetCurrentPositionOfObject( + id_str, BoundingBox(x1, y1, x2, y2)); + + env->ReleaseStringUTFChars(object_id, id_str); +} + +JNIEXPORT +jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str); + env->ReleaseStringUTFChars(object_id, id_str); + return haveObject; +} + +JNIEXPORT +jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* 
env, + jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str); + env->ReleaseStringUTFChars(object_id, id_str); + return visible; +} + +JNIEXPORT +jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env, + jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + const TrackedObject* const object = + get_object_tracker(env, thiz)->GetObject(id_str); + env->ReleaseStringUTFChars(object_id, id_str); + jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str()); + return model_name; +} + +JNIEXPORT +jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env, + jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const float correlation = + get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation(); + env->ReleaseStringUTFChars(object_id, id_str); + return correlation; +} + +JNIEXPORT +jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const float match_score = + get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value; + env->ReleaseStringUTFChars(object_id, id_str); + return match_score; +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) { + jboolean iCopied = JNI_FALSE; + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const BoundingBox bounding_box = + get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition(); + env->ReleaseStringUTFChars(object_id, id_str); + + jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied); + bounding_box.CopyToArray(reinterpret_cast(rect)); + env->ReleaseFloatArrayElements(rect_array, rect, 0); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz, + jbyteArray y_data, + jbyteArray uv_data, + jlong timestamp, + jfloatArray vg_matrix_2x3) { + TimeLog("Starting object tracker"); + + jboolean iCopied = JNI_FALSE; + + float vision_gyro_matrix_array[6]; + jfloat* jmat = NULL; + + if (vg_matrix_2x3 != NULL) { + // Copy the alignment matrix into a float array. + jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied); + for (int i = 0; i < 6; ++i) { + vision_gyro_matrix_array[i] = static_cast(jmat[i]); + } + } + // Copy image into currFrame. + jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied); + jbyte* uv_pixels = + uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL; + + TimeLog("Got elements"); + + // Add the frame to the object tracker object. + get_object_tracker(env, thiz)->NextFrame( + reinterpret_cast(pixels), reinterpret_cast(uv_pixels), + timestamp, vg_matrix_2x3 != NULL ? 
vision_gyro_matrix_array : NULL); + + env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT); + + if (uv_data != NULL) { + env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT); + } + + if (vg_matrix_2x3 != NULL) { + env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT); + } + + TimeLog("Released elements"); + + PrintTimeLog(); + ResetTimeLog(); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + get_object_tracker(env, thiz)->ForgetTarget(id_str); + + env->ReleaseStringUTFChars(object_id, id_str); +} + +JNIEXPORT +jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)( + JNIEnv* env, jobject thiz, jboolean only_found) { + jfloat keypoint_arr[kMaxKeypoints * kKeypointStep]; + + const int number_of_keypoints = + get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr); + + // Create and return the array that will be passed back to Java. + jfloatArray keypoints = + env->NewFloatArray(number_of_keypoints * kKeypointStep); + if (keypoints == NULL) { + LOGE("null array!"); + return NULL; + } + env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep, + keypoint_arr); + + return keypoints; +} + +JNIEXPORT +jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)( + JNIEnv* env, jobject thiz, jfloat scale_factor) { + // 2 bytes to a uint16 and two pairs of xy coordinates per keypoint. + const int bytes_per_keypoint = sizeof(uint16) * 2 * 2; + jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint]; + + const int number_of_keypoints = + get_object_tracker(env, thiz)->GetKeypointsPacked( + reinterpret_cast(keypoint_arr), scale_factor); + + // Create and return the array that will be passed back to Java. 
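  // Layout note: each keypoint written above occupies four uint16 values,
  // (x1, y1, x2, y2) -- the frame-1 and frame-2 positions of one found
  // correspondence, pre-multiplied by scale_factor. Judging by its name,
  // RealToFixed115() from utils.h appears to encode these in 11.5 fixed point
  // (roughly round(value * 32), so x = 37.25 would be stored as 1192); the
  // exact rounding is an implementation detail of utils.h.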
+ jbyteArray keypoints = + env->NewByteArray(number_of_keypoints * bytes_per_keypoint); + + if (keypoints == NULL) { + LOGE("null array!"); + return NULL; + } + + env->SetByteArrayRegion( + keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr); + + return keypoints; +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)( + JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1, + jfloat position_y1, jfloat position_x2, jfloat position_y2, + jfloatArray delta) { + jfloat point_arr[4]; + + const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox( + BoundingBox(position_x1, position_y1, position_x2, position_y2), + timestamp); + + new_position.CopyToArray(point_arr); + env->SetFloatArrayRegion(delta, 0, 4, point_arr); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(drawNative)( + JNIEnv* env, jobject thiz, jint view_width, jint view_height, + jfloatArray frame_to_canvas_arr) { + ObjectTracker* object_tracker = get_object_tracker(env, thiz); + if (object_tracker != NULL) { + jfloat* frame_to_canvas = + env->GetFloatArrayElements(frame_to_canvas_arr, NULL); + + object_tracker->Draw(view_width, view_height, frame_to_canvas); + env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas, + JNI_ABORT); + } +} + +JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)( + JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride, + jbyteArray input, jint factor, jbyteArray output) { + if (input == NULL || output == NULL) { + LOGW("Received null arrays, hopefully this is a test!"); + return; + } + + jbyte* const input_array = env->GetByteArrayElements(input, 0); + jbyte* const output_array = env->GetByteArrayElements(output, 0); + + { + tf_tracking::Image full_image( + width, height, reinterpret_cast(input_array), false); + + const int new_width = (width + factor - 1) / factor; + const int new_height = (height + factor - 1) / factor; + + tf_tracking::Image downsampled_image( + new_width, new_height, reinterpret_cast(output_array), false); + + downsampled_image.DownsampleAveraged(reinterpret_cast(input_array), + row_stride, factor); + } + + env->ReleaseByteArrayElements(input, input_array, JNI_ABORT); + env->ReleaseByteArrayElements(output, output_array, 0); +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.cc b/tensorflow/examples/android/jni/object_tracking/optical_flow.cc new file mode 100644 index 0000000000..fab0a3155d --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.cc @@ -0,0 +1,490 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" + +namespace tf_tracking { + +OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config) + : config_(config), + frame1_(NULL), + frame2_(NULL), + working_size_(config->image_size) {} + + +void OpticalFlow::NextFrame(const ImageData* const image_data) { + // Special case for the first frame: make sure the image ends up in + // frame1_ so that keypoint detection can be done on it if desired. + frame1_ = (frame1_ == NULL) ? image_data : frame2_; + frame2_ = image_data; +} + + +// Static heart of the optical flow computation. +// Lucas Kanade algorithm. +bool OpticalFlow::FindFlowAtPoint_LK(const Image& img_I, + const Image& img_J, + const Image& I_x, + const Image& I_y, + const float p_x, + const float p_y, + float* out_g_x, + float* out_g_y) { + float g_x = *out_g_x; + float g_y = *out_g_y; + // Get values for frame 1. They remain constant through the inner + // iteration loop. + float vals_I[kFlowArraySize]; + float vals_I_x[kFlowArraySize]; + float vals_I_y[kFlowArraySize]; + + const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1; + const float kWindowSizeFloat = static_cast(kFlowIntegrationWindowSize); + +#if USE_FIXED_POINT_FLOW + const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1; + const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1; +#else + const float real_x_max = I_x.width_less_one_ - EPSILON; + const float real_y_max = I_x.height_less_one_ - EPSILON; +#endif + + // Get the window around the original point. + const float src_left_real = p_x - kWindowSizeFloat; + const float src_top_real = p_y - kWindowSizeFloat; + float* vals_I_ptr = vals_I; + float* vals_I_x_ptr = vals_I_x; + float* vals_I_y_ptr = vals_I_y; +#if USE_FIXED_POINT_FLOW + // Source integer coordinates. 
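  // In this fixed-point path the subpixel window coordinates are carried in
  // 16.16 fixed point: RealToFixed1616(v) is, in effect, v * 65536, so 2.5
  // becomes 163840, adding (x << 16) advances exactly one pixel, and >> 16
  // recovers the integer part. That lets GetPixelInterpFixed1616() interpolate
  // using integer arithmetic only.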
+ const int src_left_fixed = RealToFixed1616(src_left_real); + const int src_top_fixed = RealToFixed1616(src_top_real); + + for (int y = 0; y < kPatchSize; ++y) { + const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max); + + for (int x = 0; x < kPatchSize; ++x) { + const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max); + + *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y); + *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y); + *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y); + } + } +#else + for (int y = 0; y < kPatchSize; ++y) { + const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max); + + for (int x = 0; x < kPatchSize; ++x) { + const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max); + + *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos); + *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos); + *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos); + } + } +#endif + + // Compute the spatial gradient matrix about point p. + float G[] = { 0, 0, 0, 0 }; + CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G); + + // Find the inverse of G. + float G_inv[4]; + if (!Invert2x2(G, G_inv)) { + return false; + } + +#if NORMALIZE + const float mean_I = ComputeMean(vals_I, kFlowArraySize); + const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I); +#endif + + // Iterate kNumIterations times or until we converge. + for (int iteration = 0; iteration < kNumIterations; ++iteration) { + // Get values for frame 2. + float vals_J[kFlowArraySize]; + + // Get the window around the destination point. + const float left_real = p_x + g_x - kWindowSizeFloat; + const float top_real = p_y + g_y - kWindowSizeFloat; + float* vals_J_ptr = vals_J; +#if USE_FIXED_POINT_FLOW + // The top-left sub-pixel is set for the current iteration (in 16:16 + // fixed). This is constant over one iteration. + const int left_fixed = RealToFixed1616(left_real); + const int top_fixed = RealToFixed1616(top_real); + + for (int win_y = 0; win_y < kPatchSize; ++win_y) { + const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max); + for (int win_x = 0; win_x < kPatchSize; ++win_x) { + const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max); + *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y); + } + } +#else + for (int win_y = 0; win_y < kPatchSize; ++win_y) { + const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max); + for (int win_x = 0; win_x < kPatchSize; ++win_x) { + const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max); + *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos); + } + } +#endif + +#if NORMALIZE + const float mean_J = ComputeMean(vals_J, kFlowArraySize); + const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J); + + // TODO(andrewharp): Probably better to completely detect and handle the + // "corner case" where the patch is fully outside the image diagonally. + const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f; +#endif + + // Compute image mismatch vector. + float b_x = 0.0f; + float b_y = 0.0f; + + vals_I_ptr = vals_I; + vals_J_ptr = vals_J; + vals_I_x_ptr = vals_I_x; + vals_I_y_ptr = vals_I_y; + + for (int win_y = 0; win_y < kPatchSize; ++win_y) { + for (int win_x = 0; win_x < kPatchSize; ++win_x) { +#if NORMALIZE + // Normalized Image difference. 
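        // For reference, the accumulation below is the classical Lucas-Kanade
        // step. With spatial gradients (I_x, I_y) and the per-pixel difference
        // dI summed over the integration window W:
        //
        //   G = sum_W [ I_x*I_x   I_x*I_y
        //               I_x*I_y   I_y*I_y ]      (built by CalculateG above)
        //   b = sum_W dI * (I_x, I_y)
        //
        // and the increment solved a few lines down is n = G^-1 * b, added to
        // the running guess g until it drops below kTrackingAbortThreshold or
        // kNumIterations is reached. Under NORMALIZE, dI first removes the
        // patch means and rescales by the standard-deviation ratio so the
        // match is less sensitive to brightness changes.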
+ const float dI = + (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio; +#else + const float dI = *vals_I_ptr++ - *vals_J_ptr++; +#endif + b_x += dI * *vals_I_x_ptr++; + b_y += dI * *vals_I_y_ptr++; + } + } + + // Optical flow... solve n = G^-1 * b + const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y); + const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y); + + // Update best guess with residual displacement from this level and + // iteration. + g_x += n_x; + g_y += n_y; + + // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y); + + // Abort early if we're already below the threshold. + if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) { + break; + } + } // Iteration. + + // Copy value back into output. + *out_g_x = g_x; + *out_g_y = g_y; + return true; +} + + +// Pointwise flow using translational 2dof ESM. +bool OpticalFlow::FindFlowAtPoint_ESM(const Image& img_I, + const Image& img_J, + const Image& I_x, + const Image& I_y, + const Image& J_x, + const Image& J_y, + const float p_x, + const float p_y, + float* out_g_x, + float* out_g_y) { + float g_x = *out_g_x; + float g_y = *out_g_y; + const float area_inv = 1.0f / static_cast(kFlowArraySize); + + // Get values for frame 1. They remain constant through the inner + // iteration loop. + uint8 vals_I[kFlowArraySize]; + uint8 vals_J[kFlowArraySize]; + int16 src_gradient_x[kFlowArraySize]; + int16 src_gradient_y[kFlowArraySize]; + + // TODO(rspring): try out the IntegerPatchAlign() method once + // the code for that is in ../common. + const float wsize_float = static_cast(kFlowIntegrationWindowSize); + const int src_left_fixed = RealToFixed1616(p_x - wsize_float); + const int src_top_fixed = RealToFixed1616(p_y - wsize_float); + const int patch_size = 2 * kFlowIntegrationWindowSize + 1; + + // Create the keypoint template patch from a subpixel location. + if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, + patch_size, patch_size, vals_I) || + !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, + patch_size, patch_size, + src_gradient_x) || + !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, + patch_size, patch_size, + src_gradient_y)) { + return false; + } + + int bright_offset = 0; + int sum_diff = 0; + + // The top-left sub-pixel is set for the current iteration (in 16:16 fixed). + // This is constant over one iteration. + int left_fixed = RealToFixed1616(p_x + g_x - wsize_float); + int top_fixed = RealToFixed1616(p_y + g_y - wsize_float); + + // The truncated version gives the most top-left pixel that is used. + int left_trunc = left_fixed >> 16; + int top_trunc = top_fixed >> 16; + + // Compute an initial brightness offset. + if (kDoBrightnessNormalize && + left_trunc >= 0 && top_trunc >= 0 && + (left_trunc + patch_size) < img_J.width_less_one_ && + (top_trunc + patch_size) < img_J.height_less_one_) { + int templ_index = 0; + const uint8* j_row = img_J[top_trunc] + left_trunc; + + const int j_stride = img_J.stride(); + + for (int y = 0; y < patch_size; ++y, j_row += j_stride) { + for (int x = 0; x < patch_size; ++x) { + sum_diff += static_cast(j_row[x]) - vals_I[templ_index++]; + } + } + + bright_offset = static_cast(static_cast(sum_diff) * area_inv); + } + + // Iterate kNumIterations times or until we go out of image. + for (int iteration = 0; iteration < kNumIterations; ++iteration) { + int jtj[3] = { 0, 0, 0 }; + int jtr[2] = { 0, 0 }; + sum_diff = 0; + + // Extract the target image values. 
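    // What follows is one Efficient Second-order Minimization (ESM) step.
    // For each pixel the source and target gradients are averaged,
    //   d_i = (grad I_i + grad J_i) / 2,
    // the residual is r_i = J_i(p + g) - I_i(p) - bright_offset, and the loop
    // accumulates the 2x2 normal equations
    //   jtj = sum_i d_i * d_i^T   (kEsmRegularizer is later added to the diagonal)
    //   jtr = sum_i d_i * r_i
    // so that the update applied below is g <- g - jtj^-1 * jtr. Averaging the
    // two gradients is what distinguishes ESM from the plain Lucas-Kanade step
    // above and gives it faster convergence.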
+ // Extract the gradient from the target image patch and accumulate to + // the gradient of the source image patch. + if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed, + patch_size, patch_size, + vals_J)) { + break; + } + + const uint8* templ_row = vals_I; + const uint8* extract_row = vals_J; + const int16* src_dx_row = src_gradient_x; + const int16* src_dy_row = src_gradient_y; + + for (int y = 0; y < patch_size; ++y, templ_row += patch_size, + src_dx_row += patch_size, src_dy_row += patch_size, + extract_row += patch_size) { + const int fp_y = top_fixed + (y << 16); + for (int x = 0; x < patch_size; ++x) { + const int fp_x = left_fixed + (x << 16); + int32 target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y); + int32 target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y); + + // Combine the two Jacobians. + // Right-shift by one to account for the fact that we add + // two Jacobians. + int32 dx = (src_dx_row[x] + target_dx) >> 1; + int32 dy = (src_dy_row[x] + target_dy) >> 1; + + // The current residual b - h(q) == extracted - (template + offset) + int32 diff = static_cast(extract_row[x]) - + static_cast(templ_row[x]) - + bright_offset; + + jtj[0] += dx * dx; + jtj[1] += dx * dy; + jtj[2] += dy * dy; + + jtr[0] += dx * diff; + jtr[1] += dy * diff; + + sum_diff += diff; + } + } + + const float jtr1_float = static_cast(jtr[0]); + const float jtr2_float = static_cast(jtr[1]); + + // Add some baseline stability to the system. + jtj[0] += kEsmRegularizer; + jtj[2] += kEsmRegularizer; + + const int64 prod1 = static_cast(jtj[0]) * jtj[2]; + const int64 prod2 = static_cast(jtj[1]) * jtj[1]; + + // One ESM step. + const float jtj_1[4] = { static_cast(jtj[2]), + static_cast(-jtj[1]), + static_cast(-jtj[1]), + static_cast(jtj[0]) }; + const double det_inv = 1.0 / static_cast(prod1 - prod2); + + g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float); + g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float); + + if (kDoBrightnessNormalize) { + bright_offset += + static_cast(area_inv * static_cast(sum_diff) + 0.5f); + } + + // Update top left position. + left_fixed = RealToFixed1616(p_x + g_x - wsize_float); + top_fixed = RealToFixed1616(p_y + g_y - wsize_float); + + left_trunc = left_fixed >> 16; + top_trunc = top_fixed >> 16; + + // Abort iterations if we go out of borders. + if (left_trunc < 0 || top_trunc < 0 || + (left_trunc + patch_size) >= J_x.width_less_one_ || + (top_trunc + patch_size) >= J_y.height_less_one_) { + break; + } + } // Iteration. + + // Copy value back into output. + *out_g_x = g_x; + *out_g_y = g_y; + return true; +} + + +bool OpticalFlow::FindFlowAtPointReversible( + const int level, const float u_x, const float u_y, + const bool reverse_flow, + float* flow_x, float* flow_y) const { + const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_; + const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_; + + // Images I (prev) and J (next). + const Image& img_I = *frame_a.GetPyramidSqrt2Level(level * 2); + const Image& img_J = *frame_b.GetPyramidSqrt2Level(level * 2); + + // Computed gradients. + const Image& I_x = *frame_a.GetSpatialX(level); + const Image& I_y = *frame_a.GetSpatialY(level); + const Image& J_x = *frame_b.GetSpatialX(level); + const Image& J_y = *frame_b.GetSpatialY(level); + + // Shrink factor from original. + const float shrink_factor = (1 << level); + + // Image position vector (p := u^l), scaled for this level. 
+ const float scaled_p_x = u_x / shrink_factor; + const float scaled_p_y = u_y / shrink_factor; + + float scaled_flow_x = *flow_x / shrink_factor; + float scaled_flow_y = *flow_y / shrink_factor; + + // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level, + // scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y); + + const bool success = kUseEsm ? + FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y, + scaled_p_x, scaled_p_y, + &scaled_flow_x, &scaled_flow_y) : + FindFlowAtPoint_LK(img_I, img_J, I_x, I_y, + scaled_p_x, scaled_p_y, + &scaled_flow_x, &scaled_flow_y); + + *flow_x = scaled_flow_x * shrink_factor; + *flow_y = scaled_flow_y * shrink_factor; + + return success; +} + + +bool OpticalFlow::FindFlowAtPointSingleLevel( + const int level, + const float u_x, const float u_y, + const bool filter_by_fb_error, + float* flow_x, float* flow_y) const { + if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) { + return false; + } + + if (filter_by_fb_error) { + const float new_position_x = u_x + *flow_x; + const float new_position_y = u_y + *flow_y; + + float reverse_flow_x = 0.0f; + float reverse_flow_y = 0.0f; + + // Now find the backwards flow and confirm it lines up with the original + // starting point. + if (!FindFlowAtPointReversible(level, new_position_x, new_position_y, + true, + &reverse_flow_x, &reverse_flow_y)) { + LOGE("Backward error!"); + return false; + } + + const float discrepancy_length = + sqrtf(Square(*flow_x + reverse_flow_x) + + Square(*flow_y + reverse_flow_y)); + + const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y)); + + return discrepancy_length < + (kMaxForwardBackwardErrorAllowed * flow_length); + } + + return true; +} + + +// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm. +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details. +bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y, + const bool filter_by_fb_error, + float* flow_x, float* flow_y) const { + const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment, + kNumPyramidLevels - kNumCacheLevels); + + // For every level in the pyramid, update the coordinates of the best match. + for (int l = max_level - 1; l >= 0; --l) { + if (!FindFlowAtPointSingleLevel(l, u_x, u_y, + filter_by_fb_error, flow_x, flow_y)) { + return false; + } + } + + return true; +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.h b/tensorflow/examples/android/jni/object_tracking/optical_flow.h new file mode 100644 index 0000000000..1329927b99 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.h @@ -0,0 +1,111 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" + +using namespace tensorflow; + +namespace tf_tracking { + +class FlowCache; + +// Class encapsulating all the data and logic necessary for performing optical +// flow. +class OpticalFlow { + public: + explicit OpticalFlow(const OpticalFlowConfig* const config); + + // Add a new frame to the optical flow. Will update all the non-keypoint + // related member variables. + // + // new_frame should be a buffer of grayscale values, one byte per pixel, + // at the original frame_width and frame_height used to initialize the + // OpticalFlow object. Downsampling will be handled internally. + // + // time_stamp should be a time in milliseconds that later calls to this and + // other methods will be relative to. + void NextFrame(const ImageData* const image_data); + + // An implementation of the Lucas-Kanade Optical Flow algorithm. + static bool FindFlowAtPoint_LK(const Image& img_I, + const Image& img_J, + const Image& I_x, + const Image& I_y, + const float p_x, + const float p_y, + float* out_g_x, + float* out_g_y); + + // Pointwise flow using translational 2dof ESM. + static bool FindFlowAtPoint_ESM(const Image& img_I, + const Image& img_J, + const Image& I_x, + const Image& I_y, + const Image& J_x, + const Image& J_y, + const float p_x, + const float p_y, + float* out_g_x, + float* out_g_y); + + // Finds the flow using a specific level, in either direction. + // If reversed, the coordinates are in the context of the latest + // frame, not the frame before it. + // All coordinates used in parameters are global, not scaled. + bool FindFlowAtPointReversible( + const int level, const float u_x, const float u_y, + const bool reverse_flow, + float* final_x, float* final_y) const; + + // Finds the flow using a specific level, filterable by forward-backward + // error. All coordinates used in parameters are global, not scaled. + bool FindFlowAtPointSingleLevel(const int level, + const float u_x, const float u_y, + const bool filter_by_fb_error, + float* flow_x, float* flow_y) const; + + // Pyramidal optical-flow using all levels. + bool FindFlowAtPointPyramidal(const float u_x, const float u_y, + const bool filter_by_fb_error, + float* flow_x, float* flow_y) const; + + private: + const OpticalFlowConfig* const config_; + + const ImageData* frame1_; + const ImageData* frame2_; + + // Size of the internally allocated images (after original is downsampled). 
+ const Size working_size_; + + TF_DISALLOW_COPY_AND_ASSIGN(OpticalFlow); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/sprite.h b/tensorflow/examples/android/jni/object_tracking/sprite.h new file mode 100755 index 0000000000..6240591cf2 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/sprite.h @@ -0,0 +1,205 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ + +#include +#include + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" + +#ifndef __RENDER_OPENGL__ +#error sprite.h should not included if OpenGL is not enabled by platform.h +#endif + +namespace tf_tracking { + +// This class encapsulates the logic necessary to load an render image data +// at the same aspect ratio as the original source. +class Sprite { + public: + // Only create Sprites when you have an OpenGl context. + explicit Sprite(const Image& image) { + LoadTexture(image, NULL); + } + + Sprite(const Image& image, const BoundingBox* const area) { + LoadTexture(image, area); + } + + // Also, try to only delete a Sprite when holding an OpenGl context. + ~Sprite() { + glDeleteTextures(1, &texture_); + } + + inline int GetWidth() const { + return actual_width_; + } + + inline int GetHeight() const { + return actual_height_; + } + + // Draw the sprite at 0,0 - original width/height in the current reference + // frame. Any transformations desired must be applied before calling this + // function. + void Draw() const { + const float float_width = static_cast(actual_width_); + const float float_height = static_cast(actual_height_); + + // Where it gets rendered to. + const float vertices[] = { 0.0f, 0.0f, 0.0f, + 0.0f, float_height, 0.0f, + float_width, 0.0f, 0.0f, + float_width, float_height, 0.0f, + }; + + // The coordinates the texture gets drawn from. 
+ const float max_x = float_width / texture_width_; + const float max_y = float_height / texture_height_; + const float textureVertices[] = { + 0, 0, + 0, max_y, + max_x, 0, + max_x, max_y, + }; + + glEnable(GL_TEXTURE_2D); + glBindTexture(GL_TEXTURE_2D, texture_); + + glEnableClientState(GL_VERTEX_ARRAY); + glEnableClientState(GL_TEXTURE_COORD_ARRAY); + + glVertexPointer(3, GL_FLOAT, 0, vertices); + glTexCoordPointer(2, GL_FLOAT, 0, textureVertices); + + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + glDisableClientState(GL_VERTEX_ARRAY); + glDisableClientState(GL_TEXTURE_COORD_ARRAY); + } + + private: + inline int GetNextPowerOfTwo(const int number) const { + int power_of_two = 1; + while (power_of_two < number) { + power_of_two *= 2; + } + return power_of_two; + } + + // TODO(andrewharp): Allow sprites to have their textures reloaded. + void LoadTexture(const Image& texture_source, + const BoundingBox* const area) { + glEnable(GL_TEXTURE_2D); + + glGenTextures(1, &texture_); + + glBindTexture(GL_TEXTURE_2D, texture_); + + int left = 0; + int top = 0; + + if (area != NULL) { + // If a sub-region was provided to pull the texture from, use that. + left = area->left_; + top = area->top_; + actual_width_ = area->GetWidth(); + actual_height_ = area->GetHeight(); + } else { + actual_width_ = texture_source.GetWidth(); + actual_height_ = texture_source.GetHeight(); + } + + // The textures must be a power of two, so find the sizes that are large + // enough to contain the image data. + texture_width_ = GetNextPowerOfTwo(actual_width_); + texture_height_ = GetNextPowerOfTwo(actual_height_); + + bool allocated_data = false; + uint8* texture_data; + + // Except in the lucky case where we're not using a sub-region of the + // original image AND the source data has dimensions that are power of two, + // care must be taken to copy data at the appropriate source and destination + // strides so that the final block can be copied directly into texture + // memory. + // TODO(andrewharp): Figure out if data can be pulled directly from the + // source image with some alignment modifications. + if (left != 0 || top != 0 || + actual_width_ != texture_source.GetWidth() || + actual_height_ != texture_source.GetHeight()) { + texture_data = new uint8[actual_width_ * actual_height_]; + + for (int y = 0; y < actual_height_; ++y) { + memcpy(texture_data + actual_width_ * y, + texture_source[top + y] + left, + actual_width_ * sizeof(uint8)); + } + allocated_data = true; + } else { + // Cast away const-ness because for some reason glTexSubImage2D wants + // a non-const data pointer. + texture_data = const_cast(texture_source.data()); + } + + glTexImage2D(GL_TEXTURE_2D, + 0, + GL_LUMINANCE, + texture_width_, + texture_height_, + 0, + GL_LUMINANCE, + GL_UNSIGNED_BYTE, + NULL); + + glPixelStorei(GL_UNPACK_ALIGNMENT, 1); + glTexSubImage2D(GL_TEXTURE_2D, + 0, + 0, + 0, + actual_width_, + actual_height_, + GL_LUMINANCE, + GL_UNSIGNED_BYTE, + texture_data); + + if (allocated_data) { + delete(texture_data); + } + + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + } + + // The id for the texture on the GPU. + GLuint texture_; + + // The width and height to be used for display purposes, referring to the + // dimensions of the original texture. + int actual_width_; + int actual_height_; + + // The allocated dimensions of the texture data, which must be powers of 2. 
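As the comment in LoadTexture above notes, OpenGL ES 1.x texture dimensions generally must be powers of two (absent NPOT extensions), so the sprite pads actual_width_ and actual_height_ up with GetNextPowerOfTwo and uploads only the real pixels via glTexSubImage2D. The standalone sketch below checks that padding arithmetic; the branch-free helper is only an alternative illustration of the same idea, not the code this patch uses, and the 320x180 frame size is made up.

#include <cstdio>

// Smallest power of two >= number, for positive inputs; same result as
// Sprite::GetNextPowerOfTwo's doubling loop.
static int NextPowerOfTwo(int number) {
  int v = number - 1;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  return v + 1;
}

int main() {
  // A 320x180 sprite would be backed by a 512x256 texture; only its top-left
  // 320x180 region is filled by glTexSubImage2D.
  std::printf("%d x %d\n", NextPowerOfTwo(320), NextPowerOfTwo(180));
  return 0;
}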
+ int texture_width_; + int texture_height_; + + TF_DISALLOW_COPY_AND_ASSIGN(Sprite); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.cc b/tensorflow/examples/android/jni/object_tracking/time_log.cc new file mode 100644 index 0000000000..cb1f3c23c8 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/time_log.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" + +using namespace tensorflow; + +#ifdef LOG_TIME +// Storage for logging functionality. +int num_time_logs = 0; +LogEntry time_logs[NUM_LOGS]; + +int num_avg_entries = 0; +AverageEntry avg_entries[NUM_LOGS]; +#endif diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.h b/tensorflow/examples/android/jni/object_tracking/time_log.h new file mode 100644 index 0000000000..ec539a1b3b --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/time_log.h @@ -0,0 +1,138 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility functions for performance profiling. + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#ifdef LOG_TIME + +// Blend constant for running average. +#define ALPHA 0.98f +#define NUM_LOGS 100 + +struct LogEntry { + const char* id; + int64 time_stamp; +}; + +struct AverageEntry { + const char* id; + float average_duration; +}; + +// Storage for keeping track of this frame's values. +extern int num_time_logs; +extern LogEntry time_logs[NUM_LOGS]; + +// Storage for keeping track of average values (each entry may not be printed +// out each frame). +extern AverageEntry avg_entries[NUM_LOGS]; +extern int num_avg_entries; + +// Call this at the start of a logging phase. 
+inline static void ResetTimeLog() { + num_time_logs = 0; +} + + +// Log a message to be printed out when printTimeLog is called, along with the +// amount of time in ms that has passed since the last call to this function. +inline static void TimeLog(const char* const str) { + LOGV("%s", str); + if (num_time_logs >= NUM_LOGS) { + LOGE("Out of log entries!"); + return; + } + + time_logs[num_time_logs].id = str; + time_logs[num_time_logs].time_stamp = CurrentThreadTimeNanos(); + ++num_time_logs; +} + + +inline static float Blend(float old_val, float new_val) { + return ALPHA * old_val + (1.0f - ALPHA) * new_val; +} + + +inline static float UpdateAverage(const char* str, const float new_val) { + for (int entry_num = 0; entry_num < num_avg_entries; ++entry_num) { + AverageEntry* const entry = avg_entries + entry_num; + if (str == entry->id) { + entry->average_duration = Blend(entry->average_duration, new_val); + return entry->average_duration; + } + } + + if (num_avg_entries >= NUM_LOGS) { + LOGE("Too many log entries!"); + } + + // If it wasn't there already, add it. + avg_entries[num_avg_entries].id = str; + avg_entries[num_avg_entries].average_duration = new_val; + ++num_avg_entries; + + return new_val; +} + + +// Prints out all the timeLog statements in chronological order with the +// interval that passed between subsequent statements. The total time between +// the first and last statements is printed last. +inline static void PrintTimeLog() { + LogEntry* last_time = time_logs; + + float average_running_total = 0.0f; + + for (int i = 0; i < num_time_logs; ++i) { + LogEntry* const this_time = time_logs + i; + + const float curr_time = + (this_time->time_stamp - last_time->time_stamp) / 1000000.0f; + + const float avg_time = UpdateAverage(this_time->id, curr_time); + average_running_total += avg_time; + + LOGD("%32s: %6.3fms %6.4fms", this_time->id, curr_time, avg_time); + last_time = this_time; + } + + const float total_time = + (last_time->time_stamp - time_logs->time_stamp) / 1000000.0f; + + LOGD("TOTAL TIME: %6.3fms %6.4fms\n", + total_time, average_running_total); + LOGD(" "); +} +#else +inline static void ResetTimeLog() {} + +inline static void TimeLog(const char* const str) { + LOGV("%s", str); +} + +inline static void PrintTimeLog() {} +#endif + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.cc b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc new file mode 100644 index 0000000000..823fb3a90e --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc @@ -0,0 +1,163 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
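A sketch of how the helpers above are intended to be used around a frame's worth of work; the phase names are invented, and the real call sites appear in object_tracker.cc elsewhere in this change. Each printed average is an exponential running average blended with ALPHA = 0.98, and with LOG_TIME undefined the calls reduce to the no-op/LOGV stubs defined above. Building this requires linking time_log.cc for the storage arrays.

#include "tensorflow/examples/android/jni/object_tracking/time_log.h"

// Hypothetical per-frame profiling pattern.
void ProcessFrameExample() {
  ResetTimeLog();               // Start a fresh log for this frame.
  TimeLog("Frame start");

  // ... downsample the incoming frame ...
  TimeLog("Downsampled frame");

  // ... run optical flow on the keypoints ...
  TimeLog("Computed flow");

  // Prints each interval, its blended average, and the frame total.
  PrintTimeLog();
}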
+==============================================================================*/ + +#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h" + +namespace tf_tracking { + +static const float kInitialDistance = 20.0f; + +static void InitNormalized(const Image& src_image, + const BoundingBox& position, + Image* const dst_image) { + BoundingBox scaled_box(position); + CopyArea(src_image, scaled_box, dst_image); + NormalizeImage(dst_image); +} + +TrackedObject::TrackedObject(const std::string& id, + const Image& image, + const BoundingBox& bounding_box, + ObjectModelBase* const model) + : id_(id), + last_known_position_(bounding_box), + last_detection_position_(bounding_box), + position_last_computed_time_(-1), + object_model_(model), + last_detection_thumbnail_(kNormalizedThumbnailSize, + kNormalizedThumbnailSize), + last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize), + tracked_correlation_(0.0f), + tracked_match_score_(0.0), + num_consecutive_frames_below_threshold_(0), + allowable_detection_distance_(Square(kInitialDistance)) { + InitNormalized(image, bounding_box, &last_detection_thumbnail_); +} + +TrackedObject::~TrackedObject() {} + +void TrackedObject::UpdatePosition(const BoundingBox& new_position, + const int64 timestamp, + const ImageData& image_data, + const bool authoratative) { + last_known_position_ = new_position; + position_last_computed_time_ = timestamp; + + InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_); + + const float last_localization_correlation = ComputeCrossCorrelation( + last_detection_thumbnail_.data(), + last_frame_thumbnail_.data(), + last_frame_thumbnail_.data_size_); + LOGV("Tracked correlation to last localization: %.6f", + last_localization_correlation); + + // Correlation to object model, if it exists. + if (object_model_ != NULL) { + tracked_correlation_ = + object_model_->GetMaxCorrelation(last_frame_thumbnail_); + LOGV("Tracked correlation to model: %.6f", + tracked_correlation_); + + tracked_match_score_ = + object_model_->GetMatchScore(new_position, image_data); + LOGV("Tracked match score with model: %.6f", + tracked_match_score_.value); + } else { + // If there's no model to check against, set the tracked correlation to + // simply be the correlation to the last set position. + tracked_correlation_ = last_localization_correlation; + tracked_match_score_ = MatchScore(0.0f); + } + + // Determine if it's still being tracked. + if (tracked_correlation_ >= kMinimumCorrelationForTracking && + tracked_match_score_ >= kMinimumMatchScore) { + num_consecutive_frames_below_threshold_ = 0; + + if (object_model_ != NULL) { + object_model_->TrackStep(last_known_position_, *image_data.GetImage(), + *image_data.GetIntegralImage(), authoratative); + } + } else if (tracked_match_score_ < kMatchScoreForImmediateTermination) { + if (num_consecutive_frames_below_threshold_ < 1000) { + LOGD("Tracked match score is way too low (%.6f), aborting track.", + tracked_match_score_.value); + } + + // Add an absurd amount of missed frames so that all heuristics will + // consider it a lost track. 
+ num_consecutive_frames_below_threshold_ += 1000; + + if (object_model_ != NULL) { + object_model_->TrackLost(); + } + } else { + ++num_consecutive_frames_below_threshold_; + allowable_detection_distance_ *= 1.1f; + } +} + +void TrackedObject::OnDetection(ObjectModelBase* const model, + const BoundingBox& detection_position, + const MatchScore match_score, + const int64 timestamp, + const ImageData& image_data) { + const float overlap = detection_position.PascalScore(last_known_position_); + if (overlap > kPositionOverlapThreshold) { + // If the position agreement with the current tracked position is good + // enough, lock all the current unlocked examples. + object_model_->TrackConfirmed(); + num_consecutive_frames_below_threshold_ = 0; + } + + // Before relocalizing, make sure the new proposed position is better than + // the existing position by a small amount to prevent thrashing. + if (match_score <= tracked_match_score_ + kMatchScoreBuffer) { + LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f", + match_score.value, tracked_match_score_.value, + kMatchScoreBuffer.value); + return; + } + + LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to " + "(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f", + last_known_position_.left_, last_known_position_.top_, + last_known_position_.GetWidth(), last_known_position_.GetHeight(), + detection_position.left_, detection_position.top_, + detection_position.GetWidth(), detection_position.GetHeight(), + match_score.value, tracked_match_score_.value); + + if (overlap < kPositionOverlapThreshold) { + // The path might be good, it might be bad, but it's no longer a path + // since we're moving the box to a new position, so just nuke it from + // orbit to be safe. + object_model_->TrackLost(); + } + + object_model_ = model; + + // Reset the last detected appearance. + InitNormalized( + *image_data.GetImage(), detection_position, &last_detection_thumbnail_); + + num_consecutive_frames_below_threshold_ = 0; + last_detection_position_ = detection_position; + + UpdatePosition(detection_position, timestamp, image_data, false); + allowable_detection_distance_ = Square(kInitialDistance); +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.h b/tensorflow/examples/android/jni/object_tracking/tracked_object.h new file mode 100644 index 0000000000..5580cd2b89 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.h @@ -0,0 +1,191 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
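For context on the relocalization gate in OnDetection above: PascalScore comes from the BoundingBox class in geom.h (not in this hunk) and presumably measures PASCAL-VOC-style overlap, intersection area over union area, which is what kPositionOverlapThreshold is compared against. A standalone sketch of that measure follows; the Box struct and the coordinates here are illustrative only, not the geom.h BoundingBox.

#include <algorithm>
#include <cstdio>

// Minimal axis-aligned box purely for illustration.
struct Box {
  float left, top, right, bottom;
  float Area() const { return (right - left) * (bottom - top); }
};

// PASCAL-style overlap: intersection over union, in [0, 1].
static float PascalOverlap(const Box& a, const Box& b) {
  const float ix =
      std::max(0.0f, std::min(a.right, b.right) - std::max(a.left, b.left));
  const float iy =
      std::max(0.0f, std::min(a.bottom, b.bottom) - std::max(a.top, b.top));
  const float intersection = ix * iy;
  const float union_area = a.Area() + b.Area() - intersection;
  return union_area > 0.0f ? intersection / union_area : 0.0f;
}

int main() {
  const Box tracked = {10, 10, 50, 50};
  const Box detected = {20, 20, 60, 60};
  // Intersection 30x30 = 900, union 1600 + 1600 - 900 = 2300, so ~0.39.
  std::printf("overlap = %.2f\n", PascalOverlap(tracked, detected));
  return 0;
}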
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ + +#ifdef __RENDER_OPENGL__ +#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h" +#endif +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" + +namespace tf_tracking { + +// A TrackedObject is a specific instance of an ObjectModel, with a known +// position in the world. +// It provides the last known position and number of recent detection failures, +// in addition to the more general appearance data associated with the object +// class (which is in ObjectModel). +// TODO(andrewharp): Make getters/setters follow styleguide. +class TrackedObject { + public: + TrackedObject(const std::string& id, + const Image& image, + const BoundingBox& bounding_box, + ObjectModelBase* const model); + + ~TrackedObject(); + + void UpdatePosition(const BoundingBox& new_position, + const int64 timestamp, + const ImageData& image_data, + const bool authoratative); + + // This method is called when the tracked object is detected at a + // given position, and allows the associated Model to grow and/or prune + // itself based on where the detection occurred. + void OnDetection(ObjectModelBase* const model, + const BoundingBox& detection_position, + const MatchScore match_score, + const int64 timestamp, + const ImageData& image_data); + + // Called when there's no detection of the tracked object. This will cause + // a tracking failure after enough consecutive failures if the area under + // the current bounding box also doesn't meet a minimum correlation threshold + // with the model. + void OnDetectionFailure() {} + + inline bool IsVisible() const { + return tracked_correlation_ >= kMinimumCorrelationForTracking || + num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures; + } + + inline float GetCorrelation() { + return tracked_correlation_; + } + + inline MatchScore GetMatchScore() { + return tracked_match_score_; + } + + inline BoundingBox GetPosition() const { + return last_known_position_; + } + + inline BoundingBox GetLastDetectionPosition() const { + return last_detection_position_; + } + + inline const ObjectModelBase* GetModel() const { + return object_model_; + } + + inline const std::string& GetName() const { + return id_; + } + + inline void Draw() const { +#ifdef __RENDER_OPENGL__ + if (tracked_correlation_ < kMinimumCorrelationForTracking) { + glColor4f(MAX(0.0f, -tracked_correlation_), + MAX(0.0f, tracked_correlation_), + 0.0f, + 1.0f); + } else { + glColor4f(MAX(0.0f, -tracked_correlation_), + MAX(0.0f, tracked_correlation_), + 1.0f, + 1.0f); + } + + // Render the box itself. + BoundingBox temp_box(last_known_position_); + DrawBox(temp_box); + + // Render a box inside this one (in case the actual box is hidden). + const float kBufferSize = 1.0f; + temp_box.left_ -= kBufferSize; + temp_box.top_ -= kBufferSize; + temp_box.right_ += kBufferSize; + temp_box.bottom_ += kBufferSize; + DrawBox(temp_box); + + // Render one outside as well. + temp_box.left_ -= -2.0f * kBufferSize; + temp_box.top_ -= -2.0f * kBufferSize; + temp_box.right_ += -2.0f * kBufferSize; + temp_box.bottom_ += -2.0f * kBufferSize; + DrawBox(temp_box); +#endif + } + + // Get current object's num_consecutive_frames_below_threshold_. 
+ inline int64 GetNumConsecutiveFramesBelowThreshold() { + return num_consecutive_frames_below_threshold_; + } + + // Reset num_consecutive_frames_below_threshold_ to 0. + inline void resetNumConsecutiveFramesBelowThreshold() { + num_consecutive_frames_below_threshold_ = 0; + } + + inline float GetAllowableDistanceSquared() const { + return allowable_detection_distance_; + } + + private: + // The unique id used throughout the system to identify this + // tracked object. + const std::string id_; + + // The last known position of the object. + BoundingBox last_known_position_; + + // The last known position of the object. + BoundingBox last_detection_position_; + + // When the position was last computed. + int64 position_last_computed_time_; + + // The object model this tracked object is representative of. + ObjectModelBase* object_model_; + + Image last_detection_thumbnail_; + + Image last_frame_thumbnail_; + + // The correlation of the object model with the preview frame at its last + // tracked position. + float tracked_correlation_; + + MatchScore tracked_match_score_; + + // The number of consecutive frames that the tracked position for this object + // has been under the correlation threshold. + int num_consecutive_frames_below_threshold_; + + float allowable_detection_distance_; + + friend std::ostream& operator<<(std::ostream& stream, + const TrackedObject& tracked_object); + + TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject); +}; + +inline std::ostream& operator<<(std::ostream& stream, + const TrackedObject& tracked_object) { + stream << tracked_object.id_ + << " " << tracked_object.last_known_position_ + << " " << tracked_object.position_last_computed_time_ + << " " << tracked_object.num_consecutive_frames_below_threshold_ + << " " << tracked_object.object_model_ + << " " << tracked_object.tracked_correlation_; + return stream; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/utils.h b/tensorflow/examples/android/jni/object_tracking/utils.h new file mode 100644 index 0000000000..cbdfc408c6 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/utils.h @@ -0,0 +1,386 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ + +#include +#include +#include + +#include // for std::abs(float) + +#ifndef HAVE_CLOCK_GETTIME +// Use gettimeofday() instead of clock_gettime(). +#include +#endif // ifdef HAVE_CLOCK_GETTIME + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using namespace tensorflow; + +// TODO(andrewharp): clean up these macros to use the codebase statndard. 
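One behavior of the class above that is easy to miss: allowable_detection_distance_ starts at Square(kInitialDistance) (20 pixels, squared) and UpdatePosition inflates it by 1.1x for every consecutive low-score frame, so the distance within which a new detection can re-associate with the track grows while the track is uncertain. The standalone sketch below simply tabulates that growth; how GetAllowableDistanceSquared is actually consumed by the detector is not shown in this hunk.

#include <cmath>
#include <cstdio>

int main() {
  // Mirrors the update in TrackedObject::UpdatePosition: the squared distance
  // starts at 20^2 and is multiplied by 1.1 per consecutive low-score frame.
  float allowable_distance_squared = 20.0f * 20.0f;
  for (int missed_frames = 1; missed_frames <= 10; ++missed_frames) {
    allowable_distance_squared *= 1.1f;
    std::printf("after %2d misses: radius ~%.1f px\n",
                missed_frames, std::sqrt(allowable_distance_squared));
  }
  return 0;
}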
+ +// A very small number, generally used as the tolerance for accumulated +// floating point errors in bounds-checks. +#define EPSILON 0.00001f + +#define SAFE_DELETE(pointer) {\ + if ((pointer) != NULL) {\ + LOGV("Safe deleting pointer: %s", #pointer);\ + delete (pointer);\ + (pointer) = NULL;\ + } else {\ + LOGV("Pointer already null: %s", #pointer);\ + }\ +} + + +#ifdef __GOOGLE__ + +#define CHECK_ALWAYS(condition, format, ...) {\ + CHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\ +} + +#define SCHECK(condition, format, ...) {\ + DCHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\ +} + +#else + +#define CHECK_ALWAYS(condition, format, ...) {\ + if (!(condition)) {\ + LOGE("CHECK FAILED (%s): " format, #condition, ##__VA_ARGS__);\ + abort();\ + }\ +} + +#ifdef SANITY_CHECKS +#define SCHECK(condition, format, ...) {\ + CHECK_ALWAYS(condition, format, ##__VA_ARGS__);\ +} +#else +#define SCHECK(condition, format, ...) {} +#endif // SANITY_CHECKS + +#endif // __GOOGLE__ + + +#ifndef MAX +#define MAX(a, b) (((a) > (b)) ? (a) : (b)) +#endif +#ifndef MIN +#define MIN(a, b) (((a) > (b)) ? (b) : (a)) +#endif + + + +inline static int64 CurrentThreadTimeNanos() { +#ifdef HAVE_CLOCK_GETTIME + struct timespec tm; + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tm); + return tm.tv_sec * 1000000000LL + tm.tv_nsec; +#else + struct timeval tv; + gettimeofday(&tv, NULL); + return tv.tv_sec * 1000000000 + tv.tv_usec * 1000; +#endif +} + + +inline static int64 CurrentRealTimeMillis() { +#ifdef HAVE_CLOCK_GETTIME + struct timespec tm; + clock_gettime(CLOCK_MONOTONIC, &tm); + return tm.tv_sec * 1000LL + tm.tv_nsec / 1000000LL; +#else + struct timeval tv; + gettimeofday(&tv, NULL); + return tv.tv_sec * 1000 + tv.tv_usec / 1000; +#endif +} + + +template +inline static T Square(const T a) { + return a * a; +} + + +template +inline static T Clip(const T a, const T floor, const T ceil) { + SCHECK(ceil >= floor, "Bounds mismatch!"); + return (a <= floor) ? floor : ((a >= ceil) ? ceil : a); +} + + +template +inline static int Floor(const T a) { + return static_cast(a); +} + + +template +inline static int Ceil(const T a) { + return Floor(a) + 1; +} + + +template +inline static bool InRange(const T a, const T min, const T max) { + return (a >= min) && (a <= max); +} + + +inline static bool ValidIndex(const int a, const int max) { + return (a >= 0) && (a < max); +} + + +inline bool NearlyEqual(const float a, const float b, const float tolerance) { + return std::abs(a - b) < tolerance; +} + + +inline bool NearlyEqual(const float a, const float b) { + return NearlyEqual(a, b, EPSILON); +} + + +template +inline static int Round(const float a) { + return (a - static_cast(floor(a) > 0.5f) ? ceil(a) : floor(a)); +} + + +template +inline static void Swap(T* const a, T* const b) { + // Cache out the VALUE of what's at a. + T tmp = *a; + *a = *b; + + *b = tmp; +} + + +static inline float randf() { + return rand() / static_cast(RAND_MAX); +} + +static inline float randf(const float min_value, const float max_value) { + return randf() * (max_value - min_value) + min_value; +} + +static inline uint16 RealToFixed115(const float real_number) { + SCHECK(InRange(real_number, 0.0f, 2048.0f), + "Value out of range! %.2f", real_number); + + static const float kMult = 32.0f; + const float round_add = (real_number > 0.0f) ? 
0.5f : -0.5f; + return static_cast(real_number * kMult + round_add); +} + +static inline float FixedToFloat115(const uint16 fp_number) { + const float kDiv = 32.0f; + return (static_cast(fp_number) / kDiv); +} + +static inline int RealToFixed1616(const float real_number) { + static const float kMult = 65536.0f; + SCHECK(InRange(real_number, -kMult, kMult), + "Value out of range! %.2f", real_number); + + const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f; + return static_cast(real_number * kMult + round_add); +} + +static inline float FixedToFloat1616(const int fp_number) { + const float kDiv = 65536.0f; + return (static_cast(fp_number) / kDiv); +} + +template +// produces numbers in range [0,2*M_PI] (rather than -PI,PI) +inline T FastAtan2(const T y, const T x) { + static const T coeff_1 = (T)(M_PI / 4.0); + static const T coeff_2 = (T)(3.0 * coeff_1); + const T abs_y = fabs(y); + T angle; + if (x >= 0) { + T r = (x - abs_y) / (x + abs_y); + angle = coeff_1 - coeff_1 * r; + } else { + T r = (x + abs_y) / (abs_y - x); + angle = coeff_2 - coeff_1 * r; + } + static const T PI_2 = 2.0 * M_PI; + return y < 0 ? PI_2 - angle : angle; +} + +#define NELEMS(X) (sizeof(X) / sizeof(X[0])) + +namespace tf_tracking { + +#ifdef __ARM_NEON +float ComputeMeanNeon(const float* const values, const int num_vals); + +float ComputeStdDevNeon(const float* const values, const int num_vals, + const float mean); + +float ComputeWeightedMeanNeon(const float* const values, + const float* const weights, const int num_vals); + +float ComputeCrossCorrelationNeon(const float* const values1, + const float* const values2, + const int num_vals); +#endif + +inline float ComputeMeanCpu(const float* const values, const int num_vals) { + // Get mean. + float sum = values[0]; + for (int i = 1; i < num_vals; ++i) { + sum += values[i]; + } + return sum / static_cast(num_vals); +} + + +inline float ComputeMean(const float* const values, const int num_vals) { + return +#ifdef __ARM_NEON + (num_vals >= 8) ? ComputeMeanNeon(values, num_vals) : +#endif + ComputeMeanCpu(values, num_vals); +} + + +inline float ComputeStdDevCpu(const float* const values, + const int num_vals, + const float mean) { + // Get Std dev. + float squared_sum = 0.0f; + for (int i = 0; i < num_vals; ++i) { + squared_sum += Square(values[i] - mean); + } + return sqrt(squared_sum / static_cast(num_vals)); +} + + +inline float ComputeStdDev(const float* const values, + const int num_vals, + const float mean) { + return +#ifdef __ARM_NEON + (num_vals >= 8) ? ComputeStdDevNeon(values, num_vals, mean) : +#endif + ComputeStdDevCpu(values, num_vals, mean); +} + + +// TODO(andrewharp): Accelerate with NEON. +inline float ComputeWeightedMean(const float* const values, + const float* const weights, + const int num_vals) { + float sum = 0.0f; + float total_weight = 0.0f; + for (int i = 0; i < num_vals; ++i) { + sum += values[i] * weights[i]; + total_weight += weights[i]; + } + return sum / num_vals; +} + + +inline float ComputeCrossCorrelationCpu(const float* const values1, + const float* const values2, + const int num_vals) { + float sxy = 0.0f; + for (int offset = 0; offset < num_vals; ++offset) { + sxy += values1[offset] * values2[offset]; + } + + const float cross_correlation = sxy / num_vals; + + return cross_correlation; +} + + +inline float ComputeCrossCorrelation(const float* const values1, + const float* const values2, + const int num_vals) { + return +#ifdef __ARM_NEON + (num_vals >= 8) ? 
ComputeCrossCorrelationNeon(values1, values2, num_vals) + : +#endif + ComputeCrossCorrelationCpu(values1, values2, num_vals); +} + + +inline void NormalizeNumbers(float* const values, const int num_vals) { + // Find the mean and then subtract so that the new mean is 0.0. + const float mean = ComputeMean(values, num_vals); + VLOG(2) << "Mean is " << mean; + float* curr_data = values; + for (int i = 0; i < num_vals; ++i) { + *curr_data -= mean; + curr_data++; + } + + // Now divide by the std deviation so the new standard deviation is 1.0. + // The numbers might all be identical (and thus shifted to 0.0 now), + // so only scale by the standard deviation if this is not the case. + const float std_dev = ComputeStdDev(values, num_vals, 0.0f); + if (std_dev > 0.0f) { + VLOG(2) << "Std dev is " << std_dev; + curr_data = values; + for (int i = 0; i < num_vals; ++i) { + *curr_data /= std_dev; + curr_data++; + } + } +} + + +// Returns the determinant of a 2x2 matrix. +template +inline T FindDeterminant2x2(const T* const a) { + // Determinant: (ad - bc) + return a[0] * a[3] - a[1] * a[2]; +} + + +// Finds the inverse of a 2x2 matrix. +// Returns true upon success, false if the matrix is not invertible. +template +inline bool Invert2x2(const T* const a, float* const a_inv) { + const float det = static_cast(FindDeterminant2x2(a)); + if (fabs(det) < EPSILON) { + return false; + } + const float inv_det = 1.0f / det; + + a_inv[0] = inv_det * static_cast(a[3]); // d + a_inv[1] = inv_det * static_cast(-a[1]); // -b + a_inv[2] = inv_det * static_cast(-a[2]); // -c + a_inv[3] = inv_det * static_cast(a[0]); // a + + return true; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/utils_neon.cc b/tensorflow/examples/android/jni/object_tracking/utils_neon.cc new file mode 100755 index 0000000000..5a5250e32e --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/utils_neon.cc @@ -0,0 +1,151 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NEON implementations of Image methods for compatible devices. Control +// should never enter this compilation unit on incompatible devices. 
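A quick worked check of the FindDeterminant2x2 / Invert2x2 helpers above, which treat a 2x2 matrix as a row-major array {a, b, c, d}. The snippet reproduces the same arithmetic standalone rather than pulling in utils.h and its platform headers; the matrix values are arbitrary.

#include <cstdio>

int main() {
  // Row-major 2x2 matrix {a, b, c, d}, as Invert2x2 expects.
  const float m[4] = {4.0f, 7.0f, 2.0f, 6.0f};
  float m_inv[4];

  // det = 4*6 - 7*2 = 10, so the inverse is (1/10) * {6, -7, -2, 4}.
  const float det = m[0] * m[3] - m[1] * m[2];
  const float inv_det = 1.0f / det;
  m_inv[0] = inv_det * m[3];    //  d
  m_inv[1] = inv_det * -m[1];   // -b
  m_inv[2] = inv_det * -m[2];   // -c
  m_inv[3] = inv_det * m[0];    //  a

  std::printf("det=%.1f  inv=[%.1f %.1f; %.1f %.1f]\n",
              det, m_inv[0], m_inv[1], m_inv[2], m_inv[3]);
  return 0;
}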
+ +#ifdef __ARM_NEON + +#include + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +namespace tf_tracking { + +inline static float GetSum(const float32x4_t& values) { + static float32_t summed_values[4]; + vst1q_f32(summed_values, values); + return summed_values[0] + + summed_values[1] + + summed_values[2] + + summed_values[3]; +} + + +float ComputeMeanNeon(const float* const values, const int num_vals) { + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); + + const float32_t* const arm_vals = (const float32_t* const) values; + float32x4_t accum = vdupq_n_f32(0.0f); + + int offset = 0; + for (; offset <= num_vals - 4; offset += 4) { + accum = vaddq_f32(accum, vld1q_f32(&arm_vals[offset])); + } + + // Pull the accumulated values into a single variable. + float sum = GetSum(accum); + + // Get the remaining 1 to 3 values. + for (; offset < num_vals; ++offset) { + sum += values[offset]; + } + + const float mean_neon = sum / static_cast(num_vals); + +#ifdef SANITY_CHECKS + const float mean_cpu = ComputeMeanCpu(values, num_vals); + SCHECK(NearlyEqual(mean_neon, mean_cpu, EPSILON * num_vals), + "Neon mismatch with CPU mean! %.10f vs %.10f", + mean_neon, mean_cpu); +#endif + + return mean_neon; +} + + +float ComputeStdDevNeon(const float* const values, + const int num_vals, const float mean) { + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); + + const float32_t* const arm_vals = (const float32_t* const) values; + const float32x4_t mean_vec = vdupq_n_f32(-mean); + + float32x4_t accum = vdupq_n_f32(0.0f); + + int offset = 0; + for (; offset <= num_vals - 4; offset += 4) { + const float32x4_t deltas = + vaddq_f32(mean_vec, vld1q_f32(&arm_vals[offset])); + + accum = vmlaq_f32(accum, deltas, deltas); + } + + // Pull the accumulated values into a single variable. + float squared_sum = GetSum(accum); + + // Get the remaining 1 to 3 values. + for (; offset < num_vals; ++offset) { + squared_sum += Square(values[offset] - mean); + } + + const float std_dev_neon = sqrt(squared_sum / static_cast(num_vals)); + +#ifdef SANITY_CHECKS + const float std_dev_cpu = ComputeStdDevCpu(values, num_vals, mean); + SCHECK(NearlyEqual(std_dev_neon, std_dev_cpu, EPSILON * num_vals), + "Neon mismatch with CPU std dev! %.10f vs %.10f", + std_dev_neon, std_dev_cpu); +#endif + + return std_dev_neon; +} + + +float ComputeCrossCorrelationNeon(const float* const values1, + const float* const values2, + const int num_vals) { + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); + + const float32_t* const arm_vals1 = (const float32_t* const) values1; + const float32_t* const arm_vals2 = (const float32_t* const) values2; + + float32x4_t accum = vdupq_n_f32(0.0f); + + int offset = 0; + for (; offset <= num_vals - 4; offset += 4) { + accum = vmlaq_f32(accum, + vld1q_f32(&arm_vals1[offset]), + vld1q_f32(&arm_vals2[offset])); + } + + // Pull the accumulated values into a single variable. + float sxy = GetSum(accum); + + // Get the remaining 1 to 3 values. 
+ for (; offset < num_vals; ++offset) { + sxy += values1[offset] * values2[offset]; + } + + const float cross_correlation_neon = sxy / num_vals; + +#ifdef SANITY_CHECKS + const float cross_correlation_cpu = + ComputeCrossCorrelationCpu(values1, values2, num_vals); + SCHECK(NearlyEqual(cross_correlation_neon, cross_correlation_cpu, + EPSILON * num_vals), + "Neon mismatch with CPU cross correlation! %.10f vs %.10f", + cross_correlation_neon, cross_correlation_cpu); +#endif + + return cross_correlation_neon; +} + +} // namespace tf_tracking + +#endif // __ARM_NEON diff --git a/tensorflow/examples/android/proto/box_coder.proto b/tensorflow/examples/android/proto/box_coder.proto new file mode 100644 index 0000000000..8576294110 --- /dev/null +++ b/tensorflow/examples/android/proto/box_coder.proto @@ -0,0 +1,42 @@ +syntax = "proto2"; + +package org_tensorflow_demo; + +// Prior for a single feature (like minimum x coordinate, width, area, etc.) +message BoxCoderPrior { + optional float mean = 1 [default = 0.0]; + optional float stddev = 2 [default = 1.0]; +}; + +// Box encoding/decoding configuration for a single box. +message BoxCoderOptions { + // Number of priors must match the number of values used to encoded + // values which is derived from the use_... flags below. + repeated BoxCoderPrior priors = 1; + + // Minimum/maximum X/Y of the four corners are used as features. + // Order: MinX, MinY, MaxX, MaxY. + // Number of values: 4. + optional bool use_corners = 2 [default = true]; + + // Width and height of the box in this order. + // Number of values: 2. + optional bool use_width_height = 3 [default = false]; + + // Coordinates of the center of the box. + // Order: X, Y. + // Number of values: 2. + optional bool use_center = 4 [default = false]; + + // Area of the box. + // Number of values: 1. + optional bool use_area = 5 [default = false]; +}; + +// Options for MultiBoxCoder which is a encoder/decoder for a fixed number of +// boxes. +// A list of BoxCoderOptions that allows for storing multiple box coder options +// in a single file. +message MultiBoxCoderOptions { + repeated BoxCoderOptions box_coder = 1; +}; diff --git a/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml new file mode 100644 index 0000000000..674f25785a --- /dev/null +++ b/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml @@ -0,0 +1,30 @@ + + + + + + + + diff --git a/tensorflow/examples/android/res/values/base-strings.xml b/tensorflow/examples/android/res/values/base-strings.xml index 93cfe0dac2..f6c57d5030 100644 --- a/tensorflow/examples/android/res/values/base-strings.xml +++ b/tensorflow/examples/android/res/values/base-strings.xml @@ -1,6 +1,6 @@
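On box_coder.proto above: each BoxCoderPrior carries a per-feature mean and stddev (defaulting to 0.0 and 1.0), which suggests the network's raw outputs are standardized residuals that are un-normalized feature by feature before use. The decode step itself is not part of this file, and the formula below is an assumption for illustration only, as are the prior values and network outputs; only the default use_corners = true layout (MinX, MinY, MaxX, MaxY) is taken from the proto.

#include <cstdio>

// Hypothetical decode of one box under the default use_corners layout.
// Assumed formula: decoded = prediction * stddev + mean, per feature.
struct Prior { float mean, stddev; };

static void DecodeBox(const float encoded[4], const Prior priors[4],
                      float decoded[4]) {
  for (int i = 0; i < 4; ++i) {
    decoded[i] = encoded[i] * priors[i].stddev + priors[i].mean;
  }
}

int main() {
  // Made-up priors and network outputs, purely for illustration.
  const Prior priors[4] = {
      {0.1f, 0.2f}, {0.1f, 0.2f}, {0.9f, 0.2f}, {0.9f, 0.2f}};
  const float encoded[4] = {0.5f, -0.5f, -1.0f, 0.25f};
  float decoded[4];
  DecodeBox(encoded, priors, decoded);
  std::printf("box = (%.2f, %.2f) - (%.2f, %.2f)\n",
              decoded[0], decoded[1], decoded[2], decoded[3]);
  return 0;
}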