about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/examples/android/jni/object_tracking/flow_cache.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/android/jni/object_tracking/flow_cache.h')
-rw-r--r-- tensorflow/examples/android/jni/object_tracking/flow_cache.h 306
1 files changed, 306 insertions, 0 deletions
diff --git a/tensorflow/examples/android/jni/object_tracking/flow_cache.h b/tensorflow/examples/android/jni/object_tracking/flow_cache.h
new file mode 100644
index 0000000000..8813ab6d71
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/flow_cache.h
@@ -0,0 +1,306 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+namespace tf_tracking {
+
+// Class that helps OpticalFlow to speed up flow computation
+// by caching coarse-grained flow.
+//
+// The cache is a stack of kNumCacheLevels square grids, one per cache level.
+// Level 0 is the finest grid and level kNumCacheLevels - 1 the coarsest. Each
+// grid cell lazily stores the flow computed at the centre of that cell, seeded
+// by the guess from the next coarser level. When a full-frame alignment matrix
+// has been supplied, lookups that reach level kCacheCutoff use that matrix
+// instead of the cache.
+class FlowCache {
+ public:
+  // Creates an empty cache for frames of size config->image_size.
+  //
+  // config is dereferenced here, so it must not be NULL. The pointer is also
+  // retained in config_ (presumably it has to outlive this object; confirm
+  // against callers).
+  explicit FlowCache(const OpticalFlowConfig* const config)
+      : config_(config),
+        image_size_(config->image_size),
+        optical_flow_(config),
+        fullframe_matrix_(NULL) {
+    for (int i = 0; i < kNumCacheLevels; ++i) {
+      const int curr_dims = BlockDimForCacheLevel(i);
+      has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
+      displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
+
+      // Start with every cell marked as "not cached" rather than relying on
+      // whatever Image's constructor leaves in its storage.
+      has_cache_[i]->Clear(false);
+    }
+  }
+
+  // Releases the per-level grids and the alignment matrix.
+  ~FlowCache() {
+    for (int i = 0; i < kNumCacheLevels; ++i) {
+      SAFE_DELETE(has_cache_[i]);
+      SAFE_DELETE(displacements_[i]);
+    }
+    delete[](fullframe_matrix_);
+    fullframe_matrix_ = NULL;
+  }
+
+  // Advances to the next frame: invalidates everything cached for the previous
+  // frame, installs the new alignment matrix (if any) and hands the frame to
+  // the underlying OpticalFlow.
+  //
+  // align_matrix23 may be NULL, in which case no alignment matrix is used for
+  // this frame; otherwise it must point to 6 floats (see
+  // SetFullframeAlignmentMatrix).
+  void NextFrame(ImageData* const new_frame,
+                 const float* const align_matrix23) {
+    ClearCache();
+    SetFullframeAlignmentMatrix(align_matrix23);
+    optical_flow_.NextFrame(new_frame);
+  }
+
+  // Marks every cell of every cache level as empty and discards the current
+  // full-frame alignment matrix.
+  void ClearCache() {
+    for (int i = 0; i < kNumCacheLevels; ++i) {
+      has_cache_[i]->Clear(false);
+    }
+    delete[](fullframe_matrix_);
+    fullframe_matrix_ = NULL;
+  }
+
+  // Finds the flow at a point, using the cache for performance.
+  //
+  // *flow_x / *flow_y are first set to the cached guess and then refined in
+  // place, from pyramid level kMinNumPyramidLevelsToUseForAdjustment - 1 down
+  // to level 0.
+  //
+  // Returns false as soon as refinement fails at any level; the outputs then
+  // hold whatever the failing call left in them.
+  bool FindFlowAtPoint(const float u_x, const float u_y,
+                       float* const flow_x, float* const flow_y) const {
+    // Get the best guess from the cache.
+    const Point2f guess_from_cache = LookupGuess(u_x, u_y);
+
+    *flow_x = guess_from_cache.x;
+    *flow_y = guess_from_cache.y;
+
+    // Now refine the guess using the image pyramid.
+    for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
+         pyramid_level >= 0; --pyramid_level) {
+      if (!optical_flow_.FindFlowAtPointSingleLevel(
+              pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  // Determines the displacement of a point, and uses that to calculate a new
+  // position.
+  // Returns true iff the displacement determination worked and the new position
+  // is in the image.
+  //
+  // If the flow lookup itself fails, *final_x / *final_y are left untouched.
+  // If it succeeds but the result is outside the image, they are still written
+  // even though false is returned.
+  bool FindNewPositionOfPoint(const float u_x, const float u_y,
+                              float* final_x, float* final_y) const {
+    float flow_x;
+    float flow_y;
+    if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
+      return false;
+    }
+
+    // Add in the displacement to get the final position.
+    *final_x = u_x + flow_x;
+    *final_y = u_y + flow_y;
+
+    // Only report success if we're still in the image.
+    return InRange(*final_x, 0.0f,
+                   static_cast<float>(image_size_.width) - 1) &&
+           InRange(*final_y, 0.0f,
+                   static_cast<float>(image_size_.height) - 1);
+  }
+
+  // Comparison function for qsort over an array of floats.
+  //
+  // Returns -1, 0 or 1. The result is built from comparisons instead of a
+  // subtraction: converting a float difference to int truncates towards zero,
+  // which would make any two values less than 1.0 apart compare as equal.
+  static int Compare(const void* a, const void* b) {
+    const float lhs = *static_cast<const float*>(a);
+    const float rhs = *static_cast<const float*>(b);
+    return (lhs > rhs) - (lhs < rhs);
+  }
+
+  // Returns the median flow within the given bounding box as determined
+  // by a grid_width x grid_height grid.
+  //
+  // The box is first clipped to the image. Sample points are spread evenly
+  // from edge to edge of the clipped box; a grid dimension of 1 samples the
+  // centre of the box along that axis. The x and y medians are taken
+  // independently, over the samples whose tracked position stayed inside the
+  // image.
+  //
+  // Returns (0, 0) if the clipped box is empty, if the grid has more than
+  // kMaxPoints samples, or if no sample could be tracked.
+  //
+  // filter_by_fb_error is currently ignored.
+  Point2f GetMedianFlow(const BoundingBox& bounding_box,
+                        const bool filter_by_fb_error,
+                        const int grid_width,
+                        const int grid_height) const {
+    const int kMaxPoints = 100;
+    SCHECK(grid_width * grid_height <= kMaxPoints,
+           "Too many points for Median flow!");
+
+    // The sample buffers below are fixed-size stack arrays, so enforce the
+    // limit explicitly as well, in case SCHECK does not stop execution in
+    // this build. The division form avoids overflowing the product.
+    if (grid_width > 0 && grid_height > 0 &&
+        grid_width > kMaxPoints / grid_height) {
+      LOGW("Too many points for Median flow!");
+      return Point2f(0, 0);
+    }
+
+    const BoundingBox valid_box = bounding_box.Intersect(
+        BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
+
+    if (valid_box.GetArea() <= 0.0f) {
+      return Point2f(0, 0);
+    }
+
+    float x_deltas[kMaxPoints];
+    float y_deltas[kMaxPoints];
+
+    int curr_offset = 0;
+    for (int i = 0; i < grid_width; ++i) {
+      // With a single column there is no (grid_width - 1) to divide by, so
+      // sample the horizontal centre instead.
+      const float x_in = valid_box.left_ +
+          (grid_width > 1 ? (valid_box.GetWidth() * i) / (grid_width - 1)
+                          : valid_box.GetWidth() * 0.5f);
+
+      for (int j = 0; j < grid_height; ++j) {
+        // Same treatment for a single row.
+        const float y_in = valid_box.top_ +
+            (grid_height > 1 ? (valid_box.GetHeight() * j) / (grid_height - 1)
+                             : valid_box.GetHeight() * 0.5f);
+
+        float new_x;
+        float new_y;
+        const bool success = FindNewPositionOfPoint(x_in, y_in,
+                                                    &new_x, &new_y);
+
+        if (success) {
+          // FindNewPositionOfPoint yields an absolute position; the flow is
+          // that position relative to where the sample started.
+          x_deltas[curr_offset] = new_x - x_in;
+          y_deltas[curr_offset] = new_y - y_in;
+          ++curr_offset;
+        } else {
+          LOGW("Tracking failure!");
+        }
+      }
+    }
+
+    if (curr_offset > 0) {
+      qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
+      qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
+
+      // For an even number of samples this picks the upper of the two middle
+      // values.
+      return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
+    }
+
+    LOGW("No points were valid!");
+    return Point2f(0, 0);
+  }
+
+  // Copies a 2x3 alignment matrix (6 floats, row-major: the first three
+  // coefficients produce x, the last three produce y) into this object.
+  //
+  // Passing NULL is a no-op: any previously stored matrix is kept.
+  void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
+    if (align_matrix23 != NULL) {
+      if (fullframe_matrix_ == NULL) {
+        fullframe_matrix_ = new float[6];
+      }
+
+      memcpy(fullframe_matrix_, align_matrix23,
+             6 * sizeof(fullframe_matrix_[0]));
+    }
+  }
+
+ private:
+  // Returns the flow guess for (x, y) at the given cache level, computing and
+  // storing it on a cache miss.
+  //
+  // On a miss the guess from the next coarser level (or (0, 0) at the coarsest
+  // level) seeds a single-level flow computation at the centre of the cell
+  // containing (x, y). The result is stored even if that computation reports
+  // failure.
+  //
+  // (x, y) must lie inside the image; LookupGuess() checks this before
+  // calling in.
+  Point2f LookupGuessFromLevel(
+      const int cache_level, const float x, const float y) const {
+    // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
+
+    // Cutoff at the target level and use the matrix transform instead.
+    if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
+      const float xnew = x * fullframe_matrix_[0] +
+                         y * fullframe_matrix_[1] +
+                             fullframe_matrix_[2];
+      const float ynew = x * fullframe_matrix_[3] +
+                         y * fullframe_matrix_[4] +
+                             fullframe_matrix_[5];
+
+      return Point2f(xnew - x, ynew - y);
+    }
+
+    // Cell size is rounded up so that level_dim cells cover the whole image.
+    const int level_dim = BlockDimForCacheLevel(cache_level);
+    const int pixels_per_cache_block_x =
+        (image_size_.width + level_dim - 1) / level_dim;
+    const int pixels_per_cache_block_y =
+        (image_size_.height + level_dim - 1) / level_dim;
+    const int index_x = x / pixels_per_cache_block_x;
+    const int index_y = y / pixels_per_cache_block_y;
+
+    Point2f displacement;
+    if (!(*has_cache_[cache_level])[index_y][index_x]) {
+      (*has_cache_[cache_level])[index_y][index_x] = true;
+
+      // Get the lower cache level's best guess, if it exists.
+      displacement = cache_level >= kNumCacheLevels - 1 ?
+          Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
+      // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
+      //      displacement.x, displacement.y);
+
+      // Find the center of the block.
+      const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
+      const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
+      const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
+
+      // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
+      //      "Querying %5.2f, %5.2f at pyramid level %d, ",
+      //      cache_level, index_x, index_y,
+      //      x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
+      //      center_x, center_y, pyramid_level);
+
+      // TODO(andrewharp): Turn on FB error filtering.
+      const bool success = optical_flow_.FindFlowAtPointSingleLevel(
+          pyramid_level, center_x, center_y, false,
+          &displacement.x, &displacement.y);
+
+      if (!success) {
+        LOGV("Computation of cached value failed for level %d!", cache_level);
+      }
+
+      // Store the value for later use.
+      (*displacements_[cache_level])[index_y][index_x] = displacement;
+    } else {
+      displacement = (*displacements_[cache_level])[index_y][index_x];
+    }
+
+    // LOGI("Returning %5.2f, %5.2f for level %d",
+    //      displacement.x, displacement.y, cache_level);
+    return displacement;
+  }
+
+  // Returns the cached flow guess for (x, y), starting at the finest cache
+  // level. Points outside the image (and builds with no cache levels) get a
+  // zero guess.
+  Point2f LookupGuess(const float x, const float y) const {
+    if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
+      return Point2f(0, 0);
+    }
+
+    // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
+    if (kNumCacheLevels > 0) {
+      return LookupGuessFromLevel(0, x, y);
+    } else {
+      return Point2f(0, 0);
+    }
+  }
+
+  // Returns the number of cache bins in each dimension for a given level
+  // of the cache.
+  int BlockDimForCacheLevel(const int cache_level) const {
+    // The coarsest cache level (kNumCacheLevels - 1) uses the starting value
+    // below, and every step towards level 0 multiplies the dimension by
+    // kCacheBranchFactor.
+    //
+    // NOTE(review): the original comment here said the coarsest level should
+    // have a block dim of kCacheBranchFactor (then kCacheBranchFactor^2, and
+    // so on), but the code starts from kNumCacheLevels. The code is left
+    // unchanged because altering it changes every cache grid size; confirm
+    // which of the two was intended.
+    int block_dim = kNumCacheLevels;
+    for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
+         --curr_level) {
+      block_dim *= kCacheBranchFactor;
+    }
+    return block_dim;
+  }
+
+  // Returns the level of the image pyramid that a given cache level maps to.
+  int PyramidLevelForCacheLevel(const int cache_level) const {
+    // Higher cache and pyramid levels have smaller dimensions. The highest
+    // cache level should refer to the highest image pyramid level. The
+    // lower, finer image pyramid levels are uncached (assuming
+    // kNumCacheLevels < kNumPyramidLevels).
+    return cache_level + (kNumPyramidLevels - kNumCacheLevels);
+  }
+
+  // Configuration passed to the constructor; not owned.
+  const OpticalFlowConfig* const config_;
+
+  // Frame dimensions, copied from the configuration at construction.
+  const Size image_size_;
+
+  // Performs the actual per-level flow computations.
+  OpticalFlow optical_flow_;
+
+  // Owned array of 6 floats holding the current 2x3 alignment matrix, or NULL
+  // when no matrix is set for the current frame.
+  float* fullframe_matrix_;
+
+  // Whether this value is currently present in the cache.
+  Image<bool>* has_cache_[kNumCacheLevels];
+
+  // The cached displacement values.
+  Image<Point2f>* displacements_[kNumCacheLevels];
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_