diff options
Diffstat (limited to 'tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm')
-rw-r--r-- | tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm | 143 |
1 files changed, 98 insertions, 45 deletions
diff --git a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm b/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm index 7a5dc31a22..43746882ee 100644 --- a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm +++ b/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm @@ -16,9 +16,9 @@ #include "tensorflow_utils.h" -#include <fstream> #include <pthread.h> #include <unistd.h> +#include <fstream> #include <queue> #include <sstream> #include <string> @@ -35,56 +35,58 @@ #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" - namespace { - class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const std::string& file_name) - : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast<char*>(buffer), size); - return ifs_.gcount(); + +// Helper class used to load protobufs efficiently. +class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { + public: + explicit IfstreamInputStream(const std::string& file_name) + : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} + ~IfstreamInputStream() { ifs_.close(); } + + int Read(void* buffer, int size) { + if (!ifs_) { + return -1; } - - private: - std::ifstream ifs_; - }; + ifs_.read(static_cast<char*>(buffer), size); + return ifs_.gcount(); + } + + private: + std::ifstream ifs_; +}; } // namespace // Returns the top N confidence values over threshold in the provided vector, // sorted by confidence in descending order. void GetTopN(const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, - Eigen::Aligned>& prediction, const int num_results, - const float threshold, + Eigen::Aligned>& prediction, + const int num_results, const float threshold, std::vector<std::pair<float, int> >* top_results) { // Will contain top N results in ascending order. std::priority_queue<std::pair<float, int>, - std::vector<std::pair<float, int> >, - std::greater<std::pair<float, int> > > top_result_pq; - + std::vector<std::pair<float, int> >, + std::greater<std::pair<float, int> > > + top_result_pq; + const int count = prediction.size(); for (int i = 0; i < count; ++i) { const float value = prediction(i); - + // Only add it if it beats the threshold and has a chance at being in // the top N. if (value < threshold) { continue; } - + top_result_pq.push(std::pair<float, int>(value, i)); - + // If at capacity, kick the smallest value out. if (top_result_pq.size() > num_results) { top_result_pq.pop(); } } - + // Copy to output vector and reverse into descending order. while (!top_result_pq.empty()) { top_results->push_back(top_result_pq.top()); @@ -93,11 +95,10 @@ void GetTopN(const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, std::reverse(top_results->begin(), top_results->end()); } - bool PortableReadFileToProto(const std::string& file_name, ::google::protobuf::MessageLite* proto) { ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(file_name)); + new IfstreamInputStream(file_name)); stream.SetOwnsCopyingStream(true); ::google::protobuf::io::CodedInputStream coded_stream(&stream); // Total bytes hard limit / warning limit are set to 1GB and 512MB @@ -107,10 +108,11 @@ bool PortableReadFileToProto(const std::string& file_name, } NSString* FilePathForResourceName(NSString* name, NSString* extension) { - NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + NSString* file_path = + [[NSBundle mainBundle] pathForResource:name ofType:extension]; if (file_path == NULL) { LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; + << [extension UTF8String] << "' in bundle."; return nullptr; } return file_path; @@ -119,19 +121,18 @@ NSString* FilePathForResourceName(NSString* name, NSString* extension) { tensorflow::Status LoadModel(NSString* file_name, NSString* file_type, std::unique_ptr<tensorflow::Session>* session) { tensorflow::SessionOptions options; - + tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = tensorflow::NewSession(options, &session_pointer); + tensorflow::Status session_status = + tensorflow::NewSession(options, &session_pointer); if (!session_status.ok()) { LOG(ERROR) << "Could not create TensorFlow Session: " << session_status; return session_status; } session->reset(session_pointer); - LOG(INFO) << "Session created."; - + tensorflow::GraphDef tensorflow_graph; - LOG(INFO) << "Graph created."; - + NSString* model_path = FilePathForResourceName(file_name, file_type); if (!model_path) { LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] @@ -139,37 +140,89 @@ tensorflow::Status LoadModel(NSString* file_name, NSString* file_type, return tensorflow::errors::NotFound([file_name UTF8String], [file_type UTF8String]); } - const bool read_proto_succeeded = PortableReadFileToProto( - [model_path UTF8String], &tensorflow_graph); + const bool read_proto_succeeded = + PortableReadFileToProto([model_path UTF8String], &tensorflow_graph); if (!read_proto_succeeded) { LOG(ERROR) << "Failed to load model proto from" << [model_path UTF8String]; return tensorflow::errors::NotFound([model_path UTF8String]); } - - LOG(INFO) << "Creating session."; + tensorflow::Status create_status = (*session)->Create(tensorflow_graph); if (!create_status.ok()) { LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status; return create_status; } - + + return tensorflow::Status::OK(); +} + +tensorflow::Status LoadMemoryMappedModel( + NSString* file_name, NSString* file_type, + std::unique_ptr<tensorflow::Session>* session, + std::unique_ptr<tensorflow::MemmappedEnv>* memmapped_env) { + NSString* network_path = FilePathForResourceName(file_name, file_type); + memmapped_env->reset( + new tensorflow::MemmappedEnv(tensorflow::Env::Default())); + tensorflow::Status mmap_status = + (memmapped_env->get())->InitializeFromFile([network_path UTF8String]); + if (!mmap_status.ok()) { + LOG(ERROR) << "MMap failed with " << mmap_status.error_message(); + return mmap_status; + } + + tensorflow::GraphDef tensorflow_graph; + tensorflow::Status load_graph_status = ReadBinaryProto( + memmapped_env->get(), + tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, + &tensorflow_graph); + if (!load_graph_status.ok()) { + LOG(ERROR) << "MMap load graph failed with " + << load_graph_status.error_message(); + return load_graph_status; + } + + tensorflow::SessionOptions options; + // Disable optimizations on this graph so that constant folding doesn't + // increase the memory footprint by creating new constant copies of the weight + // parameters. + options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_opt_level(::tensorflow::OptimizerOptions::L0); + options.env = memmapped_env->get(); + + tensorflow::Session* session_pointer = nullptr; + tensorflow::Status session_status = + tensorflow::NewSession(options, &session_pointer); + if (!session_status.ok()) { + LOG(ERROR) << "Could not create TensorFlow Session: " << session_status; + return session_status; + } + + tensorflow::Status create_status = session_pointer->Create(tensorflow_graph); + if (!create_status.ok()) { + LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status; + return create_status; + } + + session->reset(session_pointer); + return tensorflow::Status::OK(); } tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type, - std::vector<std::string>* label_strings) { + std::vector<std::string>* label_strings) { // Read the label list NSString* labels_path = FilePathForResourceName(file_name, file_type); if (!labels_path) { LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] - << [file_type UTF8String]; + << [file_type UTF8String]; return tensorflow::errors::NotFound([file_name UTF8String], [file_type UTF8String]); } std::ifstream t; t.open([labels_path UTF8String]); std::string line; - while(t){ + while (t) { std::getline(t, line); label_strings->push_back(line); } |