From 7bf4e6cbaae9ca930aa17d058c94aa11119fc0c3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Jul 2017 14:34:18 -0700 Subject: Avoid the duplication in debug_options_flags.cc by generalizing tensorflow::Flag. PiperOrigin-RevId: 162271241 --- tensorflow/core/util/command_line_flags.cc | 147 ++++++++++++++++++++++------- 1 file changed, 113 insertions(+), 34 deletions(-) (limited to 'tensorflow/core/util/command_line_flags.cc') diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 8373eb1f9e..3efc703faf 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -25,10 +25,11 @@ namespace tensorflow { namespace { bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - string* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { - *dst = arg.ToString(); + *value_parsing_ok = hook(arg.ToString()); return true; } @@ -36,14 +37,18 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - tensorflow::int32* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) { + int32 parsed_int32; + if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_int32); } return true; } @@ -52,14 +57,18 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - tensorflow::int64* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) { + int64 parsed_int64; + if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_int64); } return true; } @@ -68,19 +77,20 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - bool* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag)) { if (arg.empty()) { - *dst = true; + *value_parsing_ok = hook(true); return true; } if (arg == "=true") { - *dst = true; + *value_parsing_ok = hook(true); return true; } else if (arg == "=false") { - *dst = false; + *value_parsing_ok = hook(false); return true; } else { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag @@ -94,14 +104,18 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - float* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) { + float parsed_float; + if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_float); } return true; } @@ -112,44 +126,107 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } // namespace Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) - : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {} + : name_(name), + type_(TYPE_INT32), + int32_hook_([dst](int32 value) { + *dst = value; + return true; + }), + int32_default_for_display_(*dst), + usage_text_(usage_text) {} Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text) : name_(name), type_(TYPE_INT64), - int64_value_(dst), + int64_hook_([dst](int64 value) { + *dst = value; + return true; + }), + int64_default_for_display_(*dst), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, float* dst, const string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + float_hook_([dst](float value) { + *dst = value; + return true; + }), + float_default_for_display_(*dst), usage_text_(usage_text) {} Flag::Flag(const char* name, bool* dst, const string& usage_text) : name_(name), type_(TYPE_BOOL), - bool_value_(dst), + bool_hook_([dst](bool value) { + *dst = value; + return true; + }), + bool_default_for_display_(*dst), usage_text_(usage_text) {} Flag::Flag(const char* name, string* dst, const string& usage_text) : name_(name), type_(TYPE_STRING), - string_value_(dst), + string_hook_([dst](string value) { + *dst = std::move(value); + return true; + }), + string_default_for_display_(*dst), usage_text_(usage_text) {} -Flag::Flag(const char* name, float* dst, const string& usage_text) +Flag::Flag(const char* name, std::function int32_hook, + int32 default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_INT32), + int32_hook_(std::move(int32_hook)), + int32_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function int64_hook, + int64 default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_INT64), + int64_hook_(std::move(int64_hook)), + int64_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function float_hook, + float default_value_for_display, const string& usage_text) : name_(name), type_(TYPE_FLOAT), - float_value_(dst), + float_hook_(std::move(float_hook)), + float_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function bool_hook, + bool default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_BOOL), + bool_hook_(std::move(bool_hook)), + bool_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function string_hook, + string default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_STRING), + string_hook_(std::move(string_hook)), + string_default_for_display_(std::move(default_value_for_display)), usage_text_(usage_text) {} bool Flag::Parse(string arg, bool* value_parsing_ok) const { bool result = false; - if (type_ == TYPE_INT) { - result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok); + if (type_ == TYPE_INT32) { + result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok); } else if (type_ == TYPE_INT64) { - result = ParseInt64Flag(arg, name_, int64_value_, value_parsing_ok); + result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok); } else if (type_ == TYPE_BOOL) { - result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok); + result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok); } else if (type_ == TYPE_STRING) { - result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok); + result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok); } else if (type_ == TYPE_FLOAT) { - result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok); + result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok); } return result; } @@ -203,26 +280,28 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { for (const Flag& flag : flag_list) { const char* type_name = ""; string flag_string; - if (flag.type_ == Flag::TYPE_INT) { + if (flag.type_ == Flag::TYPE_INT32) { type_name = "int32"; - flag_string = - strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_); + flag_string = strings::Printf("--%s=%d", flag.name_.c_str(), + flag.int32_default_for_display_); } else if (flag.type_ == Flag::TYPE_INT64) { type_name = "int64"; - flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(), - static_cast(*flag.int64_value_)); + flag_string = strings::Printf( + "--%s=%lld", flag.name_.c_str(), + static_cast(flag.int64_default_for_display_)); } else if (flag.type_ == Flag::TYPE_BOOL) { type_name = "bool"; - flag_string = strings::Printf("--%s=%s", flag.name_.c_str(), - *flag.bool_value_ ? "true" : "false"); + flag_string = + strings::Printf("--%s=%s", flag.name_.c_str(), + flag.bool_default_for_display_ ? "true" : "false"); } else if (flag.type_ == Flag::TYPE_STRING) { type_name = "string"; flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(), - flag.string_value_->c_str()); + flag.string_default_for_display_.c_str()); } else if (flag.type_ == Flag::TYPE_FLOAT) { type_name = "float"; - flag_string = - strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_); + flag_string = strings::Printf("--%s=%f", flag.name_.c_str(), + flag.float_default_for_display_); } strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(), type_name, flag.usage_text_.c_str()); -- cgit v1.2.3