aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/command_line_flags.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-14 15:41:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 16:48:42 -0700
commit20c02ffe0e6b8af4902487f852e15daa16caa523 (patch)
treebd2b821e5021f714bf5db3a92bbb269ea715002d /tensorflow/core/util/command_line_flags.cc
parent5c1821be018d4a626efd0a9cee7844aaa8c69366 (diff)
Modify tensorflow command line flag parsing:
- Allow help text to be specified for each flag - Add a flag "--help" to print the help text - Rename ParseFlags to Flags::Parse(); new routine Flags::Usage() returns a usage message. - Change uses to new format In some cases reorder with InitMain(), which should be called after flag parsing. Change: 136212902
Diffstat (limited to 'tensorflow/core/util/command_line_flags.cc')
-rw-r--r--tensorflow/core/util/command_line_flags.cc70
1 files changed, 59 insertions, 11 deletions
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 2048126338..03eb076f30 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/util/command_line_flags.h"
+#include <string>
+#include <vector>
+
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
@@ -91,17 +95,26 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
} // namespace
-Flag::Flag(const char* name, tensorflow::int32* dst)
- : name_(name), type_(TYPE_INT), int_value_(dst) {}
+Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
+ : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {}
-Flag::Flag(const char* name, tensorflow::int64* dst)
- : name_(name), type_(TYPE_INT64), int64_value_(dst) {}
+Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT64),
+ int64_value_(dst),
+ usage_text_(usage_text) {}
-Flag::Flag(const char* name, bool* dst)
- : name_(name), type_(TYPE_BOOL), bool_value_(dst) {}
+Flag::Flag(const char* name, bool* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_BOOL),
+ bool_value_(dst),
+ usage_text_(usage_text) {}
-Flag::Flag(const char* name, string* dst)
- : name_(name), type_(TYPE_STRING), string_value_(dst) {}
+Flag::Flag(const char* name, string* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_STRING),
+ string_value_(dst),
+ usage_text_(usage_text) {}
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
bool result = false;
@@ -117,7 +130,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
return result;
}
-bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) {
+/*static*/ bool Flags::Parse(int* argc, char** argv,
+ const std::vector<Flag>& flag_list) {
bool result = true;
std::vector<char*> unknown_flags;
for (int i = 1; i < *argc; ++i) {
@@ -151,7 +165,41 @@ bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) {
}
argv[dst++] = nullptr;
*argc = unknown_flags.size() + 1;
- return result;
+ return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
+}
+
+/*static*/ string Flags::Usage(const string& cmdline,
+ const std::vector<Flag>& flag_list) {
+ string usage_text;
+ if (!flag_list.empty()) {
+ strings::Appendf(&usage_text, "usage: %s\nFlags:\n", cmdline.c_str());
+ } else {
+ strings::Appendf(&usage_text, "usage: %s\n", cmdline.c_str());
+ }
+ for (const Flag& flag : flag_list) {
+ const char* type_name = "";
+ string flag_string;
+ if (flag.type_ == Flag::TYPE_INT) {
+ type_name = "int32";
+ flag_string =
+ strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_);
+ } else if (flag.type_ == Flag::TYPE_INT64) {
+ type_name = "int64";
+ flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(),
+ static_cast<long long>(*flag.int64_value_));
+ } else if (flag.type_ == Flag::TYPE_BOOL) {
+ type_name = "bool";
+ flag_string = strings::Printf("--%s=%s", flag.name_.c_str(),
+ *flag.bool_value_ ? "true" : "false");
+ } else if (flag.type_ == Flag::TYPE_STRING) {
+ type_name = "string";
+ flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
+ flag.string_value_->c_str());
+ }
+ strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
+ type_name, flag.usage_text_.c_str());
+ }
+ return usage_text;
}
} // namespace tensorflow