diff options
author | 2017-07-17 14:34:18 -0700 | |
---|---|---|
committer | 2017-07-17 14:39:05 -0700 | |
commit | 7bf4e6cbaae9ca930aa17d058c94aa11119fc0c3 (patch) | |
tree | 16b1440865a9020411f73e51ac75080c2a24c729 /tensorflow/compiler/xla/legacy_flags | |
parent | 0c144afecef6800589d255dd990a9a88e9f94b23 (diff) |
Avoid the duplication in debug_options_flags.cc by generalizing tensorflow::Flag.
PiperOrigin-RevId: 162271241
Diffstat (limited to 'tensorflow/compiler/xla/legacy_flags')
-rw-r--r-- | tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc | 335 |
1 file changed, 161 insertions, 174 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 5b4fb5d0a7..8bd3d0f040 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -23,167 +23,215 @@ limitations under the License. namespace xla { namespace legacy_flags { -struct DebugOptionsFlags { - string xla_generate_hlo_graph; - bool xla_hlo_graph_addresses; - bool xla_hlo_graph_layout; - string xla_hlo_graph_path; - bool xla_hlo_dump_as_graphdef; - string xla_log_hlo_text; - string xla_generate_hlo_text_to; - - string xla_disable_hlo_passes; - bool xla_enable_fast_math; - bool xla_llvm_enable_alias_scope_metadata; - bool xla_llvm_enable_noalias_metadata; - bool xla_llvm_enable_invariant_load_metadata; - int32 xla_backend_optimization_level; - bool xla_embed_ir_in_executable; - string xla_dump_ir_to; - string xla_dump_debug_json_to; - bool xla_eliminate_hlo_implicit_broadcast; - - bool xla_cpu_multi_thread_eigen; - - string xla_gpu_cuda_data_dir; - bool xla_gpu_ftz; - - bool xla_test_all_output_layouts; - bool xla_test_all_input_layouts; - - string xla_backend_extra_options; -}; - namespace { -DebugOptionsFlags* flag_values; +DebugOptions* flag_values; std::vector<tensorflow::Flag>* flag_objects; std::once_flag flags_init; +namespace { +void SetDebugOptionsDefaults(DebugOptions* flags) { + flags->set_xla_hlo_graph_path("/tmp/"); + flags->set_xla_enable_fast_math(true); + flags->set_xla_llvm_enable_alias_scope_metadata(true); + flags->set_xla_llvm_enable_noalias_metadata(true); + flags->set_xla_llvm_enable_invariant_load_metadata(true); + flags->set_xla_backend_optimization_level(3); + flags->set_xla_cpu_multi_thread_eigen(true); + flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); +} +} // namespace + // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. 
void AllocateFlags() { - flag_values = new DebugOptionsFlags; - flag_values->xla_generate_hlo_graph = ""; - flag_values->xla_hlo_graph_addresses = false; - flag_values->xla_hlo_graph_layout = false; - flag_values->xla_hlo_graph_path = "/tmp/"; - flag_values->xla_hlo_dump_as_graphdef = false; - flag_values->xla_log_hlo_text = ""; - flag_values->xla_generate_hlo_text_to = ""; - flag_values->xla_disable_hlo_passes = ""; - flag_values->xla_enable_fast_math = true; - flag_values->xla_llvm_enable_alias_scope_metadata = true; - flag_values->xla_llvm_enable_noalias_metadata = true; - flag_values->xla_llvm_enable_invariant_load_metadata = true; - flag_values->xla_backend_optimization_level = 3; - flag_values->xla_embed_ir_in_executable = false; - flag_values->xla_dump_ir_to = ""; - flag_values->xla_dump_debug_json_to = ""; - flag_values->xla_eliminate_hlo_implicit_broadcast = false; - flag_values->xla_cpu_multi_thread_eigen = true; - flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib"; - flag_values->xla_gpu_ftz = false; - flag_values->xla_test_all_output_layouts = false; - flag_values->xla_backend_extra_options = ""; - flag_values->xla_test_all_input_layouts = false; + flag_values = new DebugOptions; + + SetDebugOptionsDefaults(flag_values); + + // Returns a lambda that calls "member_setter" on "flag_values" with the + // argument passed in to the lambda. + auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) { + return [member_setter](bool value) { + (flag_values->*member_setter)(value); + return true; + }; + }; + + // Returns a lambda that calls "member_setter" on "flag_values" with the + // argument passed in to the lambda. + auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32)) { + return [member_setter](int32 value) { + (flag_values->*member_setter)(value); + return true; + }; + }; + + // Returns a lambda that is a custom "sub-parser" for xla_disable_hlo_passes. 
+ auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { + std::vector<string> disabled_passes = + tensorflow::str_util::Split(comma_separated_values, ','); + for (const auto& passname : disabled_passes) { + flag_values->add_xla_disable_hlo_passes(passname); + } + return true; + }; + + // Returns a lambda that is a custom "sub-parser" for + // xla_backend_extra_options. + auto setter_for_xla_backend_extra_options = + [](string comma_separated_values) { + std::vector<string> extra_options_parts = + tensorflow::str_util::Split(comma_separated_values, ','); + auto* extra_options_map = + flag_values->mutable_xla_backend_extra_options(); + + // The flag contains a comma-separated list of options; some options + // have arguments following "=", some don't. + for (const auto& part : extra_options_parts) { + size_t eq_pos = part.find_first_of('='); + if (eq_pos == string::npos) { + (*extra_options_map)[part] = ""; + } else { + string value = ""; + if (eq_pos + 1 < part.size()) { + value = part.substr(eq_pos + 1); + } + (*extra_options_map)[part.substr(0, eq_pos)] = value; + } + } + + return true; + }; flag_objects = new std::vector<tensorflow::Flag>( {tensorflow::Flag( - "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph, + "xla_generate_hlo_graph", + flag_values->mutable_xla_generate_hlo_graph(), "HLO modules matching this regex will be dumped to a .dot file " "throughout various stages in compilation."), tensorflow::Flag( - "xla_hlo_graph_addresses", &flag_values->xla_hlo_graph_addresses, + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), "With xla_generate_hlo_graph, show addresses of HLO ops in " "graph dump."), tensorflow::Flag( - "xla_hlo_graph_layout", &flag_values->xla_hlo_graph_layout, + "xla_hlo_graph_layout", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_layout), + flag_values->xla_hlo_graph_layout(), "With xla_generate_hlo_graph, show layout 
of HLO ops in " "graph dump."), tensorflow::Flag( - "xla_hlo_graph_path", &flag_values->xla_hlo_graph_path, + "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag("xla_hlo_dump_as_graphdef", - &flag_values->xla_hlo_dump_as_graphdef, - "Dump HLO graphs as TensorFlow GraphDefs."), tensorflow::Flag( - "xla_log_hlo_text", &flag_values->xla_log_hlo_text, + "xla_hlo_dump_as_graphdef", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), + flag_values->xla_hlo_dump_as_graphdef(), + "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), "HLO modules matching this regex will be dumped to LOG(INFO). "), tensorflow::Flag( - "xla_generate_hlo_text_to", &flag_values->xla_generate_hlo_text_to, + "xla_generate_hlo_text_to", + flag_values->mutable_xla_generate_hlo_text_to(), "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( - "xla_enable_fast_math", &flag_values->xla_enable_fast_math, + "xla_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_enable_fast_math), + flag_values->xla_enable_fast_math(), "Enable unsafe fast-math optimizations in the compiler; " "this may produce faster code at the expense of some accuracy."), - tensorflow::Flag("xla_llvm_enable_alias_scope_metadata", - &flag_values->xla_llvm_enable_alias_scope_metadata, - "In LLVM-based backends, enable the emission of " - "!alias.scope metadata in the generated IR."), - tensorflow::Flag("xla_llvm_enable_noalias_metadata", - &flag_values->xla_llvm_enable_noalias_metadata, - "In LLVM-based backends, enable the emission of " - "!noalias metadata in the generated IR."), - tensorflow::Flag("xla_llvm_enable_invariant_load_metadata", - &flag_values->xla_llvm_enable_invariant_load_metadata, - "In LLVM-based backends, enable the emission of " - "!invariant.load metadata in " - "the generated IR."), + 
tensorflow::Flag( + "xla_llvm_enable_alias_scope_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), + flag_values->xla_llvm_enable_alias_scope_metadata(), + "In LLVM-based backends, enable the emission of " + "!alias.scope metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_noalias_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), + flag_values->xla_llvm_enable_noalias_metadata(), + "In LLVM-based backends, enable the emission of " + "!noalias metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_invariant_load_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), + flag_values->xla_llvm_enable_invariant_load_metadata(), + "In LLVM-based backends, enable the emission of " + "!invariant.load metadata in " + "the generated IR."), tensorflow::Flag( "xla_backend_optimization_level", - &flag_values->xla_backend_optimization_level, + int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), + flag_values->xla_backend_optimization_level(), "Numerical optimization level for the XLA compiler backend."), tensorflow::Flag( - "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes, + "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", "Comma-separated list of hlo passes to be disabled. 
These names " "must exactly match the passes' names; no whitespace around " "commas."), - tensorflow::Flag("xla_embed_ir_in_executable", - &flag_values->xla_embed_ir_in_executable, - "Embed the compiler IR as a string in the executable."), tensorflow::Flag( - "xla_dump_ir_to", &flag_values->xla_dump_ir_to, + "xla_embed_ir_in_executable", + bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), + flag_values->xla_embed_ir_in_executable(), + "Embed the compiler IR as a string in the executable."), + tensorflow::Flag( + "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), "Dump the compiler IR into this directory as individual files."), - tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast", - &flag_values->xla_eliminate_hlo_implicit_broadcast, - "Eliminate implicit broadcasts when lowering user " - "computations to HLO instructions; use explicit " - "broadcast instead."), - tensorflow::Flag("xla_cpu_multi_thread_eigen", - &flag_values->xla_cpu_multi_thread_eigen, - "When generating calls to Eigen in the CPU backend, " - "use multi-threaded Eigen mode."), + tensorflow::Flag( + "xla_eliminate_hlo_implicit_broadcast", + bool_setter_for( + &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), + flag_values->xla_eliminate_hlo_implicit_broadcast(), + "Eliminate implicit broadcasts when lowering user " + "computations to HLO instructions; use explicit " + "broadcast instead."), + tensorflow::Flag( + "xla_cpu_multi_thread_eigen", + bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), + flag_values->xla_cpu_multi_thread_eigen(), + "When generating calls to Eigen in the CPU backend, " + "use multi-threaded Eigen mode."), tensorflow::Flag("xla_gpu_cuda_data_dir", - &flag_values->xla_gpu_cuda_data_dir, + flag_values->mutable_xla_gpu_cuda_data_dir(), "If non-empty, speficies a local directory containing " "ptxas and nvvm libdevice files; otherwise we use " "those from runfile directories."), - tensorflow::Flag("xla_gpu_ftz", 
&flag_values->xla_gpu_ftz, + tensorflow::Flag("xla_gpu_ftz", + bool_setter_for(&DebugOptions::set_xla_gpu_ftz), + flag_values->xla_gpu_ftz(), "If true, flush-to-zero semantics are enabled in the " "code generated for GPUs."), tensorflow::Flag( - "xla_dump_debug_json_to", &flag_values->xla_dump_debug_json_to, + "xla_dump_debug_json_to", + flag_values->mutable_xla_dump_debug_json_to(), "Dump compilation artifacts as JSON into this directory."), - tensorflow::Flag("xla_test_all_output_layouts", - &flag_values->xla_test_all_output_layouts, - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of output layouts. For example, with " - "a 3D shape, all permutations of the set {0, 1, 2} are " - "tried."), - tensorflow::Flag("xla_test_all_input_layouts", - &flag_values->xla_test_all_input_layouts, - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of *input* layouts. For example, for " - "2 input arguments with 2D shape and 4D shape, the " - "computation will run 2! * 4! times for every possible " - "layouts"), + tensorflow::Flag( + "xla_test_all_output_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), + flag_values->xla_test_all_output_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of output layouts. For example, with " + "a 3D shape, all permutations of the set {0, 1, 2} are " + "tried."), + tensorflow::Flag( + "xla_test_all_input_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), + flag_values->xla_test_all_input_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of *input* layouts. For example, for " + "2 input arguments with 2D shape and 4D shape, the " + "computation will run 2! * 4! 
times for every possible " + "layouts"), tensorflow::Flag("xla_backend_extra_options", - &flag_values->xla_backend_extra_options, + setter_for_xla_backend_extra_options, "", "Extra options to pass to a backend; " "comma-separated list of 'key=val' strings (=val " "may be omitted); no whitespace around commas.")}); - ParseFlagsFromEnv(*flag_objects); } @@ -197,68 +245,7 @@ void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) { xla::DebugOptions GetDebugOptionsFromFlags() { std::call_once(flags_init, &AllocateFlags); - - DebugOptions options; - options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph); - options.set_xla_hlo_graph_addresses(flag_values->xla_hlo_graph_addresses); - options.set_xla_hlo_graph_layout(flag_values->xla_hlo_graph_layout); - options.set_xla_hlo_graph_path(flag_values->xla_hlo_graph_path); - options.set_xla_hlo_dump_as_graphdef(flag_values->xla_hlo_dump_as_graphdef); - options.set_xla_log_hlo_text(flag_values->xla_log_hlo_text); - options.set_xla_generate_hlo_text_to(flag_values->xla_generate_hlo_text_to); - - std::vector<string> disabled_passes = - tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ','); - for (const auto& passname : disabled_passes) { - options.add_xla_disable_hlo_passes(passname); - } - - options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math); - options.set_xla_backend_optimization_level( - flag_values->xla_backend_optimization_level); - options.set_xla_embed_ir_in_executable( - flag_values->xla_embed_ir_in_executable); - options.set_xla_dump_ir_to(flag_values->xla_dump_ir_to); - options.set_xla_eliminate_hlo_implicit_broadcast( - flag_values->xla_eliminate_hlo_implicit_broadcast); - options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to); - - options.set_xla_cpu_multi_thread_eigen( - flag_values->xla_cpu_multi_thread_eigen); - options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir); - options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz); 
- options.set_xla_llvm_enable_alias_scope_metadata( - flag_values->xla_llvm_enable_alias_scope_metadata); - options.set_xla_llvm_enable_noalias_metadata( - flag_values->xla_llvm_enable_noalias_metadata); - options.set_xla_llvm_enable_invariant_load_metadata( - flag_values->xla_llvm_enable_invariant_load_metadata); - - options.set_xla_test_all_output_layouts( - flag_values->xla_test_all_output_layouts); - options.set_xla_test_all_input_layouts( - flag_values->xla_test_all_input_layouts); - - std::vector<string> extra_options_parts = - tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ','); - auto* extra_options_map = options.mutable_xla_backend_extra_options(); - - // The flag contains a comma-separated list of options; some options have - // arguments following "=", some don't. - for (const auto& part : extra_options_parts) { - size_t eq_pos = part.find_first_of('='); - if (eq_pos == string::npos) { - (*extra_options_map)[part] = ""; - } else { - string value = ""; - if (eq_pos + 1 < part.size()) { - value = part.substr(eq_pos + 1); - } - (*extra_options_map)[part.substr(0, eq_pos)] = value; - } - } - - return options; + return *flag_values; } } // namespace legacy_flags |