diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-01-06 12:54:03 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2016-01-06 12:54:03 -0800 |
commit | b947bcc154db62e2306e8bf3edb1a94868edbbca (patch) | |
tree | b3019f9edc459e53185cabe64ef27890fc1aebb8 /tensorflow/examples/label_image | |
parent | e26f0f34018b7adb519e989b59a0462b08a93ea8 (diff) |
TensorFlow: Remove use of command_line_flags library in our usage. Soon
to be deleted completely.
Change: 111521876
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r-- | tensorflow/examples/label_image/main.cc | 152 |
1 files changed, 111 insertions, 41 deletions
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index a5f14e4d6f..dd07e1e5cd 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -39,7 +39,6 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/default_device.h" #include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/lib/core/command_line_flags.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -56,32 +55,6 @@ using tensorflow::Status; using tensorflow::string; using tensorflow::int32; -// These are the command-line flags the program can understand. -// They define where the graph and input data is located, and what kind of -// input the model expects. If you train your own model, or use something -// other than GoogLeNet you'll need to update these. -TF_DEFINE_string(image, - "tensorflow/examples/label_image/data/grace_hopper.jpg", - "The image to classify (JPEG or PNG)."); -TF_DEFINE_string(graph, - "tensorflow/examples/label_image/data/" - "tensorflow_inception_graph.pb", - "The location of the GraphDef file containing the protobuf" - " definition of the network."); -TF_DEFINE_string(labels, - "tensorflow/examples/label_image/data/" - "imagenet_comp_graph_label_strings.txt", - "A text file containing the labels of all the categories, one" - " per line."); -TF_DEFINE_int32(input_width, 299, "Width of the image the network expects."); -TF_DEFINE_int32(input_height, 299, "Height of the image the network expects."); -TF_DEFINE_int32(input_mean, 128, "How much to subtract from input values."); -TF_DEFINE_int32(input_std, 128, "What to divide the input values by."); -TF_DEFINE_string(input_layer, "Mul", "The name of the input node."); -TF_DEFINE_string(output_layer, "softmax", "The name of the output node."); -TF_DEFINE_bool(self_test, false, "Whether to run a sanity check on the results."); -TF_DEFINE_string(root_dir, "", "The directory at the root of the data files."); - // Takes a file name, and loads a list of labels from it, one per line, and // returns a vector of the strings. It pads with empty strings so the length // of the result is a multiple of 16, because our model expects that. @@ -249,18 +222,115 @@ Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected, return Status::OK(); } +namespace { + +bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + string* dst) { + if (arg.Consume(flag) && arg.Consume("=")) { + *dst = arg.ToString(); + return true; + } + + return false; +} + +bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + int32* dst) { + if (arg.Consume(flag) && arg.Consume("=")) { + char extra; + return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); + } + + return false; +} + +bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + bool* dst) { + if (arg.Consume(flag)) { + if (arg.empty()) { + *dst = true; + return true; + } + + if (arg == "=true") { + *dst = true; + return true; + } else if (arg == "=false") { + *dst = false; + return true; + } + } + + return false; +} + +} // namespace + int main(int argc, char* argv[]) { - // We need to call this to set up global state for TensorFlow. - tensorflow::port::InitMain(argv[0], &argc, &argv); - Status s = tensorflow::ParseCommandLineFlags(&argc, argv); - if (!s.ok()) { - LOG(ERROR) << "Error parsing command line flags: " << s.ToString(); + // These are the command-line flags the program can understand. + // They define where the graph and input data is located, and what kind of + // input the model expects. If you train your own model, or use something + // other than GoogLeNet you'll need to update these. + string image = "tensorflow/examples/label_image/data/grace_hopper.jpg"; + string graph = + "tensorflow/examples/label_image/data/" + "tensorflow_inception_graph.pb"; + string labels = + "tensorflow/examples/label_image/data/" + "imagenet_comp_graph_label_strings.txt"; + int32 input_width = 299; + int32 input_height = 299; + int32 input_mean = 128; + int32 input_std = 128; + + string input_layer = "Mul"; + string output_layer = "softmax"; + bool self_test = false; + string root_dir = ""; + + std::vector<char*> unknown_flags; + for (int i = 1; i < argc; ++i) { + if (string(argv[i]) == "--") { + while (i < argc) { + unknown_flags.push_back(argv[i]); + ++i; + } + break; + } + + if (ParseStringFlag(argv[i], "--image", &image) || + ParseStringFlag(argv[i], "--graph", &graph) || + ParseStringFlag(argv[i], "--labels", &labels) || + ParseInt32Flag(argv[i], "--input_width", &input_width) || + ParseInt32Flag(argv[i], "--input_height", &input_height) || + ParseInt32Flag(argv[i], "--input_mean", &input_mean) || + ParseInt32Flag(argv[i], "--input_std", &input_std) || + ParseStringFlag(argv[i], "--input_layer", &input_layer) || + ParseStringFlag(argv[i], "--output_layer", &output_layer) || + ParseBoolFlag(argv[i], "--self_test", &self_test) || + ParseStringFlag(argv[i], "--root_dir", &root_dir)) { + continue; + } + + fprintf(stderr, "Unknown flag: %s\n", argv[i]); return -1; } + // Passthrough any extra flags. + int dst = 1; // Skip argv[0] + + for (char* f : unknown_flags) { + argv[dst++] = f; + } + argv[dst++] = nullptr; + argc = unknown_flags.size() + 1; + + // We need to call this to set up global state for TensorFlow. + tensorflow::port::InitMain(argv[0], &argc, &argv); + // First we load and initialize the model. std::unique_ptr<tensorflow::Session> session; - string graph_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_graph); + string graph_path = tensorflow::io::JoinPath(root_dir, graph); Status load_graph_status = LoadGraph(graph_path, &session); if (!load_graph_status.ok()) { LOG(ERROR) << load_graph_status; @@ -270,10 +340,10 @@ int main(int argc, char* argv[]) { // Get the image from disk as a float array of numbers, resized and normalized // to the specifications the main graph expects. std::vector<Tensor> resized_tensors; - string image_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_image); - Status read_tensor_status = ReadTensorFromImageFile( - image_path, FLAGS_input_height, FLAGS_input_width, FLAGS_input_mean, - FLAGS_input_std, &resized_tensors); + string image_path = tensorflow::io::JoinPath(root_dir, image); + Status read_tensor_status = + ReadTensorFromImageFile(image_path, input_height, input_width, input_mean, + input_std, &resized_tensors); if (!read_tensor_status.ok()) { LOG(ERROR) << read_tensor_status; return -1; @@ -282,8 +352,8 @@ int main(int argc, char* argv[]) { // Actually run the image through the model. std::vector<Tensor> outputs; - Status run_status = session->Run({{FLAGS_input_layer, resized_tensor}}, - {FLAGS_output_layer}, {}, &outputs); + Status run_status = session->Run({{input_layer, resized_tensor}}, + {output_layer}, {}, &outputs); if (!run_status.ok()) { LOG(ERROR) << "Running model failed: " << run_status; return -1; @@ -292,7 +362,7 @@ int main(int argc, char* argv[]) { // This is for automated testing to make sure we get the expected result with // the default settings. We know that label 866 (military uniform) should be // the top label for the Admiral Hopper image. - if (FLAGS_self_test) { + if (self_test) { bool expected_matches; Status check_status = CheckTopLabel(outputs, 866, &expected_matches); if (!check_status.ok()) { @@ -306,7 +376,7 @@ int main(int argc, char* argv[]) { } // Do something interesting with the results we've generated. - Status print_status = PrintTopLabels(outputs, FLAGS_labels); + Status print_status = PrintTopLabels(outputs, labels); if (!print_status.ok()) { LOG(ERROR) << "Running print failed: " << print_status; return -1; |