aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/android/jni/object_tracking/flow_cache.h
blob: 8813ab6d71846f5ce2e13a2853594de43d95b0b7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_

#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"

#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"

namespace tf_tracking {

// Class that helps OpticalFlow to speed up flow computation
// by caching coarse-grained flow.
//
// The cache consists of kNumCacheLevels square grids of displacement values,
// one grid per cache level. Level 0 is the finest grid; each higher level is
// kCacheBranchFactor times smaller in each dimension. A lookup at level N
// that misses is seeded from level N + 1 and then computed at the centre of
// the grid cell, so every cell is computed at most once per frame.
class FlowCache {
 public:
  // Creates an empty cache for frames of size config->image_size.
  //
  // config must be non-NULL. The pointer itself is retained in config_ and is
  // also handed to the owned OpticalFlow instance.
  explicit FlowCache(const OpticalFlowConfig* const config)
      : config_(config),
        image_size_(config->image_size),
        optical_flow_(config),
        has_fullframe_matrix_(false) {
    for (int i = 0; i < kNumCacheLevels; ++i) {
      const int curr_dims = BlockDimForCacheLevel(i);
      has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
      displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);

      // Mark every cell as empty explicitly rather than depending on what
      // Image's constructor leaves in its buffer, so that lookups made before
      // the first NextFrame()/ClearCache() never read a stale "cached" flag.
      has_cache_[i]->Clear(false);
    }
  }

  // Releases the per-level cache grids.
  ~FlowCache() {
    for (int i = 0; i < kNumCacheLevels; ++i) {
      SAFE_DELETE(has_cache_[i]);
      SAFE_DELETE(displacements_[i]);
    }
  }

  // Advances to a new frame: invalidates everything cached for the previous
  // frame, records the full-frame alignment matrix (if one is supplied) and
  // forwards the frame to the underlying OpticalFlow.
  //
  // align_matrix23 may be NULL, in which case no alignment matrix is used for
  // this frame. Otherwise it must point to 6 floats (see
  // SetFullframeAlignmentMatrix).
  void NextFrame(ImageData* const new_frame,
                 const float* const align_matrix23) {
    ClearCache();
    SetFullframeAlignmentMatrix(align_matrix23);
    optical_flow_.NextFrame(new_frame);
  }

  // Marks every cache cell at every level as empty and forgets the current
  // full-frame alignment matrix.
  void ClearCache() {
    for (int i = 0; i < kNumCacheLevels; ++i) {
      has_cache_[i]->Clear(false);
    }
    has_fullframe_matrix_ = false;
  }

  // Finds the flow at a point, using the cache for performance.
  //
  // On entry *flow_x / *flow_y are overwritten with the cached guess; they are
  // then refined in place, from pyramid level
  // kMinNumPyramidLevelsToUseForAdjustment - 1 down to level 0.
  //
  // Returns false as soon as any refinement step fails; in that case
  // *flow_x / *flow_y hold whatever the failed step left in them.
  bool FindFlowAtPoint(const float u_x, const float u_y,
                       float* const flow_x, float* const flow_y) const {
    // Get the best guess from the cache.
    const Point2f guess_from_cache = LookupGuess(u_x, u_y);

    *flow_x = guess_from_cache.x;
    *flow_y = guess_from_cache.y;

    // Now refine the guess using the image pyramid.
    for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
        pyramid_level >= 0; --pyramid_level) {
      if (!optical_flow_.FindFlowAtPointSingleLevel(
          pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
        return false;
      }
    }

    return true;
  }

  // Determines the displacement of a point, and uses that to calculate a new
  // position.
  // Returns true iff the displacement determination worked and the new position
  // is in the image.
  //
  // *final_x / *final_y are only written when the displacement determination
  // succeeded; they are written even if the resulting position then turns out
  // to lie outside the image (in which case false is returned).
  bool FindNewPositionOfPoint(const float u_x, const float u_y,
                              float* final_x, float* final_y) const {
    float flow_x;
    float flow_y;
    if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
      return false;
    }

    // Add in the displacement to get the final position.
    *final_x = u_x + flow_x;
    *final_y = u_y + flow_y;

    // The result only counts if we're still in the image.
    return InRange(*final_x, 0.0f,
                   static_cast<float>(image_size_.width) - 1) &&
           InRange(*final_y, 0.0f,
                   static_cast<float>(image_size_.height) - 1);
  }

  // Comparison function for qsort; orders floats ascending.
  //
  // Returns -1, 0 or 1. The float difference must not simply be returned:
  // converting it to int truncates every difference of magnitude below 1 to
  // zero, which makes qsort treat distinct values as equal and leaves the
  // array effectively unsorted for sub-pixel flow values.
  static int Compare(const void* a, const void* b) {
    const float lhs = *static_cast<const float*>(a);
    const float rhs = *static_cast<const float*>(b);
    return (lhs > rhs) - (lhs < rhs);
  }

  // Returns the median flow within the given bounding box as determined
  // by a grid_width x grid_height grid.
  //
  // The box is first clipped to the image. Sample points are spread evenly
  // from edge to edge of the clipped box; a grid dimension of 1 samples the
  // centre of the box along that axis. Points whose flow cannot be determined,
  // or whose new position falls outside the image, are skipped. The x and y
  // medians are taken independently over the remaining points.
  //
  // Returns (0, 0) if the clipped box is empty, if the grid holds more than
  // 100 points, or if no sample point could be tracked.
  //
  // filter_by_fb_error is currently unused.
  Point2f GetMedianFlow(const BoundingBox& bounding_box,
                        const bool filter_by_fb_error,
                        const int grid_width,
                        const int grid_height) const {
    const int kMaxPoints = 100;
    SCHECK(grid_width * grid_height <= kMaxPoints,
          "Too many points for Median flow!");

    // The check above may not be active in every build configuration, and the
    // sample buffers below live on the stack, so enforce the limit for real.
    // (Written as a division so that the product cannot overflow.)
    if (grid_width > 0 && grid_height > 0 &&
        grid_width > kMaxPoints / grid_height) {
      LOGW("Too many points for Median flow!");
      return Point2f(0, 0);
    }

    const BoundingBox valid_box = bounding_box.Intersect(
        BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));

    if (valid_box.GetArea() <= 0.0f) {
      return Point2f(0, 0);
    }

    float x_deltas[kMaxPoints];
    float y_deltas[kMaxPoints];

    int curr_offset = 0;
    for (int i = 0; i < grid_width; ++i) {
      // With a single column there is no "edge to edge" spacing (it would be
      // a division by zero), so sample the horizontal centre instead.
      const float x_in = grid_width > 1 ?
          valid_box.left_ +
              (valid_box.GetWidth() * i) / (grid_width - 1) :
          valid_box.left_ + valid_box.GetWidth() * 0.5f;

      for (int j = 0; j < grid_height; ++j) {
        // Same treatment for a single row.
        const float y_in = grid_height > 1 ?
            valid_box.top_ +
                (valid_box.GetHeight() * j) / (grid_height - 1) :
            valid_box.top_ + valid_box.GetHeight() * 0.5f;

        float new_x;
        float new_y;
        const bool success = FindNewPositionOfPoint(x_in, y_in,
                                                    &new_x, &new_y);

        if (success) {
          // FindNewPositionOfPoint yields an absolute position; subtract the
          // sample point to recover the displacement this function reports.
          x_deltas[curr_offset] = new_x - x_in;
          y_deltas[curr_offset] = new_y - y_in;
          ++curr_offset;
        } else {
          LOGW("Tracking failure!");
        }
      }
    }

    if (curr_offset > 0) {
      qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
      qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);

      // For an even number of samples this picks the upper of the two middle
      // elements.
      return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
    }

    LOGW("No points were valid!");
    return Point2f(0, 0);
  }

  // Records the full-frame alignment matrix used as the flow guess at cache
  // level kCacheCutoff.
  //
  // align_matrix23 points to 6 floats m[0..5] that are applied as
  //   x' = x * m[0] + y * m[1] + m[2]
  //   y' = x * m[3] + y * m[4] + m[5]
  // The values are copied. Passing NULL is a no-op: a previously set matrix
  // stays in effect until ClearCache() is called.
  void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
    if (align_matrix23 == NULL) {
      return;
    }

    memcpy(fullframe_matrix_, align_matrix23, sizeof(fullframe_matrix_));
    has_fullframe_matrix_ = true;
  }

 private:
  // Number of coefficients in the full-frame alignment matrix.
  static const int kFullframeMatrixSize = 6;

  // Returns the displacement guess for image position (x, y) at the given
  // cache level, computing and caching it on a miss.
  //
  // The caller must ensure that (x, y) lies inside the image so that the
  // computed cell indices are valid (LookupGuess does this).
  //
  // This method is const but updates the cache grids, which are reached
  // through pointers.
  Point2f LookupGuessFromLevel(
      const int cache_level, const float x, const float y) const {
    // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);

    // Cutoff at the target level and use the matrix transform instead.
    if (has_fullframe_matrix_ && cache_level == kCacheCutoff) {
      const float xnew = x * fullframe_matrix_[0] +
                         y * fullframe_matrix_[1] +
                             fullframe_matrix_[2];
      const float ynew = x * fullframe_matrix_[3] +
                         y * fullframe_matrix_[4] +
                             fullframe_matrix_[5];

      return Point2f(xnew - x, ynew - y);
    }

    // Cell size in pixels, rounded up so that level_dim cells always cover
    // the whole image.
    const int level_dim = BlockDimForCacheLevel(cache_level);
    const int pixels_per_cache_block_x =
        (image_size_.width + level_dim - 1) / level_dim;
    const int pixels_per_cache_block_y =
        (image_size_.height + level_dim - 1) / level_dim;
    const int index_x = static_cast<int>(x / pixels_per_cache_block_x);
    const int index_y = static_cast<int>(y / pixels_per_cache_block_y);

    Point2f displacement;
    if (!(*has_cache_[cache_level])[index_y][index_x]) {
      (*has_cache_[cache_level])[index_y][index_x] = true;

      // Get the next coarser cache level's best guess, if there is one.
      displacement = cache_level >= kNumCacheLevels - 1 ?
          Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
      // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
      //      displacement.x, displacement.y);

      // Find the center of the block.
      const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
      const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
      const int pyramid_level = PyramidLevelForCacheLevel(cache_level);

      // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
      //      "Querying %5.2f, %5.2f at pyramid level %d, ",
      //      cache_level, index_x, index_y,
      //      x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
      //      center_x, center_y, pyramid_level);

      // TODO(andrewharp): Turn on FB error filtering.
      const bool success = optical_flow_.FindFlowAtPointSingleLevel(
          pyramid_level, center_x, center_y, false,
          &displacement.x, &displacement.y);

      if (!success) {
        LOGV("Computation of cached value failed for level %d!", cache_level);
      }

      // Store the value for later use. This happens even when the computation
      // failed, so the cell is not retried until the cache is cleared.
      (*displacements_[cache_level])[index_y][index_x] = displacement;
    } else {
      displacement = (*displacements_[cache_level])[index_y][index_x];
    }

    // LOGI("Returning %5.2f, %5.2f for level %d",
    //      displacement.x, displacement.y, cache_level);
    return displacement;
  }

  // Returns the cached displacement guess for image position (x, y), or
  // (0, 0) if the position is outside the image or there are no cache levels.
  Point2f LookupGuess(const float x, const float y) const {
    if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
      return Point2f(0, 0);
    }

    // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
    if (kNumCacheLevels > 0) {
      return LookupGuessFromLevel(0, x, y);
    } else {
      return Point2f(0, 0);
    }
  }

  // Returns the number of cache bins in each dimension for a given level
  // of the cache.
  int BlockDimForCacheLevel(const int cache_level) const {
    // The highest (coarsest) cache level has a block dim of kNumCacheLevels,
    // and each step towards level 0 multiplies that by kCacheBranchFactor.
    //
    // NOTE(review): earlier documentation stated that the coarsest level
    // should have a block dim of kCacheBranchFactor (then
    // kCacheBranchFactor^2, and so on), which is not what this code computes.
    // Confirm which of the two is intended before changing the seed value, as
    // it determines the resolution of every cache grid.
    int block_dim = kNumCacheLevels;
    for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
        --curr_level) {
      block_dim *= kCacheBranchFactor;
    }
    return block_dim;
  }

  // Returns the level of the image pyramid that a given cache level maps to.
  int PyramidLevelForCacheLevel(const int cache_level) const {
    // Higher cache and pyramid levels have smaller dimensions. The highest
    // cache level should refer to the highest image pyramid level. The
    // lower, finer image pyramid levels are uncached (assuming
    // kNumCacheLevels < kNumPyramidLevels).
    return cache_level + (kNumPyramidLevels - kNumCacheLevels);
  }

  // Configuration this cache was created with. Not owned.
  const OpticalFlowConfig* const config_;

  // Dimensions of the frames being tracked.
  const Size image_size_;

  // Performs the actual per-point flow computation.
  OpticalFlow optical_flow_;

  // Whether fullframe_matrix_ holds a matrix for the current frame.
  bool has_fullframe_matrix_;

  // The full-frame alignment matrix; only meaningful while
  // has_fullframe_matrix_ is true.
  float fullframe_matrix_[kFullframeMatrixSize];

  // Whether this value is currently present in the cache.
  Image<bool>* has_cache_[kNumCacheLevels];

  // The cached displacement values.
  Image<Point2f>* displacements_[kNumCacheLevels];

  TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
};

}  // namespace tf_tracking

#endif  // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_