aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/android/jni/object_tracking/object_tracker.h
blob: 3d2a9af360756937b430b06406a09223a519b1df (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_

#include <map>
#include <string>

#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"

#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"

namespace tf_tracking {

// Associates a string key (the object id used for lookups) with a pointer to
// its TrackedObject.
using TrackedObjectMap = std::map<const std::string, TrackedObject*>;

inline std::ostream& operator<<(std::ostream& stream,
                                const TrackedObjectMap& map) {
  for (TrackedObjectMap::const_iterator iter = map.begin();
      iter != map.end(); ++iter) {
    const TrackedObject& tracked_object = *iter->second;
    const std::string& key = iter->first;
    stream << key << ": " << tracked_object;
  }
  return stream;
}


// ObjectTracker is the highest-level class in the tracking/detection framework.
// It handles basic image processing, keypoint detection, keypoint tracking,
// object tracking, and object detection/relocalization.
class ObjectTracker {
 public:
  ObjectTracker(const TrackerConfig* const config,
                ObjectDetectorBase* const detector);
  virtual ~ObjectTracker();

  virtual void NextFrame(const uint8* const new_frame,
                         const int64 timestamp,
                         const float* const alignment_matrix_2x3) {
    NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
  }

  // Called upon the arrival of a new frame of raw data.
  // Does all image processing, keypoint detection, and object
  // tracking/detection for registered objects.
  // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
  // represents the main transformation that has happened between the last
  // and the current frame.
  // Argument align_level is the pyramid level (where 0 == finest) that
  // the matrix is valid for.
  virtual void NextFrame(const uint8* const new_frame,
                         const uint8* const uv_frame,
                         const int64 timestamp,
                         const float* const alignment_matrix_2x3);

  virtual void RegisterNewObjectWithAppearance(
      const std::string& id, const uint8* const new_frame,
      const BoundingBox& bounding_box);

  // Updates the position of a tracked object, given that it was known to be at
  // a certain position at some point in the past.
  virtual void SetPreviousPositionOfObject(const std::string& id,
                                           const BoundingBox& bounding_box,
                                           const int64 timestamp);

  // Sets the current position of the object in the most recent frame provided.
  virtual void SetCurrentPositionOfObject(const std::string& id,
                                          const BoundingBox& bounding_box);

  // Tells the ObjectTracker to stop tracking a target.
  void ForgetTarget(const std::string& id);

  // Fills the given out_data buffer with the latest detected keypoint
  // correspondences, first scaled by scale_factor (to adjust for downsampling
  // that may have occurred elsewhere), then packed in a fixed-point format.
  int GetKeypointsPacked(uint16* const out_data,
                         const float scale_factor) const;

  // Copy the keypoint arrays after computeFlow is called.
  // out_data should be at least kMaxKeypoints * kKeypointStep long.
  // Currently, its format is [x1 y1 found x2 y2 score] repeated N times,
  // where N is the number of keypoints tracked.  N is returned as the result.
  int GetKeypoints(const bool only_found, float* const out_data) const;

  // Returns the current position of a box, given that it was at a certain
  // position at the given time.
  BoundingBox TrackBox(const BoundingBox& region,
                       const int64 timestamp) const;

  // Returns the number of frames that have been passed to NextFrame().
  inline int GetNumFrames() const {
    return num_frames_;
  }

  inline bool HaveObject(const std::string& id) const {
    return objects_.find(id) != objects_.end();
  }

  // Returns the TrackedObject associated with the given id.
  inline const TrackedObject* GetObject(const std::string& id) const {
    TrackedObjectMap::const_iterator iter = objects_.find(id);
    CHECK_ALWAYS(iter != objects_.end(),
                 "Unknown object key! \"%s\"", id.c_str());
    TrackedObject* const object = iter->second;
    return object;
  }

  // Returns the TrackedObject associated with the given id.
  inline TrackedObject* GetObject(const std::string& id) {
    TrackedObjectMap::iterator iter = objects_.find(id);
    CHECK_ALWAYS(iter != objects_.end(),
                 "Unknown object key! \"%s\"", id.c_str());
    TrackedObject* const object = iter->second;
    return object;
  }

  bool IsObjectVisible(const std::string& id) const {
    SCHECK(HaveObject(id), "Don't have this object.");

    const TrackedObject* object = GetObject(id);
    return object->IsVisible();
  }

  virtual void Draw(const int canvas_width, const int canvas_height,
                    const float* const frame_to_canvas) const;

 protected:
  // Creates a new tracked object at the given position.
  // If an object model is provided, then that model will be associated with the
  // object. If not, a new model may be created from the appearance at the
  // initial position and registered with the object detector.
  virtual TrackedObject* MaybeAddObject(const std::string& id,
                                        const Image<uint8>& image,
                                        const BoundingBox& bounding_box,
                                        const ObjectModelBase* object_model);

  // Find the keypoints in the frame before the current frame.
  // If only one frame exists, keypoints will be found in that frame.
  void ComputeKeypoints(const bool cached_ok = false);

  // Finds the correspondences for all the points in the current pair of frames.
  // Stores the results in the given FramePair.
  void FindCorrespondences(FramePair* const curr_change) const;

  inline int GetNthIndexFromEnd(const int offset) const {
    return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
  }

  BoundingBox TrackBox(const BoundingBox& region,
                       const FramePair& frame_pair) const;

  inline void IncrementFrameIndex() {
    // Move the current framechange index up.
    ++num_frames_;
    ++curr_num_frame_pairs_;

    // If we've got too many, push up the start of the queue.
    if (curr_num_frame_pairs_ > kNumFrames) {
      first_frame_index_ = GetNthIndexFromStart(1);
      --curr_num_frame_pairs_;
    }
  }

  inline int GetNthIndexFromStart(const int offset) const {
    SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
          "Offset out of range!  %d out of %d.", offset, curr_num_frame_pairs_);
    return (first_frame_index_ + offset) % kNumFrames;
  }

  void TrackObjects();

  const std::unique_ptr<const TrackerConfig> config_;

  const int frame_width_;
  const int frame_height_;

  int64 curr_time_;

  int num_frames_;

  TrackedObjectMap objects_;

  FlowCache flow_cache_;

  KeypointDetector keypoint_detector_;

  int curr_num_frame_pairs_;
  int first_frame_index_;

  std::unique_ptr<ImageData> frame1_;
  std::unique_ptr<ImageData> frame2_;

  FramePair frame_pairs_[kNumFrames];

  std::unique_ptr<ObjectDetectorBase> detector_;

  int num_detected_;

 private:
  void TrackTarget(TrackedObject* const object);

  bool GetBestObjectForDetection(
      const Detection& detection, TrackedObject** match) const;

  void ProcessDetections(std::vector<Detection>* const detections);

  void DetectTargets();

  // Temp object used in ObjectTracker::CreateNewExample.
  mutable std::vector<BoundingSquare> squares;

  friend std::ostream& operator<<(std::ostream& stream,
                                  const ObjectTracker& tracker);

  TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
};

inline std::ostream& operator<<(std::ostream& stream,
                                const ObjectTracker& tracker) {
  stream << "Frame size: " << tracker.frame_width_ << "x"
         << tracker.frame_height_ << std::endl;

  stream << "Num frames: " << tracker.num_frames_ << std::endl;

  stream << "Curr time: " << tracker.curr_time_ << std::endl;

  const int first_frame_index = tracker.GetNthIndexFromStart(0);
  const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];

  const int last_frame_index = tracker.GetNthIndexFromEnd(0);
  const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];

  stream << "first frame: " << first_frame_index << ","
         << first_frame_pair.end_time_ << "    "
         << "last frame: " << last_frame_index << ","
         << last_frame_pair.end_time_ << "   diff: "
         << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
         << std::endl;

  stream << "Tracked targets:";
  stream << tracker.objects_;

  return stream;
}

}  // namespace tf_tracking

#endif  // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_