diff options
author | 2017-07-17 14:34:18 -0700 | |
---|---|---|
committer | 2017-07-17 14:39:05 -0700 | |
commit | 7bf4e6cbaae9ca930aa17d058c94aa11119fc0c3 (patch) | |
tree | 16b1440865a9020411f73e51ac75080c2a24c729 /tensorflow/core | |
parent | 0c144afecef6800589d255dd990a9a88e9f94b23 (diff) |
Avoid the duplication in debug_options_flags.cc by generalizing tensorflow::Flag.
PiperOrigin-RevId: 162271241
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/util/command_line_flags.cc | 147 | ||||
-rw-r--r-- | tensorflow/core/util/command_line_flags.h | 51 | ||||
-rw-r--r-- | tensorflow/core/util/command_line_flags_test.cc | 163 |
3 files changed, 296 insertions, 65 deletions
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 8373eb1f9e..3efc703faf 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -25,10 +25,11 @@ namespace tensorflow { namespace { bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - string* dst, bool* value_parsing_ok) { + const std::function<bool(string)>& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { - *dst = arg.ToString(); + *value_parsing_ok = hook(arg.ToString()); return true; } @@ -36,14 +37,18 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - tensorflow::int32* dst, bool* value_parsing_ok) { + const std::function<bool(int32)>& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) { + int32 parsed_int32; + if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_int32); } return true; } @@ -52,14 +57,18 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - tensorflow::int64* dst, bool* value_parsing_ok) { + const std::function<bool(int64)>& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) { + int64 parsed_int64; + if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_int64); } return true; } @@ -68,19 +77,20 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - bool* dst, bool* value_parsing_ok) { + const std::function<bool(bool)>& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag)) { if (arg.empty()) { - *dst = true; + *value_parsing_ok = hook(true); return true; } if (arg == "=true") { - *dst = true; + *value_parsing_ok = hook(true); return true; } else if (arg == "=false") { - *dst = false; + *value_parsing_ok = hook(false); return true; } else { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag @@ -94,14 +104,18 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - float* dst, bool* value_parsing_ok) { + const std::function<bool(float)>& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) { + float parsed_float; + if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_float); } return true; } @@ -112,44 +126,107 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } // namespace Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) - : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {} + : name_(name), + type_(TYPE_INT32), + int32_hook_([dst](int32 value) { + *dst = value; + return true; + }), + int32_default_for_display_(*dst), + usage_text_(usage_text) {} Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text) : name_(name), type_(TYPE_INT64), - int64_value_(dst), + int64_hook_([dst](int64 value) { + *dst = value; + return true; + }), + int64_default_for_display_(*dst), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, float* dst, const string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + float_hook_([dst](float value) { + *dst = value; + return true; + }), + float_default_for_display_(*dst), usage_text_(usage_text) {} Flag::Flag(const char* name, bool* dst, const string& usage_text) : name_(name), type_(TYPE_BOOL), - bool_value_(dst), + bool_hook_([dst](bool value) { + *dst = value; + return true; + }), + bool_default_for_display_(*dst), usage_text_(usage_text) {} Flag::Flag(const char* name, string* dst, const string& usage_text) : name_(name), type_(TYPE_STRING), - string_value_(dst), + string_hook_([dst](string value) { + *dst = std::move(value); + return true; + }), + string_default_for_display_(*dst), usage_text_(usage_text) {} -Flag::Flag(const char* name, float* dst, const string& usage_text) +Flag::Flag(const char* name, std::function<bool(int32)> int32_hook, + int32 default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_INT32), + int32_hook_(std::move(int32_hook)), + int32_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function<bool(int64)> int64_hook, + int64 default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_INT64), + int64_hook_(std::move(int64_hook)), + int64_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function<bool(float)> float_hook, + float default_value_for_display, const string& usage_text) : name_(name), type_(TYPE_FLOAT), - float_value_(dst), + float_hook_(std::move(float_hook)), + float_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function<bool(bool)> bool_hook, + bool default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_BOOL), + bool_hook_(std::move(bool_hook)), + bool_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function<bool(string)> string_hook, + string default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_STRING), + string_hook_(std::move(string_hook)), + string_default_for_display_(std::move(default_value_for_display)), usage_text_(usage_text) {} bool Flag::Parse(string arg, bool* value_parsing_ok) const { bool result = false; - if (type_ == TYPE_INT) { - result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok); + if (type_ == TYPE_INT32) { + result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok); } else if (type_ == TYPE_INT64) { - result = ParseInt64Flag(arg, name_, int64_value_, value_parsing_ok); + result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok); } else if (type_ == TYPE_BOOL) { - result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok); + result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok); } else if (type_ == TYPE_STRING) { - result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok); + result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok); } else if (type_ == TYPE_FLOAT) { - result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok); + result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok); } return result; } @@ -203,26 +280,28 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { for (const Flag& flag : flag_list) { const char* type_name = ""; string flag_string; - if (flag.type_ == Flag::TYPE_INT) { + if (flag.type_ == Flag::TYPE_INT32) { type_name = "int32"; - flag_string = - strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_); + flag_string = strings::Printf("--%s=%d", flag.name_.c_str(), + flag.int32_default_for_display_); } else if (flag.type_ == Flag::TYPE_INT64) { type_name = "int64"; - flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(), - static_cast<long long>(*flag.int64_value_)); + flag_string = strings::Printf( + "--%s=%lld", flag.name_.c_str(), + static_cast<long long>(flag.int64_default_for_display_)); } else if (flag.type_ == Flag::TYPE_BOOL) { type_name = "bool"; - flag_string = strings::Printf("--%s=%s", flag.name_.c_str(), - *flag.bool_value_ ? "true" : "false"); + flag_string = + strings::Printf("--%s=%s", flag.name_.c_str(), + flag.bool_default_for_display_ ? "true" : "false"); } else if (flag.type_ == Flag::TYPE_STRING) { type_name = "string"; flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(), - flag.string_value_->c_str()); + flag.string_default_for_display_.c_str()); } else if (flag.type_ == Flag::TYPE_FLOAT) { type_name = "float"; - flag_string = - strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_); + flag_string = strings::Printf("--%s=%f", flag.name_.c_str(), + flag.float_default_for_display_); } strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(), type_name, flag.usage_text_.c_str()); diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index f349df16fd..121c7063c9 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H #define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H +#include <functional> #include <string> #include <vector> #include "tensorflow/core/platform/types.h" @@ -61,24 +62,58 @@ namespace tensorflow { // text, and a pointer to the corresponding variable. class Flag { public: - Flag(const char* name, int32* dst1, const string& usage_text); - Flag(const char* name, int64* dst1, const string& usage_text); + Flag(const char* name, int32* dst, const string& usage_text); + Flag(const char* name, int64* dst, const string& usage_text); Flag(const char* name, bool* dst, const string& usage_text); Flag(const char* name, string* dst, const string& usage_text); Flag(const char* name, float* dst, const string& usage_text); + // These constructors invoke a hook on a match instead of writing to a + // specific memory location. The hook may return false to signal a malformed + // or illegal value, which will then fail the command line parse. + // + // "default_value_for_display" is shown as the default value of this flag in + // Flags::Usage(). + Flag(const char* name, std::function<bool(int32)> int32_hook, + int32 default_value_for_display, const string& usage_text); + Flag(const char* name, std::function<bool(int64)> int64_hook, + int64 default_value_for_display, const string& usage_text); + Flag(const char* name, std::function<bool(float)> float_hook, + float default_value_for_display, const string& usage_text); + Flag(const char* name, std::function<bool(bool)> bool_hook, + bool default_value_for_display, const string& usage_text); + Flag(const char* name, std::function<bool(string)> string_hook, + string default_value_for_display, const string& usage_text); + private: friend class Flags; bool Parse(string arg, bool* value_parsing_ok) const; string name_; - enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_; - int* int_value_; - int64* int64_value_; - bool* bool_value_; - string* string_value_; - float* float_value_; + enum { + TYPE_INT32, + TYPE_INT64, + TYPE_BOOL, + TYPE_STRING, + TYPE_FLOAT, + } type_; + + std::function<bool(int32)> int32_hook_; + int32 int32_default_for_display_; + + std::function<bool(int64)> int64_hook_; + int64 int64_default_for_display_; + + std::function<bool(float)> float_hook_; + float float_default_for_display_; + + std::function<bool(bool)> bool_hook_; + bool bool_default_for_display_; + + std::function<bool(string)> string_hook_; + string string_default_for_display_; + string usage_text_; }; diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index c86a70ec9d..6139c8e7bc 100644 --- a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -36,32 +36,85 @@ std::vector<char *> CharPointerVectorFromStrings( } // namespace TEST(CommandLineFlagsTest, BasicUsage) { - int some_int = 10; - int64 some_int64 = 21474836470; // max int32 is 2147483647 - bool some_switch = false; - string some_name = "something"; - float some_float = -23.23f; - int argc = 6; + int some_int32_set_directly = 10; + int some_int32_set_via_hook = 20; + int64 some_int64_set_directly = 21474836470; // max int32 is 2147483647 + int64 some_int64_set_via_hook = 21474836479; // max int32 is 2147483647 + bool some_switch_set_directly = false; + bool some_switch_set_via_hook = true; + string some_name_set_directly = "something_a"; + string some_name_set_via_hook = "something_b"; + float some_float_set_directly = -23.23f; + float some_float_set_via_hook = -25.23f; std::vector<string> argv_strings = {"program_name", - "--some_int=20", - "--some_int64=214748364700", - "--some_switch", - "--some_name=somethingelse", - "--some_float=42.0"}; + "--some_int32_set_directly=20", + "--some_int32_set_via_hook=50", + "--some_int64_set_directly=214748364700", + "--some_int64_set_via_hook=214748364710", + "--some_switch_set_directly", + "--some_switch_set_via_hook=false", + "--some_name_set_directly=somethingelse", + "--some_name_set_via_hook=anythingelse", + "--some_float_set_directly=42.0", + "--some_float_set_via_hook=43.0"}; + int argc = argv_strings.size(); std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); - bool parsed_ok = - Flags::Parse(&argc, argv_array.data(), - {Flag("some_int", &some_int, "some int"), - Flag("some_int64", &some_int64, "some int64"), - Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name"), - Flag("some_float", &some_float, "some float")}); + bool parsed_ok = Flags::Parse( + &argc, argv_array.data(), + { + Flag("some_int32_set_directly", &some_int32_set_directly, + "some int32 set directly"), + Flag("some_int32_set_via_hook", + [&](int32 value) { + some_int32_set_via_hook = value; + return true; + }, + some_int32_set_via_hook, "some int32 set via hook"), + Flag("some_int64_set_directly", &some_int64_set_directly, + "some int64 set directly"), + Flag("some_int64_set_via_hook", + [&](int64 value) { + some_int64_set_via_hook = value; + return true; + }, + some_int64_set_via_hook, "some int64 set via hook"), + Flag("some_switch_set_directly", &some_switch_set_directly, + "some switch set directly"), + Flag("some_switch_set_via_hook", + [&](bool value) { + some_switch_set_via_hook = value; + return true; + }, + some_switch_set_via_hook, "some switch set via hook"), + Flag("some_name_set_directly", &some_name_set_directly, + "some name set directly"), + Flag("some_name_set_via_hook", + [&](string value) { + some_name_set_via_hook = std::move(value); + return true; + }, + some_name_set_via_hook, "some name set via hook"), + Flag("some_float_set_directly", &some_float_set_directly, + "some float set directly"), + Flag("some_float_set_via_hook", + [&](float value) { + some_float_set_via_hook = value; + return true; + }, + some_float_set_via_hook, "some float set via hook"), + }); + EXPECT_EQ(true, parsed_ok); - EXPECT_EQ(20, some_int); - EXPECT_EQ(214748364700, some_int64); - EXPECT_EQ(true, some_switch); - EXPECT_EQ("somethingelse", some_name); - EXPECT_NEAR(42.0f, some_float, 1e-5f); + EXPECT_EQ(20, some_int32_set_directly); + EXPECT_EQ(50, some_int32_set_via_hook); + EXPECT_EQ(214748364700, some_int64_set_directly); + EXPECT_EQ(214748364710, some_int64_set_via_hook); + EXPECT_EQ(true, some_switch_set_directly); + EXPECT_EQ(false, some_switch_set_via_hook); + EXPECT_EQ("somethingelse", some_name_set_directly); + EXPECT_EQ("anythingelse", some_name_set_via_hook); + EXPECT_NEAR(42.0f, some_float_set_directly, 1e-5f); + EXPECT_NEAR(43.0f, some_float_set_via_hook, 1e-5f); EXPECT_EQ(argc, 1); } @@ -107,6 +160,70 @@ TEST(CommandLineFlagsTest, BadFloatValue) { EXPECT_EQ(argc, 1); } +TEST(CommandLineFlagsTest, FailedInt32Hook) { + int argc = 2; + std::vector<string> argv_strings = {"program_name", "--some_int32=200"}; + std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_int32", [](int32 value) { return false; }, 30, + "some int32")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, FailedInt64Hook) { + int argc = 2; + std::vector<string> argv_strings = {"program_name", "--some_int64=200"}; + std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_int64", [](int64 value) { return false; }, 30, + "some int64")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, FailedFloatHook) { + int argc = 2; + std::vector<string> argv_strings = {"program_name", "--some_float=200.0"}; + std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_float", [](float value) { return false; }, 30.0f, + "some float")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, FailedBoolHook) { + int argc = 2; + std::vector<string> argv_strings = {"program_name", "--some_switch=true"}; + std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_switch", [](bool value) { return false; }, false, + "some switch")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, FailedStringHook) { + int argc = 2; + std::vector<string> argv_strings = {"program_name", "--some_name=true"}; + std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = Flags::Parse( + &argc, argv_array.data(), + {Flag("some_name", [](string value) { return false; }, "", "some name")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + // Return whether str==pat, but allowing any whitespace in pat // to match zero or more whitespace characters in str. static bool MatchWithAnyWhitespace(const string &str, const string &pat) { |