diff options
Diffstat (limited to 'tensorflow/examples/android/jni/object_tracking/frame_pair.cc')
-rw-r--r-- | tensorflow/examples/android/jni/object_tracking/frame_pair.cc | 308 |
1 files changed, 308 insertions, 0 deletions
diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.cc b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc new file mode 100644 index 0000000000..fa86e2363c --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc @@ -0,0 +1,308 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <float.h> + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" + +namespace tf_tracking { + +void FramePair::Init(const int64 start_time, const int64 end_time) { + start_time_ = start_time; + end_time_ = end_time; + memset(optical_flow_found_keypoint_, false, + sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints); + number_of_keypoints_ = 0; +} + +void FramePair::AdjustBox(const BoundingBox box, + float* const translation_x, + float* const translation_y, + float* const scale_x, + float* const scale_y) const { + static float weights[kMaxKeypoints]; + static Point2f deltas[kMaxKeypoints]; + memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints); + + BoundingBox resized_box(box); + resized_box.Scale(0.4f, 0.4f); + FillWeights(resized_box, weights); + FillTranslations(deltas); + + const Point2f translation = GetWeightedMedian(weights, deltas); + + *translation_x = translation.x; + *translation_y = translation.y; + + const Point2f old_center = box.GetCenter(); + const int good_scale_points = + FillScales(old_center, translation, weights, deltas); + + // Default scale factor is 1 for x and y. + *scale_x = 1.0f; + *scale_y = 1.0f; + + // The assumption is that all deltas that make it to this stage with a + // correspondending optical_flow_found_keypoint_[i] == true are not in + // themselves degenerate. + // + // The degeneracy with scale arose because if the points are too close to the + // center of the objects, the scale ratio determination might be incalculable. + // + // The check for kMinNumInRange is not a degeneracy check, but merely an + // attempt to ensure some sort of stability. The actual degeneracy check is in + // the comparison to EPSILON in FillScales (which I've updated to return the + // number good remaining as well). + static const int kMinNumInRange = 5; + if (good_scale_points >= kMinNumInRange) { + const float scale_factor = GetWeightedMedianScale(weights, deltas); + + if (scale_factor > 0.0f) { + *scale_x = scale_factor; + *scale_y = scale_factor; + } + } +} + +int FramePair::FillWeights(const BoundingBox& box, + float* const weights) const { + // Compute the max score. + float max_score = -FLT_MAX; + float min_score = FLT_MAX; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (optical_flow_found_keypoint_[i]) { + max_score = MAX(max_score, frame1_keypoints_[i].score_); + min_score = MIN(min_score, frame1_keypoints_[i].score_); + } + } + + int num_in_range = 0; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + weights[i] = 0.0f; + continue; + } + + const bool in_box = box.Contains(frame1_keypoints_[i].pos_); + if (in_box) { + ++num_in_range; + } + + // The weighting based off distance. Anything within the bounding box + // has a weight of 1, and everything outside of that is within the range + // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio. + float distance_score = 1.0f; + if (!in_box) { + const Point2f initial = box.GetCenter(); + const float sq_x_dist = + Square(initial.x - frame1_keypoints_[i].pos_.x); + const float sq_y_dist = + Square(initial.y - frame1_keypoints_[i].pos_.y); + const float squared_half_width = Square(box.GetWidth() / 2.0f); + const float squared_half_height = Square(box.GetHeight() / 2.0f); + + static const float kOutOfBoxMultiplier = 0.5f; + distance_score = kOutOfBoxMultiplier * + MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist); + } + + // The weighting based on relative score strength. kBaseScore - 1.0f. + float intrinsic_score = 1.0f; + if (max_score > min_score) { + static const float kBaseScore = 0.5f; + intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) / + (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore; + } + + // The final score will be in the range [0, 1]. + weights[i] = distance_score * intrinsic_score; + } + + return num_in_range; +} + +void FramePair::FillTranslations(Point2f* const translations) const { + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + continue; + } + translations[i].x = + frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x; + translations[i].y = + frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y; + } +} + +int FramePair::FillScales(const Point2f& old_center, + const Point2f& translation, + float* const weights, + Point2f* const scales) const { + int num_good = 0; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + continue; + } + + const Keypoint keypoint1 = frame1_keypoints_[i]; + const Keypoint keypoint2 = frame2_keypoints_[i]; + + const float dist1_x = keypoint1.pos_.x - old_center.x; + const float dist1_y = keypoint1.pos_.y - old_center.y; + + const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x; + const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y; + + // Make sure that the scale makes sense; points too close to the center + // will result in either NaNs or infinite results for scale due to + // limited tracking and floating point resolution. + // Also check that the parity of the points is the same with respect to + // x and y, as we can't really make sense of data that has flipped. + if (((dist2_x > EPSILON && dist1_x > EPSILON) || + (dist2_x < -EPSILON && dist1_x < -EPSILON)) && + ((dist2_y > EPSILON && dist1_y > EPSILON) || + (dist2_y < -EPSILON && dist1_y < -EPSILON))) { + scales[i].x = dist2_x / dist1_x; + scales[i].y = dist2_y / dist1_y; + ++num_good; + } else { + weights[i] = 0.0f; + scales[i].x = 1.0f; + scales[i].y = 1.0f; + } + } + return num_good; +} + +struct WeightedDelta { + float weight; + float delta; +}; + +// Sort by delta, not by weight. +inline int WeightedDeltaCompare(const void* const a, const void* const b) { + return (reinterpret_cast<const WeightedDelta*>(a)->delta - + reinterpret_cast<const WeightedDelta*>(b)->delta) <= 0 ? 1 : -1; +} + +// Returns the median delta from a sorted set of weighted deltas. +static float GetMedian(const int num_items, + const WeightedDelta* const weighted_deltas, + const float sum) { + if (num_items == 0 || sum < EPSILON) { + return 0.0f; + } + + float current_weight = 0.0f; + const float target_weight = sum / 2.0f; + for (int i = 0; i < num_items; ++i) { + if (weighted_deltas[i].weight > 0.0f) { + current_weight += weighted_deltas[i].weight; + if (current_weight >= target_weight) { + return weighted_deltas[i].delta; + } + } + } + LOGW("Median not found! %d points, sum of %.2f", num_items, sum); + return 0.0f; +} + +Point2f FramePair::GetWeightedMedian( + const float* const weights, const Point2f* const deltas) const { + Point2f median_delta; + + // TODO(andrewharp): only sort deltas that could possibly have an effect. + static WeightedDelta weighted_deltas[kMaxKeypoints]; + + // Compute median X value. + { + float total_weight = 0.0f; + + // Compute weighted mean and deltas. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i].delta = deltas[i].x; + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta), + WeightedDeltaCompare); + median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight); + } + + // Compute median Y value. + { + float total_weight = 0.0f; + + // Compute weighted mean and deltas. + for (int i = 0; i < kMaxKeypoints; ++i) { + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + weighted_deltas[i].delta = deltas[i].y; + if (weight > 0.0f) { + total_weight += weight; + } + } + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta), + WeightedDeltaCompare); + median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight); + } + + return median_delta; +} + +float FramePair::GetWeightedMedianScale( + const float* const weights, const Point2f* const deltas) const { + float median_delta; + + // TODO(andrewharp): only sort deltas that could possibly have an effect. + static WeightedDelta weighted_deltas[kMaxKeypoints * 2]; + + // Compute median scale value across x and y. + { + float total_weight = 0.0f; + + // Add X values. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i].delta = deltas[i].x; + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + + // Add Y values. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y; + const float weight = weights[i]; + weighted_deltas[i + kMaxKeypoints].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + + qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta), + WeightedDeltaCompare); + + median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight); + } + + return median_delta; +} + +} // namespace tf_tracking |