aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-17 14:34:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-17 14:39:05 -0700
commit7bf4e6cbaae9ca930aa17d058c94aa11119fc0c3 (patch)
tree16b1440865a9020411f73e51ac75080c2a24c729 /tensorflow/core
parent0c144afecef6800589d255dd990a9a88e9f94b23 (diff)
Avoid the duplication in debug_options_flags.cc by generalizing tensorflow::Flag.
PiperOrigin-RevId: 162271241
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/util/command_line_flags.cc147
-rw-r--r--tensorflow/core/util/command_line_flags.h51
-rw-r--r--tensorflow/core/util/command_line_flags_test.cc163
3 files changed, 296 insertions, 65 deletions
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 8373eb1f9e..3efc703faf 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -25,10 +25,11 @@ namespace tensorflow {
namespace {
bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- string* dst, bool* value_parsing_ok) {
+ const std::function<bool(string)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
- *dst = arg.ToString();
+ *value_parsing_ok = hook(arg.ToString());
return true;
}
@@ -36,14 +37,18 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- tensorflow::int32* dst, bool* value_parsing_ok) {
+ const std::function<bool(int32)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) {
+ int32 parsed_int32;
+ if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_int32);
}
return true;
}
@@ -52,14 +57,18 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- tensorflow::int64* dst, bool* value_parsing_ok) {
+ const std::function<bool(int64)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) {
+ int64 parsed_int64;
+ if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_int64);
}
return true;
}
@@ -68,19 +77,20 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- bool* dst, bool* value_parsing_ok) {
+ const std::function<bool(bool)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag)) {
if (arg.empty()) {
- *dst = true;
+ *value_parsing_ok = hook(true);
return true;
}
if (arg == "=true") {
- *dst = true;
+ *value_parsing_ok = hook(true);
return true;
} else if (arg == "=false") {
- *dst = false;
+ *value_parsing_ok = hook(false);
return true;
} else {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
@@ -94,14 +104,18 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- float* dst, bool* value_parsing_ok) {
+ const std::function<bool(float)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
+ float parsed_float;
+ if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_float);
}
return true;
}
@@ -112,44 +126,107 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
} // namespace
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
- : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {}
+ : name_(name),
+ type_(TYPE_INT32),
+ int32_hook_([dst](int32 value) {
+ *dst = value;
+ return true;
+ }),
+ int32_default_for_display_(*dst),
+ usage_text_(usage_text) {}
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
: name_(name),
type_(TYPE_INT64),
- int64_value_(dst),
+ int64_hook_([dst](int64 value) {
+ *dst = value;
+ return true;
+ }),
+ int64_default_for_display_(*dst),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, float* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_FLOAT),
+ float_hook_([dst](float value) {
+ *dst = value;
+ return true;
+ }),
+ float_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, bool* dst, const string& usage_text)
: name_(name),
type_(TYPE_BOOL),
- bool_value_(dst),
+ bool_hook_([dst](bool value) {
+ *dst = value;
+ return true;
+ }),
+ bool_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, string* dst, const string& usage_text)
: name_(name),
type_(TYPE_STRING),
- string_value_(dst),
+ string_hook_([dst](string value) {
+ *dst = std::move(value);
+ return true;
+ }),
+ string_default_for_display_(*dst),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, float* dst, const string& usage_text)
+Flag::Flag(const char* name, std::function<bool(int32)> int32_hook,
+ int32 default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT32),
+ int32_hook_(std::move(int32_hook)),
+ int32_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(int64)> int64_hook,
+ int64 default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT64),
+ int64_hook_(std::move(int64_hook)),
+ int64_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(float)> float_hook,
+ float default_value_for_display, const string& usage_text)
: name_(name),
type_(TYPE_FLOAT),
- float_value_(dst),
+ float_hook_(std::move(float_hook)),
+ float_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(bool)> bool_hook,
+ bool default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_BOOL),
+ bool_hook_(std::move(bool_hook)),
+ bool_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(string)> string_hook,
+ string default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_STRING),
+ string_hook_(std::move(string_hook)),
+ string_default_for_display_(std::move(default_value_for_display)),
usage_text_(usage_text) {}
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
bool result = false;
- if (type_ == TYPE_INT) {
- result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok);
+ if (type_ == TYPE_INT32) {
+ result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok);
} else if (type_ == TYPE_INT64) {
- result = ParseInt64Flag(arg, name_, int64_value_, value_parsing_ok);
+ result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok);
} else if (type_ == TYPE_BOOL) {
- result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
+ result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok);
} else if (type_ == TYPE_STRING) {
- result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
+ result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok);
} else if (type_ == TYPE_FLOAT) {
- result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok);
+ result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok);
}
return result;
}
@@ -203,26 +280,28 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
for (const Flag& flag : flag_list) {
const char* type_name = "";
string flag_string;
- if (flag.type_ == Flag::TYPE_INT) {
+ if (flag.type_ == Flag::TYPE_INT32) {
type_name = "int32";
- flag_string =
- strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_);
+ flag_string = strings::Printf("--%s=%d", flag.name_.c_str(),
+ flag.int32_default_for_display_);
} else if (flag.type_ == Flag::TYPE_INT64) {
type_name = "int64";
- flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(),
- static_cast<long long>(*flag.int64_value_));
+ flag_string = strings::Printf(
+ "--%s=%lld", flag.name_.c_str(),
+ static_cast<long long>(flag.int64_default_for_display_));
} else if (flag.type_ == Flag::TYPE_BOOL) {
type_name = "bool";
- flag_string = strings::Printf("--%s=%s", flag.name_.c_str(),
- *flag.bool_value_ ? "true" : "false");
+ flag_string =
+ strings::Printf("--%s=%s", flag.name_.c_str(),
+ flag.bool_default_for_display_ ? "true" : "false");
} else if (flag.type_ == Flag::TYPE_STRING) {
type_name = "string";
flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
- flag.string_value_->c_str());
+ flag.string_default_for_display_.c_str());
} else if (flag.type_ == Flag::TYPE_FLOAT) {
type_name = "float";
- flag_string =
- strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_);
+ flag_string = strings::Printf("--%s=%f", flag.name_.c_str(),
+ flag.float_default_for_display_);
}
strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
type_name, flag.usage_text_.c_str());
diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h
index f349df16fd..121c7063c9 100644
--- a/tensorflow/core/util/command_line_flags.h
+++ b/tensorflow/core/util/command_line_flags.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
+#include <functional>
#include <string>
#include <vector>
#include "tensorflow/core/platform/types.h"
@@ -61,24 +62,58 @@ namespace tensorflow {
// text, and a pointer to the corresponding variable.
class Flag {
public:
- Flag(const char* name, int32* dst1, const string& usage_text);
- Flag(const char* name, int64* dst1, const string& usage_text);
+ Flag(const char* name, int32* dst, const string& usage_text);
+ Flag(const char* name, int64* dst, const string& usage_text);
Flag(const char* name, bool* dst, const string& usage_text);
Flag(const char* name, string* dst, const string& usage_text);
Flag(const char* name, float* dst, const string& usage_text);
+ // These constructors invoke a hook on a match instead of writing to a
+ // specific memory location. The hook may return false to signal a malformed
+ // or illegal value, which will then fail the command line parse.
+ //
+ // "default_value_for_display" is shown as the default value of this flag in
+ // Flags::Usage().
+ Flag(const char* name, std::function<bool(int32)> int32_hook,
+ int32 default_value_for_display, const string& usage_text);
+ Flag(const char* name, std::function<bool(int64)> int64_hook,
+ int64 default_value_for_display, const string& usage_text);
+ Flag(const char* name, std::function<bool(float)> float_hook,
+ float default_value_for_display, const string& usage_text);
+ Flag(const char* name, std::function<bool(bool)> bool_hook,
+ bool default_value_for_display, const string& usage_text);
+ Flag(const char* name, std::function<bool(string)> string_hook,
+ string default_value_for_display, const string& usage_text);
+
private:
friend class Flags;
bool Parse(string arg, bool* value_parsing_ok) const;
string name_;
- enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_;
- int* int_value_;
- int64* int64_value_;
- bool* bool_value_;
- string* string_value_;
- float* float_value_;
+ enum {
+ TYPE_INT32,
+ TYPE_INT64,
+ TYPE_BOOL,
+ TYPE_STRING,
+ TYPE_FLOAT,
+ } type_;
+
+ std::function<bool(int32)> int32_hook_;
+ int32 int32_default_for_display_;
+
+ std::function<bool(int64)> int64_hook_;
+ int64 int64_default_for_display_;
+
+ std::function<bool(float)> float_hook_;
+ float float_default_for_display_;
+
+ std::function<bool(bool)> bool_hook_;
+ bool bool_default_for_display_;
+
+ std::function<bool(string)> string_hook_;
+ string string_default_for_display_;
+
string usage_text_;
};
diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc
index c86a70ec9d..6139c8e7bc 100644
--- a/tensorflow/core/util/command_line_flags_test.cc
+++ b/tensorflow/core/util/command_line_flags_test.cc
@@ -36,32 +36,85 @@ std::vector<char *> CharPointerVectorFromStrings(
} // namespace
TEST(CommandLineFlagsTest, BasicUsage) {
- int some_int = 10;
- int64 some_int64 = 21474836470; // max int32 is 2147483647
- bool some_switch = false;
- string some_name = "something";
- float some_float = -23.23f;
- int argc = 6;
+ int some_int32_set_directly = 10;
+ int some_int32_set_via_hook = 20;
+ int64 some_int64_set_directly = 21474836470; // max int32 is 2147483647
+ int64 some_int64_set_via_hook = 21474836479; // max int32 is 2147483647
+ bool some_switch_set_directly = false;
+ bool some_switch_set_via_hook = true;
+ string some_name_set_directly = "something_a";
+ string some_name_set_via_hook = "something_b";
+ float some_float_set_directly = -23.23f;
+ float some_float_set_via_hook = -25.23f;
std::vector<string> argv_strings = {"program_name",
- "--some_int=20",
- "--some_int64=214748364700",
- "--some_switch",
- "--some_name=somethingelse",
- "--some_float=42.0"};
+ "--some_int32_set_directly=20",
+ "--some_int32_set_via_hook=50",
+ "--some_int64_set_directly=214748364700",
+ "--some_int64_set_via_hook=214748364710",
+ "--some_switch_set_directly",
+ "--some_switch_set_via_hook=false",
+ "--some_name_set_directly=somethingelse",
+ "--some_name_set_via_hook=anythingelse",
+ "--some_float_set_directly=42.0",
+ "--some_float_set_via_hook=43.0"};
+ int argc = argv_strings.size();
std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
- bool parsed_ok =
- Flags::Parse(&argc, argv_array.data(),
- {Flag("some_int", &some_int, "some int"),
- Flag("some_int64", &some_int64, "some int64"),
- Flag("some_switch", &some_switch, "some switch"),
- Flag("some_name", &some_name, "some name"),
- Flag("some_float", &some_float, "some float")});
+ bool parsed_ok = Flags::Parse(
+ &argc, argv_array.data(),
+ {
+ Flag("some_int32_set_directly", &some_int32_set_directly,
+ "some int32 set directly"),
+ Flag("some_int32_set_via_hook",
+ [&](int32 value) {
+ some_int32_set_via_hook = value;
+ return true;
+ },
+ some_int32_set_via_hook, "some int32 set via hook"),
+ Flag("some_int64_set_directly", &some_int64_set_directly,
+ "some int64 set directly"),
+ Flag("some_int64_set_via_hook",
+ [&](int64 value) {
+ some_int64_set_via_hook = value;
+ return true;
+ },
+ some_int64_set_via_hook, "some int64 set via hook"),
+ Flag("some_switch_set_directly", &some_switch_set_directly,
+ "some switch set directly"),
+ Flag("some_switch_set_via_hook",
+ [&](bool value) {
+ some_switch_set_via_hook = value;
+ return true;
+ },
+ some_switch_set_via_hook, "some switch set via hook"),
+ Flag("some_name_set_directly", &some_name_set_directly,
+ "some name set directly"),
+ Flag("some_name_set_via_hook",
+ [&](string value) {
+ some_name_set_via_hook = std::move(value);
+ return true;
+ },
+ some_name_set_via_hook, "some name set via hook"),
+ Flag("some_float_set_directly", &some_float_set_directly,
+ "some float set directly"),
+ Flag("some_float_set_via_hook",
+ [&](float value) {
+ some_float_set_via_hook = value;
+ return true;
+ },
+ some_float_set_via_hook, "some float set via hook"),
+ });
+
EXPECT_EQ(true, parsed_ok);
- EXPECT_EQ(20, some_int);
- EXPECT_EQ(214748364700, some_int64);
- EXPECT_EQ(true, some_switch);
- EXPECT_EQ("somethingelse", some_name);
- EXPECT_NEAR(42.0f, some_float, 1e-5f);
+ EXPECT_EQ(20, some_int32_set_directly);
+ EXPECT_EQ(50, some_int32_set_via_hook);
+ EXPECT_EQ(214748364700, some_int64_set_directly);
+ EXPECT_EQ(214748364710, some_int64_set_via_hook);
+ EXPECT_EQ(true, some_switch_set_directly);
+ EXPECT_EQ(false, some_switch_set_via_hook);
+ EXPECT_EQ("somethingelse", some_name_set_directly);
+ EXPECT_EQ("anythingelse", some_name_set_via_hook);
+ EXPECT_NEAR(42.0f, some_float_set_directly, 1e-5f);
+ EXPECT_NEAR(43.0f, some_float_set_via_hook, 1e-5f);
EXPECT_EQ(argc, 1);
}
@@ -107,6 +160,70 @@ TEST(CommandLineFlagsTest, BadFloatValue) {
EXPECT_EQ(argc, 1);
}
+TEST(CommandLineFlagsTest, FailedInt32Hook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_int32=200"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_int32", [](int32 value) { return false; }, 30,
+ "some int32")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, FailedInt64Hook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_int64=200"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_int64", [](int64 value) { return false; }, 30,
+ "some int64")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, FailedFloatHook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_float=200.0"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_float", [](float value) { return false; }, 30.0f,
+ "some float")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, FailedBoolHook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_switch=true"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_switch", [](bool value) { return false; }, false,
+ "some switch")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, FailedStringHook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_name=true"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok = Flags::Parse(
+ &argc, argv_array.data(),
+ {Flag("some_name", [](string value) { return false; }, "", "some name")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
// Return whether str==pat, but allowing any whitespace in pat
// to match zero or more whitespace characters in str.
static bool MatchWithAnyWhitespace(const string &str, const string &pat) {