Diffstat (limited to 'tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm')
-rw-r--r--  tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm  |  143
1 file changed, 98 insertions(+), 45 deletions(-)
diff --git a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm b/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm
index 7a5dc31a22..43746882ee 100644
--- a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm
+++ b/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm
@@ -16,9 +16,9 @@
#include "tensorflow_utils.h"
-#include <fstream>
#include <pthread.h>
#include <unistd.h>
+#include <fstream>
#include <queue>
#include <sstream>
#include <string>
@@ -35,56 +35,58 @@
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
-
namespace {
- class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
- public:
- explicit IfstreamInputStream(const std::string& file_name)
- : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {}
- ~IfstreamInputStream() { ifs_.close(); }
-
- int Read(void* buffer, int size) {
- if (!ifs_) {
- return -1;
- }
- ifs_.read(static_cast<char*>(buffer), size);
- return ifs_.gcount();
+
+// Helper class used to load protobufs efficiently.
+class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
+ public:
+ explicit IfstreamInputStream(const std::string& file_name)
+ : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {}
+ ~IfstreamInputStream() { ifs_.close(); }
+
+ int Read(void* buffer, int size) {
+ if (!ifs_) {
+ return -1;
}
-
- private:
- std::ifstream ifs_;
- };
+ ifs_.read(static_cast<char*>(buffer), size);
+ return ifs_.gcount();
+ }
+
+ private:
+ std::ifstream ifs_;
+};
} // namespace
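Context for the class above: protobuf parsing consumes a ZeroCopyInputStream, while IfstreamInputStream implements only the simpler CopyingInputStream Read() contract. A CopyingInputStreamAdaptor bridges the two, which is how PortableReadFileToProto below uses it. A minimal sketch, assuming a placeholder file name:

    // Sketch only: "model.pb" is a placeholder path.
    IfstreamInputStream ifs("model.pb");
    ::google::protobuf::io::CopyingInputStreamAdaptor adaptor(&ifs);
    tensorflow::GraphDef graph;
    bool ok = graph.ParseFromZeroCopyStream(&adaptor);  // false on parse failure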
// Returns the top N confidence values over threshold in the provided vector,
// sorted by confidence in descending order.
void GetTopN(const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
- Eigen::Aligned>& prediction, const int num_results,
- const float threshold,
+ Eigen::Aligned>& prediction,
+ const int num_results, const float threshold,
std::vector<std::pair<float, int> >* top_results) {
// Will contain top N results in ascending order.
std::priority_queue<std::pair<float, int>,
- std::vector<std::pair<float, int> >,
- std::greater<std::pair<float, int> > > top_result_pq;
-
+ std::vector<std::pair<float, int> >,
+ std::greater<std::pair<float, int> > >
+ top_result_pq;
+
const int count = prediction.size();
for (int i = 0; i < count; ++i) {
const float value = prediction(i);
-
+
// Only add it if it beats the threshold and has a chance at being in
// the top N.
if (value < threshold) {
continue;
}
-
+
top_result_pq.push(std::pair<float, int>(value, i));
-
+
// If at capacity, kick the smallest value out.
if (top_result_pq.size() > num_results) {
top_result_pq.pop();
}
}
-
+
// Copy to output vector and reverse into descending order.
while (!top_result_pq.empty()) {
top_results->push_back(top_result_pq.top());
@@ -93,11 +95,10 @@ void GetTopN(const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
std::reverse(top_results->begin(), top_results->end());
}
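A hedged usage sketch for GetTopN, assuming a Session::Run call has already filled an outputs vector and labels holds the label strings (both names are illustrative):

    // Illustrative only; outputs and labels are assumed to exist.
    auto predictions = outputs[0].flat<float>();  // Eigen::TensorMap over scores
    std::vector<std::pair<float, int> > top_results;
    GetTopN(predictions, 5, 0.1f, &top_results);  // top 5 scores above 0.1
    for (const auto& result : top_results) {
      LOG(INFO) << result.first << "  " << labels[result.second];
    }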
-
bool PortableReadFileToProto(const std::string& file_name,
::google::protobuf::MessageLite* proto) {
::google::protobuf::io::CopyingInputStreamAdaptor stream(
- new IfstreamInputStream(file_name));
+ new IfstreamInputStream(file_name));
stream.SetOwnsCopyingStream(true);
::google::protobuf::io::CodedInputStream coded_stream(&stream);
// Total bytes hard limit / warning limit are set to 1GB and 512MB
@@ -107,10 +108,11 @@ bool PortableReadFileToProto(const std::string& file_name,
}
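The parsing body of PortableReadFileToProto falls between these hunks and is unchanged; per the comment above it presumably raises protobuf's default message size cap before parsing. A sketch of that pattern (an assumption, not shown in this diff):

    // Assumed shape of the elided body: 1GB hard limit, 512MB warning.
    coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
    return proto->ParseFromCodedStream(&coded_stream);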
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
- NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
+ NSString* file_path =
+ [[NSBundle mainBundle] pathForResource:name ofType:extension];
if (file_path == NULL) {
LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "."
- << [extension UTF8String] << "' in bundle.";
+ << [extension UTF8String] << "' in bundle.";
return nullptr;
}
return file_path;
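FilePathForResourceName is a checked wrapper over the NSBundle lookup; callers pass the resource name and extension separately. An illustrative call (the file name is an assumption):

    // Illustrative: resolve a label file shipped in the app bundle.
    NSString* labels_path = FilePathForResourceName(@"labels", @"txt");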
@@ -119,19 +121,18 @@ NSString* FilePathForResourceName(NSString* name, NSString* extension) {
tensorflow::Status LoadModel(NSString* file_name, NSString* file_type,
std::unique_ptr<tensorflow::Session>* session) {
tensorflow::SessionOptions options;
-
+
tensorflow::Session* session_pointer = nullptr;
- tensorflow::Status session_status = tensorflow::NewSession(options, &session_pointer);
+ tensorflow::Status session_status =
+ tensorflow::NewSession(options, &session_pointer);
if (!session_status.ok()) {
LOG(ERROR) << "Could not create TensorFlow Session: " << session_status;
return session_status;
}
session->reset(session_pointer);
- LOG(INFO) << "Session created.";
-
+
tensorflow::GraphDef tensorflow_graph;
- LOG(INFO) << "Graph created.";
-
+
NSString* model_path = FilePathForResourceName(file_name, file_type);
if (!model_path) {
LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String]
@@ -139,37 +140,89 @@ tensorflow::Status LoadModel(NSString* file_name, NSString* file_type,
return tensorflow::errors::NotFound([file_name UTF8String],
[file_type UTF8String]);
}
- const bool read_proto_succeeded = PortableReadFileToProto(
- [model_path UTF8String], &tensorflow_graph);
+ const bool read_proto_succeeded =
+ PortableReadFileToProto([model_path UTF8String], &tensorflow_graph);
if (!read_proto_succeeded) {
LOG(ERROR) << "Failed to load model proto from" << [model_path UTF8String];
return tensorflow::errors::NotFound([model_path UTF8String]);
}
-
- LOG(INFO) << "Creating session.";
+
tensorflow::Status create_status = (*session)->Create(tensorflow_graph);
if (!create_status.ok()) {
LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status;
return create_status;
}
-
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status LoadMemoryMappedModel(
+ NSString* file_name, NSString* file_type,
+ std::unique_ptr<tensorflow::Session>* session,
+ std::unique_ptr<tensorflow::MemmappedEnv>* memmapped_env) {
+ NSString* network_path = FilePathForResourceName(file_name, file_type);
+ memmapped_env->reset(
+ new tensorflow::MemmappedEnv(tensorflow::Env::Default()));
+ tensorflow::Status mmap_status =
+ (memmapped_env->get())->InitializeFromFile([network_path UTF8String]);
+ if (!mmap_status.ok()) {
+ LOG(ERROR) << "MMap failed with " << mmap_status.error_message();
+ return mmap_status;
+ }
+
+ tensorflow::GraphDef tensorflow_graph;
+ tensorflow::Status load_graph_status = ReadBinaryProto(
+ memmapped_env->get(),
+ tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
+ &tensorflow_graph);
+ if (!load_graph_status.ok()) {
+ LOG(ERROR) << "MMap load graph failed with "
+ << load_graph_status.error_message();
+ return load_graph_status;
+ }
+
+ tensorflow::SessionOptions options;
+ // Disable optimizations on this graph so that constant folding doesn't
+ // increase the memory footprint by creating new constant copies of the weight
+ // parameters.
+ options.config.mutable_graph_options()
+ ->mutable_optimizer_options()
+ ->set_opt_level(::tensorflow::OptimizerOptions::L0);
+ options.env = memmapped_env->get();
+
+ tensorflow::Session* session_pointer = nullptr;
+ tensorflow::Status session_status =
+ tensorflow::NewSession(options, &session_pointer);
+ if (!session_status.ok()) {
+ LOG(ERROR) << "Could not create TensorFlow Session: " << session_status;
+ return session_status;
+ }
+
+ tensorflow::Status create_status = session_pointer->Create(tensorflow_graph);
+ if (!create_status.ok()) {
+ LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status;
+ return create_status;
+ }
+
+ session->reset(session_pointer);
+
return tensorflow::Status::OK();
}
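Taken together, a caller picks one of the two loaders depending on whether the model has been converted to TensorFlow's memory-mapped format. A hedged sketch of the calling side, with illustrative file names and a hypothetical use_memory_mapping flag:

    std::unique_ptr<tensorflow::Session> session;
    std::unique_ptr<tensorflow::MemmappedEnv> memmapped_env;
    tensorflow::Status load_status;
    if (use_memory_mapping) {  // hypothetical flag
      load_status = LoadMemoryMappedModel(@"memmapped_graph", @"pb", &session,
                                          &memmapped_env);
    } else {
      load_status = LoadModel(@"frozen_graph", @"pb", &session);
    }
    if (!load_status.ok()) {
      LOG(FATAL) << "Couldn't load model: " << load_status;
    }

Note that memmapped_env must outlive the session, since the weights are read directly out of the mapped file rather than copied into the graph.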
tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type,
- std::vector<std::string>* label_strings) {
+ std::vector<std::string>* label_strings) {
// Read the label list
NSString* labels_path = FilePathForResourceName(file_name, file_type);
if (!labels_path) {
LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String]
- << [file_type UTF8String];
+ << [file_type UTF8String];
return tensorflow::errors::NotFound([file_name UTF8String],
[file_type UTF8String]);
}
std::ifstream t;
t.open([labels_path UTF8String]);
std::string line;
- while(t){
+ while (t) {
std::getline(t, line);
label_strings->push_back(line);
}