diff options
author | Pete Warden <pete@petewarden.com> | 2016-01-16 14:59:03 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2016-01-18 12:21:50 -0800 |
commit | 8ffdf4a33941e8b1f259c3834d4376cd1cdf3790 (patch) | |
tree | f723f7af3e6ea63b1a363d40eb340b47f08897b0 /tensorflow/examples/label_image | |
parent | 48a95edc43fa420342c8002b3643afbbbffb0065 (diff) |
Moved the flag parsing into a separate module.
Tidying up the label_image example so that the argument handling is taken care
of outside of the main flow of the code, to make understanding it easier.
Change: 112333770
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r-- | tensorflow/examples/label_image/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/examples/label_image/command_line_flags.cc | 135 | ||||
-rw-r--r-- | tensorflow/examples/label_image/command_line_flags.h | 59 | ||||
-rw-r--r-- | tensorflow/examples/label_image/command_line_flags_test.cc | 79 | ||||
-rw-r--r-- | tensorflow/examples/label_image/main.cc | 99 |
5 files changed, 308 insertions, 81 deletions
diff --git a/tensorflow/examples/label_image/BUILD b/tensorflow/examples/label_image/BUILD index c13b530e85..1498a21c53 100644 --- a/tensorflow/examples/label_image/BUILD +++ b/tensorflow/examples/label_image/BUILD @@ -12,7 +12,10 @@ exports_files(["LICENSE"]) cc_binary( name = "label_image", - srcs = ["main.cc"], + srcs = [ + "command_line_flags.cc", + "main.cc", + ], linkopts = ["-lm"], deps = [ "//tensorflow/cc:cc_ops", @@ -20,6 +23,18 @@ cc_binary( ], ) +cc_test( + name = "command_line_flags_test", + srcs = [ + "command_line_flags.cc", + "command_line_flags_test.cc", + ], + deps = [ + "//tensorflow/core:tensorflow", + "//tensorflow/core/platform/default/build_config:test_main", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/examples/label_image/command_line_flags.cc b/tensorflow/examples/label_image/command_line_flags.cc new file mode 100644 index 0000000000..9806ccb624 --- /dev/null +++ b/tensorflow/examples/label_image/command_line_flags.cc @@ -0,0 +1,135 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/examples/label_image/command_line_flags.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::string; + +namespace { + +bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + string* dst, bool* value_parsing_ok) { + *value_parsing_ok = true; + if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + *dst = arg.ToString(); + return true; + } + + return false; +} + +bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + tensorflow::int32* dst, bool* value_parsing_ok) { + *value_parsing_ok = true; + if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + char extra; + if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + } + return true; + } + + return false; +} + +bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + bool* dst, bool* value_parsing_ok) { + *value_parsing_ok = true; + if (arg.Consume("--") && arg.Consume(flag)) { + if (arg.empty()) { + *dst = true; + return true; + } + + if (arg == "=true") { + *dst = true; + return true; + } else if (arg == "=false") { + *dst = false; + return true; + } else { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + return true; + } + } + + return false; +} + +} // namespace + +Flag::Flag(const char* name, tensorflow::int32* dst) + : name_(name), type_(TYPE_INT), int_value_(dst) {} + +Flag::Flag(const char* name, bool* dst) + : name_(name), type_(TYPE_BOOL), bool_value_(dst) {} + +Flag::Flag(const char* name, string* dst) + : name_(name), type_(TYPE_STRING), string_value_(dst) {} + +bool Flag::Parse(string arg, bool* value_parsing_ok) const { + bool result = false; + if (type_ == TYPE_INT) { + result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok); + } else if (type_ == TYPE_BOOL) { + result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok); + } else if (type_ == TYPE_STRING) { + result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok); + } + return result; +} + +bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) { + bool result = true; + std::vector<char*> unknown_flags; + for (int i = 1; i < *argc; ++i) { + if (string(argv[i]) == "--") { + while (i < *argc) { + unknown_flags.push_back(argv[i]); + ++i; + } + break; + } + + bool was_found = false; + for (const Flag& flag : flag_list) { + bool value_parsing_ok; + was_found = flag.Parse(argv[i], &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + if (was_found) { + break; + } + } + if (!was_found) { + unknown_flags.push_back(argv[i]); + } + } + // Passthrough any extra flags. + int dst = 1; // Skip argv[0] + for (char* f : unknown_flags) { + argv[dst++] = f; + } + argv[dst++] = nullptr; + *argc = unknown_flags.size() + 1; + return result; +} diff --git a/tensorflow/examples/label_image/command_line_flags.h b/tensorflow/examples/label_image/command_line_flags.h new file mode 100644 index 0000000000..dcc8f3fdde --- /dev/null +++ b/tensorflow/examples/label_image/command_line_flags.h @@ -0,0 +1,59 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_EXAMPLES_LABEL_IMAGE_COMMAND_LINE_FLAGS_H +#define TENSORFLOW_EXAMPLES_LABEL_IMAGE_COMMAND_LINE_FLAGS_H + +#include <vector> +#include "tensorflow/core/platform/types.h" + +// This is a simple command-line argument parsing module to help us handle +// example parameters. The recommended way of using it is with local variables +// and an initializer list of Flag objects, for example: +// +// int some_int = 10; +// bool some_switch = false; +// string some_name = "something"; +// bool parsed_values_ok = ParseFlags(&argc, argv, { +// Flag("some_int", &some_int), +// Flag("some_switch", &some_switch), +// Flag("some_name", &some_name)}); +// +// The argc and argv values are adjusted by the Parse function so all that +// remains is the program name (at argv[0]) and any unknown arguments fill the +// rest of the array. This means you can check for flags that weren't understood +// by seeing if argv is greater than 1. +// The result indicates if there were any errors parsing the values that were +// passed to the command-line switches. For example, --some_int=foo would return +// false because the argument is expected to be an integer. +class Flag { + public: + Flag(const char* name, tensorflow::int32* dst1); + Flag(const char* name, bool* dst); + Flag(const char* name, tensorflow::string* dst); + + bool Parse(tensorflow::string arg, bool* value_parsing_ok) const; + + private: + tensorflow::string name_; + enum { TYPE_INT, TYPE_BOOL, TYPE_STRING } type_; + int* int_value_; + bool* bool_value_; + tensorflow::string* string_value_; +}; + +bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list); + +#endif // TENSORFLOW_EXAMPLES_LABEL_IMAGE_COMMAND_LINE_FLAGS_H diff --git a/tensorflow/examples/label_image/command_line_flags_test.cc b/tensorflow/examples/label_image/command_line_flags_test.cc new file mode 100644 index 0000000000..dd9cda5bf6 --- /dev/null +++ b/tensorflow/examples/label_image/command_line_flags_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/examples/label_image/command_line_flags.h" +#include <gtest/gtest.h> + +namespace { +// The returned array is only valid for the lifetime of the input vector. +// We're using const casting because we need to pass in an argv-style array of +// char* pointers for the API, even though we know they won't be altered. +std::vector<char*> CharPointerVectorFromStrings( + const std::vector<tensorflow::string>& strings) { + std::vector<char*> result; + for (const tensorflow::string& string : strings) { + result.push_back(const_cast<char*>(string.c_str())); + } + return result; +} +} + +TEST(CommandLineFlagsTest, BasicUsage) { + int some_int = 10; + bool some_switch = false; + string some_name = "something"; + int argc = 4; + std::vector<tensorflow::string> argv_strings = { + "program_name", "--some_int=20", "--some_switch", + "--some_name=somethingelse"}; + std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int), + Flag("some_switch", &some_switch), + Flag("some_name", &some_name)}); + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(20, some_int); + EXPECT_EQ(true, some_switch); + EXPECT_EQ("somethingelse", some_name); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadIntValue) { + int some_int = 10; + int argc = 2; + std::vector<tensorflow::string> argv_strings = {"program_name", + "--some_int=notanumber"}; + std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int)}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(10, some_int); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadBoolValue) { + bool some_switch = false; + int argc = 2; + std::vector<tensorflow::string> argv_strings = {"program_name", + "--some_switch=notabool"}; + std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + ParseFlags(&argc, argv_array.data(), {Flag("some_switch", &some_switch)}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(false, some_switch); + EXPECT_EQ(argc, 1); +} diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 9bbc21a575..4460fefa3f 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/tensor.h" +#include "tensorflow/examples/label_image/command_line_flags.h" // These are all common classes it's handy to reference with no namespace. using tensorflow::Tensor; @@ -223,50 +224,6 @@ Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected, return Status::OK(); } -namespace { - -bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - string* dst) { - if (arg.Consume(flag) && arg.Consume("=")) { - *dst = arg.ToString(); - return true; - } - - return false; -} - -bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - int32* dst) { - if (arg.Consume(flag) && arg.Consume("=")) { - char extra; - return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); - } - - return false; -} - -bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - bool* dst) { - if (arg.Consume(flag)) { - if (arg.empty()) { - *dst = true; - return true; - } - - if (arg == "=true") { - *dst = true; - return true; - } else if (arg == "=false") { - *dst = false; - return true; - } - } - - return false; -} - -} // namespace - int main(int argc, char* argv[]) { // These are the command-line flags the program can understand. // They define where the graph and input data is located, and what kind of @@ -283,51 +240,33 @@ int main(int argc, char* argv[]) { int32 input_height = 299; int32 input_mean = 128; int32 input_std = 128; - string input_layer = "Mul"; string output_layer = "softmax"; bool self_test = false; string root_dir = ""; - - std::vector<char*> unknown_flags; - for (int i = 1; i < argc; ++i) { - if (string(argv[i]) == "--") { - while (i < argc) { - unknown_flags.push_back(argv[i]); - ++i; - } - break; - } - - if (ParseStringFlag(argv[i], "--image", &image) || - ParseStringFlag(argv[i], "--graph", &graph) || - ParseStringFlag(argv[i], "--labels", &labels) || - ParseInt32Flag(argv[i], "--input_width", &input_width) || - ParseInt32Flag(argv[i], "--input_height", &input_height) || - ParseInt32Flag(argv[i], "--input_mean", &input_mean) || - ParseInt32Flag(argv[i], "--input_std", &input_std) || - ParseStringFlag(argv[i], "--input_layer", &input_layer) || - ParseStringFlag(argv[i], "--output_layer", &output_layer) || - ParseBoolFlag(argv[i], "--self_test", &self_test) || - ParseStringFlag(argv[i], "--root_dir", &root_dir)) { - continue; - } - - fprintf(stderr, "Unknown flag: %s\n", argv[i]); + const bool parse_result = + ParseFlags(&argc, argv, {Flag("image", &image), // + Flag("graph", &graph), // + Flag("labels", &labels), // + Flag("input_width", &input_width), // + Flag("input_height", &input_height), // + Flag("input_mean", &input_mean), // + Flag("input_std", &input_std), // + Flag("input_layer", &input_layer), // + Flag("output_layer", &output_layer), // + Flag("self_test", &self_test), // + Flag("root_dir", &root_dir)}); + if (!parse_result) { + LOG(ERROR) << "Error parsing command-line flags."; return -1; } - // Passthrough any extra flags. - int dst = 1; // Skip argv[0] - - for (char* f : unknown_flags) { - argv[dst++] = f; - } - argv[dst++] = nullptr; - argc = unknown_flags.size() + 1; - // We need to call this to set up global state for TensorFlow. tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1]; + return -1; + } // First we load and initialize the model. std::unique_ptr<tensorflow::Session> session; |