aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-17 14:34:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-17 14:39:05 -0700
commit7bf4e6cbaae9ca930aa17d058c94aa11119fc0c3 (patch)
tree16b1440865a9020411f73e51ac75080c2a24c729
parent0c144afecef6800589d255dd990a9a88e9f94b23 (diff)
Avoid the duplication in debug_options_flags.cc by generalizing tensorflow::Flag.
PiperOrigin-RevId: 162271241
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc335
-rw-r--r--tensorflow/core/util/command_line_flags.cc147
-rw-r--r--tensorflow/core/util/command_line_flags.h51
-rw-r--r--tensorflow/core/util/command_line_flags_test.cc163
4 files changed, 457 insertions, 239 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 5b4fb5d0a7..8bd3d0f040 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -23,167 +23,215 @@ limitations under the License.
namespace xla {
namespace legacy_flags {
-struct DebugOptionsFlags {
- string xla_generate_hlo_graph;
- bool xla_hlo_graph_addresses;
- bool xla_hlo_graph_layout;
- string xla_hlo_graph_path;
- bool xla_hlo_dump_as_graphdef;
- string xla_log_hlo_text;
- string xla_generate_hlo_text_to;
-
- string xla_disable_hlo_passes;
- bool xla_enable_fast_math;
- bool xla_llvm_enable_alias_scope_metadata;
- bool xla_llvm_enable_noalias_metadata;
- bool xla_llvm_enable_invariant_load_metadata;
- int32 xla_backend_optimization_level;
- bool xla_embed_ir_in_executable;
- string xla_dump_ir_to;
- string xla_dump_debug_json_to;
- bool xla_eliminate_hlo_implicit_broadcast;
-
- bool xla_cpu_multi_thread_eigen;
-
- string xla_gpu_cuda_data_dir;
- bool xla_gpu_ftz;
-
- bool xla_test_all_output_layouts;
- bool xla_test_all_input_layouts;
-
- string xla_backend_extra_options;
-};
-
namespace {
-DebugOptionsFlags* flag_values;
+DebugOptions* flag_values;
std::vector<tensorflow::Flag>* flag_objects;
std::once_flag flags_init;
+namespace {
+void SetDebugOptionsDefaults(DebugOptions* flags) {
+ flags->set_xla_hlo_graph_path("/tmp/");
+ flags->set_xla_enable_fast_math(true);
+ flags->set_xla_llvm_enable_alias_scope_metadata(true);
+ flags->set_xla_llvm_enable_noalias_metadata(true);
+ flags->set_xla_llvm_enable_invariant_load_metadata(true);
+ flags->set_xla_backend_optimization_level(3);
+ flags->set_xla_cpu_multi_thread_eigen(true);
+ flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
+}
+} // namespace
+
// Allocates flag_values and flag_objects; this function must not be called more
// than once - its call done via call_once.
void AllocateFlags() {
- flag_values = new DebugOptionsFlags;
- flag_values->xla_generate_hlo_graph = "";
- flag_values->xla_hlo_graph_addresses = false;
- flag_values->xla_hlo_graph_layout = false;
- flag_values->xla_hlo_graph_path = "/tmp/";
- flag_values->xla_hlo_dump_as_graphdef = false;
- flag_values->xla_log_hlo_text = "";
- flag_values->xla_generate_hlo_text_to = "";
- flag_values->xla_disable_hlo_passes = "";
- flag_values->xla_enable_fast_math = true;
- flag_values->xla_llvm_enable_alias_scope_metadata = true;
- flag_values->xla_llvm_enable_noalias_metadata = true;
- flag_values->xla_llvm_enable_invariant_load_metadata = true;
- flag_values->xla_backend_optimization_level = 3;
- flag_values->xla_embed_ir_in_executable = false;
- flag_values->xla_dump_ir_to = "";
- flag_values->xla_dump_debug_json_to = "";
- flag_values->xla_eliminate_hlo_implicit_broadcast = false;
- flag_values->xla_cpu_multi_thread_eigen = true;
- flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib";
- flag_values->xla_gpu_ftz = false;
- flag_values->xla_test_all_output_layouts = false;
- flag_values->xla_backend_extra_options = "";
- flag_values->xla_test_all_input_layouts = false;
+ flag_values = new DebugOptions;
+
+ SetDebugOptionsDefaults(flag_values);
+
+ // Returns a lambda that calls "member_setter" on "flag_values" with the
+ // argument passed in to the lambda.
+ auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) {
+ return [member_setter](bool value) {
+ (flag_values->*member_setter)(value);
+ return true;
+ };
+ };
+
+ // Returns a lambda that calls "member_setter" on "flag_values" with the
+ // argument passed in to the lambda.
+ auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32)) {
+ return [member_setter](int32 value) {
+ (flag_values->*member_setter)(value);
+ return true;
+ };
+ };
+
+ // Returns a lambda that is a custom "sub-parser" for xla_disable_hlo_passes.
+ auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) {
+ std::vector<string> disabled_passes =
+ tensorflow::str_util::Split(comma_separated_values, ',');
+ for (const auto& passname : disabled_passes) {
+ flag_values->add_xla_disable_hlo_passes(passname);
+ }
+ return true;
+ };
+
+ // Returns a lambda that is a custom "sub-parser" for
+ // xla_backend_extra_options.
+ auto setter_for_xla_backend_extra_options =
+ [](string comma_separated_values) {
+ std::vector<string> extra_options_parts =
+ tensorflow::str_util::Split(comma_separated_values, ',');
+ auto* extra_options_map =
+ flag_values->mutable_xla_backend_extra_options();
+
+ // The flag contains a comma-separated list of options; some options
+ // have arguments following "=", some don't.
+ for (const auto& part : extra_options_parts) {
+ size_t eq_pos = part.find_first_of('=');
+ if (eq_pos == string::npos) {
+ (*extra_options_map)[part] = "";
+ } else {
+ string value = "";
+ if (eq_pos + 1 < part.size()) {
+ value = part.substr(eq_pos + 1);
+ }
+ (*extra_options_map)[part.substr(0, eq_pos)] = value;
+ }
+ }
+
+ return true;
+ };
flag_objects = new std::vector<tensorflow::Flag>(
{tensorflow::Flag(
- "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph,
+ "xla_generate_hlo_graph",
+ flag_values->mutable_xla_generate_hlo_graph(),
"HLO modules matching this regex will be dumped to a .dot file "
"throughout various stages in compilation."),
tensorflow::Flag(
- "xla_hlo_graph_addresses", &flag_values->xla_hlo_graph_addresses,
+ "xla_hlo_graph_addresses",
+ bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
+ flag_values->xla_hlo_graph_addresses(),
"With xla_generate_hlo_graph, show addresses of HLO ops in "
"graph dump."),
tensorflow::Flag(
- "xla_hlo_graph_layout", &flag_values->xla_hlo_graph_layout,
+ "xla_hlo_graph_layout",
+ bool_setter_for(&DebugOptions::set_xla_hlo_graph_layout),
+ flag_values->xla_hlo_graph_layout(),
"With xla_generate_hlo_graph, show layout of HLO ops in "
"graph dump."),
tensorflow::Flag(
- "xla_hlo_graph_path", &flag_values->xla_hlo_graph_path,
+ "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(),
"With xla_generate_hlo_graph, dump the graphs into this path."),
- tensorflow::Flag("xla_hlo_dump_as_graphdef",
- &flag_values->xla_hlo_dump_as_graphdef,
- "Dump HLO graphs as TensorFlow GraphDefs."),
tensorflow::Flag(
- "xla_log_hlo_text", &flag_values->xla_log_hlo_text,
+ "xla_hlo_dump_as_graphdef",
+ bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef),
+ flag_values->xla_hlo_dump_as_graphdef(),
+ "Dump HLO graphs as TensorFlow GraphDefs."),
+ tensorflow::Flag(
+ "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(),
"HLO modules matching this regex will be dumped to LOG(INFO). "),
tensorflow::Flag(
- "xla_generate_hlo_text_to", &flag_values->xla_generate_hlo_text_to,
+ "xla_generate_hlo_text_to",
+ flag_values->mutable_xla_generate_hlo_text_to(),
"Dump all HLO modules as text into the provided directory path."),
tensorflow::Flag(
- "xla_enable_fast_math", &flag_values->xla_enable_fast_math,
+ "xla_enable_fast_math",
+ bool_setter_for(&DebugOptions::set_xla_enable_fast_math),
+ flag_values->xla_enable_fast_math(),
"Enable unsafe fast-math optimizations in the compiler; "
"this may produce faster code at the expense of some accuracy."),
- tensorflow::Flag("xla_llvm_enable_alias_scope_metadata",
- &flag_values->xla_llvm_enable_alias_scope_metadata,
- "In LLVM-based backends, enable the emission of "
- "!alias.scope metadata in the generated IR."),
- tensorflow::Flag("xla_llvm_enable_noalias_metadata",
- &flag_values->xla_llvm_enable_noalias_metadata,
- "In LLVM-based backends, enable the emission of "
- "!noalias metadata in the generated IR."),
- tensorflow::Flag("xla_llvm_enable_invariant_load_metadata",
- &flag_values->xla_llvm_enable_invariant_load_metadata,
- "In LLVM-based backends, enable the emission of "
- "!invariant.load metadata in "
- "the generated IR."),
+ tensorflow::Flag(
+ "xla_llvm_enable_alias_scope_metadata",
+ bool_setter_for(
+ &DebugOptions::set_xla_llvm_enable_alias_scope_metadata),
+ flag_values->xla_llvm_enable_alias_scope_metadata(),
+ "In LLVM-based backends, enable the emission of "
+ "!alias.scope metadata in the generated IR."),
+ tensorflow::Flag(
+ "xla_llvm_enable_noalias_metadata",
+ bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata),
+ flag_values->xla_llvm_enable_noalias_metadata(),
+ "In LLVM-based backends, enable the emission of "
+ "!noalias metadata in the generated IR."),
+ tensorflow::Flag(
+ "xla_llvm_enable_invariant_load_metadata",
+ bool_setter_for(
+ &DebugOptions::set_xla_llvm_enable_invariant_load_metadata),
+ flag_values->xla_llvm_enable_invariant_load_metadata(),
+ "In LLVM-based backends, enable the emission of "
+ "!invariant.load metadata in "
+ "the generated IR."),
tensorflow::Flag(
"xla_backend_optimization_level",
- &flag_values->xla_backend_optimization_level,
+ int32_setter_for(&DebugOptions::set_xla_backend_optimization_level),
+ flag_values->xla_backend_optimization_level(),
"Numerical optimization level for the XLA compiler backend."),
tensorflow::Flag(
- "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes,
+ "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "",
"Comma-separated list of hlo passes to be disabled. These names "
"must exactly match the passes' names; no whitespace around "
"commas."),
- tensorflow::Flag("xla_embed_ir_in_executable",
- &flag_values->xla_embed_ir_in_executable,
- "Embed the compiler IR as a string in the executable."),
tensorflow::Flag(
- "xla_dump_ir_to", &flag_values->xla_dump_ir_to,
+ "xla_embed_ir_in_executable",
+ bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable),
+ flag_values->xla_embed_ir_in_executable(),
+ "Embed the compiler IR as a string in the executable."),
+ tensorflow::Flag(
+ "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(),
"Dump the compiler IR into this directory as individual files."),
- tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast",
- &flag_values->xla_eliminate_hlo_implicit_broadcast,
- "Eliminate implicit broadcasts when lowering user "
- "computations to HLO instructions; use explicit "
- "broadcast instead."),
- tensorflow::Flag("xla_cpu_multi_thread_eigen",
- &flag_values->xla_cpu_multi_thread_eigen,
- "When generating calls to Eigen in the CPU backend, "
- "use multi-threaded Eigen mode."),
+ tensorflow::Flag(
+ "xla_eliminate_hlo_implicit_broadcast",
+ bool_setter_for(
+ &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast),
+ flag_values->xla_eliminate_hlo_implicit_broadcast(),
+ "Eliminate implicit broadcasts when lowering user "
+ "computations to HLO instructions; use explicit "
+ "broadcast instead."),
+ tensorflow::Flag(
+ "xla_cpu_multi_thread_eigen",
+ bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen),
+ flag_values->xla_cpu_multi_thread_eigen(),
+ "When generating calls to Eigen in the CPU backend, "
+ "use multi-threaded Eigen mode."),
tensorflow::Flag("xla_gpu_cuda_data_dir",
- &flag_values->xla_gpu_cuda_data_dir,
+ flag_values->mutable_xla_gpu_cuda_data_dir(),
"If non-empty, speficies a local directory containing "
"ptxas and nvvm libdevice files; otherwise we use "
"those from runfile directories."),
- tensorflow::Flag("xla_gpu_ftz", &flag_values->xla_gpu_ftz,
+ tensorflow::Flag("xla_gpu_ftz",
+ bool_setter_for(&DebugOptions::set_xla_gpu_ftz),
+ flag_values->xla_gpu_ftz(),
"If true, flush-to-zero semantics are enabled in the "
"code generated for GPUs."),
tensorflow::Flag(
- "xla_dump_debug_json_to", &flag_values->xla_dump_debug_json_to,
+ "xla_dump_debug_json_to",
+ flag_values->mutable_xla_dump_debug_json_to(),
"Dump compilation artifacts as JSON into this directory."),
- tensorflow::Flag("xla_test_all_output_layouts",
- &flag_values->xla_test_all_output_layouts,
- "Let ClientLibraryTestBase::ComputeAndCompare* test "
- "all permutations of output layouts. For example, with "
- "a 3D shape, all permutations of the set {0, 1, 2} are "
- "tried."),
- tensorflow::Flag("xla_test_all_input_layouts",
- &flag_values->xla_test_all_input_layouts,
- "Let ClientLibraryTestBase::ComputeAndCompare* test "
- "all permutations of *input* layouts. For example, for "
- "2 input arguments with 2D shape and 4D shape, the "
- "computation will run 2! * 4! times for every possible "
- "layouts"),
+ tensorflow::Flag(
+ "xla_test_all_output_layouts",
+ bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
+ flag_values->xla_test_all_output_layouts(),
+ "Let ClientLibraryTestBase::ComputeAndCompare* test "
+ "all permutations of output layouts. For example, with "
+ "a 3D shape, all permutations of the set {0, 1, 2} are "
+ "tried."),
+ tensorflow::Flag(
+ "xla_test_all_input_layouts",
+ bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts),
+ flag_values->xla_test_all_input_layouts(),
+ "Let ClientLibraryTestBase::ComputeAndCompare* test "
+ "all permutations of *input* layouts. For example, for "
+ "2 input arguments with 2D shape and 4D shape, the "
+ "computation will run 2! * 4! times for every possible "
+ "layouts"),
tensorflow::Flag("xla_backend_extra_options",
- &flag_values->xla_backend_extra_options,
+ setter_for_xla_backend_extra_options, "",
"Extra options to pass to a backend; "
"comma-separated list of 'key=val' strings (=val "
"may be omitted); no whitespace around commas.")});
-
ParseFlagsFromEnv(*flag_objects);
}
@@ -197,68 +245,7 @@ void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) {
xla::DebugOptions GetDebugOptionsFromFlags() {
std::call_once(flags_init, &AllocateFlags);
-
- DebugOptions options;
- options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph);
- options.set_xla_hlo_graph_addresses(flag_values->xla_hlo_graph_addresses);
- options.set_xla_hlo_graph_layout(flag_values->xla_hlo_graph_layout);
- options.set_xla_hlo_graph_path(flag_values->xla_hlo_graph_path);
- options.set_xla_hlo_dump_as_graphdef(flag_values->xla_hlo_dump_as_graphdef);
- options.set_xla_log_hlo_text(flag_values->xla_log_hlo_text);
- options.set_xla_generate_hlo_text_to(flag_values->xla_generate_hlo_text_to);
-
- std::vector<string> disabled_passes =
- tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ',');
- for (const auto& passname : disabled_passes) {
- options.add_xla_disable_hlo_passes(passname);
- }
-
- options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math);
- options.set_xla_backend_optimization_level(
- flag_values->xla_backend_optimization_level);
- options.set_xla_embed_ir_in_executable(
- flag_values->xla_embed_ir_in_executable);
- options.set_xla_dump_ir_to(flag_values->xla_dump_ir_to);
- options.set_xla_eliminate_hlo_implicit_broadcast(
- flag_values->xla_eliminate_hlo_implicit_broadcast);
- options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to);
-
- options.set_xla_cpu_multi_thread_eigen(
- flag_values->xla_cpu_multi_thread_eigen);
- options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir);
- options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz);
- options.set_xla_llvm_enable_alias_scope_metadata(
- flag_values->xla_llvm_enable_alias_scope_metadata);
- options.set_xla_llvm_enable_noalias_metadata(
- flag_values->xla_llvm_enable_noalias_metadata);
- options.set_xla_llvm_enable_invariant_load_metadata(
- flag_values->xla_llvm_enable_invariant_load_metadata);
-
- options.set_xla_test_all_output_layouts(
- flag_values->xla_test_all_output_layouts);
- options.set_xla_test_all_input_layouts(
- flag_values->xla_test_all_input_layouts);
-
- std::vector<string> extra_options_parts =
- tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ',');
- auto* extra_options_map = options.mutable_xla_backend_extra_options();
-
- // The flag contains a comma-separated list of options; some options have
- // arguments following "=", some don't.
- for (const auto& part : extra_options_parts) {
- size_t eq_pos = part.find_first_of('=');
- if (eq_pos == string::npos) {
- (*extra_options_map)[part] = "";
- } else {
- string value = "";
- if (eq_pos + 1 < part.size()) {
- value = part.substr(eq_pos + 1);
- }
- (*extra_options_map)[part.substr(0, eq_pos)] = value;
- }
- }
-
- return options;
+ return *flag_values;
}
} // namespace legacy_flags
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 8373eb1f9e..3efc703faf 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -25,10 +25,11 @@ namespace tensorflow {
namespace {
bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- string* dst, bool* value_parsing_ok) {
+ const std::function<bool(string)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
- *dst = arg.ToString();
+ *value_parsing_ok = hook(arg.ToString());
return true;
}
@@ -36,14 +37,18 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- tensorflow::int32* dst, bool* value_parsing_ok) {
+ const std::function<bool(int32)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) {
+ int32 parsed_int32;
+ if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_int32);
}
return true;
}
@@ -52,14 +57,18 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- tensorflow::int64* dst, bool* value_parsing_ok) {
+ const std::function<bool(int64)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) {
+ int64 parsed_int64;
+ if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_int64);
}
return true;
}
@@ -68,19 +77,20 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- bool* dst, bool* value_parsing_ok) {
+ const std::function<bool(bool)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag)) {
if (arg.empty()) {
- *dst = true;
+ *value_parsing_ok = hook(true);
return true;
}
if (arg == "=true") {
- *dst = true;
+ *value_parsing_ok = hook(true);
return true;
} else if (arg == "=false") {
- *dst = false;
+ *value_parsing_ok = hook(false);
return true;
} else {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
@@ -94,14 +104,18 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
}
bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- float* dst, bool* value_parsing_ok) {
+ const std::function<bool(float)>& hook,
+ bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
- if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
+ float parsed_float;
+ if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
+ } else {
+ *value_parsing_ok = hook(parsed_float);
}
return true;
}
@@ -112,44 +126,107 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
} // namespace
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
- : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {}
+ : name_(name),
+ type_(TYPE_INT32),
+ int32_hook_([dst](int32 value) {
+ *dst = value;
+ return true;
+ }),
+ int32_default_for_display_(*dst),
+ usage_text_(usage_text) {}
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
: name_(name),
type_(TYPE_INT64),
- int64_value_(dst),
+ int64_hook_([dst](int64 value) {
+ *dst = value;
+ return true;
+ }),
+ int64_default_for_display_(*dst),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, float* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_FLOAT),
+ float_hook_([dst](float value) {
+ *dst = value;
+ return true;
+ }),
+ float_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, bool* dst, const string& usage_text)
: name_(name),
type_(TYPE_BOOL),
- bool_value_(dst),
+ bool_hook_([dst](bool value) {
+ *dst = value;
+ return true;
+ }),
+ bool_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, string* dst, const string& usage_text)
: name_(name),
type_(TYPE_STRING),
- string_value_(dst),
+ string_hook_([dst](string value) {
+ *dst = std::move(value);
+ return true;
+ }),
+ string_default_for_display_(*dst),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, float* dst, const string& usage_text)
+Flag::Flag(const char* name, std::function<bool(int32)> int32_hook,
+ int32 default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT32),
+ int32_hook_(std::move(int32_hook)),
+ int32_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(int64)> int64_hook,
+ int64 default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT64),
+ int64_hook_(std::move(int64_hook)),
+ int64_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(float)> float_hook,
+ float default_value_for_display, const string& usage_text)
: name_(name),
type_(TYPE_FLOAT),
- float_value_(dst),
+ float_hook_(std::move(float_hook)),
+ float_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(bool)> bool_hook,
+ bool default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_BOOL),
+ bool_hook_(std::move(bool_hook)),
+ bool_default_for_display_(default_value_for_display),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::function<bool(string)> string_hook,
+ string default_value_for_display, const string& usage_text)
+ : name_(name),
+ type_(TYPE_STRING),
+ string_hook_(std::move(string_hook)),
+ string_default_for_display_(std::move(default_value_for_display)),
usage_text_(usage_text) {}
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
bool result = false;
- if (type_ == TYPE_INT) {
- result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok);
+ if (type_ == TYPE_INT32) {
+ result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok);
} else if (type_ == TYPE_INT64) {
- result = ParseInt64Flag(arg, name_, int64_value_, value_parsing_ok);
+ result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok);
} else if (type_ == TYPE_BOOL) {
- result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
+ result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok);
} else if (type_ == TYPE_STRING) {
- result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
+ result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok);
} else if (type_ == TYPE_FLOAT) {
- result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok);
+ result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok);
}
return result;
}
@@ -203,26 +280,28 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
for (const Flag& flag : flag_list) {
const char* type_name = "";
string flag_string;
- if (flag.type_ == Flag::TYPE_INT) {
+ if (flag.type_ == Flag::TYPE_INT32) {
type_name = "int32";
- flag_string =
- strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_);
+ flag_string = strings::Printf("--%s=%d", flag.name_.c_str(),
+ flag.int32_default_for_display_);
} else if (flag.type_ == Flag::TYPE_INT64) {
type_name = "int64";
- flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(),
- static_cast<long long>(*flag.int64_value_));
+ flag_string = strings::Printf(
+ "--%s=%lld", flag.name_.c_str(),
+ static_cast<long long>(flag.int64_default_for_display_));
} else if (flag.type_ == Flag::TYPE_BOOL) {
type_name = "bool";
- flag_string = strings::Printf("--%s=%s", flag.name_.c_str(),
- *flag.bool_value_ ? "true" : "false");
+ flag_string =
+ strings::Printf("--%s=%s", flag.name_.c_str(),
+ flag.bool_default_for_display_ ? "true" : "false");
} else if (flag.type_ == Flag::TYPE_STRING) {
type_name = "string";
flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
- flag.string_value_->c_str());
+ flag.string_default_for_display_.c_str());
} else if (flag.type_ == Flag::TYPE_FLOAT) {
type_name = "float";
- flag_string =
- strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_);
+ flag_string = strings::Printf("--%s=%f", flag.name_.c_str(),
+ flag.float_default_for_display_);
}
strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
type_name, flag.usage_text_.c_str());
diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h
index f349df16fd..121c7063c9 100644
--- a/tensorflow/core/util/command_line_flags.h
+++ b/tensorflow/core/util/command_line_flags.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
+#include <functional>
#include <string>
#include <vector>
#include "tensorflow/core/platform/types.h"
@@ -61,24 +62,58 @@ namespace tensorflow {
// text, and a pointer to the corresponding variable.
class Flag {
public:
- Flag(const char* name, int32* dst1, const string& usage_text);
- Flag(const char* name, int64* dst1, const string& usage_text);
+ Flag(const char* name, int32* dst, const string& usage_text);
+ Flag(const char* name, int64* dst, const string& usage_text);
Flag(const char* name, bool* dst, const string& usage_text);
Flag(const char* name, string* dst, const string& usage_text);
Flag(const char* name, float* dst, const string& usage_text);
+ // These constructors invoke a hook on a match instead of writing to a
+ // specific memory location. The hook may return false to signal a malformed
+ // or illegal value, which will then fail the command line parse.
+ //
+ // "default_value_for_display" is shown as the default value of this flag in
+ // Flags::Usage().
+ Flag(const char* name, std::function<bool(int32)> int32_hook,
+ int32 default_value_for_display, const string& usage_text);
+ Flag(const char* name, std::function<bool(int64)> int64_hook,
+ int64 default_value_for_display, const string& usage_text);
+ Flag(const char* name, std::function<bool(float)> float_hook,
+ float default_value_for_display, const string& usage_text);
+ Flag(const char* name, std::function<bool(bool)> bool_hook,
+ bool default_value_for_display, const string& usage_text);
+ Flag(const char* name, std::function<bool(string)> string_hook,
+ string default_value_for_display, const string& usage_text);
+
private:
friend class Flags;
bool Parse(string arg, bool* value_parsing_ok) const;
string name_;
- enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_;
- int* int_value_;
- int64* int64_value_;
- bool* bool_value_;
- string* string_value_;
- float* float_value_;
+ enum {
+ TYPE_INT32,
+ TYPE_INT64,
+ TYPE_BOOL,
+ TYPE_STRING,
+ TYPE_FLOAT,
+ } type_;
+
+ std::function<bool(int32)> int32_hook_;
+ int32 int32_default_for_display_;
+
+ std::function<bool(int64)> int64_hook_;
+ int64 int64_default_for_display_;
+
+ std::function<bool(float)> float_hook_;
+ float float_default_for_display_;
+
+ std::function<bool(bool)> bool_hook_;
+ bool bool_default_for_display_;
+
+ std::function<bool(string)> string_hook_;
+ string string_default_for_display_;
+
string usage_text_;
};
diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc
index c86a70ec9d..6139c8e7bc 100644
--- a/tensorflow/core/util/command_line_flags_test.cc
+++ b/tensorflow/core/util/command_line_flags_test.cc
@@ -36,32 +36,85 @@ std::vector<char *> CharPointerVectorFromStrings(
} // namespace
TEST(CommandLineFlagsTest, BasicUsage) {
- int some_int = 10;
- int64 some_int64 = 21474836470; // max int32 is 2147483647
- bool some_switch = false;
- string some_name = "something";
- float some_float = -23.23f;
- int argc = 6;
+ int some_int32_set_directly = 10;
+ int some_int32_set_via_hook = 20;
+ int64 some_int64_set_directly = 21474836470; // max int32 is 2147483647
+ int64 some_int64_set_via_hook = 21474836479; // max int32 is 2147483647
+ bool some_switch_set_directly = false;
+ bool some_switch_set_via_hook = true;
+ string some_name_set_directly = "something_a";
+ string some_name_set_via_hook = "something_b";
+ float some_float_set_directly = -23.23f;
+ float some_float_set_via_hook = -25.23f;
std::vector<string> argv_strings = {"program_name",
- "--some_int=20",
- "--some_int64=214748364700",
- "--some_switch",
- "--some_name=somethingelse",
- "--some_float=42.0"};
+ "--some_int32_set_directly=20",
+ "--some_int32_set_via_hook=50",
+ "--some_int64_set_directly=214748364700",
+ "--some_int64_set_via_hook=214748364710",
+ "--some_switch_set_directly",
+ "--some_switch_set_via_hook=false",
+ "--some_name_set_directly=somethingelse",
+ "--some_name_set_via_hook=anythingelse",
+ "--some_float_set_directly=42.0",
+ "--some_float_set_via_hook=43.0"};
+ int argc = argv_strings.size();
std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
- bool parsed_ok =
- Flags::Parse(&argc, argv_array.data(),
- {Flag("some_int", &some_int, "some int"),
- Flag("some_int64", &some_int64, "some int64"),
- Flag("some_switch", &some_switch, "some switch"),
- Flag("some_name", &some_name, "some name"),
- Flag("some_float", &some_float, "some float")});
+ bool parsed_ok = Flags::Parse(
+ &argc, argv_array.data(),
+ {
+ Flag("some_int32_set_directly", &some_int32_set_directly,
+ "some int32 set directly"),
+ Flag("some_int32_set_via_hook",
+ [&](int32 value) {
+ some_int32_set_via_hook = value;
+ return true;
+ },
+ some_int32_set_via_hook, "some int32 set via hook"),
+ Flag("some_int64_set_directly", &some_int64_set_directly,
+ "some int64 set directly"),
+ Flag("some_int64_set_via_hook",
+ [&](int64 value) {
+ some_int64_set_via_hook = value;
+ return true;
+ },
+ some_int64_set_via_hook, "some int64 set via hook"),
+ Flag("some_switch_set_directly", &some_switch_set_directly,
+ "some switch set directly"),
+ Flag("some_switch_set_via_hook",
+ [&](bool value) {
+ some_switch_set_via_hook = value;
+ return true;
+ },
+ some_switch_set_via_hook, "some switch set via hook"),
+ Flag("some_name_set_directly", &some_name_set_directly,
+ "some name set directly"),
+ Flag("some_name_set_via_hook",
+ [&](string value) {
+ some_name_set_via_hook = std::move(value);
+ return true;
+ },
+ some_name_set_via_hook, "some name set via hook"),
+ Flag("some_float_set_directly", &some_float_set_directly,
+ "some float set directly"),
+ Flag("some_float_set_via_hook",
+ [&](float value) {
+ some_float_set_via_hook = value;
+ return true;
+ },
+ some_float_set_via_hook, "some float set via hook"),
+ });
+
EXPECT_EQ(true, parsed_ok);
- EXPECT_EQ(20, some_int);
- EXPECT_EQ(214748364700, some_int64);
- EXPECT_EQ(true, some_switch);
- EXPECT_EQ("somethingelse", some_name);
- EXPECT_NEAR(42.0f, some_float, 1e-5f);
+ EXPECT_EQ(20, some_int32_set_directly);
+ EXPECT_EQ(50, some_int32_set_via_hook);
+ EXPECT_EQ(214748364700, some_int64_set_directly);
+ EXPECT_EQ(214748364710, some_int64_set_via_hook);
+ EXPECT_EQ(true, some_switch_set_directly);
+ EXPECT_EQ(false, some_switch_set_via_hook);
+ EXPECT_EQ("somethingelse", some_name_set_directly);
+ EXPECT_EQ("anythingelse", some_name_set_via_hook);
+ EXPECT_NEAR(42.0f, some_float_set_directly, 1e-5f);
+ EXPECT_NEAR(43.0f, some_float_set_via_hook, 1e-5f);
EXPECT_EQ(argc, 1);
}
@@ -107,6 +160,70 @@ TEST(CommandLineFlagsTest, BadFloatValue) {
EXPECT_EQ(argc, 1);
}
+TEST(CommandLineFlagsTest, FailedInt32Hook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_int32=200"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_int32", [](int32 value) { return false; }, 30,
+ "some int32")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, FailedInt64Hook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_int64=200"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_int64", [](int64 value) { return false; }, 30,
+ "some int64")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, FailedFloatHook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_float=200.0"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_float", [](float value) { return false; }, 30.0f,
+ "some float")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, FailedBoolHook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_switch=true"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_switch", [](bool value) { return false; }, false,
+ "some switch")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, FailedStringHook) {
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name", "--some_name=true"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok = Flags::Parse(
+ &argc, argv_array.data(),
+ {Flag("some_name", [](string value) { return false; }, "", "some name")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(argc, 1);
+}
+
// Return whether str==pat, but allowing any whitespace in pat
// to match zero or more whitespace characters in str.
static bool MatchWithAnyWhitespace(const string &str, const string &pat) {