about summary refs log tree commit diff homepage
path: root/tensorflow/examples/android/jni/object_tracking/object_tracker.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/android/jni/object_tracking/object_tracker.h')
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_tracker.h  271
1 file changed, 271 insertions(+), 0 deletions(-)
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/examples/android/jni/object_tracking/object_tracker.h
new file mode 100644
index 0000000000..3d2a9af360
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.h
@@ -0,0 +1,271 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+
+#include <map>
+#include <string>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
+
+namespace tf_tracking {
+
+// Maps an object's string id to a non-owning pointer to its TrackedObject.
+// NOTE(review): pointer lifetime is managed elsewhere (presumably by
+// ObjectTracker, which inserts/erases entries) -- confirm in the .cc.
+typedef std::map<const std::string, TrackedObject*> TrackedObjectMap;
+
+// Streams every entry of the map as "<id>: <tracked object>", in map
+// (lexicographic key) order, with no separator between entries.
+inline std::ostream& operator<<(std::ostream& stream,
+                                const TrackedObjectMap& map) {
+  for (TrackedObjectMap::const_iterator entry = map.begin();
+       entry != map.end(); ++entry) {
+    stream << entry->first << ": " << *entry->second;
+  }
+  return stream;
+}
+
+
+// ObjectTracker is the highest-level class in the tracking/detection framework.
+// It handles basic image processing, keypoint detection, keypoint tracking,
+// object tracking, and object detection/relocalization.
+class ObjectTracker {
+ public:
+  // Takes ownership of both |config| and |detector|; they are held in the
+  // unique_ptr members config_ and detector_ below.
+  ObjectTracker(const TrackerConfig* const config,
+                ObjectDetectorBase* const detector);
+  virtual ~ObjectTracker();
+
+  // Convenience overload for frames without a chroma plane: forwards to the
+  // full NextFrame() below with a NULL uv_frame.
+  virtual void NextFrame(const uint8* const new_frame,
+                         const int64 timestamp,
+                         const float* const alignment_matrix_2x3) {
+    NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
+  }
+
+  // Called upon the arrival of a new frame of raw data.
+  // Does all image processing, keypoint detection, and object
+  // tracking/detection for registered objects.
+  // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
+  // represents the main transformation that has happened between the last
+  // and the current frame.
+  // NOTE(review): an earlier version of this comment described an
+  // "align_level" pyramid-level argument that this method does not take;
+  // confirm against the implementation which level the matrix is valid for.
+  virtual void NextFrame(const uint8* const new_frame,
+                         const uint8* const uv_frame,
+                         const int64 timestamp,
+                         const float* const alignment_matrix_2x3);
+
+  // Registers a new target |id|, initialized from its appearance in
+  // |new_frame| at |bounding_box|.
+  virtual void RegisterNewObjectWithAppearance(
+      const std::string& id, const uint8* const new_frame,
+      const BoundingBox& bounding_box);
+
+  // Updates the position of a tracked object, given that it was known to be at
+  // a certain position at some point in the past.
+  virtual void SetPreviousPositionOfObject(const std::string& id,
+                                           const BoundingBox& bounding_box,
+                                           const int64 timestamp);
+
+  // Sets the current position of the object in the most recent frame provided.
+  virtual void SetCurrentPositionOfObject(const std::string& id,
+                                          const BoundingBox& bounding_box);
+
+  // Tells the ObjectTracker to stop tracking a target.
+  void ForgetTarget(const std::string& id);
+
+  // Fills the given out_data buffer with the latest detected keypoint
+  // correspondences, first scaled by scale_factor (to adjust for downsampling
+  // that may have occurred elsewhere), then packed in a fixed-point format.
+  int GetKeypointsPacked(uint16* const out_data,
+                         const float scale_factor) const;
+
+  // Copy the keypoint arrays after computeFlow is called.
+  // out_data should be at least kMaxKeypoints * kKeypointStep long.
+  // Currently, its format is [x1 y1 found x2 y2 score] repeated N times,
+  // where N is the number of keypoints tracked. N is returned as the result.
+  int GetKeypoints(const bool only_found, float* const out_data) const;
+
+  // Returns the current position of a box, given that it was at a certain
+  // position at the given time.
+  BoundingBox TrackBox(const BoundingBox& region,
+                       const int64 timestamp) const;
+
+  // Returns the number of frames that have been passed to NextFrame().
+  inline int GetNumFrames() const {
+    return num_frames_;
+  }
+
+  // Returns true iff an object with the given id is currently registered.
+  inline bool HaveObject(const std::string& id) const {
+    return objects_.find(id) != objects_.end();
+  }
+
+  // Returns the TrackedObject associated with the given id.
+  // CHECK-fails on an unknown id; guard with HaveObject() first.
+  inline const TrackedObject* GetObject(const std::string& id) const {
+    TrackedObjectMap::const_iterator iter = objects_.find(id);
+    CHECK_ALWAYS(iter != objects_.end(),
+                 "Unknown object key! \"%s\"", id.c_str());
+    TrackedObject* const object = iter->second;
+    return object;
+  }
+
+  // Non-const version of the above; same CHECK-fail behavior on unknown ids.
+  inline TrackedObject* GetObject(const std::string& id) {
+    TrackedObjectMap::iterator iter = objects_.find(id);
+    CHECK_ALWAYS(iter != objects_.end(),
+                 "Unknown object key! \"%s\"", id.c_str());
+    TrackedObject* const object = iter->second;
+    return object;
+  }
+
+  // Returns whether the given tracked object is currently visible.
+  // SCHECKs that the object is registered.
+  bool IsObjectVisible(const std::string& id) const {
+    SCHECK(HaveObject(id), "Don't have this object.");
+
+    const TrackedObject* object = GetObject(id);
+    return object->IsVisible();
+  }
+
+  // Renders a debug visualization of the tracker state onto a canvas of the
+  // given dimensions; frame_to_canvas maps frame coordinates to canvas
+  // coordinates (format presumably matches alignment_matrix_2x3 -- confirm
+  // in the implementation).
+  virtual void Draw(const int canvas_width, const int canvas_height,
+                    const float* const frame_to_canvas) const;
+
+ protected:
+  // Creates a new tracked object at the given position.
+  // If an object model is provided, then that model will be associated with the
+  // object. If not, a new model may be created from the appearance at the
+  // initial position and registered with the object detector.
+  virtual TrackedObject* MaybeAddObject(const std::string& id,
+                                        const Image<uint8>& image,
+                                        const BoundingBox& bounding_box,
+                                        const ObjectModelBase* object_model);
+
+  // Find the keypoints in the frame before the current frame.
+  // If only one frame exists, keypoints will be found in that frame.
+  void ComputeKeypoints(const bool cached_ok = false);
+
+  // Finds the correspondences for all the points in the current pair of frames.
+  // Stores the results in the given FramePair.
+  void FindCorrespondences(FramePair* const curr_change) const;
+
+  // Returns the ring-buffer index of the frame pair |offset| steps back from
+  // the newest one (offset 0 == newest).
+  inline int GetNthIndexFromEnd(const int offset) const {
+    return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
+  }
+
+  // Tracks |region| through the motion captured by a single frame pair
+  // (helper for the public, timestamp-based TrackBox above).
+  BoundingBox TrackBox(const BoundingBox& region,
+                       const FramePair& frame_pair) const;
+
+  // Advances the frame counters, treating frame_pairs_ as a ring buffer of at
+  // most kNumFrames entries: once full, the oldest entry is dropped by
+  // advancing first_frame_index_ instead of growing the count.
+  inline void IncrementFrameIndex() {
+    // Move the current framechange index up.
+    ++num_frames_;
+    ++curr_num_frame_pairs_;
+
+    // If we've got too many, push up the start of the queue.
+    if (curr_num_frame_pairs_ > kNumFrames) {
+      first_frame_index_ = GetNthIndexFromStart(1);
+      --curr_num_frame_pairs_;
+    }
+  }
+
+  // Returns the ring-buffer index of the frame pair |offset| steps after the
+  // oldest one. SCHECK-fails unless 0 <= offset < curr_num_frame_pairs_
+  // (so it must not be called while the queue is empty).
+  inline int GetNthIndexFromStart(const int offset) const {
+    SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
+           "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_);
+    return (first_frame_index_ + offset) % kNumFrames;
+  }
+
+  // Runs tracking for all registered objects on the latest frame data.
+  void TrackObjects();
+
+  // Tracker configuration; owned.
+  const std::unique_ptr<const TrackerConfig> config_;
+
+  // Dimensions of incoming frames, in pixels.
+  const int frame_width_;
+  const int frame_height_;
+
+  // Timestamp of the most recently processed frame.
+  int64 curr_time_;
+
+  // Total frames seen; incremented (never decremented) by
+  // IncrementFrameIndex().
+  int num_frames_;
+
+  // All registered targets, keyed by id. Values are raw pointers; lifetime is
+  // presumably managed by this class (ForgetTarget/destructor) -- confirm in
+  // the .cc.
+  TrackedObjectMap objects_;
+
+  FlowCache flow_cache_;
+
+  KeypointDetector keypoint_detector_;
+
+  // Ring-buffer bookkeeping for frame_pairs_: the number of valid entries and
+  // the index of the oldest one. See IncrementFrameIndex().
+  int curr_num_frame_pairs_;
+  int first_frame_index_;
+
+  // Image data for the previous and current frames.
+  std::unique_ptr<ImageData> frame1_;
+  std::unique_ptr<ImageData> frame2_;
+
+  // Ring buffer of inter-frame keypoint correspondences; indexed via
+  // GetNthIndexFromStart()/GetNthIndexFromEnd().
+  FramePair frame_pairs_[kNumFrames];
+
+  // Object detector used for detection/relocalization; owned.
+  std::unique_ptr<ObjectDetectorBase> detector_;
+
+  // Running count of detections processed.
+  int num_detected_;
+
+ private:
+  // Tracks a single target for the newest frame.
+  void TrackTarget(TrackedObject* const object);
+
+  // Finds the tracked object best matching |detection|, writing it to *match.
+  // Return-value semantics (match found vs. not) -- confirm in the .cc.
+  bool GetBestObjectForDetection(
+      const Detection& detection, TrackedObject** match) const;
+
+  // Applies a batch of detector results to the tracked-object set.
+  void ProcessDetections(std::vector<Detection>* const detections);
+
+  // Runs the object detector over the current frame.
+  void DetectTargets();
+
+  // Temp object used in ObjectTracker::CreateNewExample.
+  mutable std::vector<BoundingSquare> squares;
+
+  friend std::ostream& operator<<(std::ostream& stream,
+                                  const ObjectTracker& tracker);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
+};
+
+// Streams a human-readable summary of the tracker: frame dimensions, frame
+// count, current timestamp, the time span covered by the frame-pair queue,
+// and all tracked targets.
+inline std::ostream& operator<<(std::ostream& stream,
+                                const ObjectTracker& tracker) {
+  stream << "Frame size: " << tracker.frame_width_ << "x"
+         << tracker.frame_height_ << std::endl;
+
+  stream << "Num frames: " << tracker.num_frames_ << std::endl;
+
+  stream << "Curr time: " << tracker.curr_time_ << std::endl;
+
+  // Only touch the frame-pair queue when it is non-empty:
+  // GetNthIndexFromStart() SCHECKs offset < curr_num_frame_pairs_, which
+  // always fails for an empty queue (e.g. printing a freshly built tracker).
+  if (tracker.curr_num_frame_pairs_ > 0) {
+    const int first_frame_index = tracker.GetNthIndexFromStart(0);
+    const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];
+
+    const int last_frame_index = tracker.GetNthIndexFromEnd(0);
+    const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];
+
+    stream << "first frame: " << first_frame_index << ","
+           << first_frame_pair.end_time_ << " "
+           << "last frame: " << last_frame_index << ","
+           << last_frame_pair.end_time_ << " diff: "
+           << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
+           << std::endl;
+  }
+
+  stream << "Tracked targets:";
+  stream << tracker.objects_;
+
+  return stream;
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_