aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/android/jni/object_tracking/tracked_object.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/android/jni/object_tracking/tracked_object.cc')
-rw-r--r--tensorflow/examples/android/jni/object_tracking/tracked_object.cc163
1 files changed, 163 insertions, 0 deletions
diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.cc b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc
new file mode 100644
index 0000000000..823fb3a90e
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc
@@ -0,0 +1,163 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
+
+namespace tf_tracking {
+
+static const float kInitialDistance = 20.0f;
+
+static void InitNormalized(const Image<uint8>& src_image,
+ const BoundingBox& position,
+ Image<float>* const dst_image) {
+ BoundingBox scaled_box(position);
+ CopyArea(src_image, scaled_box, dst_image);
+ NormalizeImage(dst_image);
+}
+
+TrackedObject::TrackedObject(const std::string& id,
+ const Image<uint8>& image,
+ const BoundingBox& bounding_box,
+ ObjectModelBase* const model)
+ : id_(id),
+ last_known_position_(bounding_box),
+ last_detection_position_(bounding_box),
+ position_last_computed_time_(-1),
+ object_model_(model),
+ last_detection_thumbnail_(kNormalizedThumbnailSize,
+ kNormalizedThumbnailSize),
+ last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize),
+ tracked_correlation_(0.0f),
+ tracked_match_score_(0.0),
+ num_consecutive_frames_below_threshold_(0),
+ allowable_detection_distance_(Square(kInitialDistance)) {
+ InitNormalized(image, bounding_box, &last_detection_thumbnail_);
+}
+
+TrackedObject::~TrackedObject() {}
+
+void TrackedObject::UpdatePosition(const BoundingBox& new_position,
+ const int64 timestamp,
+ const ImageData& image_data,
+ const bool authoratative) {
+ last_known_position_ = new_position;
+ position_last_computed_time_ = timestamp;
+
+ InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_);
+
+ const float last_localization_correlation = ComputeCrossCorrelation(
+ last_detection_thumbnail_.data(),
+ last_frame_thumbnail_.data(),
+ last_frame_thumbnail_.data_size_);
+ LOGV("Tracked correlation to last localization: %.6f",
+ last_localization_correlation);
+
+ // Correlation to object model, if it exists.
+ if (object_model_ != NULL) {
+ tracked_correlation_ =
+ object_model_->GetMaxCorrelation(last_frame_thumbnail_);
+ LOGV("Tracked correlation to model: %.6f",
+ tracked_correlation_);
+
+ tracked_match_score_ =
+ object_model_->GetMatchScore(new_position, image_data);
+ LOGV("Tracked match score with model: %.6f",
+ tracked_match_score_.value);
+ } else {
+ // If there's no model to check against, set the tracked correlation to
+ // simply be the correlation to the last set position.
+ tracked_correlation_ = last_localization_correlation;
+ tracked_match_score_ = MatchScore(0.0f);
+ }
+
+ // Determine if it's still being tracked.
+ if (tracked_correlation_ >= kMinimumCorrelationForTracking &&
+ tracked_match_score_ >= kMinimumMatchScore) {
+ num_consecutive_frames_below_threshold_ = 0;
+
+ if (object_model_ != NULL) {
+ object_model_->TrackStep(last_known_position_, *image_data.GetImage(),
+ *image_data.GetIntegralImage(), authoratative);
+ }
+ } else if (tracked_match_score_ < kMatchScoreForImmediateTermination) {
+ if (num_consecutive_frames_below_threshold_ < 1000) {
+ LOGD("Tracked match score is way too low (%.6f), aborting track.",
+ tracked_match_score_.value);
+ }
+
+ // Add an absurd amount of missed frames so that all heuristics will
+ // consider it a lost track.
+ num_consecutive_frames_below_threshold_ += 1000;
+
+ if (object_model_ != NULL) {
+ object_model_->TrackLost();
+ }
+ } else {
+ ++num_consecutive_frames_below_threshold_;
+ allowable_detection_distance_ *= 1.1f;
+ }
+}
+
+void TrackedObject::OnDetection(ObjectModelBase* const model,
+ const BoundingBox& detection_position,
+ const MatchScore match_score,
+ const int64 timestamp,
+ const ImageData& image_data) {
+ const float overlap = detection_position.PascalScore(last_known_position_);
+ if (overlap > kPositionOverlapThreshold) {
+ // If the position agreement with the current tracked position is good
+ // enough, lock all the current unlocked examples.
+ object_model_->TrackConfirmed();
+ num_consecutive_frames_below_threshold_ = 0;
+ }
+
+ // Before relocalizing, make sure the new proposed position is better than
+ // the existing position by a small amount to prevent thrashing.
+ if (match_score <= tracked_match_score_ + kMatchScoreBuffer) {
+ LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f",
+ match_score.value, tracked_match_score_.value,
+ kMatchScoreBuffer.value);
+ return;
+ }
+
+ LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to "
+ "(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f",
+ last_known_position_.left_, last_known_position_.top_,
+ last_known_position_.GetWidth(), last_known_position_.GetHeight(),
+ detection_position.left_, detection_position.top_,
+ detection_position.GetWidth(), detection_position.GetHeight(),
+ match_score.value, tracked_match_score_.value);
+
+ if (overlap < kPositionOverlapThreshold) {
+ // The path might be good, it might be bad, but it's no longer a path
+ // since we're moving the box to a new position, so just nuke it from
+ // orbit to be safe.
+ object_model_->TrackLost();
+ }
+
+ object_model_ = model;
+
+ // Reset the last detected appearance.
+ InitNormalized(
+ *image_data.GetImage(), detection_position, &last_detection_thumbnail_);
+
+ num_consecutive_frames_below_threshold_ = 0;
+ last_detection_position_ = detection_position;
+
+ UpdatePosition(detection_position, timestamp, image_data, false);
+ allowable_detection_distance_ = Square(kInitialDistance);
+}
+
+} // namespace tf_tracking