aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/command_line_flags.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-17 14:34:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-17 14:39:05 -0700
commit7bf4e6cbaae9ca930aa17d058c94aa11119fc0c3 (patch)
tree16b1440865a9020411f73e51ac75080c2a24c729 /tensorflow/core/util/command_line_flags.cc
parent0c144afecef6800589d255dd990a9a88e9f94b23 (diff)
Avoid the duplication in debug_options_flags.cc by generalizing tensorflow::Flag.
PiperOrigin-RevId: 162271241
Diffstat (limited to 'tensorflow/core/util/command_line_flags.cc')
-rw-r--r--tensorflow/core/util/command_line_flags.cc147
1 files changed, 113 insertions, 34 deletions
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 8373eb1f9e..3efc703faf 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -25,10 +25,11 @@ namespace tensorflow {
namespace {
bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- string* dst, bool* value_parsing_ok) {
+ const std::function<bool(string)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
- *dst = arg.ToString();
+ *value_parsing_ok = hook(arg.ToString());
return true;
}
@@ -36,14 +37,18 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- tensorflow::int32* dst, bool* value_parsing_ok) {
+ const std::function<bool(int32)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) {
+ int32 parsed_int32;
+ if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_int32);
}
return true;
}
@@ -52,14 +57,18 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- tensorflow::int64* dst, bool* value_parsing_ok) {
+ const std::function<bool(int64)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) {
+ int64 parsed_int64;
+ if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_int64);
}
return true;
}
@@ -68,19 +77,20 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- bool* dst, bool* value_parsing_ok) {
+ const std::function<bool(bool)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag)) {
if (arg.empty()) {
- *dst = true;
+ *value_parsing_ok = hook(true);
return true;
}
if (arg == "=true") {
- *dst = true;
+ *value_parsing_ok = hook(true);
return true;
} else if (arg == "=false") {
- *dst = false;
+ *value_parsing_ok = hook(false);
return true;
} else {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
@@ -94,14 +104,18 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- float* dst, bool* value_parsing_ok) {
+ const std::function<bool(float)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
+ float parsed_float;
+ if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_float);
}
return true;
}
@@ -112,44 +126,107 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
} // namespace
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
- : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {}
+ : name_(name),
+ type_(TYPE_INT32),
+ int32_hook_([dst](int32 value) {
+ *dst = value;
+ return true;
+ }),
+ int32_default_for_display_(*dst),
+ usage_text_(usage_text) {}
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
: name_(name),
type_(TYPE_INT64),
- int64_value_(dst),
+ int64_hook_([dst](int64 value) {
+ *dst = value;
+ return true;
+ }),
+ int64_default_for_display_(*dst),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, float* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_FLOAT),
+ float_hook_([dst](float value) {
+ *dst = value;
+ return true;
+ }),
+ float_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, bool* dst, const string& usage_text)
: name_(name),
type_(TYPE_BOOL),
- bool_value_(dst),
+ bool_hook_([dst](bool value) {
+ *dst = value;
+ return true;
+ }),
+ bool_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, string* dst, const string& usage_text)
: name_(name),
type_(TYPE_STRING),
- string_value_(dst),
+ string_hook_([dst](string value) {
+ *dst = std::move(value);
+ return true;
+ }),
+ string_default_for_display_(*dst),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, float* dst, const string& usage_text)
+Flag::Flag(const char* name, std::function<bool(int32)> int32_hook,
+ int32 default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT32),
+ int32_hook_(std::move(int32_hook)),
+ int32_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(int64)> int64_hook,
+ int64 default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT64),
+ int64_hook_(std::move(int64_hook)),
+ int64_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(float)> float_hook,
+ float default_value_for_display, const string& usage_text)
: name_(name),
type_(TYPE_FLOAT),
- float_value_(dst),
+ float_hook_(std::move(float_hook)),
+ float_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(bool)> bool_hook,
+ bool default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_BOOL),
+ bool_hook_(std::move(bool_hook)),
+ bool_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(string)> string_hook,
+ string default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_STRING),
+ string_hook_(std::move(string_hook)),
+ string_default_for_display_(std::move(default_value_for_display)),
usage_text_(usage_text) {}
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
bool result = false;
- if (type_ == TYPE_INT) {
- result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok);
+ if (type_ == TYPE_INT32) {
+ result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok);
} else if (type_ == TYPE_INT64) {
- result = ParseInt64Flag(arg, name_, int64_value_, value_parsing_ok);
+ result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok);
} else if (type_ == TYPE_BOOL) {
- result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
+ result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok);
} else if (type_ == TYPE_STRING) {
- result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
+ result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok);
} else if (type_ == TYPE_FLOAT) {
- result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok);
+ result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok);
}
return result;
}
@@ -203,26 +280,28 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
for (const Flag& flag : flag_list) {
const char* type_name = "";
string flag_string;
- if (flag.type_ == Flag::TYPE_INT) {
+ if (flag.type_ == Flag::TYPE_INT32) {
type_name = "int32";
- flag_string =
- strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_);
+ flag_string = strings::Printf("--%s=%d", flag.name_.c_str(),
+ flag.int32_default_for_display_);
} else if (flag.type_ == Flag::TYPE_INT64) {
type_name = "int64";
- flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(),
- static_cast<long long>(*flag.int64_value_));
+ flag_string = strings::Printf(
+ "--%s=%lld", flag.name_.c_str(),
+ static_cast<long long>(flag.int64_default_for_display_));
} else if (flag.type_ == Flag::TYPE_BOOL) {
type_name = "bool";
- flag_string = strings::Printf("--%s=%s", flag.name_.c_str(),
- *flag.bool_value_ ? "true" : "false");
+ flag_string =
+ strings::Printf("--%s=%s", flag.name_.c_str(),
+ flag.bool_default_for_display_ ? "true" : "false");
} else if (flag.type_ == Flag::TYPE_STRING) {
type_name = "string";
flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
- flag.string_value_->c_str());
+ flag.string_default_for_display_.c_str());
} else if (flag.type_ == Flag::TYPE_FLOAT) {
type_name = "float";
- flag_string =
- strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_);
+ flag_string = strings::Printf("--%s=%f", flag.name_.c_str(),
+ flag.float_default_for_display_);
}
strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
type_name, flag.usage_text_.c_str());