aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-14 15:41:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 16:48:42 -0700
commit20c02ffe0e6b8af4902487f852e15daa16caa523 (patch)
treebd2b821e5021f714bf5db3a92bbb269ea715002d
parent5c1821be018d4a626efd0a9cee7844aaa8c69366 (diff)
Modify tensorflow command line flag parsing:
- Allow help text to be specified for each flag - Add a flag "--help" to print the help text - Rename ParseFlags to Flags::Parse(); new routine Flags::Usage() returns a usage message. - Change uses to new format In some cases reorder with InitMain(), which should be called after flag parsing. Change: 136212902
-rw-r--r--tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc19
-rw-r--r--tensorflow/contrib/pi_examples/camera/camera.cc35
-rw-r--r--tensorflow/contrib/pi_examples/label_image/label_image.cc41
-rw-r--r--tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc58
-rw-r--r--tensorflow/contrib/util/convert_graphdef_memmapped_format.cc26
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc32
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc47
-rw-r--r--tensorflow/core/util/command_line_flags.cc70
-rw-r--r--tensorflow/core/util/command_line_flags.h49
-rw-r--r--tensorflow/core/util/command_line_flags_test.cc94
-rw-r--r--tensorflow/examples/label_image/main.cc34
-rw-r--r--tensorflow/tools/benchmark/benchmark_model.cc33
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_tool.cc19
13 files changed, 376 insertions, 181 deletions
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc
index f4cfa0bbc6..f374283c07 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc
@@ -42,10 +42,17 @@ const char kTestMp3Filename[] =
mutex mu;
bool should_ffmpeg_be_installed GUARDED_BY(mu) = false;
-void ParseTestFlags(int* argc, char** argv) {
+string ParseTestFlags(int* argc, char** argv) {
mutex_lock l(mu);
- CHECK(ParseFlags(argc, argv, {Flag("should_ffmpeg_be_installed",
- &should_ffmpeg_be_installed)}));
+ vector<Flag> flag_list = {Flag("should_ffmpeg_be_installed",
+ &should_ffmpeg_be_installed,
+ "indicates that ffmpeg should be installed")};
+ string usage = Flags::Usage(argv[0], flag_list);
+ if (!Flags::Parse(argc, argv, flag_list)) {
+ LOG(ERROR) << "\n" << usage;
+ exit(2);
+ }
+ return usage;
}
TEST(FfmpegLibTest, TestUninstalled) {
@@ -132,7 +139,11 @@ TEST(FfmpegLibTest, TestRoundTripWav) {
} // namespace tensorflow
int main(int argc, char **argv) {
- tensorflow::ffmpeg::ParseTestFlags(&argc, argv);
+ tensorflow::string usage = tensorflow::ffmpeg::ParseTestFlags(&argc, argv);
testing::InitGoogleTest(&argc, argv);
+ if (argc != 1) {
+ LOG(ERROR) << usage;
+ return 2;
+ }
return RUN_ALL_TESTS();
}
diff --git a/tensorflow/contrib/pi_examples/camera/camera.cc b/tensorflow/contrib/pi_examples/camera/camera.cc
index 9bba110a52..cb20661662 100644
--- a/tensorflow/contrib/pi_examples/camera/camera.cc
+++ b/tensorflow/contrib/pi_examples/camera/camera.cc
@@ -412,21 +412,26 @@ int main(int argc, char** argv) {
int32 video_height = 480;
int print_threshold = 50;
string root_dir = "";
- const bool parse_result = tensorflow::ParseFlags(
- &argc, argv, {Flag("graph", &graph), //
- Flag("labels", &labels_file_name), //
- Flag("input_width", &input_width), //
- Flag("input_height", &input_height), //
- Flag("input_mean", &input_mean), //
- Flag("input_std", &input_std), //
- Flag("input_layer", &input_layer), //
- Flag("output_layer", &output_layer), //
- Flag("video_width", &video_width), //
- Flag("video_height", &video_height), //
- Flag("print_threshold", &print_threshold), //
- Flag("root_dir", &root_dir)});
- if (!parse_result) {
- LOG(ERROR) << "Error parsing command-line flags.";
+ std::vector<Flag> flag_list = {
+ Flag("graph", &graph, "graph file name"),
+ Flag("labels", &labels_file_name, "labels file name"),
+ Flag("input_width", &input_width, "image input width"),
+ Flag("input_height", &input_height, "image input height"),
+ Flag("input_mean", &input_mean, "transformed mean of input pixels"),
+ Flag("input_std", &input_std, "transformed std dev of input pixels"),
+ Flag("input_layer", &input_layer, "input layer name"),
+ Flag("output_layer", &output_layer, "output layer name"),
+ Flag("video_width", &video_width, "video width expected from device"),
+ Flag("video_height", &video_height, "video height expected from device"),
+ Flag("print_threshold", &print_threshold,
+ "print labels with scoe exceeding this"),
+ Flag("root_dir", &root_dir,
+ "interpret graph file name relative to this directory")};
+ string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+
+ if (!parse_result || argc != 1) {
+ LOG(ERROR) << "\n" << usage;
return -1;
}
diff --git a/tensorflow/contrib/pi_examples/label_image/label_image.cc b/tensorflow/contrib/pi_examples/label_image/label_image.cc
index 70f32f2199..ab19398ef2 100644
--- a/tensorflow/contrib/pi_examples/label_image/label_image.cc
+++ b/tensorflow/contrib/pi_examples/label_image/label_image.cc
@@ -23,9 +23,10 @@ limitations under the License.
//
// Full build instructions are at tensorflow/contrib/pi_examples/README.md.
-#include <fstream>
#include <jpeglib.h>
#include <setjmp.h>
+#include <fstream>
+#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
@@ -297,7 +298,8 @@ int main(int argc, char* argv[]) {
// They define where the graph and input data is located, and what kind of
// input the model expects. If you train your own model, or use something
// other than GoogLeNet you'll need to update these.
- string image = "tensorflow/contrib/pi_examples/label_image/data/"
+ string image =
+ "tensorflow/contrib/pi_examples/label_image/data/"
"grace_hopper.jpg";
string graph =
"tensorflow/contrib/pi_examples/label_image/data/"
@@ -313,27 +315,32 @@ int main(int argc, char* argv[]) {
string output_layer = "softmax";
bool self_test = false;
string root_dir = "";
- const bool parse_result = tensorflow::ParseFlags(
- &argc, argv, {Flag("image", &image), //
- Flag("graph", &graph), //
- Flag("labels", &labels), //
- Flag("input_width", &input_width), //
- Flag("input_height", &input_height), //
- Flag("input_mean", &input_mean), //
- Flag("input_std", &input_std), //
- Flag("input_layer", &input_layer), //
- Flag("output_layer", &output_layer), //
- Flag("self_test", &self_test), //
- Flag("root_dir", &root_dir)});
+ vector tensorflow::Flag > flag_list = {
+ Flag("image", &image, "image to be processed"),
+ Flag("graph", &graph, "graph to be executed"),
+ Flag("labels", &labels, "name of file containing labels"),
+ Flag("input_width", &input_width, "resize image to this width in pixels"),
+ Flag("input_height", &input_height,
+ "resize image to this height in pixels"),
+ Flag("input_mean", &input_mean, "scale pixel values to this mean"),
+ Flag("input_std", &input_std, "scale pixel values to this std deviation"),
+ Flag("input_layer", &input_layer, "name of input layer"),
+ Flag("output_layer", &output_layer, "name of output layer"),
+ Flag("self_test", &self_test, "run a self test"),
+ Flag("root_dir", &root_dir,
+ "interpret image and graph file names relative to this directory"),
+ };
+ string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
- LOG(ERROR) << "Error parsing command-line flags.";
+ LOG(ERROR) << "\n" << usage;
return -1;
}
// We need to call this to set up global state for TensorFlow.
- tensorflow::port::InitMain(argv[0], &argc, &argv);
+ tensorflow::port::InitMain(usage, &argc, &argv);
if (argc > 1) {
- LOG(ERROR) << "Unknown argument " << argv[1];
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
diff --git a/tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc b/tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc
index d9080242d6..38b1588d72 100644
--- a/tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc
+++ b/tensorflow/contrib/tfprof/tools/tfprof/tfprof_main.cc
@@ -88,27 +88,43 @@ int main(int argc, char** argv) {
fprintf(stderr, "%s\n", argv[i]);
}
- CHECK(tensorflow::ParseFlags(
- &argc, argv,
- {tensorflow::Flag("graph_path", &FLAGS_graph_path),
- tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path),
- tensorflow::Flag("op_log_path", &FLAGS_op_log_path),
- tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path),
- tensorflow::Flag("max_depth", &FLAGS_max_depth),
- tensorflow::Flag("min_bytes", &FLAGS_min_bytes),
- tensorflow::Flag("min_micros", &FLAGS_min_micros),
- tensorflow::Flag("min_params", &FLAGS_min_params),
- tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops),
- tensorflow::Flag("device_regexes", &FLAGS_device_regexes),
- tensorflow::Flag("order_by", &FLAGS_order_by),
- tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes),
- tensorflow::Flag("trim_name_regexes", &FLAGS_trim_name_regexes),
- tensorflow::Flag("show_name_regexes", &FLAGS_show_name_regexes),
- tensorflow::Flag("hide_name_regexes", &FLAGS_hide_name_regexes),
- tensorflow::Flag("account_displayed_op_only",
- &FLAGS_account_displayed_op_only),
- tensorflow::Flag("select", &FLAGS_select),
- tensorflow::Flag("dump_to_file", &FLAGS_dump_to_file)}));
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("graph_path", &FLAGS_graph_path,
+ "GraphDef proto text file name"),
+ tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path,
+ "RunMetadata proto binary file name"),
+ tensorflow::Flag("op_log_path", &FLAGS_op_log_path,
+ "tensorflow::tfprof::OpLog proto binary file name"),
+ tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path,
+ "TensorFlow Checkpoint file name"),
+ tensorflow::Flag("max_depth", &FLAGS_max_depth, "max depth"),
+ tensorflow::Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"),
+ tensorflow::Flag("min_micros", &FLAGS_min_micros, "min micros"),
+ tensorflow::Flag("min_params", &FLAGS_min_params, "min params"),
+ tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"),
+ tensorflow::Flag("device_regexes", &FLAGS_device_regexes,
+ "device regexes"),
+ tensorflow::Flag("order_by", &FLAGS_order_by, "order by"),
+ tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes,
+ "start name regexes"),
+ tensorflow::Flag("trim_name_regexes", &FLAGS_trim_name_regexes,
+ "trim name regexes"),
+ tensorflow::Flag("show_name_regexes", &FLAGS_show_name_regexes,
+ "show name regexes"),
+ tensorflow::Flag("hide_name_regexes", &FLAGS_hide_name_regexes,
+ "hide name regexes"),
+ tensorflow::Flag("account_displayed_op_only",
+ &FLAGS_account_displayed_op_only,
+ "account displayed op only"),
+ tensorflow::Flag("select", &FLAGS_select, "select"),
+ tensorflow::Flag("dump_to_file", &FLAGS_dump_to_file, "dump to file"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_ok) {
+ printf("%s", usage.c_str());
+ return (2);
+ }
tensorflow::port::InitMain(argv[0], &argc, &argv);
fprintf(stderr, "%s\n", FLAGS_graph_path.c_str());
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
index 6287d3eb7e..29b124e2a8 100644
--- a/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
+++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
@@ -28,6 +28,8 @@ limitations under the License.
// min_conversion_size_bytes - tensors with fewer than this many bytes of data
// will not be converted to ImmutableConst format, and kept in the graph.
+#include <vector>
+
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -40,23 +42,23 @@ int ParseFlagsAndConvertGraph(int argc, char* argv[]) {
string in_graph = "";
string out_graph = "";
int min_conversion_tensor_size = 10000;
- const bool parse_result = ParseFlags(
- &argc, argv,
- {// input graph
- Flag("in_graph", &in_graph),
- // output graph
- Flag("out_graph", &out_graph),
- // constants with tensors that have less than this number elements won't
- // be converted into ImmutableConst (be memmapped).
- Flag("min_conversion_tensor_size", &min_conversion_tensor_size)});
+ std::vector<Flag> flag_list = {
+ Flag("in_graph", &in_graph, "input graph"),
+ Flag("out_graph", &out_graph, "output graph"),
+ Flag("min_conversion_tensor_size", &min_conversion_tensor_size,
+ "constants with tensors that have less than this number elements "
+ "won't be converted into ImmutableConst (be memmapped)"),
+ };
+ string usage = Flags::Usage(argv[0], flag_list);
+ const bool parse_result = Flags::Parse(&argc, argv, flag_list);
// We need to call this to set up global state for TensorFlow.
- port::InitMain(argv[0], &argc, &argv);
+ port::InitMain(usage.c_str(), &argc, &argv);
if (!parse_result) {
- LOG(ERROR) << "Error parsing command-line flags.";
+ LOG(ERROR) << "\n" << usage;
return -1;
}
if (argc > 1) {
- LOG(ERROR) << "Unknown argument " << argv[1];
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
if (in_graph.empty()) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
index 3416f99919..1c7bb4375c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include <vector>
#include "grpc++/grpc++.h"
#include "grpc++/security/credentials.h"
@@ -36,17 +37,10 @@ limitations under the License.
namespace tensorflow {
namespace {
-Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
+Status FillServerDef(const string& cluster_spec, const string& job_name,
+ int task_index, ServerDef* options) {
options->set_protocol("grpc");
- string cluster_spec;
- int task_index = 0;
- const bool parse_result = ParseFlags(
- &argc, argv, {Flag("cluster_spec", &cluster_spec), //
- Flag("job_name", options->mutable_job_name()), //
- Flag("task_id", &task_index)});
- if (!parse_result) {
- return errors::InvalidArgument("Error parsing command-line flags");
- }
+ options->set_job_name(job_name);
options->set_task_index(task_index);
size_t my_num_tasks = 0;
@@ -101,9 +95,25 @@ void Usage(char* const argv_0) {
}
int main(int argc, char* argv[]) {
+ tensorflow::string cluster_spec;
+ tensorflow::string job_name;
+ int task_index = 0;
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("cluster_spec", &cluster_spec, "cluster spec"),
+ tensorflow::Flag("job_name", &job_name, "job name"),
+ tensorflow::Flag("task_id", &task_index, "task id"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (!parse_result || argc != 1) {
+ std::cerr << usage << std::endl;
+ Usage(argv[0]);
+ return -1;
+ }
tensorflow::ServerDef server_def;
- tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &server_def);
+ tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name,
+ task_index, &server_def);
if (!s.ok()) {
std::cerr << "ERROR: " << s.error_message() << std::endl;
Usage(argv[0]);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
index f31687068b..953cf933d5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <vector>
+
#include "grpc++/grpc++.h"
#include "grpc++/security/credentials.h"
#include "grpc++/server_builder.h"
@@ -32,25 +34,13 @@ limitations under the License.
namespace tensorflow {
namespace {
-Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
+Status FillServerDef(const string& job_spec, const string& job_name,
+ int num_cpus, int num_gpus, int task_index,
+ ServerDef* options) {
options->set_protocol("grpc");
- string job_spec;
- int num_cpus = 1;
- int num_gpus = 0;
- int task_index = 0;
- const bool parse_result =
- ParseFlags(&argc, argv, {Flag("tf_jobs", &job_spec), //
- Flag("tf_job", options->mutable_job_name()), //
- Flag("tf_task", &task_index), //
- Flag("num_cpus", &num_cpus), //
- Flag("num_gpus", &num_gpus)});
-
+ options->set_job_name(job_name);
options->set_task_index(task_index);
- if (!parse_result) {
- return errors::InvalidArgument("Error parsing command-line flags");
- }
-
uint32 my_tasks_per_replica = 0;
for (const string& job_str : str_util::Split(job_spec, ',')) {
JobDef* job_def = options->mutable_cluster()->add_job();
@@ -85,13 +75,32 @@ Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
} // namespace tensorflow
int main(int argc, char* argv[]) {
+ tensorflow::string job_spec;
+ tensorflow::string job_name;
+ int num_cpus = 1;
+ int num_gpus = 0;
+ int task_index = 0;
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("tf_jobs", &job_spec, "job specification"),
+ tensorflow::Flag("tf_job", &job_name, "job name"),
+ tensorflow::Flag("tf_task", &task_index, "task index"),
+ tensorflow::Flag("num_cpus", &num_cpus, "number of CPUs"),
+ tensorflow::Flag("num_gpus", &num_gpus, "number of GPUs"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (!parse_result || argc != 1) {
+ LOG(ERROR) << usage;
+ return -1;
+ }
tensorflow::ServerDef def;
- tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &def);
-
+ tensorflow::Status s = tensorflow::FillServerDef(job_spec, job_name, num_cpus,
+ num_gpus, task_index, &def);
if (!s.ok()) {
- LOG(ERROR) << "Could not parse flags: " << s.error_message();
+ LOG(ERROR) << "Could not parse job spec: " << s.error_message() << "\n"
+ << usage;
return -1;
}
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 2048126338..03eb076f30 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/util/command_line_flags.h"
+#include <string>
+#include <vector>
+
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
@@ -91,17 +95,26 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
} // namespace
-Flag::Flag(const char* name, tensorflow::int32* dst)
- : name_(name), type_(TYPE_INT), int_value_(dst) {}
+Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
+ : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {}
-Flag::Flag(const char* name, tensorflow::int64* dst)
- : name_(name), type_(TYPE_INT64), int64_value_(dst) {}
+Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT64),
+ int64_value_(dst),
+ usage_text_(usage_text) {}
-Flag::Flag(const char* name, bool* dst)
- : name_(name), type_(TYPE_BOOL), bool_value_(dst) {}
+Flag::Flag(const char* name, bool* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_BOOL),
+ bool_value_(dst),
+ usage_text_(usage_text) {}
-Flag::Flag(const char* name, string* dst)
- : name_(name), type_(TYPE_STRING), string_value_(dst) {}
+Flag::Flag(const char* name, string* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_STRING),
+ string_value_(dst),
+ usage_text_(usage_text) {}
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
bool result = false;
@@ -117,7 +130,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
return result;
}
-bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) {
+/*static*/ bool Flags::Parse(int* argc, char** argv,
+ const std::vector<Flag>& flag_list) {
bool result = true;
std::vector<char*> unknown_flags;
for (int i = 1; i < *argc; ++i) {
@@ -151,7 +165,41 @@ bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) {
}
argv[dst++] = nullptr;
*argc = unknown_flags.size() + 1;
- return result;
+ return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
+}
+
+/*static*/ string Flags::Usage(const string& cmdline,
+ const std::vector<Flag>& flag_list) {
+ string usage_text;
+ if (!flag_list.empty()) {
+ strings::Appendf(&usage_text, "usage: %s\nFlags:\n", cmdline.c_str());
+ } else {
+ strings::Appendf(&usage_text, "usage: %s\n", cmdline.c_str());
+ }
+ for (const Flag& flag : flag_list) {
+ const char* type_name = "";
+ string flag_string;
+ if (flag.type_ == Flag::TYPE_INT) {
+ type_name = "int32";
+ flag_string =
+ strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_);
+ } else if (flag.type_ == Flag::TYPE_INT64) {
+ type_name = "int64";
+ flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(),
+ static_cast<long long>(*flag.int64_value_));
+ } else if (flag.type_ == Flag::TYPE_BOOL) {
+ type_name = "bool";
+ flag_string = strings::Printf("--%s=%s", flag.name_.c_str(),
+ *flag.bool_value_ ? "true" : "false");
+ } else if (flag.type_ == Flag::TYPE_STRING) {
+ type_name = "string";
+ flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
+ flag.string_value_->c_str());
+ }
+ strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
+ type_name, flag.usage_text_.c_str());
+ }
+ return usage_text;
}
} // namespace tensorflow
diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h
index 9297fb066d..2c77d7874f 100644
--- a/tensorflow/core/util/command_line_flags.h
+++ b/tensorflow/core/util/command_line_flags.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
+#include <string>
#include <vector>
#include "tensorflow/core/platform/types.h"
@@ -30,10 +31,19 @@ namespace tensorflow {
// int some_int = 10;
// bool some_switch = false;
// string some_name = "something";
-// bool parsed_values_ok = ParseFlags(&argc, argv, {
-// Flag("some_int", &some_int),
-// Flag("some_switch", &some_switch),
-// Flag("some_name", &some_name)});
+// std::vector<tensorFlow::Flag> flag_list = {
+// Flag("some_int", &some_int, "an integer that affects X"),
+// Flag("some_switch", &some_switch, "a bool that affects Y"),
+// Flag("some_name", &some_name, "a string that affects Z")
+// };
+// // Get usage message before ParseFlags() to capture default values.
+// string usage = Flag::Usage(argv[0], flag_list);
+// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list);
+//
+// tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
+// if (argc != 1 || !parsed_values_ok) {
+// ...output usage and error message...
+// }
//
// The argc and argv values are adjusted by the Parse function so all that
// remains is the program name (at argv[0]) and any unknown arguments fill the
@@ -46,25 +56,44 @@ namespace tensorflow {
// NOTE: Unlike gflags-style libraries, this library is intended to be
// used in the `main()` function of your binary. It does not handle
// flag definitions that are scattered around the source code.
+
+// A description of a single command line flag, holding its name, type, usage
+// text, and a pointer to the corresponding variable.
class Flag {
public:
- Flag(const char* name, int32* dst1);
- Flag(const char* name, int64* dst1);
- Flag(const char* name, bool* dst);
- Flag(const char* name, string* dst);
+ Flag(const char* name, int32* dst1, const string& usage_text);
+ Flag(const char* name, int64* dst1, const string& usage_text);
+ Flag(const char* name, bool* dst, const string& usage_text);
+ Flag(const char* name, string* dst, const string& usage_text);
+
+ private:
+ friend class Flags;
bool Parse(string arg, bool* value_parsing_ok) const;
- private:
string name_;
enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING } type_;
int* int_value_;
int64* int64_value_;
bool* bool_value_;
string* string_value_;
+ string usage_text_;
};
-bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list);
+class Flags {
+ public:
+ // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag
+ // instances matching flags in flaglist[]. Update the variables associated
+ // with matching flags, and remove the matching arguments from (*argc, argv).
+ // Return true iff all recognized flag values were parsed correctly, and the
+ // first remaining argument is not "--help".
+ static bool Parse(int* argc, char** argv, const std::vector<Flag>& flag_list);
+
+ // Return a usage message with command line cmdline, and the
+ // usage_text strings in flag_list[].
+ static string Usage(const string& cmdline,
+ const std::vector<Flag>& flag_list);
+};
} // namespace tensorflow
diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc
index bc38fff8fd..b002e35899 100644
--- a/tensorflow/core/util/command_line_flags_test.cc
+++ b/tensorflow/core/util/command_line_flags_test.cc
@@ -13,19 +13,22 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/util/command_line_flags.h"
+#include <ctype.h>
+#include <vector>
+
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
// The returned array is only valid for the lifetime of the input vector.
// We're using const casting because we need to pass in an argv-style array of
// char* pointers for the API, even though we know they won't be altered.
-std::vector<char*> CharPointerVectorFromStrings(
- const std::vector<tensorflow::string>& strings) {
- std::vector<char*> result;
- for (const tensorflow::string& string : strings) {
- result.push_back(const_cast<char*>(string.c_str()));
+std::vector<char *> CharPointerVectorFromStrings(
+ const std::vector<string> &strings) {
+ std::vector<char *> result;
+ for (const string &string : strings) {
+ result.push_back(const_cast<char *>(string.c_str()));
}
return result;
}
@@ -35,16 +38,18 @@ TEST(CommandLineFlagsTest, BasicUsage) {
int some_int = 10;
int64 some_int64 = 21474836470; // max int32 is 2147483647
bool some_switch = false;
- tensorflow::string some_name = "something";
+ string some_name = "something";
int argc = 5;
- std::vector<tensorflow::string> argv_strings = {
+ std::vector<string> argv_strings = {
"program_name", "--some_int=20", "--some_int64=214748364700",
"--some_switch", "--some_name=somethingelse"};
- std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
- bool parsed_ok = ParseFlags(
- &argc, argv_array.data(),
- {Flag("some_int", &some_int), Flag("some_int64", &some_int64),
- Flag("some_switch", &some_switch), Flag("some_name", &some_name)});
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_int", &some_int, "some int"),
+ Flag("some_int64", &some_int64, "some int64"),
+ Flag("some_switch", &some_switch, "some switch"),
+ Flag("some_name", &some_name, "some name")});
EXPECT_EQ(true, parsed_ok);
EXPECT_EQ(20, some_int);
EXPECT_EQ(214748364700, some_int64);
@@ -56,11 +61,10 @@ TEST(CommandLineFlagsTest, BasicUsage) {
TEST(CommandLineFlagsTest, BadIntValue) {
int some_int = 10;
int argc = 2;
- std::vector<tensorflow::string> argv_strings = {"program_name",
- "--some_int=notanumber"};
- std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
- bool parsed_ok =
- ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int)});
+ std::vector<string> argv_strings = {"program_name", "--some_int=notanumber"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok = Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_int", &some_int, "some int")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(10, some_int);
@@ -70,15 +74,61 @@ TEST(CommandLineFlagsTest, BadIntValue) {
TEST(CommandLineFlagsTest, BadBoolValue) {
bool some_switch = false;
int argc = 2;
- std::vector<tensorflow::string> argv_strings = {"program_name",
- "--some_switch=notabool"};
- std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
+ std::vector<string> argv_strings = {"program_name", "--some_switch=notabool"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
bool parsed_ok =
- ParseFlags(&argc, argv_array.data(), {Flag("some_switch", &some_switch)});
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_switch", &some_switch, "some switch")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(false, some_switch);
EXPECT_EQ(argc, 1);
}
+// Return whether str==pat, but allowing any whitespace in pat
+// to match zero or more whitespace characters in str.
+static bool MatchWithAnyWhitespace(const string &str, const string &pat) {
+ bool matching = true;
+ int pat_i = 0;
+ for (int str_i = 0; str_i != str.size() && matching; str_i++) {
+ if (isspace(str[str_i])) {
+ matching = (pat_i != pat.size() && isspace(pat[pat_i]));
+ } else {
+ while (pat_i != pat.size() && isspace(pat[pat_i])) {
+ pat_i++;
+ }
+ matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]);
+ }
+ }
+ while (pat_i != pat.size() && isspace(pat[pat_i])) {
+ pat_i++;
+ }
+ return (matching && pat_i == pat.size());
+}
+
+TEST(CommandLineFlagsTest, UsageString) {
+ int some_int = 10;
+ int64 some_int64 = 21474836470; // max int32 is 2147483647
+ bool some_switch = false;
+ string some_name = "something";
+ const string tool_name = "some_tool_name";
+ string usage = Flags::Usage(tool_name + "<flags>",
+ {Flag("some_int", &some_int, "some int"),
+ Flag("some_int64", &some_int64, "some int64"),
+ Flag("some_switch", &some_switch, "some switch"),
+ Flag("some_name", &some_name, "some name")});
+ // Match the usage message, being sloppy about whitespace.
+ const char *expected_usage =
+ " usage: some_tool_name <flags>\n"
+ "Flags:\n"
+ "--some_int=10 int32 some int\n"
+ "--some_int64=21474836470 int64 some int64\n"
+ "--some_switch=false bool some switch\n"
+ "--some_name=\"something\" string some name\n";
+ ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true);
+
+ // Again but with no flags.
+ usage = Flags::Usage(tool_name, {});
+ ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true);
+}
} // namespace tensorflow
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc
index 7faf1cef61..3a927ca14b 100644
--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -32,6 +32,7 @@ limitations under the License.
// The googlenet_graph.pb file included by default is created from Inception.
#include <fstream>
+#include <vector>
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
@@ -246,27 +247,32 @@ int main(int argc, char* argv[]) {
string output_layer = "softmax";
bool self_test = false;
string root_dir = "";
- const bool parse_result = tensorflow::ParseFlags(
- &argc, argv, {Flag("image", &image), //
- Flag("graph", &graph), //
- Flag("labels", &labels), //
- Flag("input_width", &input_width), //
- Flag("input_height", &input_height), //
- Flag("input_mean", &input_mean), //
- Flag("input_std", &input_std), //
- Flag("input_layer", &input_layer), //
- Flag("output_layer", &output_layer), //
- Flag("self_test", &self_test), //
- Flag("root_dir", &root_dir)});
+ std::vector<Flag> flag_list = {
+ Flag("image", &image, "image to be processed"),
+ Flag("graph", &graph, "graph to be executed"),
+ Flag("labels", &labels, "name of file containing labels"),
+ Flag("input_width", &input_width, "resize image to this width in pixels"),
+ Flag("input_height", &input_height,
+ "resize image to this height in pixels"),
+ Flag("input_mean", &input_mean, "scale pixel values to this mean"),
+ Flag("input_std", &input_std, "scale pixel values to this std deviation"),
+ Flag("input_layer", &input_layer, "name of input layer"),
+ Flag("output_layer", &output_layer, "name of output layer"),
+ Flag("self_test", &self_test, "run a self test"),
+ Flag("root_dir", &root_dir,
+ "interpret image and graph file names relative to this directory"),
+ };
+ string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
- LOG(ERROR) << "Error parsing command-line flags.";
+ LOG(ERROR) << usage;
return -1;
}
// We need to call this to set up global state for TensorFlow.
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc > 1) {
- LOG(ERROR) << "Unknown argument " << argv[1];
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index 5d824df6cf..3279967aaa 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -173,29 +173,30 @@ int Main(int argc, char** argv) {
string output_prefix = "";
bool show_sizes = false;
- const bool parse_result = ParseFlags(
- &argc, argv, {
- Flag("graph", &graph), //
- Flag("input_layer", &input_layer), //
- Flag("input_layer_shape", &input_layer_shape), //
- Flag("input_layer_type", &input_layer_type), //
- Flag("output_layer", &output_layer), //
- Flag("num_runs", &num_runs), //
- Flag("run_delay", &run_delay), //
- Flag("num_threads", &num_threads), //
- Flag("benchmark_name", &benchmark_name), //
- Flag("output_prefix", &output_prefix), //
- Flag("show_sizes", &show_sizes), //
- });
+ std::vector<Flag> flag_list = {
+ Flag("graph", &graph, "graph file name"),
+ Flag("input_layer", &input_layer, "input layer name"),
+ Flag("input_layer_shape", &input_layer_shape, "input layer shape"),
+ Flag("input_layer_type", &input_layer_type, "input layer type"),
+ Flag("output_layer", &output_layer, "output layer name"),
+ Flag("num_runs", &num_runs, "number of runs"),
+ Flag("run_delay", &run_delay, "delay between runs in seconds"),
+ Flag("num_threads", &num_threads, "number of threads"),
+ Flag("benchmark_name", &benchmark_name, "benchmark name"),
+ Flag("output_prefix", &output_prefix, "benchmark output prefix"),
+ Flag("show_sizes", &show_sizes, "whether to show sizes"),
+ };
+ string usage = Flags::Usage(argv[0], flag_list);
+ const bool parse_result = Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
- LOG(ERROR) << "Error parsing command-line flags.";
+ LOG(ERROR) << usage;
return -1;
}
::tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc > 1) {
- LOG(ERROR) << "Unknown argument " << argv[1];
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
diff --git a/tensorflow/tools/graph_transforms/fold_constants_tool.cc b/tensorflow/tools/graph_transforms/fold_constants_tool.cc
index ae4880b7e8..bfcbdf6b14 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_tool.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_tool.cc
@@ -46,21 +46,22 @@ int ParseFlagsAndConvertGraph(int argc, char* argv[]) {
string out_graph = "";
string inputs_string = "";
string outputs_string = "";
- const bool parse_result =
- ParseFlags(&argc, argv, {
- Flag("in_graph", &in_graph), //
- Flag("out_graph", &out_graph), //
- Flag("inputs", &inputs_string), //
- Flag("outputs", &outputs_string), //
- });
+ std::vector<Flag> flag_list = {
+ Flag("in_graph", &in_graph, "input graph file name"),
+ Flag("out_graph", &out_graph, "output graph file name"),
+ Flag("inputs", &inputs_string, "inputs"),
+ Flag("outputs", &outputs_string, "outputs"),
+ };
+ string usage = Flags::Usage(argv[0], flag_list);
+ const bool parse_result = Flags::Parse(&argc, argv, flag_list);
// We need to call this to set up global state for TensorFlow.
port::InitMain(argv[0], &argc, &argv);
if (!parse_result) {
- LOG(ERROR) << "Error parsing command-line flags.";
+ LOG(ERROR) << usage;
return -1;
}
if (argc > 1) {
- LOG(ERROR) << "Unknown argument " << argv[1];
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
if (in_graph.empty()) {