aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/label_image
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-01-06 12:54:03 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-06 12:54:03 -0800
commitb947bcc154db62e2306e8bf3edb1a94868edbbca (patch)
treeb3019f9edc459e53185cabe64ef27890fc1aebb8 /tensorflow/examples/label_image
parente26f0f34018b7adb519e989b59a0462b08a93ea8 (diff)
TensorFlow: Remove use of command_line_flags library in our usage. Soon
to be deleted completely. Change: 111521876
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r--tensorflow/examples/label_image/main.cc152
1 files changed, 111 insertions, 41 deletions
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc
index a5f14e4d6f..dd07e1e5cd 100644
--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -39,7 +39,6 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
-#include "tensorflow/core/lib/core/command_line_flags.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -56,32 +55,6 @@ using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;
-// These are the command-line flags the program can understand.
-// They define where the graph and input data is located, and what kind of
-// input the model expects. If you train your own model, or use something
-// other than GoogLeNet you'll need to update these.
-TF_DEFINE_string(image,
- "tensorflow/examples/label_image/data/grace_hopper.jpg",
- "The image to classify (JPEG or PNG).");
-TF_DEFINE_string(graph,
- "tensorflow/examples/label_image/data/"
- "tensorflow_inception_graph.pb",
- "The location of the GraphDef file containing the protobuf"
- " definition of the network.");
-TF_DEFINE_string(labels,
- "tensorflow/examples/label_image/data/"
- "imagenet_comp_graph_label_strings.txt",
- "A text file containing the labels of all the categories, one"
- " per line.");
-TF_DEFINE_int32(input_width, 299, "Width of the image the network expects.");
-TF_DEFINE_int32(input_height, 299, "Height of the image the network expects.");
-TF_DEFINE_int32(input_mean, 128, "How much to subtract from input values.");
-TF_DEFINE_int32(input_std, 128, "What to divide the input values by.");
-TF_DEFINE_string(input_layer, "Mul", "The name of the input node.");
-TF_DEFINE_string(output_layer, "softmax", "The name of the output node.");
-TF_DEFINE_bool(self_test, false, "Whether to run a sanity check on the results.");
-TF_DEFINE_string(root_dir, "", "The directory at the root of the data files.");
-
// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
@@ -249,18 +222,115 @@ Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
return Status::OK();
}
+namespace {
+
+bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
+ string* dst) {
+ if (arg.Consume(flag) && arg.Consume("=")) {
+ *dst = arg.ToString();
+ return true;
+ }
+
+ return false;
+}
+
+bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
+ int32* dst) {
+ if (arg.Consume(flag) && arg.Consume("=")) {
+ char extra;
+ return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
+ }
+
+ return false;
+}
+
+bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
+ bool* dst) {
+ if (arg.Consume(flag)) {
+ if (arg.empty()) {
+ *dst = true;
+ return true;
+ }
+
+ if (arg == "=true") {
+ *dst = true;
+ return true;
+ } else if (arg == "=false") {
+ *dst = false;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+} // namespace
+
int main(int argc, char* argv[]) {
- // We need to call this to set up global state for TensorFlow.
- tensorflow::port::InitMain(argv[0], &argc, &argv);
- Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
- if (!s.ok()) {
- LOG(ERROR) << "Error parsing command line flags: " << s.ToString();
+ // These are the command-line flags the program can understand.
+ // They define where the graph and input data is located, and what kind of
+ // input the model expects. If you train your own model, or use something
+ // other than GoogLeNet you'll need to update these.
+ string image = "tensorflow/examples/label_image/data/grace_hopper.jpg";
+ string graph =
+ "tensorflow/examples/label_image/data/"
+ "tensorflow_inception_graph.pb";
+ string labels =
+ "tensorflow/examples/label_image/data/"
+ "imagenet_comp_graph_label_strings.txt";
+ int32 input_width = 299;
+ int32 input_height = 299;
+ int32 input_mean = 128;
+ int32 input_std = 128;
+
+ string input_layer = "Mul";
+ string output_layer = "softmax";
+ bool self_test = false;
+ string root_dir = "";
+
+ std::vector<char*> unknown_flags;
+ for (int i = 1; i < argc; ++i) {
+ if (string(argv[i]) == "--") {
+ while (i < argc) {
+ unknown_flags.push_back(argv[i]);
+ ++i;
+ }
+ break;
+ }
+
+ if (ParseStringFlag(argv[i], "--image", &image) ||
+ ParseStringFlag(argv[i], "--graph", &graph) ||
+ ParseStringFlag(argv[i], "--labels", &labels) ||
+ ParseInt32Flag(argv[i], "--input_width", &input_width) ||
+ ParseInt32Flag(argv[i], "--input_height", &input_height) ||
+ ParseInt32Flag(argv[i], "--input_mean", &input_mean) ||
+ ParseInt32Flag(argv[i], "--input_std", &input_std) ||
+ ParseStringFlag(argv[i], "--input_layer", &input_layer) ||
+ ParseStringFlag(argv[i], "--output_layer", &output_layer) ||
+ ParseBoolFlag(argv[i], "--self_test", &self_test) ||
+ ParseStringFlag(argv[i], "--root_dir", &root_dir)) {
+ continue;
+ }
+
+ fprintf(stderr, "Unknown flag: %s\n", argv[i]);
return -1;
}
+ // Passthrough any extra flags.
+ int dst = 1; // Skip argv[0]
+
+ for (char* f : unknown_flags) {
+ argv[dst++] = f;
+ }
+ argv[dst++] = nullptr;
+ argc = unknown_flags.size() + 1;
+
+ // We need to call this to set up global state for TensorFlow.
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
// First we load and initialize the model.
std::unique_ptr<tensorflow::Session> session;
- string graph_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_graph);
+ string graph_path = tensorflow::io::JoinPath(root_dir, graph);
Status load_graph_status = LoadGraph(graph_path, &session);
if (!load_graph_status.ok()) {
LOG(ERROR) << load_graph_status;
@@ -270,10 +340,10 @@ int main(int argc, char* argv[]) {
// Get the image from disk as a float array of numbers, resized and normalized
// to the specifications the main graph expects.
std::vector<Tensor> resized_tensors;
- string image_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_image);
- Status read_tensor_status = ReadTensorFromImageFile(
- image_path, FLAGS_input_height, FLAGS_input_width, FLAGS_input_mean,
- FLAGS_input_std, &resized_tensors);
+ string image_path = tensorflow::io::JoinPath(root_dir, image);
+ Status read_tensor_status =
+ ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
+ input_std, &resized_tensors);
if (!read_tensor_status.ok()) {
LOG(ERROR) << read_tensor_status;
return -1;
@@ -282,8 +352,8 @@ int main(int argc, char* argv[]) {
// Actually run the image through the model.
std::vector<Tensor> outputs;
- Status run_status = session->Run({{FLAGS_input_layer, resized_tensor}},
- {FLAGS_output_layer}, {}, &outputs);
+ Status run_status = session->Run({{input_layer, resized_tensor}},
+ {output_layer}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed: " << run_status;
return -1;
@@ -292,7 +362,7 @@ int main(int argc, char* argv[]) {
// This is for automated testing to make sure we get the expected result with
// the default settings. We know that label 866 (military uniform) should be
// the top label for the Admiral Hopper image.
- if (FLAGS_self_test) {
+ if (self_test) {
bool expected_matches;
Status check_status = CheckTopLabel(outputs, 866, &expected_matches);
if (!check_status.ok()) {
@@ -306,7 +376,7 @@ int main(int argc, char* argv[]) {
}
// Do something interesting with the results we've generated.
- Status print_status = PrintTopLabels(outputs, FLAGS_labels);
+ Status print_status = PrintTopLabels(outputs, labels);
if (!print_status.ok()) {
LOG(ERROR) << "Running print failed: " << print_status;
return -1;