From 20c02ffe0e6b8af4902487f852e15daa16caa523 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Oct 2016 15:41:38 -0800 Subject: Modify tensorflow command line flag parsing: - Allow help text to be specified for each flag - Add a flag "--help" to print the help text - Rename ParseFlags to Flags::Parse(); new routine Flags::Usage() returns a usage message. - Change uses to new format In some cases reorder with InitMain(), which should be called after flag parsing. Change: 136212902 --- .../contrib/ffmpeg/default/ffmpeg_lib_test.cc | 19 ++++- tensorflow/contrib/pi_examples/camera/camera.cc | 35 ++++---- .../contrib/pi_examples/label_image/label_image.cc | 41 ++++++---- .../contrib/tfprof/tools/tfprof/tfprof_main.cc | 58 ++++++++----- .../util/convert_graphdef_memmapped_format.cc | 26 +++--- .../rpc/grpc_tensorflow_server.cc | 32 +++++--- .../distributed_runtime/rpc/grpc_testlib_server.cc | 47 ++++++----- tensorflow/core/util/command_line_flags.cc | 70 +++++++++++++--- tensorflow/core/util/command_line_flags.h | 49 ++++++++--- tensorflow/core/util/command_line_flags_test.cc | 94 +++++++++++++++++----- tensorflow/examples/label_image/main.cc | 34 ++++---- tensorflow/tools/benchmark/benchmark_model.cc | 33 ++++---- .../tools/graph_transforms/fold_constants_tool.cc | 19 ++--- 13 files changed, 376 insertions(+), 181 deletions(-) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index f4cfa0bbc6..f374283c07 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -42,10 +42,17 @@ const char kTestMp3Filename[] = mutex mu; bool should_ffmpeg_be_installed GUARDED_BY(mu) = false; -void ParseTestFlags(int* argc, char** argv) { +string ParseTestFlags(int* argc, char** argv) { mutex_lock l(mu); - CHECK(ParseFlags(argc, argv, {Flag("should_ffmpeg_be_installed", - &should_ffmpeg_be_installed)})); + vector flag_list = {Flag("should_ffmpeg_be_installed", + &should_ffmpeg_be_installed, + "indicates that ffmpeg should be installed")}; + string usage = Flags::Usage(argv[0], flag_list); + if (!Flags::Parse(argc, argv, flag_list)) { + LOG(ERROR) << "\n" << usage; + exit(2); + } + return usage; } TEST(FfmpegLibTest, TestUninstalled) { @@ -132,7 +139,11 @@ TEST(FfmpegLibTest, TestRoundTripWav) { } // namespace tensorflow int main(int argc, char **argv) { - tensorflow::ffmpeg::ParseTestFlags(&argc, argv); + tensorflow::string usage = tensorflow::ffmpeg::ParseTestFlags(&argc, argv); testing::InitGoogleTest(&argc, argv); + if (argc != 1) { + LOG(ERROR) << usage; + return 2; + } return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/pi_examples/camera/camera.cc b/tensorflow/contrib/pi_examples/camera/camera.cc index 9bba110a52..cb20661662 100644 --- a/tensorflow/contrib/pi_examples/camera/camera.cc +++ b/tensorflow/contrib/pi_examples/camera/camera.cc @@ -412,21 +412,26 @@ int main(int argc, char** argv) { int32 video_height = 480; int print_threshold = 50; string root_dir = ""; - const bool parse_result = tensorflow::ParseFlags( - &argc, argv, {Flag("graph", &graph), // - Flag("labels", &labels_file_name), // - Flag("input_width", &input_width), // - Flag("input_height", &input_height), // - Flag("input_mean", &input_mean), // - Flag("input_std", &input_std), // - Flag("input_layer", &input_layer), // - Flag("output_layer", &output_layer), // - Flag("video_width", &video_width), // - Flag("video_height", &video_height), // - Flag("print_threshold", &print_threshold), // - Flag("root_dir", &root_dir)}); - if (!parse_result) { - LOG(ERROR) << "Error parsing command-line flags."; + std::vector flag_list = { + Flag("graph", &graph, "graph file name"), + Flag("labels", &labels_file_name, "labels file name"), + Flag("input_width", &input_width, "image input width"), + Flag("input_height", &input_height, "image input height"), + Flag("input_mean", &input_mean, "transformed mean of input pixels"), + Flag("input_std", &input_std, "transformed std dev of input pixels"), + Flag("input_layer", &input_layer, "input layer name"), + Flag("output_layer", &output_layer, "output layer name"), + Flag("video_width", &video_width, "video width expected from device"), + Flag("video_height", &video_height, "video height expected from device"), + Flag("print_threshold", &print_threshold, + "print labels with scoe exceeding this"), + Flag("root_dir", &root_dir, + "interpret graph file name relative to this directory")}; + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + + if (!parse_result || argc != 1) { + LOG(ERROR) << "\n" << usage; return -1; } diff --git a/tensorflow/contrib/pi_examples/label_image/label_image.cc b/tensorflow/contrib/pi_examples/label_image/label_image.cc index 70f32f2199..ab19398ef2 100644 --- a/tensorflow/contrib/pi_examples/label_image/label_image.cc +++ b/tensorflow/contrib/pi_examples/label_image/label_image.cc @@ -23,9 +23,10 @@ limitations under the License. // // Full build instructions are at tensorflow/contrib/pi_examples/README.md. -#include #include #include +#include +#include #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -297,7 +298,8 @@ int main(int argc, char* argv[]) { // They define where the graph and input data is located, and what kind of // input the model expects. If you train your own model, or use something // other than GoogLeNet you'll need to update these. - string image = "tensorflow/contrib/pi_examples/label_image/data/" + string image = + "tensorflow/contrib/pi_examples/label_image/data/" "grace_hopper.jpg"; string graph = "tensorflow/contrib/pi_examples/label_image/data/" @@ -313,27 +315,32 @@ int main(int argc, char* argv[]) { string output_layer = "softmax"; bool self_test = false; string root_dir = ""; - const bool parse_result = tensorflow::ParseFlags( - &argc, argv, {Flag("image", &image), // - Flag("graph", &graph), // - Flag("labels", &labels), // - Flag("input_width", &input_width), // - Flag("input_height", &input_height), // - Flag("input_mean", &input_mean), // - Flag("input_std", &input_std), // - Flag("input_layer", &input_layer), // - Flag("output_layer", &output_layer), // - Flag("self_test", &self_test), // - Flag("root_dir", &root_dir)}); + vector tensorflow::Flag > flag_list = { + Flag("image", &image, "image to be processed"), + Flag("graph", &graph, "graph to be executed"), + Flag("labels", &labels, "name of file containing labels"), + Flag("input_width", &input_width, "resize image to this width in pixels"), + Flag("input_height", &input_height, + "resize image to this height in pixels"), + Flag("input_mean", &input_mean, "scale pixel values to this mean"), + Flag("input_std", &input_std, "scale pixel values to this std deviation"), + Flag("input_layer", &input_layer, "name of input layer"), + Flag("output_layer", &output_layer, "name of output layer"), + Flag("self_test", &self_test, "run a self test"), + Flag("root_dir", &root_dir, + "interpret image and graph file names relative to this directory"), + }; + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { - LOG(ERROR) << "Error parsing command-line flags."; + LOG(ERROR) << "\n" << usage; return -1; } // We need to call this to set up global state for TensorFlow. - tensorflow::port::InitMain(argv[0], &argc, &argv); + tensorflow::port::InitMain(usage, &argc, &argv); if (argc > 1) { - LOG(ERROR) << "Unknown argument " << argv[1]; + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return -1; } diff --git a/tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc b/tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc index d9080242d6..38b1588d72 100644 --- a/tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc +++ b/tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc @@ -88,27 +88,43 @@ int main(int argc, char** argv) { fprintf(stderr, "%s\n", argv[i]); } - CHECK(tensorflow::ParseFlags( - &argc, argv, - {tensorflow::Flag("graph_path", &FLAGS_graph_path), - tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path), - tensorflow::Flag("op_log_path", &FLAGS_op_log_path), - tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path), - tensorflow::Flag("max_depth", &FLAGS_max_depth), - tensorflow::Flag("min_bytes", &FLAGS_min_bytes), - tensorflow::Flag("min_micros", &FLAGS_min_micros), - tensorflow::Flag("min_params", &FLAGS_min_params), - tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops), - tensorflow::Flag("device_regexes", &FLAGS_device_regexes), - tensorflow::Flag("order_by", &FLAGS_order_by), - tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes), - tensorflow::Flag("trim_name_regexes", &FLAGS_trim_name_regexes), - tensorflow::Flag("show_name_regexes", &FLAGS_show_name_regexes), - tensorflow::Flag("hide_name_regexes", &FLAGS_hide_name_regexes), - tensorflow::Flag("account_displayed_op_only", - &FLAGS_account_displayed_op_only), - tensorflow::Flag("select", &FLAGS_select), - tensorflow::Flag("dump_to_file", &FLAGS_dump_to_file)})); + std::vector flag_list = { + tensorflow::Flag("graph_path", &FLAGS_graph_path, + "GraphDef proto text file name"), + tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path, + "RunMetadata proto binary file name"), + tensorflow::Flag("op_log_path", &FLAGS_op_log_path, + "tensorflow::tfprof::OpLog proto binary file name"), + tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path, + "TensorFlow Checkpoint file name"), + tensorflow::Flag("max_depth", &FLAGS_max_depth, "max depth"), + tensorflow::Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"), + tensorflow::Flag("min_micros", &FLAGS_min_micros, "min micros"), + tensorflow::Flag("min_params", &FLAGS_min_params, "min params"), + tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"), + tensorflow::Flag("device_regexes", &FLAGS_device_regexes, + "device regexes"), + tensorflow::Flag("order_by", &FLAGS_order_by, "order by"), + tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes, + "start name regexes"), + tensorflow::Flag("trim_name_regexes", &FLAGS_trim_name_regexes, + "trim name regexes"), + tensorflow::Flag("show_name_regexes", &FLAGS_show_name_regexes, + "show name regexes"), + tensorflow::Flag("hide_name_regexes", &FLAGS_hide_name_regexes, + "hide name regexes"), + tensorflow::Flag("account_displayed_op_only", + &FLAGS_account_displayed_op_only, + "account displayed op only"), + tensorflow::Flag("select", &FLAGS_select, "select"), + tensorflow::Flag("dump_to_file", &FLAGS_dump_to_file, "dump to file"), + }; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_ok) { + printf("%s", usage.c_str()); + return (2); + } tensorflow::port::InitMain(argv[0], &argc, &argv); fprintf(stderr, "%s\n", FLAGS_graph_path.c_str()); diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc index 6287d3eb7e..29b124e2a8 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc @@ -28,6 +28,8 @@ limitations under the License. // min_conversion_size_bytes - tensors with fewer than this many bytes of data // will not be converted to ImmutableConst format, and kept in the graph. +#include + #include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -40,23 +42,23 @@ int ParseFlagsAndConvertGraph(int argc, char* argv[]) { string in_graph = ""; string out_graph = ""; int min_conversion_tensor_size = 10000; - const bool parse_result = ParseFlags( - &argc, argv, - {// input graph - Flag("in_graph", &in_graph), - // output graph - Flag("out_graph", &out_graph), - // constants with tensors that have less than this number elements won't - // be converted into ImmutableConst (be memmapped). - Flag("min_conversion_tensor_size", &min_conversion_tensor_size)}); + std::vector flag_list = { + Flag("in_graph", &in_graph, "input graph"), + Flag("out_graph", &out_graph, "output graph"), + Flag("min_conversion_tensor_size", &min_conversion_tensor_size, + "constants with tensors that have less than this number elements " + "won't be converted into ImmutableConst (be memmapped)"), + }; + string usage = Flags::Usage(argv[0], flag_list); + const bool parse_result = Flags::Parse(&argc, argv, flag_list); // We need to call this to set up global state for TensorFlow. - port::InitMain(argv[0], &argc, &argv); + port::InitMain(usage.c_str(), &argc, &argv); if (!parse_result) { - LOG(ERROR) << "Error parsing command-line flags."; + LOG(ERROR) << "\n" << usage; return -1; } if (argc > 1) { - LOG(ERROR) << "Unknown argument " << argv[1]; + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return -1; } if (in_graph.empty()) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc index 3416f99919..1c7bb4375c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "grpc++/grpc++.h" #include "grpc++/security/credentials.h" @@ -36,17 +37,10 @@ limitations under the License. namespace tensorflow { namespace { -Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { +Status FillServerDef(const string& cluster_spec, const string& job_name, + int task_index, ServerDef* options) { options->set_protocol("grpc"); - string cluster_spec; - int task_index = 0; - const bool parse_result = ParseFlags( - &argc, argv, {Flag("cluster_spec", &cluster_spec), // - Flag("job_name", options->mutable_job_name()), // - Flag("task_id", &task_index)}); - if (!parse_result) { - return errors::InvalidArgument("Error parsing command-line flags"); - } + options->set_job_name(job_name); options->set_task_index(task_index); size_t my_num_tasks = 0; @@ -101,9 +95,25 @@ void Usage(char* const argv_0) { } int main(int argc, char* argv[]) { + tensorflow::string cluster_spec; + tensorflow::string job_name; + int task_index = 0; + std::vector flag_list = { + tensorflow::Flag("cluster_spec", &cluster_spec, "cluster spec"), + tensorflow::Flag("job_name", &job_name, "job name"), + tensorflow::Flag("task_id", &task_index, "task id"), + }; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(argv[0], &argc, &argv); + if (!parse_result || argc != 1) { + std::cerr << usage << std::endl; + Usage(argv[0]); + return -1; + } tensorflow::ServerDef server_def; - tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &server_def); + tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name, + task_index, &server_def); if (!s.ok()) { std::cerr << "ERROR: " << s.error_message() << std::endl; Usage(argv[0]); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc index f31687068b..953cf933d5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "grpc++/grpc++.h" #include "grpc++/security/credentials.h" #include "grpc++/server_builder.h" @@ -32,25 +34,13 @@ limitations under the License. namespace tensorflow { namespace { -Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { +Status FillServerDef(const string& job_spec, const string& job_name, + int num_cpus, int num_gpus, int task_index, + ServerDef* options) { options->set_protocol("grpc"); - string job_spec; - int num_cpus = 1; - int num_gpus = 0; - int task_index = 0; - const bool parse_result = - ParseFlags(&argc, argv, {Flag("tf_jobs", &job_spec), // - Flag("tf_job", options->mutable_job_name()), // - Flag("tf_task", &task_index), // - Flag("num_cpus", &num_cpus), // - Flag("num_gpus", &num_gpus)}); - + options->set_job_name(job_name); options->set_task_index(task_index); - if (!parse_result) { - return errors::InvalidArgument("Error parsing command-line flags"); - } - uint32 my_tasks_per_replica = 0; for (const string& job_str : str_util::Split(job_spec, ',')) { JobDef* job_def = options->mutable_cluster()->add_job(); @@ -85,13 +75,32 @@ Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { } // namespace tensorflow int main(int argc, char* argv[]) { + tensorflow::string job_spec; + tensorflow::string job_name; + int num_cpus = 1; + int num_gpus = 0; + int task_index = 0; + std::vector flag_list = { + tensorflow::Flag("tf_jobs", &job_spec, "job specification"), + tensorflow::Flag("tf_job", &job_name, "job name"), + tensorflow::Flag("tf_task", &task_index, "task index"), + tensorflow::Flag("num_cpus", &num_cpus, "number of CPUs"), + tensorflow::Flag("num_gpus", &num_gpus, "number of GPUs"), + }; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(argv[0], &argc, &argv); + if (!parse_result || argc != 1) { + LOG(ERROR) << usage; + return -1; + } tensorflow::ServerDef def; - tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &def); - + tensorflow::Status s = tensorflow::FillServerDef(job_spec, job_name, num_cpus, + num_gpus, task_index, &def); if (!s.ok()) { - LOG(ERROR) << "Could not parse flags: " << s.error_message(); + LOG(ERROR) << "Could not parse job spec: " << s.error_message() << "\n" + << usage; return -1; } diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 2048126338..03eb076f30 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -13,9 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/util/command_line_flags.h" +#include +#include + #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { namespace { @@ -91,17 +95,26 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } // namespace -Flag::Flag(const char* name, tensorflow::int32* dst) - : name_(name), type_(TYPE_INT), int_value_(dst) {} +Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) + : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {} -Flag::Flag(const char* name, tensorflow::int64* dst) - : name_(name), type_(TYPE_INT64), int64_value_(dst) {} +Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text) + : name_(name), + type_(TYPE_INT64), + int64_value_(dst), + usage_text_(usage_text) {} -Flag::Flag(const char* name, bool* dst) - : name_(name), type_(TYPE_BOOL), bool_value_(dst) {} +Flag::Flag(const char* name, bool* dst, const string& usage_text) + : name_(name), + type_(TYPE_BOOL), + bool_value_(dst), + usage_text_(usage_text) {} -Flag::Flag(const char* name, string* dst) - : name_(name), type_(TYPE_STRING), string_value_(dst) {} +Flag::Flag(const char* name, string* dst, const string& usage_text) + : name_(name), + type_(TYPE_STRING), + string_value_(dst), + usage_text_(usage_text) {} bool Flag::Parse(string arg, bool* value_parsing_ok) const { bool result = false; @@ -117,7 +130,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { return result; } -bool ParseFlags(int* argc, char** argv, const std::vector& flag_list) { +/*static*/ bool Flags::Parse(int* argc, char** argv, + const std::vector& flag_list) { bool result = true; std::vector unknown_flags; for (int i = 1; i < *argc; ++i) { @@ -151,7 +165,41 @@ bool ParseFlags(int* argc, char** argv, const std::vector& flag_list) { } argv[dst++] = nullptr; *argc = unknown_flags.size() + 1; - return result; + return result && (*argc < 2 || strcmp(argv[1], "--help") != 0); +} + +/*static*/ string Flags::Usage(const string& cmdline, + const std::vector& flag_list) { + string usage_text; + if (!flag_list.empty()) { + strings::Appendf(&usage_text, "usage: %s\nFlags:\n", cmdline.c_str()); + } else { + strings::Appendf(&usage_text, "usage: %s\n", cmdline.c_str()); + } + for (const Flag& flag : flag_list) { + const char* type_name = ""; + string flag_string; + if (flag.type_ == Flag::TYPE_INT) { + type_name = "int32"; + flag_string = + strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_); + } else if (flag.type_ == Flag::TYPE_INT64) { + type_name = "int64"; + flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(), + static_cast(*flag.int64_value_)); + } else if (flag.type_ == Flag::TYPE_BOOL) { + type_name = "bool"; + flag_string = strings::Printf("--%s=%s", flag.name_.c_str(), + *flag.bool_value_ ? "true" : "false"); + } else if (flag.type_ == Flag::TYPE_STRING) { + type_name = "string"; + flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(), + flag.string_value_->c_str()); + } + strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(), + type_name, flag.usage_text_.c_str()); + } + return usage_text; } } // namespace tensorflow diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index 9297fb066d..2c77d7874f 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H #define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H +#include #include #include "tensorflow/core/platform/types.h" @@ -30,10 +31,19 @@ namespace tensorflow { // int some_int = 10; // bool some_switch = false; // string some_name = "something"; -// bool parsed_values_ok = ParseFlags(&argc, argv, { -// Flag("some_int", &some_int), -// Flag("some_switch", &some_switch), -// Flag("some_name", &some_name)}); +// std::vector flag_list = { +// Flag("some_int", &some_int, "an integer that affects X"), +// Flag("some_switch", &some_switch, "a bool that affects Y"), +// Flag("some_name", &some_name, "a string that affects Z") +// }; +// // Get usage message before ParseFlags() to capture default values. +// string usage = Flag::Usage(argv[0], flag_list); +// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list); +// +// tensorflow::port::InitMain(usage.c_str(), &argc, &argv); +// if (argc != 1 || !parsed_values_ok) { +// ...output usage and error message... +// } // // The argc and argv values are adjusted by the Parse function so all that // remains is the program name (at argv[0]) and any unknown arguments fill the @@ -46,25 +56,44 @@ namespace tensorflow { // NOTE: Unlike gflags-style libraries, this library is intended to be // used in the `main()` function of your binary. It does not handle // flag definitions that are scattered around the source code. + +// A description of a single command line flag, holding its name, type, usage +// text, and a pointer to the corresponding variable. class Flag { public: - Flag(const char* name, int32* dst1); - Flag(const char* name, int64* dst1); - Flag(const char* name, bool* dst); - Flag(const char* name, string* dst); + Flag(const char* name, int32* dst1, const string& usage_text); + Flag(const char* name, int64* dst1, const string& usage_text); + Flag(const char* name, bool* dst, const string& usage_text); + Flag(const char* name, string* dst, const string& usage_text); + + private: + friend class Flags; bool Parse(string arg, bool* value_parsing_ok) const; - private: string name_; enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING } type_; int* int_value_; int64* int64_value_; bool* bool_value_; string* string_value_; + string usage_text_; }; -bool ParseFlags(int* argc, char** argv, const std::vector& flag_list); +class Flags { + public: + // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag + // instances matching flags in flaglist[]. Update the variables associated + // with matching flags, and remove the matching arguments from (*argc, argv). + // Return true iff all recognized flag values were parsed correctly, and the + // first remaining argument is not "--help". + static bool Parse(int* argc, char** argv, const std::vector& flag_list); + + // Return a usage message with command line cmdline, and the + // usage_text strings in flag_list[]. + static string Usage(const string& cmdline, + const std::vector& flag_list); +}; } // namespace tensorflow diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index bc38fff8fd..b002e35899 100644 --- a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -13,19 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/util/command_line_flags.h" +#include +#include + #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { namespace { // The returned array is only valid for the lifetime of the input vector. // We're using const casting because we need to pass in an argv-style array of // char* pointers for the API, even though we know they won't be altered. -std::vector CharPointerVectorFromStrings( - const std::vector& strings) { - std::vector result; - for (const tensorflow::string& string : strings) { - result.push_back(const_cast(string.c_str())); +std::vector CharPointerVectorFromStrings( + const std::vector &strings) { + std::vector result; + for (const string &string : strings) { + result.push_back(const_cast(string.c_str())); } return result; } @@ -35,16 +38,18 @@ TEST(CommandLineFlagsTest, BasicUsage) { int some_int = 10; int64 some_int64 = 21474836470; // max int32 is 2147483647 bool some_switch = false; - tensorflow::string some_name = "something"; + string some_name = "something"; int argc = 5; - std::vector argv_strings = { + std::vector argv_strings = { "program_name", "--some_int=20", "--some_int64=214748364700", "--some_switch", "--some_name=somethingelse"}; - std::vector argv_array = CharPointerVectorFromStrings(argv_strings); - bool parsed_ok = ParseFlags( - &argc, argv_array.data(), - {Flag("some_int", &some_int), Flag("some_int64", &some_int64), - Flag("some_switch", &some_switch), Flag("some_name", &some_name)}); + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_int", &some_int, "some int"), + Flag("some_int64", &some_int64, "some int64"), + Flag("some_switch", &some_switch, "some switch"), + Flag("some_name", &some_name, "some name")}); EXPECT_EQ(true, parsed_ok); EXPECT_EQ(20, some_int); EXPECT_EQ(214748364700, some_int64); @@ -56,11 +61,10 @@ TEST(CommandLineFlagsTest, BasicUsage) { TEST(CommandLineFlagsTest, BadIntValue) { int some_int = 10; int argc = 2; - std::vector argv_strings = {"program_name", - "--some_int=notanumber"}; - std::vector argv_array = CharPointerVectorFromStrings(argv_strings); - bool parsed_ok = - ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int)}); + std::vector argv_strings = {"program_name", "--some_int=notanumber"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = Flags::Parse(&argc, argv_array.data(), + {Flag("some_int", &some_int, "some int")}); EXPECT_EQ(false, parsed_ok); EXPECT_EQ(10, some_int); @@ -70,15 +74,61 @@ TEST(CommandLineFlagsTest, BadIntValue) { TEST(CommandLineFlagsTest, BadBoolValue) { bool some_switch = false; int argc = 2; - std::vector argv_strings = {"program_name", - "--some_switch=notabool"}; - std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + std::vector argv_strings = {"program_name", "--some_switch=notabool"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); bool parsed_ok = - ParseFlags(&argc, argv_array.data(), {Flag("some_switch", &some_switch)}); + Flags::Parse(&argc, argv_array.data(), + {Flag("some_switch", &some_switch, "some switch")}); EXPECT_EQ(false, parsed_ok); EXPECT_EQ(false, some_switch); EXPECT_EQ(argc, 1); } +// Return whether str==pat, but allowing any whitespace in pat +// to match zero or more whitespace characters in str. +static bool MatchWithAnyWhitespace(const string &str, const string &pat) { + bool matching = true; + int pat_i = 0; + for (int str_i = 0; str_i != str.size() && matching; str_i++) { + if (isspace(str[str_i])) { + matching = (pat_i != pat.size() && isspace(pat[pat_i])); + } else { + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]); + } + } + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + return (matching && pat_i == pat.size()); +} + +TEST(CommandLineFlagsTest, UsageString) { + int some_int = 10; + int64 some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + string some_name = "something"; + const string tool_name = "some_tool_name"; + string usage = Flags::Usage(tool_name + "", + {Flag("some_int", &some_int, "some int"), + Flag("some_int64", &some_int64, "some int64"), + Flag("some_switch", &some_switch, "some switch"), + Flag("some_name", &some_name, "some name")}); + // Match the usage message, being sloppy about whitespace. + const char *expected_usage = + " usage: some_tool_name \n" + "Flags:\n" + "--some_int=10 int32 some int\n" + "--some_int64=21474836470 int64 some int64\n" + "--some_switch=false bool some switch\n" + "--some_name=\"something\" string some name\n"; + ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true); + + // Again but with no flags. + usage = Flags::Usage(tool_name, {}); + ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true); +} } // namespace tensorflow diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 7faf1cef61..3a927ca14b 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -32,6 +32,7 @@ limitations under the License. // The googlenet_graph.pb file included by default is created from Inception. #include +#include #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/image_ops.h" @@ -246,27 +247,32 @@ int main(int argc, char* argv[]) { string output_layer = "softmax"; bool self_test = false; string root_dir = ""; - const bool parse_result = tensorflow::ParseFlags( - &argc, argv, {Flag("image", &image), // - Flag("graph", &graph), // - Flag("labels", &labels), // - Flag("input_width", &input_width), // - Flag("input_height", &input_height), // - Flag("input_mean", &input_mean), // - Flag("input_std", &input_std), // - Flag("input_layer", &input_layer), // - Flag("output_layer", &output_layer), // - Flag("self_test", &self_test), // - Flag("root_dir", &root_dir)}); + std::vector flag_list = { + Flag("image", &image, "image to be processed"), + Flag("graph", &graph, "graph to be executed"), + Flag("labels", &labels, "name of file containing labels"), + Flag("input_width", &input_width, "resize image to this width in pixels"), + Flag("input_height", &input_height, + "resize image to this height in pixels"), + Flag("input_mean", &input_mean, "scale pixel values to this mean"), + Flag("input_std", &input_std, "scale pixel values to this std deviation"), + Flag("input_layer", &input_layer, "name of input layer"), + Flag("output_layer", &output_layer, "name of output layer"), + Flag("self_test", &self_test, "run a self test"), + Flag("root_dir", &root_dir, + "interpret image and graph file names relative to this directory"), + }; + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { - LOG(ERROR) << "Error parsing command-line flags."; + LOG(ERROR) << usage; return -1; } // We need to call this to set up global state for TensorFlow. tensorflow::port::InitMain(argv[0], &argc, &argv); if (argc > 1) { - LOG(ERROR) << "Unknown argument " << argv[1]; + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return -1; } diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index 5d824df6cf..3279967aaa 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -173,29 +173,30 @@ int Main(int argc, char** argv) { string output_prefix = ""; bool show_sizes = false; - const bool parse_result = ParseFlags( - &argc, argv, { - Flag("graph", &graph), // - Flag("input_layer", &input_layer), // - Flag("input_layer_shape", &input_layer_shape), // - Flag("input_layer_type", &input_layer_type), // - Flag("output_layer", &output_layer), // - Flag("num_runs", &num_runs), // - Flag("run_delay", &run_delay), // - Flag("num_threads", &num_threads), // - Flag("benchmark_name", &benchmark_name), // - Flag("output_prefix", &output_prefix), // - Flag("show_sizes", &show_sizes), // - }); + std::vector flag_list = { + Flag("graph", &graph, "graph file name"), + Flag("input_layer", &input_layer, "input layer name"), + Flag("input_layer_shape", &input_layer_shape, "input layer shape"), + Flag("input_layer_type", &input_layer_type, "input layer type"), + Flag("output_layer", &output_layer, "output layer name"), + Flag("num_runs", &num_runs, "number of runs"), + Flag("run_delay", &run_delay, "delay between runs in seconds"), + Flag("num_threads", &num_threads, "number of threads"), + Flag("benchmark_name", &benchmark_name, "benchmark name"), + Flag("output_prefix", &output_prefix, "benchmark output prefix"), + Flag("show_sizes", &show_sizes, "whether to show sizes"), + }; + string usage = Flags::Usage(argv[0], flag_list); + const bool parse_result = Flags::Parse(&argc, argv, flag_list); if (!parse_result) { - LOG(ERROR) << "Error parsing command-line flags."; + LOG(ERROR) << usage; return -1; } ::tensorflow::port::InitMain(argv[0], &argc, &argv); if (argc > 1) { - LOG(ERROR) << "Unknown argument " << argv[1]; + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return -1; } diff --git a/tensorflow/tools/graph_transforms/fold_constants_tool.cc b/tensorflow/tools/graph_transforms/fold_constants_tool.cc index ae4880b7e8..bfcbdf6b14 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_tool.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_tool.cc @@ -46,21 +46,22 @@ int ParseFlagsAndConvertGraph(int argc, char* argv[]) { string out_graph = ""; string inputs_string = ""; string outputs_string = ""; - const bool parse_result = - ParseFlags(&argc, argv, { - Flag("in_graph", &in_graph), // - Flag("out_graph", &out_graph), // - Flag("inputs", &inputs_string), // - Flag("outputs", &outputs_string), // - }); + std::vector flag_list = { + Flag("in_graph", &in_graph, "input graph file name"), + Flag("out_graph", &out_graph, "output graph file name"), + Flag("inputs", &inputs_string, "inputs"), + Flag("outputs", &outputs_string, "outputs"), + }; + string usage = Flags::Usage(argv[0], flag_list); + const bool parse_result = Flags::Parse(&argc, argv, flag_list); // We need to call this to set up global state for TensorFlow. port::InitMain(argv[0], &argc, &argv); if (!parse_result) { - LOG(ERROR) << "Error parsing command-line flags."; + LOG(ERROR) << usage; return -1; } if (argc > 1) { - LOG(ERROR) << "Unknown argument " << argv[1]; + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return -1; } if (in_graph.empty()) { -- cgit v1.2.3