aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/label_image
diff options
context:
space:
mode:
authorGravatar Pete Warden <pete@petewarden.com>2016-02-12 15:50:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-12 16:43:39 -0800
commit2179890199c3561ff3a1297c5e9c073471473a77 (patch)
tree4a1c5e20734a1ca2ea0bbc4b757e5cc329e82e96 /tensorflow/examples/label_image
parent80d0a94f66f381fb30799035a88c8f0b39a63cd4 (diff)
Transfer learning example, retraining Inception to recognize flowers.
Change: 114578030
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r--tensorflow/examples/label_image/main.cc11
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc
index 424a594059..78fd1dd860 100644
--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -60,7 +60,8 @@ using tensorflow::int32;
// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
-Status ReadLabelsFile(string file_name, std::vector<string>* result) {
+Status ReadLabelsFile(string file_name, std::vector<string>* result,
+ size_t* found_label_count) {
std::ifstream file(file_name);
if (!file) {
return tensorflow::errors::NotFound("Labels file ", file_name,
@@ -71,6 +72,7 @@ Status ReadLabelsFile(string file_name, std::vector<string>* result) {
while (std::getline(file, line)) {
result->push_back(line);
}
+ *found_label_count = result->size();
const int padding = 16;
while (result->size() % padding) {
result->emplace_back();
@@ -146,7 +148,6 @@ Status LoadGraph(string graph_file_name,
return tensorflow::errors::NotFound("Failed to load compute graph at '",
graph_file_name, "'");
}
-
session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
Status session_create_status = (*session)->Create(graph_def);
if (!session_create_status.ok()) {
@@ -186,12 +187,14 @@ Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
Status PrintTopLabels(const std::vector<Tensor>& outputs,
string labels_file_name) {
std::vector<string> labels;
- Status read_labels_status = ReadLabelsFile(labels_file_name, &labels);
+ size_t label_count;
+ Status read_labels_status =
+ ReadLabelsFile(labels_file_name, &labels, &label_count);
if (!read_labels_status.ok()) {
LOG(ERROR) << read_labels_status;
return read_labels_status;
}
- const int how_many_labels = 5;
+ const int how_many_labels = std::min(5, static_cast<int>(label_count));
Tensor indices;
Tensor scores;
TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));