aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/android/jni/object_tracking/image_data.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/android/jni/object_tracking/image_data.h')
-rw-r--r--tensorflow/examples/android/jni/object_tracking/image_data.h270
1 files changed, 270 insertions, 0 deletions
diff --git a/tensorflow/examples/android/jni/object_tracking/image_data.h b/tensorflow/examples/android/jni/object_tracking/image_data.h
new file mode 100644
index 0000000000..16b1864ee6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image_data.h
@@ -0,0 +1,270 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+
+#include <memory>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+// Class that encapsulates all bulky processed data for a frame.
+class ImageData {
+ public:
+ explicit ImageData(const int width, const int height)
+ : uv_frame_width_(width << 1),
+ uv_frame_height_(height << 1),
+ timestamp_(0),
+ image_(width, height) {
+ InitPyramid(width, height);
+ ResetComputationCache();
+ }
+
+ private:
+ void ResetComputationCache() {
+ uv_data_computed_ = false;
+ integral_image_computed_ = false;
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ spatial_x_computed_[i] = false;
+ spatial_y_computed_[i] = false;
+ pyramid_sqrt2_computed_[i * 2] = false;
+ pyramid_sqrt2_computed_[i * 2 + 1] = false;
+ }
+ }
+
+ void InitPyramid(const int width, const int height) {
+ int level_width = width;
+ int level_height = height;
+
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ pyramid_sqrt2_[i * 2] = NULL;
+ pyramid_sqrt2_[i * 2 + 1] = NULL;
+ spatial_x_[i] = NULL;
+ spatial_y_[i] = NULL;
+
+ level_width /= 2;
+ level_height /= 2;
+ }
+
+ // Alias the first pyramid level to image_.
+ pyramid_sqrt2_[0] = &image_;
+ }
+
+ public:
+ ~ImageData() {
+ // The first pyramid level is actually an alias to image_,
+ // so make sure it doesn't get deleted here.
+ pyramid_sqrt2_[0] = NULL;
+
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ SAFE_DELETE(pyramid_sqrt2_[i * 2]);
+ SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]);
+ SAFE_DELETE(spatial_x_[i]);
+ SAFE_DELETE(spatial_y_[i]);
+ }
+ }
+
+ void SetData(const uint8* const new_frame, const int stride,
+ const int64 timestamp, const int downsample_factor) {
+ SetData(new_frame, NULL, stride, timestamp, downsample_factor);
+ }
+
+ void SetData(const uint8* const new_frame,
+ const uint8* const uv_frame,
+ const int stride,
+ const int64 timestamp, const int downsample_factor) {
+ ResetComputationCache();
+
+ timestamp_ = timestamp;
+
+ TimeLog("SetData!");
+
+ pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor);
+ pyramid_sqrt2_computed_[0] = true;
+ TimeLog("Downsampled image");
+
+ if (uv_frame != NULL) {
+ if (u_data_.get() == NULL) {
+ u_data_.reset(new Image<uint8>(uv_frame_width_, uv_frame_height_));
+ v_data_.reset(new Image<uint8>(uv_frame_width_, uv_frame_height_));
+ }
+
+ GetUV(uv_frame, u_data_.get(), v_data_.get());
+ uv_data_computed_ = true;
+ TimeLog("Copied UV data");
+ } else {
+ LOGV("No uv data!");
+ }
+
+#ifdef LOG_TIME
+ // If profiling is enabled, precompute here to make it easier to distinguish
+ // total costs.
+ Precompute();
+#endif
+ }
+
+ inline const uint64 GetTimestamp() const {
+ return timestamp_;
+ }
+
+ inline const Image<uint8>* GetImage() const {
+ SCHECK(pyramid_sqrt2_computed_[0], "image not set!");
+ return pyramid_sqrt2_[0];
+ }
+
+ const Image<uint8>* GetPyramidSqrt2Level(const int level) const {
+ if (!pyramid_sqrt2_computed_[level]) {
+ SCHECK(level != 0, "Level equals 0!");
+ if (level == 1) {
+ const Image<uint8>& upper_level = *GetPyramidSqrt2Level(0);
+ if (pyramid_sqrt2_[level] == NULL) {
+ const int new_width =
+ (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2;
+ const int new_height =
+ (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 *
+ 2;
+
+ pyramid_sqrt2_[level] = new Image<uint8>(new_width, new_height);
+ }
+ pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level);
+ } else {
+ const Image<uint8>& upper_level = *GetPyramidSqrt2Level(level - 2);
+ if (pyramid_sqrt2_[level] == NULL) {
+ pyramid_sqrt2_[level] = new Image<uint8>(
+ upper_level.GetWidth() / 2, upper_level.GetHeight() / 2);
+ }
+ pyramid_sqrt2_[level]->DownsampleAveraged(
+ upper_level.data(), upper_level.stride(), 2);
+ }
+ pyramid_sqrt2_computed_[level] = true;
+ }
+ return pyramid_sqrt2_[level];
+ }
+
+ inline const Image<int32>* GetSpatialX(const int level) const {
+ if (!spatial_x_computed_[level]) {
+ const Image<uint8>& src = *GetPyramidSqrt2Level(level * 2);
+ if (spatial_x_[level] == NULL) {
+ spatial_x_[level] = new Image<int32>(src.GetWidth(), src.GetHeight());
+ }
+ spatial_x_[level]->DerivativeX(src);
+ spatial_x_computed_[level] = true;
+ }
+ return spatial_x_[level];
+ }
+
+ inline const Image<int32>* GetSpatialY(const int level) const {
+ if (!spatial_y_computed_[level]) {
+ const Image<uint8>& src = *GetPyramidSqrt2Level(level * 2);
+ if (spatial_y_[level] == NULL) {
+ spatial_y_[level] = new Image<int32>(src.GetWidth(), src.GetHeight());
+ }
+ spatial_y_[level]->DerivativeY(src);
+ spatial_y_computed_[level] = true;
+ }
+ return spatial_y_[level];
+ }
+
+ // The integral image is currently only used for object detection, so lazily
+ // initialize it on request.
+ inline const IntegralImage* GetIntegralImage() const {
+ if (integral_image_.get() == NULL) {
+ integral_image_.reset(new IntegralImage(image_));
+ } else if (!integral_image_computed_) {
+ integral_image_->Recompute(image_);
+ }
+ integral_image_computed_ = true;
+ return integral_image_.get();
+ }
+
+ inline const Image<uint8>* GetU() const {
+ SCHECK(uv_data_computed_, "UV data not provided!");
+ return u_data_.get();
+ }
+
+ inline const Image<uint8>* GetV() const {
+ SCHECK(uv_data_computed_, "UV data not provided!");
+ return v_data_.get();
+ }
+
+ private:
+ void Precompute() {
+ // Create the smoothed pyramids.
+ for (int i = 0; i < kNumPyramidLevels * 2; i += 2) {
+ (void) GetPyramidSqrt2Level(i);
+ }
+ TimeLog("Created smoothed pyramids");
+
+ // Create the smoothed pyramids.
+ for (int i = 1; i < kNumPyramidLevels * 2; i += 2) {
+ (void) GetPyramidSqrt2Level(i);
+ }
+ TimeLog("Created smoothed sqrt pyramids");
+
+ // Create the spatial derivatives for frame 1.
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ (void) GetSpatialX(i);
+ (void) GetSpatialY(i);
+ }
+ TimeLog("Created spatial derivatives");
+
+ (void) GetIntegralImage();
+ TimeLog("Got integral image!");
+ }
+
+ const int uv_frame_width_;
+ const int uv_frame_height_;
+
+ int64 timestamp_;
+
+ Image<uint8> image_;
+
+ bool uv_data_computed_;
+ std::unique_ptr<Image<uint8> > u_data_;
+ std::unique_ptr<Image<uint8> > v_data_;
+
+ mutable bool spatial_x_computed_[kNumPyramidLevels];
+ mutable Image<int32>* spatial_x_[kNumPyramidLevels];
+
+ mutable bool spatial_y_computed_[kNumPyramidLevels];
+ mutable Image<int32>* spatial_y_[kNumPyramidLevels];
+
+ // Mutable so the lazy initialization can work when this class is const.
+ // Whether or not the integral image has been computed for the current image.
+ mutable bool integral_image_computed_;
+ mutable std::unique_ptr<IntegralImage> integral_image_;
+
+ mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2];
+ mutable Image<uint8>* pyramid_sqrt2_[kNumPyramidLevels * 2];
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ImageData);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_