diff options
Diffstat (limited to 'tensorflow/examples/android/jni/object_tracking/object_tracker.h')
-rw-r--r-- | tensorflow/examples/android/jni/object_tracking/object_tracker.h | 271 |
1 files changed, 271 insertions, 0 deletions
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/examples/android/jni/object_tracking/object_tracker.h new file mode 100644 index 0000000000..3d2a9af360 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.h @@ -0,0 +1,271 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ + +#include <map> +#include <string> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/examples/android/jni/object_tracking/object_model.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" +#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h" + +namespace tf_tracking { + +typedef std::map<const std::string, TrackedObject*> TrackedObjectMap; + +inline std::ostream& operator<<(std::ostream& stream, + const TrackedObjectMap& map) { + for (TrackedObjectMap::const_iterator iter = map.begin(); + iter != map.end(); ++iter) { + const TrackedObject& tracked_object = *iter->second; + const std::string& key = iter->first; + stream << key << ": " << tracked_object; + } + return stream; +} + + +// ObjectTracker is the highest-level class in the tracking/detection framework. +// It handles basic image processing, keypoint detection, keypoint tracking, +// object tracking, and object detection/relocalization. +class ObjectTracker { + public: + ObjectTracker(const TrackerConfig* const config, + ObjectDetectorBase* const detector); + virtual ~ObjectTracker(); + + virtual void NextFrame(const uint8* const new_frame, + const int64 timestamp, + const float* const alignment_matrix_2x3) { + NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3); + } + + // Called upon the arrival of a new frame of raw data. + // Does all image processing, keypoint detection, and object + // tracking/detection for registered objects. + // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that + // represents the main transformation that has happened between the last + // and the current frame. + // Argument align_level is the pyramid level (where 0 == finest) that + // the matrix is valid for. + virtual void NextFrame(const uint8* const new_frame, + const uint8* const uv_frame, + const int64 timestamp, + const float* const alignment_matrix_2x3); + + virtual void RegisterNewObjectWithAppearance( + const std::string& id, const uint8* const new_frame, + const BoundingBox& bounding_box); + + // Updates the position of a tracked object, given that it was known to be at + // a certain position at some point in the past. + virtual void SetPreviousPositionOfObject(const std::string& id, + const BoundingBox& bounding_box, + const int64 timestamp); + + // Sets the current position of the object in the most recent frame provided. + virtual void SetCurrentPositionOfObject(const std::string& id, + const BoundingBox& bounding_box); + + // Tells the ObjectTracker to stop tracking a target. + void ForgetTarget(const std::string& id); + + // Fills the given out_data buffer with the latest detected keypoint + // correspondences, first scaled by scale_factor (to adjust for downsampling + // that may have occurred elsewhere), then packed in a fixed-point format. + int GetKeypointsPacked(uint16* const out_data, + const float scale_factor) const; + + // Copy the keypoint arrays after computeFlow is called. + // out_data should be at least kMaxKeypoints * kKeypointStep long. + // Currently, its format is [x1 y1 found x2 y2 score] repeated N times, + // where N is the number of keypoints tracked. N is returned as the result. + int GetKeypoints(const bool only_found, float* const out_data) const; + + // Returns the current position of a box, given that it was at a certain + // position at the given time. + BoundingBox TrackBox(const BoundingBox& region, + const int64 timestamp) const; + + // Returns the number of frames that have been passed to NextFrame(). + inline int GetNumFrames() const { + return num_frames_; + } + + inline bool HaveObject(const std::string& id) const { + return objects_.find(id) != objects_.end(); + } + + // Returns the TrackedObject associated with the given id. + inline const TrackedObject* GetObject(const std::string& id) const { + TrackedObjectMap::const_iterator iter = objects_.find(id); + CHECK_ALWAYS(iter != objects_.end(), + "Unknown object key! \"%s\"", id.c_str()); + TrackedObject* const object = iter->second; + return object; + } + + // Returns the TrackedObject associated with the given id. + inline TrackedObject* GetObject(const std::string& id) { + TrackedObjectMap::iterator iter = objects_.find(id); + CHECK_ALWAYS(iter != objects_.end(), + "Unknown object key! \"%s\"", id.c_str()); + TrackedObject* const object = iter->second; + return object; + } + + bool IsObjectVisible(const std::string& id) const { + SCHECK(HaveObject(id), "Don't have this object."); + + const TrackedObject* object = GetObject(id); + return object->IsVisible(); + } + + virtual void Draw(const int canvas_width, const int canvas_height, + const float* const frame_to_canvas) const; + + protected: + // Creates a new tracked object at the given position. + // If an object model is provided, then that model will be associated with the + // object. If not, a new model may be created from the appearance at the + // initial position and registered with the object detector. + virtual TrackedObject* MaybeAddObject(const std::string& id, + const Image<uint8>& image, + const BoundingBox& bounding_box, + const ObjectModelBase* object_model); + + // Find the keypoints in the frame before the current frame. + // If only one frame exists, keypoints will be found in that frame. + void ComputeKeypoints(const bool cached_ok = false); + + // Finds the correspondences for all the points in the current pair of frames. + // Stores the results in the given FramePair. + void FindCorrespondences(FramePair* const curr_change) const; + + inline int GetNthIndexFromEnd(const int offset) const { + return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset); + } + + BoundingBox TrackBox(const BoundingBox& region, + const FramePair& frame_pair) const; + + inline void IncrementFrameIndex() { + // Move the current framechange index up. + ++num_frames_; + ++curr_num_frame_pairs_; + + // If we've got too many, push up the start of the queue. + if (curr_num_frame_pairs_ > kNumFrames) { + first_frame_index_ = GetNthIndexFromStart(1); + --curr_num_frame_pairs_; + } + } + + inline int GetNthIndexFromStart(const int offset) const { + SCHECK(offset >= 0 && offset < curr_num_frame_pairs_, + "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_); + return (first_frame_index_ + offset) % kNumFrames; + } + + void TrackObjects(); + + const std::unique_ptr<const TrackerConfig> config_; + + const int frame_width_; + const int frame_height_; + + int64 curr_time_; + + int num_frames_; + + TrackedObjectMap objects_; + + FlowCache flow_cache_; + + KeypointDetector keypoint_detector_; + + int curr_num_frame_pairs_; + int first_frame_index_; + + std::unique_ptr<ImageData> frame1_; + std::unique_ptr<ImageData> frame2_; + + FramePair frame_pairs_[kNumFrames]; + + std::unique_ptr<ObjectDetectorBase> detector_; + + int num_detected_; + + private: + void TrackTarget(TrackedObject* const object); + + bool GetBestObjectForDetection( + const Detection& detection, TrackedObject** match) const; + + void ProcessDetections(std::vector<Detection>* const detections); + + void DetectTargets(); + + // Temp object used in ObjectTracker::CreateNewExample. + mutable std::vector<BoundingSquare> squares; + + friend std::ostream& operator<<(std::ostream& stream, + const ObjectTracker& tracker); + + TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker); +}; + +inline std::ostream& operator<<(std::ostream& stream, + const ObjectTracker& tracker) { + stream << "Frame size: " << tracker.frame_width_ << "x" + << tracker.frame_height_ << std::endl; + + stream << "Num frames: " << tracker.num_frames_ << std::endl; + + stream << "Curr time: " << tracker.curr_time_ << std::endl; + + const int first_frame_index = tracker.GetNthIndexFromStart(0); + const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index]; + + const int last_frame_index = tracker.GetNthIndexFromEnd(0); + const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index]; + + stream << "first frame: " << first_frame_index << "," + << first_frame_pair.end_time_ << " " + << "last frame: " << last_frame_index << "," + << last_frame_pair.end_time_ << " diff: " + << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms" + << std::endl; + + stream << "Tracked targets:"; + stream << tracker.objects_; + + return stream; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ |