diff options
author | Pete Warden <pete@petewarden.com> | 2016-02-12 15:50:54 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-12 16:43:39 -0800 |
commit | 2179890199c3561ff3a1297c5e9c073471473a77 (patch) | |
tree | 4a1c5e20734a1ca2ea0bbc4b757e5cc329e82e96 /tensorflow/examples/label_image | |
parent | 80d0a94f66f381fb30799035a88c8f0b39a63cd4 (diff) |
Transfer learning example, retraining Inception to recognize flowers.
Change: 114578030
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r-- | tensorflow/examples/label_image/main.cc | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 424a594059..78fd1dd860 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -60,7 +60,8 @@ using tensorflow::int32; // Takes a file name, and loads a list of labels from it, one per line, and // returns a vector of the strings. It pads with empty strings so the length // of the result is a multiple of 16, because our model expects that. -Status ReadLabelsFile(string file_name, std::vector<string>* result) { +Status ReadLabelsFile(string file_name, std::vector<string>* result, + size_t* found_label_count) { std::ifstream file(file_name); if (!file) { return tensorflow::errors::NotFound("Labels file ", file_name, @@ -71,6 +72,7 @@ Status ReadLabelsFile(string file_name, std::vector<string>* result) { while (std::getline(file, line)) { result->push_back(line); } + *found_label_count = result->size(); const int padding = 16; while (result->size() % padding) { result->emplace_back(); @@ -146,7 +148,6 @@ Status LoadGraph(string graph_file_name, return tensorflow::errors::NotFound("Failed to load compute graph at '", graph_file_name, "'"); } - session->reset(tensorflow::NewSession(tensorflow::SessionOptions())); Status session_create_status = (*session)->Create(graph_def); if (!session_create_status.ok()) { @@ -186,12 +187,14 @@ Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels, Status PrintTopLabels(const std::vector<Tensor>& outputs, string labels_file_name) { std::vector<string> labels; - Status read_labels_status = ReadLabelsFile(labels_file_name, &labels); + size_t label_count; + Status read_labels_status = + ReadLabelsFile(labels_file_name, &labels, &label_count); if (!read_labels_status.ok()) { LOG(ERROR) << read_labels_status; return read_labels_status; } - const int how_many_labels = 5; + const int how_many_labels = std::min(5, static_cast<int>(label_count)); Tensor indices; Tensor scores; TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores)); |