From 7bf4e6cbaae9ca930aa17d058c94aa11119fc0c3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Jul 2017 14:34:18 -0700 Subject: Avoid the duplication in debug_options_flags.cc by generalizing tensorflow::Flag. PiperOrigin-RevId: 162271241 --- .../xla/legacy_flags/debug_options_flags.cc | 335 ++++++++++----------- tensorflow/core/util/command_line_flags.cc | 147 ++++++--- tensorflow/core/util/command_line_flags.h | 51 +++- tensorflow/core/util/command_line_flags_test.cc | 163 ++++++++-- 4 files changed, 457 insertions(+), 239 deletions(-) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 5b4fb5d0a7..8bd3d0f040 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -23,167 +23,215 @@ limitations under the License. namespace xla { namespace legacy_flags { -struct DebugOptionsFlags { - string xla_generate_hlo_graph; - bool xla_hlo_graph_addresses; - bool xla_hlo_graph_layout; - string xla_hlo_graph_path; - bool xla_hlo_dump_as_graphdef; - string xla_log_hlo_text; - string xla_generate_hlo_text_to; - - string xla_disable_hlo_passes; - bool xla_enable_fast_math; - bool xla_llvm_enable_alias_scope_metadata; - bool xla_llvm_enable_noalias_metadata; - bool xla_llvm_enable_invariant_load_metadata; - int32 xla_backend_optimization_level; - bool xla_embed_ir_in_executable; - string xla_dump_ir_to; - string xla_dump_debug_json_to; - bool xla_eliminate_hlo_implicit_broadcast; - - bool xla_cpu_multi_thread_eigen; - - string xla_gpu_cuda_data_dir; - bool xla_gpu_ftz; - - bool xla_test_all_output_layouts; - bool xla_test_all_input_layouts; - - string xla_backend_extra_options; -}; - namespace { -DebugOptionsFlags* flag_values; +DebugOptions* flag_values; std::vector* flag_objects; std::once_flag flags_init; +namespace { +void SetDebugOptionsDefaults(DebugOptions* flags) { + flags->set_xla_hlo_graph_path("/tmp/"); + flags->set_xla_enable_fast_math(true); + flags->set_xla_llvm_enable_alias_scope_metadata(true); + flags->set_xla_llvm_enable_noalias_metadata(true); + flags->set_xla_llvm_enable_invariant_load_metadata(true); + flags->set_xla_backend_optimization_level(3); + flags->set_xla_cpu_multi_thread_eigen(true); + flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); +} +} // namespace + // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. void AllocateFlags() { - flag_values = new DebugOptionsFlags; - flag_values->xla_generate_hlo_graph = ""; - flag_values->xla_hlo_graph_addresses = false; - flag_values->xla_hlo_graph_layout = false; - flag_values->xla_hlo_graph_path = "/tmp/"; - flag_values->xla_hlo_dump_as_graphdef = false; - flag_values->xla_log_hlo_text = ""; - flag_values->xla_generate_hlo_text_to = ""; - flag_values->xla_disable_hlo_passes = ""; - flag_values->xla_enable_fast_math = true; - flag_values->xla_llvm_enable_alias_scope_metadata = true; - flag_values->xla_llvm_enable_noalias_metadata = true; - flag_values->xla_llvm_enable_invariant_load_metadata = true; - flag_values->xla_backend_optimization_level = 3; - flag_values->xla_embed_ir_in_executable = false; - flag_values->xla_dump_ir_to = ""; - flag_values->xla_dump_debug_json_to = ""; - flag_values->xla_eliminate_hlo_implicit_broadcast = false; - flag_values->xla_cpu_multi_thread_eigen = true; - flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib"; - flag_values->xla_gpu_ftz = false; - flag_values->xla_test_all_output_layouts = false; - flag_values->xla_backend_extra_options = ""; - flag_values->xla_test_all_input_layouts = false; + flag_values = new DebugOptions; + + SetDebugOptionsDefaults(flag_values); + + // Returns a lambda that calls "member_setter" on "flag_values" with the + // argument passed in to the lambda. + auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) { + return [member_setter](bool value) { + (flag_values->*member_setter)(value); + return true; + }; + }; + + // Returns a lambda that calls "member_setter" on "flag_values" with the + // argument passed in to the lambda. + auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32)) { + return [member_setter](int32 value) { + (flag_values->*member_setter)(value); + return true; + }; + }; + + // Returns a lambda that is a custom "sub-parser" for xla_disable_hlo_passes. + auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { + std::vector disabled_passes = + tensorflow::str_util::Split(comma_separated_values, ','); + for (const auto& passname : disabled_passes) { + flag_values->add_xla_disable_hlo_passes(passname); + } + return true; + }; + + // Returns a lambda that is a custom "sub-parser" for + // xla_backend_extra_options. + auto setter_for_xla_backend_extra_options = + [](string comma_separated_values) { + std::vector extra_options_parts = + tensorflow::str_util::Split(comma_separated_values, ','); + auto* extra_options_map = + flag_values->mutable_xla_backend_extra_options(); + + // The flag contains a comma-separated list of options; some options + // have arguments following "=", some don't. + for (const auto& part : extra_options_parts) { + size_t eq_pos = part.find_first_of('='); + if (eq_pos == string::npos) { + (*extra_options_map)[part] = ""; + } else { + string value = ""; + if (eq_pos + 1 < part.size()) { + value = part.substr(eq_pos + 1); + } + (*extra_options_map)[part.substr(0, eq_pos)] = value; + } + } + + return true; + }; flag_objects = new std::vector( {tensorflow::Flag( - "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph, + "xla_generate_hlo_graph", + flag_values->mutable_xla_generate_hlo_graph(), "HLO modules matching this regex will be dumped to a .dot file " "throughout various stages in compilation."), tensorflow::Flag( - "xla_hlo_graph_addresses", &flag_values->xla_hlo_graph_addresses, + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), "With xla_generate_hlo_graph, show addresses of HLO ops in " "graph dump."), tensorflow::Flag( - "xla_hlo_graph_layout", &flag_values->xla_hlo_graph_layout, + "xla_hlo_graph_layout", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_layout), + flag_values->xla_hlo_graph_layout(), "With xla_generate_hlo_graph, show layout of HLO ops in " "graph dump."), tensorflow::Flag( - "xla_hlo_graph_path", &flag_values->xla_hlo_graph_path, + "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag("xla_hlo_dump_as_graphdef", - &flag_values->xla_hlo_dump_as_graphdef, - "Dump HLO graphs as TensorFlow GraphDefs."), tensorflow::Flag( - "xla_log_hlo_text", &flag_values->xla_log_hlo_text, + "xla_hlo_dump_as_graphdef", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), + flag_values->xla_hlo_dump_as_graphdef(), + "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), "HLO modules matching this regex will be dumped to LOG(INFO). "), tensorflow::Flag( - "xla_generate_hlo_text_to", &flag_values->xla_generate_hlo_text_to, + "xla_generate_hlo_text_to", + flag_values->mutable_xla_generate_hlo_text_to(), "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( - "xla_enable_fast_math", &flag_values->xla_enable_fast_math, + "xla_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_enable_fast_math), + flag_values->xla_enable_fast_math(), "Enable unsafe fast-math optimizations in the compiler; " "this may produce faster code at the expense of some accuracy."), - tensorflow::Flag("xla_llvm_enable_alias_scope_metadata", - &flag_values->xla_llvm_enable_alias_scope_metadata, - "In LLVM-based backends, enable the emission of " - "!alias.scope metadata in the generated IR."), - tensorflow::Flag("xla_llvm_enable_noalias_metadata", - &flag_values->xla_llvm_enable_noalias_metadata, - "In LLVM-based backends, enable the emission of " - "!noalias metadata in the generated IR."), - tensorflow::Flag("xla_llvm_enable_invariant_load_metadata", - &flag_values->xla_llvm_enable_invariant_load_metadata, - "In LLVM-based backends, enable the emission of " - "!invariant.load metadata in " - "the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_alias_scope_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), + flag_values->xla_llvm_enable_alias_scope_metadata(), + "In LLVM-based backends, enable the emission of " + "!alias.scope metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_noalias_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), + flag_values->xla_llvm_enable_noalias_metadata(), + "In LLVM-based backends, enable the emission of " + "!noalias metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_invariant_load_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), + flag_values->xla_llvm_enable_invariant_load_metadata(), + "In LLVM-based backends, enable the emission of " + "!invariant.load metadata in " + "the generated IR."), tensorflow::Flag( "xla_backend_optimization_level", - &flag_values->xla_backend_optimization_level, + int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), + flag_values->xla_backend_optimization_level(), "Numerical optimization level for the XLA compiler backend."), tensorflow::Flag( - "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes, + "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", "Comma-separated list of hlo passes to be disabled. These names " "must exactly match the passes' names; no whitespace around " "commas."), - tensorflow::Flag("xla_embed_ir_in_executable", - &flag_values->xla_embed_ir_in_executable, - "Embed the compiler IR as a string in the executable."), tensorflow::Flag( - "xla_dump_ir_to", &flag_values->xla_dump_ir_to, + "xla_embed_ir_in_executable", + bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), + flag_values->xla_embed_ir_in_executable(), + "Embed the compiler IR as a string in the executable."), + tensorflow::Flag( + "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), "Dump the compiler IR into this directory as individual files."), - tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast", - &flag_values->xla_eliminate_hlo_implicit_broadcast, - "Eliminate implicit broadcasts when lowering user " - "computations to HLO instructions; use explicit " - "broadcast instead."), - tensorflow::Flag("xla_cpu_multi_thread_eigen", - &flag_values->xla_cpu_multi_thread_eigen, - "When generating calls to Eigen in the CPU backend, " - "use multi-threaded Eigen mode."), + tensorflow::Flag( + "xla_eliminate_hlo_implicit_broadcast", + bool_setter_for( + &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), + flag_values->xla_eliminate_hlo_implicit_broadcast(), + "Eliminate implicit broadcasts when lowering user " + "computations to HLO instructions; use explicit " + "broadcast instead."), + tensorflow::Flag( + "xla_cpu_multi_thread_eigen", + bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), + flag_values->xla_cpu_multi_thread_eigen(), + "When generating calls to Eigen in the CPU backend, " + "use multi-threaded Eigen mode."), tensorflow::Flag("xla_gpu_cuda_data_dir", - &flag_values->xla_gpu_cuda_data_dir, + flag_values->mutable_xla_gpu_cuda_data_dir(), "If non-empty, speficies a local directory containing " "ptxas and nvvm libdevice files; otherwise we use " "those from runfile directories."), - tensorflow::Flag("xla_gpu_ftz", &flag_values->xla_gpu_ftz, + tensorflow::Flag("xla_gpu_ftz", + bool_setter_for(&DebugOptions::set_xla_gpu_ftz), + flag_values->xla_gpu_ftz(), "If true, flush-to-zero semantics are enabled in the " "code generated for GPUs."), tensorflow::Flag( - "xla_dump_debug_json_to", &flag_values->xla_dump_debug_json_to, + "xla_dump_debug_json_to", + flag_values->mutable_xla_dump_debug_json_to(), "Dump compilation artifacts as JSON into this directory."), - tensorflow::Flag("xla_test_all_output_layouts", - &flag_values->xla_test_all_output_layouts, - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of output layouts. For example, with " - "a 3D shape, all permutations of the set {0, 1, 2} are " - "tried."), - tensorflow::Flag("xla_test_all_input_layouts", - &flag_values->xla_test_all_input_layouts, - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of *input* layouts. For example, for " - "2 input arguments with 2D shape and 4D shape, the " - "computation will run 2! * 4! times for every possible " - "layouts"), + tensorflow::Flag( + "xla_test_all_output_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), + flag_values->xla_test_all_output_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of output layouts. For example, with " + "a 3D shape, all permutations of the set {0, 1, 2} are " + "tried."), + tensorflow::Flag( + "xla_test_all_input_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), + flag_values->xla_test_all_input_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of *input* layouts. For example, for " + "2 input arguments with 2D shape and 4D shape, the " + "computation will run 2! * 4! times for every possible " + "layouts"), tensorflow::Flag("xla_backend_extra_options", - &flag_values->xla_backend_extra_options, + setter_for_xla_backend_extra_options, "", "Extra options to pass to a backend; " "comma-separated list of 'key=val' strings (=val " "may be omitted); no whitespace around commas.")}); - ParseFlagsFromEnv(*flag_objects); } @@ -197,68 +245,7 @@ void AppendDebugOptionsFlags(std::vector* flag_list) { xla::DebugOptions GetDebugOptionsFromFlags() { std::call_once(flags_init, &AllocateFlags); - - DebugOptions options; - options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph); - options.set_xla_hlo_graph_addresses(flag_values->xla_hlo_graph_addresses); - options.set_xla_hlo_graph_layout(flag_values->xla_hlo_graph_layout); - options.set_xla_hlo_graph_path(flag_values->xla_hlo_graph_path); - options.set_xla_hlo_dump_as_graphdef(flag_values->xla_hlo_dump_as_graphdef); - options.set_xla_log_hlo_text(flag_values->xla_log_hlo_text); - options.set_xla_generate_hlo_text_to(flag_values->xla_generate_hlo_text_to); - - std::vector disabled_passes = - tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ','); - for (const auto& passname : disabled_passes) { - options.add_xla_disable_hlo_passes(passname); - } - - options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math); - options.set_xla_backend_optimization_level( - flag_values->xla_backend_optimization_level); - options.set_xla_embed_ir_in_executable( - flag_values->xla_embed_ir_in_executable); - options.set_xla_dump_ir_to(flag_values->xla_dump_ir_to); - options.set_xla_eliminate_hlo_implicit_broadcast( - flag_values->xla_eliminate_hlo_implicit_broadcast); - options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to); - - options.set_xla_cpu_multi_thread_eigen( - flag_values->xla_cpu_multi_thread_eigen); - options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir); - options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz); - options.set_xla_llvm_enable_alias_scope_metadata( - flag_values->xla_llvm_enable_alias_scope_metadata); - options.set_xla_llvm_enable_noalias_metadata( - flag_values->xla_llvm_enable_noalias_metadata); - options.set_xla_llvm_enable_invariant_load_metadata( - flag_values->xla_llvm_enable_invariant_load_metadata); - - options.set_xla_test_all_output_layouts( - flag_values->xla_test_all_output_layouts); - options.set_xla_test_all_input_layouts( - flag_values->xla_test_all_input_layouts); - - std::vector extra_options_parts = - tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ','); - auto* extra_options_map = options.mutable_xla_backend_extra_options(); - - // The flag contains a comma-separated list of options; some options have - // arguments following "=", some don't. - for (const auto& part : extra_options_parts) { - size_t eq_pos = part.find_first_of('='); - if (eq_pos == string::npos) { - (*extra_options_map)[part] = ""; - } else { - string value = ""; - if (eq_pos + 1 < part.size()) { - value = part.substr(eq_pos + 1); - } - (*extra_options_map)[part.substr(0, eq_pos)] = value; - } - } - - return options; + return *flag_values; } } // namespace legacy_flags diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 8373eb1f9e..3efc703faf 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -25,10 +25,11 @@ namespace tensorflow { namespace { bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - string* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { - *dst = arg.ToString(); + *value_parsing_ok = hook(arg.ToString()); return true; } @@ -36,14 +37,18 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - tensorflow::int32* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) { + int32 parsed_int32; + if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_int32); } return true; } @@ -52,14 +57,18 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - tensorflow::int64* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) { + int64 parsed_int64; + if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_int64); } return true; } @@ -68,19 +77,20 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - bool* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag)) { if (arg.empty()) { - *dst = true; + *value_parsing_ok = hook(true); return true; } if (arg == "=true") { - *dst = true; + *value_parsing_ok = hook(true); return true; } else if (arg == "=false") { - *dst = false; + *value_parsing_ok = hook(false); return true; } else { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag @@ -94,14 +104,18 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - float* dst, bool* value_parsing_ok) { + const std::function& hook, + bool* value_parsing_ok) { *value_parsing_ok = true; if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { char extra; - if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) { + float parsed_float; + if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; + } else { + *value_parsing_ok = hook(parsed_float); } return true; } @@ -112,44 +126,107 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } // namespace Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) - : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {} + : name_(name), + type_(TYPE_INT32), + int32_hook_([dst](int32 value) { + *dst = value; + return true; + }), + int32_default_for_display_(*dst), + usage_text_(usage_text) {} Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text) : name_(name), type_(TYPE_INT64), - int64_value_(dst), + int64_hook_([dst](int64 value) { + *dst = value; + return true; + }), + int64_default_for_display_(*dst), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, float* dst, const string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + float_hook_([dst](float value) { + *dst = value; + return true; + }), + float_default_for_display_(*dst), usage_text_(usage_text) {} Flag::Flag(const char* name, bool* dst, const string& usage_text) : name_(name), type_(TYPE_BOOL), - bool_value_(dst), + bool_hook_([dst](bool value) { + *dst = value; + return true; + }), + bool_default_for_display_(*dst), usage_text_(usage_text) {} Flag::Flag(const char* name, string* dst, const string& usage_text) : name_(name), type_(TYPE_STRING), - string_value_(dst), + string_hook_([dst](string value) { + *dst = std::move(value); + return true; + }), + string_default_for_display_(*dst), usage_text_(usage_text) {} -Flag::Flag(const char* name, float* dst, const string& usage_text) +Flag::Flag(const char* name, std::function int32_hook, + int32 default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_INT32), + int32_hook_(std::move(int32_hook)), + int32_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function int64_hook, + int64 default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_INT64), + int64_hook_(std::move(int64_hook)), + int64_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function float_hook, + float default_value_for_display, const string& usage_text) : name_(name), type_(TYPE_FLOAT), - float_value_(dst), + float_hook_(std::move(float_hook)), + float_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function bool_hook, + bool default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_BOOL), + bool_hook_(std::move(bool_hook)), + bool_default_for_display_(default_value_for_display), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::function string_hook, + string default_value_for_display, const string& usage_text) + : name_(name), + type_(TYPE_STRING), + string_hook_(std::move(string_hook)), + string_default_for_display_(std::move(default_value_for_display)), usage_text_(usage_text) {} bool Flag::Parse(string arg, bool* value_parsing_ok) const { bool result = false; - if (type_ == TYPE_INT) { - result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok); + if (type_ == TYPE_INT32) { + result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok); } else if (type_ == TYPE_INT64) { - result = ParseInt64Flag(arg, name_, int64_value_, value_parsing_ok); + result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok); } else if (type_ == TYPE_BOOL) { - result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok); + result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok); } else if (type_ == TYPE_STRING) { - result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok); + result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok); } else if (type_ == TYPE_FLOAT) { - result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok); + result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok); } return result; } @@ -203,26 +280,28 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { for (const Flag& flag : flag_list) { const char* type_name = ""; string flag_string; - if (flag.type_ == Flag::TYPE_INT) { + if (flag.type_ == Flag::TYPE_INT32) { type_name = "int32"; - flag_string = - strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_); + flag_string = strings::Printf("--%s=%d", flag.name_.c_str(), + flag.int32_default_for_display_); } else if (flag.type_ == Flag::TYPE_INT64) { type_name = "int64"; - flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(), - static_cast(*flag.int64_value_)); + flag_string = strings::Printf( + "--%s=%lld", flag.name_.c_str(), + static_cast(flag.int64_default_for_display_)); } else if (flag.type_ == Flag::TYPE_BOOL) { type_name = "bool"; - flag_string = strings::Printf("--%s=%s", flag.name_.c_str(), - *flag.bool_value_ ? "true" : "false"); + flag_string = + strings::Printf("--%s=%s", flag.name_.c_str(), + flag.bool_default_for_display_ ? "true" : "false"); } else if (flag.type_ == Flag::TYPE_STRING) { type_name = "string"; flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(), - flag.string_value_->c_str()); + flag.string_default_for_display_.c_str()); } else if (flag.type_ == Flag::TYPE_FLOAT) { type_name = "float"; - flag_string = - strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_); + flag_string = strings::Printf("--%s=%f", flag.name_.c_str(), + flag.float_default_for_display_); } strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(), type_name, flag.usage_text_.c_str()); diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index f349df16fd..121c7063c9 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H #define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H +#include #include #include #include "tensorflow/core/platform/types.h" @@ -61,24 +62,58 @@ namespace tensorflow { // text, and a pointer to the corresponding variable. class Flag { public: - Flag(const char* name, int32* dst1, const string& usage_text); - Flag(const char* name, int64* dst1, const string& usage_text); + Flag(const char* name, int32* dst, const string& usage_text); + Flag(const char* name, int64* dst, const string& usage_text); Flag(const char* name, bool* dst, const string& usage_text); Flag(const char* name, string* dst, const string& usage_text); Flag(const char* name, float* dst, const string& usage_text); + // These constructors invoke a hook on a match instead of writing to a + // specific memory location. The hook may return false to signal a malformed + // or illegal value, which will then fail the command line parse. + // + // "default_value_for_display" is shown as the default value of this flag in + // Flags::Usage(). + Flag(const char* name, std::function int32_hook, + int32 default_value_for_display, const string& usage_text); + Flag(const char* name, std::function int64_hook, + int64 default_value_for_display, const string& usage_text); + Flag(const char* name, std::function float_hook, + float default_value_for_display, const string& usage_text); + Flag(const char* name, std::function bool_hook, + bool default_value_for_display, const string& usage_text); + Flag(const char* name, std::function string_hook, + string default_value_for_display, const string& usage_text); + private: friend class Flags; bool Parse(string arg, bool* value_parsing_ok) const; string name_; - enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_; - int* int_value_; - int64* int64_value_; - bool* bool_value_; - string* string_value_; - float* float_value_; + enum { + TYPE_INT32, + TYPE_INT64, + TYPE_BOOL, + TYPE_STRING, + TYPE_FLOAT, + } type_; + + std::function int32_hook_; + int32 int32_default_for_display_; + + std::function int64_hook_; + int64 int64_default_for_display_; + + std::function float_hook_; + float float_default_for_display_; + + std::function bool_hook_; + bool bool_default_for_display_; + + std::function string_hook_; + string string_default_for_display_; + string usage_text_; }; diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index c86a70ec9d..6139c8e7bc 100644 --- a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -36,32 +36,85 @@ std::vector CharPointerVectorFromStrings( } // namespace TEST(CommandLineFlagsTest, BasicUsage) { - int some_int = 10; - int64 some_int64 = 21474836470; // max int32 is 2147483647 - bool some_switch = false; - string some_name = "something"; - float some_float = -23.23f; - int argc = 6; + int some_int32_set_directly = 10; + int some_int32_set_via_hook = 20; + int64 some_int64_set_directly = 21474836470; // max int32 is 2147483647 + int64 some_int64_set_via_hook = 21474836479; // max int32 is 2147483647 + bool some_switch_set_directly = false; + bool some_switch_set_via_hook = true; + string some_name_set_directly = "something_a"; + string some_name_set_via_hook = "something_b"; + float some_float_set_directly = -23.23f; + float some_float_set_via_hook = -25.23f; std::vector argv_strings = {"program_name", - "--some_int=20", - "--some_int64=214748364700", - "--some_switch", - "--some_name=somethingelse", - "--some_float=42.0"}; + "--some_int32_set_directly=20", + "--some_int32_set_via_hook=50", + "--some_int64_set_directly=214748364700", + "--some_int64_set_via_hook=214748364710", + "--some_switch_set_directly", + "--some_switch_set_via_hook=false", + "--some_name_set_directly=somethingelse", + "--some_name_set_via_hook=anythingelse", + "--some_float_set_directly=42.0", + "--some_float_set_via_hook=43.0"}; + int argc = argv_strings.size(); std::vector argv_array = CharPointerVectorFromStrings(argv_strings); - bool parsed_ok = - Flags::Parse(&argc, argv_array.data(), - {Flag("some_int", &some_int, "some int"), - Flag("some_int64", &some_int64, "some int64"), - Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name"), - Flag("some_float", &some_float, "some float")}); + bool parsed_ok = Flags::Parse( + &argc, argv_array.data(), + { + Flag("some_int32_set_directly", &some_int32_set_directly, + "some int32 set directly"), + Flag("some_int32_set_via_hook", + [&](int32 value) { + some_int32_set_via_hook = value; + return true; + }, + some_int32_set_via_hook, "some int32 set via hook"), + Flag("some_int64_set_directly", &some_int64_set_directly, + "some int64 set directly"), + Flag("some_int64_set_via_hook", + [&](int64 value) { + some_int64_set_via_hook = value; + return true; + }, + some_int64_set_via_hook, "some int64 set via hook"), + Flag("some_switch_set_directly", &some_switch_set_directly, + "some switch set directly"), + Flag("some_switch_set_via_hook", + [&](bool value) { + some_switch_set_via_hook = value; + return true; + }, + some_switch_set_via_hook, "some switch set via hook"), + Flag("some_name_set_directly", &some_name_set_directly, + "some name set directly"), + Flag("some_name_set_via_hook", + [&](string value) { + some_name_set_via_hook = std::move(value); + return true; + }, + some_name_set_via_hook, "some name set via hook"), + Flag("some_float_set_directly", &some_float_set_directly, + "some float set directly"), + Flag("some_float_set_via_hook", + [&](float value) { + some_float_set_via_hook = value; + return true; + }, + some_float_set_via_hook, "some float set via hook"), + }); + EXPECT_EQ(true, parsed_ok); - EXPECT_EQ(20, some_int); - EXPECT_EQ(214748364700, some_int64); - EXPECT_EQ(true, some_switch); - EXPECT_EQ("somethingelse", some_name); - EXPECT_NEAR(42.0f, some_float, 1e-5f); + EXPECT_EQ(20, some_int32_set_directly); + EXPECT_EQ(50, some_int32_set_via_hook); + EXPECT_EQ(214748364700, some_int64_set_directly); + EXPECT_EQ(214748364710, some_int64_set_via_hook); + EXPECT_EQ(true, some_switch_set_directly); + EXPECT_EQ(false, some_switch_set_via_hook); + EXPECT_EQ("somethingelse", some_name_set_directly); + EXPECT_EQ("anythingelse", some_name_set_via_hook); + EXPECT_NEAR(42.0f, some_float_set_directly, 1e-5f); + EXPECT_NEAR(43.0f, some_float_set_via_hook, 1e-5f); EXPECT_EQ(argc, 1); } @@ -107,6 +160,70 @@ TEST(CommandLineFlagsTest, BadFloatValue) { EXPECT_EQ(argc, 1); } +TEST(CommandLineFlagsTest, FailedInt32Hook) { + int argc = 2; + std::vector argv_strings = {"program_name", "--some_int32=200"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_int32", [](int32 value) { return false; }, 30, + "some int32")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, FailedInt64Hook) { + int argc = 2; + std::vector argv_strings = {"program_name", "--some_int64=200"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_int64", [](int64 value) { return false; }, 30, + "some int64")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, FailedFloatHook) { + int argc = 2; + std::vector argv_strings = {"program_name", "--some_float=200.0"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_float", [](float value) { return false; }, 30.0f, + "some float")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, FailedBoolHook) { + int argc = 2; + std::vector argv_strings = {"program_name", "--some_switch=true"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_switch", [](bool value) { return false; }, false, + "some switch")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, FailedStringHook) { + int argc = 2; + std::vector argv_strings = {"program_name", "--some_name=true"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = Flags::Parse( + &argc, argv_array.data(), + {Flag("some_name", [](string value) { return false; }, "", "some name")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(argc, 1); +} + // Return whether str==pat, but allowing any whitespace in pat // to match zero or more whitespace characters in str. static bool MatchWithAnyWhitespace(const string &str, const string &pat) { -- cgit v1.2.3