/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// NOTE: no native object detectors are currently provided or used by the code
// in this directory. This class remains mainly for historical reasons.
// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.

// Defines the ObjectDetector class, the main interface for detecting
// ObjectModelBases in frames.

#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_

#include <float.h>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#ifdef __RENDER_OPENGL__
#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
#endif
#include "tensorflow/examples/android/jni/object_tracking/utils.h"

#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
#include "tensorflow/examples/android/jni/object_tracking/object_model.h"

namespace tf_tracking {

// Adds BoundingSquares to a vector such that the first square added is
// centered on the given position with sides of starting_square_size, and the
// remaining squares are added concentrically, scaling down by scale_factor
// until the size drops below smallest_square_size.
// Squares that do not fall completely within image_bounds will not be added.
static inline void FillWithSquares(
    const BoundingBox& image_bounds,
    const BoundingBox& position,
    const float starting_square_size,
    const float smallest_square_size,
    const float scale_factor,
    std::vector<BoundingSquare>* const squares) {
  BoundingSquare descriptor_area =
      GetCenteredSquare(position, starting_square_size);

  SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor);

  // Use a do/while loop so that at least one candidate square is considered,
  // even when starting_square_size is already at the minimum. Note that a
  // candidate is only added if it fits within image_bounds.
  do {
    if (image_bounds.Contains(descriptor_area.ToBoundingBox())) {
      squares->push_back(descriptor_area);
    }
    descriptor_area.Scale(scale_factor);
  } while (descriptor_area.size_ >= smallest_square_size - EPSILON);
  LOGV("Created %zu squares starting from size %.2f to min size %.2f "
       "using scale factor: %.2f",
       squares->size(), starting_square_size, smallest_square_size,
       scale_factor);
}
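
// Usage sketch (illustrative only: the frame and box coordinates below are
// made up, and a BoundingBox (left, top, right, bottom) constructor from
// geom.h is assumed). This generates every concentric candidate square from
// 128px down to 32px, shrinking by 25% per step:
//
//   std::vector<BoundingSquare> squares;
//   const BoundingBox image_bounds(0.0f, 0.0f, 640.0f, 480.0f);
//   const BoundingBox position(200.0f, 150.0f, 328.0f, 278.0f);
//   FillWithSquares(image_bounds, position, 128.0f, 32.0f, 0.75f, &squares);
//   // squares now holds only the candidates fully inside image_bounds.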


// Represents a potential detection of a specific ObjectExemplar and Descriptor
// at a specific position in the image.
class Detection {
 public:
  explicit Detection(const ObjectModelBase* const object_model,
                     const MatchScore match_score,
                     const BoundingBox& bounding_box)
      : object_model_(object_model),
        match_score_(match_score),
        bounding_box_(bounding_box) {}

  Detection(const Detection& other)
      : object_model_(other.object_model_),
        match_score_(other.match_score_),
        bounding_box_(other.bounding_box_) {}

  virtual ~Detection() {}

  inline BoundingBox GetObjectBoundingBox() const {
    return bounding_box_;
  }

  inline MatchScore GetMatchScore() const {
    return match_score_;
  }

  inline const ObjectModelBase* GetObjectModel() const {
    return object_model_;
  }

  inline bool Intersects(const Detection& other) {
    // Separating-axis test: the two boxes intersect unless at least one of
    // the four axes separates them.
    return bounding_box_.Intersects(other.bounding_box_);
  }

  struct Comp {
    inline bool operator()(const Detection& a, const Detection& b) const {
      return a.match_score_ > b.match_score_;
    }
  };

  // TODO(andrewharp): add accessors to update these instead.
  const ObjectModelBase* object_model_;
  MatchScore match_score_;
  BoundingBox bounding_box_;
};
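
// Sorting sketch (illustrative): Detection::Comp orders by descending match
// score, so sorting a vector of detections (e.g. one filled in by
// ObjectDetectorBase::Detect below) puts the best candidate first. Assumes
// the caller includes <algorithm>.
//
//   std::sort(detections.begin(), detections.end(), Detection::Comp());
//   if (!detections.empty()) {
//     const Detection& best = detections[0];  // Highest-scoring candidate.
//   }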

inline std::ostream& operator<<(std::ostream& stream,
                                const Detection& detection) {
  const BoundingBox actual_area = detection.GetObjectBoundingBox();
  stream << actual_area;
  return stream;
}

class ObjectDetectorBase {
 public:
  explicit ObjectDetectorBase(const ObjectDetectorConfig* const config)
      : config_(config),
        image_data_(NULL) {}

  virtual ~ObjectDetectorBase();

  // Sets the current image data. All calls to ObjectDetector other than
  // FillDescriptors use the image data last set.
  inline void SetImageData(const ImageData* const image_data) {
    image_data_ = image_data;
  }

  // Main entry point into the detection algorithm.
  // Scans the frame at the given candidate positions, refines promising
  // candidates, and fills the given std::vector with acceptable Detection
  // matches.
  virtual void Detect(const std::vector<BoundingSquare>& positions,
                      std::vector<Detection>* const detections) const = 0;

  virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0;

  virtual void DeleteObjectModel(const std::string& name) = 0;

  virtual void GetObjectModels(
      std::vector<const ObjectModelBase*>* models) const = 0;

  // Creates a new ObjectExemplar from the given bounding box in base_image
  // and uses it to update the given model.
  // No update is performed in the case that there's no room for a descriptor
  // to be created in the example area, or the example area is not completely
  // contained within the frame.
  virtual void UpdateModel(
      const Image<uint8>& base_image,
      const IntegralImage& integral_image,
      const BoundingBox& bounding_box,
      const bool locked,
      ObjectModelBase* model) const = 0;

  virtual void Draw() const = 0;

  virtual bool AllowSpontaneousDetections() = 0;

 protected:
  const std::unique_ptr<const ObjectDetectorConfig> config_;

  // The latest frame data, upon which all detections will be performed.
  // Not owned by this object, just provided for reference by ObjectTracker
  // via SetImageData().
  const ImageData* image_data_;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase);
};
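
// Typical per-frame driving sequence (a sketch, not code from this demo;
// "detector" stands for any concrete ObjectDetectorBase subclass and
// "image_data" for the current frame's ImageData):
//
//   detector->SetImageData(&image_data);  // Borrowed, not owned.
//   std::vector<BoundingSquare> positions;
//   FillWithSquares(image_bounds, last_known_box, 128.0f, 32.0f, 0.75f,
//                   &positions);
//   std::vector<Detection> detections;
//   detector->Detect(positions, &detections);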

template <typename ModelType>
class ObjectDetector : public ObjectDetectorBase {
 public:
  explicit ObjectDetector(const ObjectDetectorConfig* const config)
      : ObjectDetectorBase(config) {}

  virtual ~ObjectDetector() {
    // The map owns its models through raw pointers; delete them all here.
    for (auto& entry : object_models_) {
      delete entry.second;
    }
  }

  virtual void DeleteObjectModel(const std::string& name) {
    // Look up with find() rather than operator[], which would insert a NULL
    // entry into the map if the name were not present.
    typename std::map<std::string, ModelType*>::iterator it =
        object_models_.find(name);
    CHECK_ALWAYS(it != object_models_.end(), "Model not found!");
    ModelType* const model = it->second;
    object_models_.erase(it);
    SAFE_DELETE(model);
  }

  virtual void GetObjectModels(
      std::vector<const ObjectModelBase*>* models) const {
    for (const auto& entry : object_models_) {
      models->push_back(entry.second);
    }
  }

  virtual bool AllowSpontaneousDetections() {
    return false;
  }

 protected:
  std::map<std::string, ModelType*> object_models_;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector);
};
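
// Minimal subclass sketch (hypothetical; as noted at the top of this file,
// no native detector currently ships in this directory). A concrete detector
// derives from ObjectDetector<ModelType> and supplies the pure virtual
// methods left open by ObjectDetectorBase. "MyModel" and "MyDetector" are
// made-up names, and a MyModel(name) constructor is assumed:
//
//   class MyDetector : public ObjectDetector<MyModel> {
//    public:
//     explicit MyDetector(const ObjectDetectorConfig* const config)
//         : ObjectDetector<MyModel>(config) {}
//
//     virtual void Detect(const std::vector<BoundingSquare>& positions,
//                         std::vector<Detection>* const detections) const {
//       // Score each candidate square against each model in object_models_
//       // and push acceptable matches onto *detections.
//     }
//
//     virtual ObjectModelBase* CreateObjectModel(const std::string& name) {
//       // Register the model so that the base class destructor and
//       // DeleteObjectModel can manage its lifetime.
//       MyModel* const model = new MyModel(name);
//       object_models_[name] = model;
//       return model;
//     }
//
//     virtual void UpdateModel(const Image<uint8>& base_image,
//                              const IntegralImage& integral_image,
//                              const BoundingBox& bounding_box,
//                              const bool locked,
//                              ObjectModelBase* model) const {
//       // Extract a new exemplar from bounding_box and fold it into *model.
//     }
//
//     virtual void Draw() const {}
//   };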

}  // namespace tf_tracking

#endif  // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_