diff options
Diffstat (limited to 'tensorflow/examples/android/jni/object_tracking/object_detector.h')
-rw-r--r-- | tensorflow/examples/android/jni/object_tracking/object_detector.h | 232 |
1 files changed, 232 insertions, 0 deletions
diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.h b/tensorflow/examples/android/jni/object_tracking/object_detector.h new file mode 100644 index 0000000000..043f606e1d --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_detector.h @@ -0,0 +1,232 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NOTE: no native object detectors are currently provided or used by the code +// in this directory. This class remains mainly for historical reasons. +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. + +// Defines the ObjectDetector class that is the main interface for detecting +// ObjectModelBases in frames. + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ + +#include <float.h> +#include <map> +#include <memory> +#include <sstream> +#include <string> +#include <vector> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#ifdef __RENDER_OPENGL__ +#include "tensorflow/examples/android/jni/object_tracking/sprite.h" +#endif +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/object_model.h" + +namespace tf_tracking { + +// Adds BoundingSquares to a vector such that the first square added is centered +// in the position given and of square_size, and the remaining squares are added +// concentrentically, scaling down by scale_factor until the minimum threshold +// size is passed. +// Squares that do not fall completely within image_bounds will not be added. +static inline void FillWithSquares( + const BoundingBox& image_bounds, + const BoundingBox& position, + const float starting_square_size, + const float smallest_square_size, + const float scale_factor, + std::vector<BoundingSquare>* const squares) { + BoundingSquare descriptor_area = + GetCenteredSquare(position, starting_square_size); + + SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor); + + // Use a do/while loop to ensure that at least one descriptor is created. + do { + if (image_bounds.Contains(descriptor_area.ToBoundingBox())) { + squares->push_back(descriptor_area); + } + descriptor_area.Scale(scale_factor); + } while (descriptor_area.size_ >= smallest_square_size - EPSILON); + LOGV("Created %zu squares starting from size %.2f to min size %.2f " + "using scale factor: %.2f", + squares->size(), starting_square_size, smallest_square_size, + scale_factor); +} + + +// Represents a potential detection of a specific ObjectExemplar and Descriptor +// at a specific position in the image. +class Detection { + public: + explicit Detection(const ObjectModelBase* const object_model, + const MatchScore match_score, + const BoundingBox& bounding_box) + : object_model_(object_model), + match_score_(match_score), + bounding_box_(bounding_box) {} + + Detection(const Detection& other) + : object_model_(other.object_model_), + match_score_(other.match_score_), + bounding_box_(other.bounding_box_) {} + + virtual ~Detection() {} + + inline BoundingBox GetObjectBoundingBox() const { + return bounding_box_; + } + + inline MatchScore GetMatchScore() const { + return match_score_; + } + + inline const ObjectModelBase* GetObjectModel() const { + return object_model_; + } + + inline bool Intersects(const Detection& other) { + // Check if any of the four axes separates us, there must be at least one. + return bounding_box_.Intersects(other.bounding_box_); + } + + struct Comp { + inline bool operator()(const Detection& a, const Detection& b) const { + return a.match_score_ > b.match_score_; + } + }; + + // TODO(andrewharp): add accessors to update these instead. + const ObjectModelBase* object_model_; + MatchScore match_score_; + BoundingBox bounding_box_; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const Detection& detection) { + const BoundingBox actual_area = detection.GetObjectBoundingBox(); + stream << actual_area; + return stream; +} + +class ObjectDetectorBase { + public: + explicit ObjectDetectorBase(const ObjectDetectorConfig* const config) + : config_(config), + image_data_(NULL) {} + + virtual ~ObjectDetectorBase(); + + // Sets the current image data. All calls to ObjectDetector other than + // FillDescriptors use the image data last set. + inline void SetImageData(const ImageData* const image_data) { + image_data_ = image_data; + } + + // Main entry point into the detection algorithm. + // Scans the frame for candidates, tweaks them, and fills in the + // given std::vector of Detection objects with acceptable matches. + virtual void Detect(const std::vector<BoundingSquare>& positions, + std::vector<Detection>* const detections) const = 0; + + virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0; + + virtual void DeleteObjectModel(const std::string& name) = 0; + + virtual void GetObjectModels( + std::vector<const ObjectModelBase*>* models) const = 0; + + // Creates a new ObjectExemplar from the given position in the context of + // the last frame passed to NextFrame. + // Will return null in the case that there's no room for a descriptor to be + // created in the example area, or the example area is not completely + // contained within the frame. + virtual void UpdateModel( + const Image<uint8>& base_image, + const IntegralImage& integral_image, + const BoundingBox& bounding_box, + const bool locked, + ObjectModelBase* model) const = 0; + + virtual void Draw() const = 0; + + virtual bool AllowSpontaneousDetections() = 0; + + protected: + const std::unique_ptr<const ObjectDetectorConfig> config_; + + // The latest frame data, upon which all detections will be performed. + // Not owned by this object, just provided for reference by ObjectTracker + // via SetImageData(). + const ImageData* image_data_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase); +}; + +template <typename ModelType> +class ObjectDetector : public ObjectDetectorBase { + public: + explicit ObjectDetector(const ObjectDetectorConfig* const config) + : ObjectDetectorBase(config) {} + + virtual ~ObjectDetector() { + typename std::map<std::string, ModelType*>::const_iterator it = + object_models_.begin(); + for (; it != object_models_.end(); ++it) { + ModelType* model = it->second; + delete model; + } + } + + virtual void DeleteObjectModel(const std::string& name) { + ModelType* model = object_models_[name]; + CHECK_ALWAYS(model != NULL, "Model was null!"); + object_models_.erase(name); + SAFE_DELETE(model); + } + + virtual void GetObjectModels( + std::vector<const ObjectModelBase*>* models) const { + typename std::map<std::string, ModelType*>::const_iterator it = + object_models_.begin(); + for (; it != object_models_.end(); ++it) { + models->push_back(it->second); + } + } + + virtual bool AllowSpontaneousDetections() { + return false; + } + + protected: + std::map<std::string, ModelType*> object_models_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ |