diff options
author | 2017-06-26 09:35:14 -0700 | |
---|---|---|
committer | 2017-06-26 09:42:32 -0700 | |
commit | a3a7d1ac38da8fec75ae5a0eaee743b065a9b85c (patch) | |
tree | f5353838a3405866ad5adada0eaed0fbb60e05f6 | |
parent | 6b7b01a9c5df50977476b3c2892a896d9934f381 (diff) |
[XLA] Move HLO dumping flags from service_flags to debug_options_flags
This also removes the duplicated definition of the xla_generate_hlo_graph flag, which previously existed in both service_flags and debug_options_flags.
This CL also moves the actual dumping logic from Executable to the
hlo_graph_dumper namespace, where it belongs; this is in preparation for
removing the hlo_dumper callback altogether, since it serves no purpose beyond
what a direct call to hlo_graph_dumper would provide (see b/62872831 for more
details).
PiperOrigin-RevId: 160154869
-rw-r--r-- | tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc | 27 | ||||
-rw-r--r-- | tensorflow/compiler/xla/legacy_flags/service_flags.cc | 23 | ||||
-rw-r--r-- | tensorflow/compiler/xla/legacy_flags/service_flags.h | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/executable.cc | 30 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/executable.h | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 26 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper.h | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_test_base.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/xla.proto | 13 |
14 files changed, 101 insertions, 81 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 7f771bf601..f87cf7083e 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -25,6 +25,11 @@ namespace legacy_flags { struct DebugOptionsFlags { string xla_generate_hlo_graph; + bool xla_hlo_graph_addresses; + bool xla_hlo_graph_layout; + string xla_log_hlo_text; + string xla_generate_hlo_text_to; + string xla_disable_hlo_passes; bool xla_enable_fast_math; bool xla_llvm_enable_alias_scope_metadata; @@ -54,6 +59,10 @@ std::once_flag flags_init; void AllocateFlags() { flag_values = new DebugOptionsFlags; flag_values->xla_generate_hlo_graph = ""; + flag_values->xla_hlo_graph_addresses = false; + flag_values->xla_hlo_graph_layout = false; + flag_values->xla_log_hlo_text = ""; + flag_values->xla_generate_hlo_text_to = ""; flag_values->xla_disable_hlo_passes = ""; flag_values->xla_enable_fast_math = true; flag_values->xla_llvm_enable_alias_scope_metadata = true; @@ -74,6 +83,20 @@ void AllocateFlags() { "HLO modules matching this regex will be dumped to a .dot file " "throughout various stages in compilation."), tensorflow::Flag( + "xla_hlo_graph_addresses", &flag_values->xla_hlo_graph_addresses, + "With xla_generate_hlo_graph, show addresses of HLO ops in " + "graph dump."), + tensorflow::Flag( + "xla_hlo_graph_layout", &flag_values->xla_hlo_graph_layout, + "With xla_generate_hlo_graph, show layout of HLO ops in " + "graph dump."), + tensorflow::Flag( + "xla_log_hlo_text", &flag_values->xla_log_hlo_text, + "HLO modules matching this regex will be dumped to LOG(INFO). 
"), + tensorflow::Flag( + "xla_generate_hlo_text_to", &flag_values->xla_generate_hlo_text_to, + "Dump all HLO modules as text into the provided directory path."), + tensorflow::Flag( "xla_enable_fast_math", &flag_values->xla_enable_fast_math, "Enable unsafe fast-math optimizations in the compiler; " "this may produce faster code at the expense of some accuracy."), @@ -141,6 +164,10 @@ xla::DebugOptions GetDebugOptionsFromFlags() { DebugOptions options; options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph); + options.set_xla_hlo_graph_addresses(flag_values->xla_hlo_graph_addresses); + options.set_xla_hlo_graph_layout(flag_values->xla_hlo_graph_layout); + options.set_xla_log_hlo_text(flag_values->xla_log_hlo_text); + options.set_xla_generate_hlo_text_to(flag_values->xla_generate_hlo_text_to); std::vector<string> disabled_passes = tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ','); diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.cc b/tensorflow/compiler/xla/legacy_flags/service_flags.cc index 41cb8d8bdf..90d30e7569 100644 --- a/tensorflow/compiler/xla/legacy_flags/service_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/service_flags.cc @@ -36,34 +36,14 @@ static std::once_flag flags_init; static void AllocateFlags() { flags = new ServiceFlags; flags->xla_hlo_profile = false; - flags->xla_log_hlo_text = ""; - flags->xla_generate_hlo_graph = ""; - flags->xla_hlo_graph_addresses = false; - flags->xla_hlo_graph_layout = false; flags->xla_hlo_graph_for_compute_constant = false; flags->xla_dump_computations_to = ""; - flags->xla_dump_hlo_text_to = ""; flags->xla_dump_executions_to = ""; flag_list = new std::vector<tensorflow::Flag>({ tensorflow::Flag( "xla_hlo_profile", &flags->xla_hlo_profile, "Instrument the computation to collect per-HLO cycle counts"), tensorflow::Flag( - "xla_log_hlo_text", &flags->xla_log_hlo_text, - "If non-empty, print the text format of " - "HLO modules whose name partially matches this 
regex. E.g. " - "xla_log_hlo_text=.* will dump the text for every module."), - tensorflow::Flag( - "xla_generate_hlo_graph", &flags->xla_generate_hlo_graph, - "If non-empty, dump graph of HLO modules whose name partially " - "matches this regex. E.g. --xla_generate_hlo_graph=.* will dump " - "the graph of every module."), - tensorflow::Flag("xla_hlo_graph_addresses", - &flags->xla_hlo_graph_addresses, - "Show addresses of HLO ops in graph"), - tensorflow::Flag("xla_hlo_graph_layout", &flags->xla_hlo_graph_layout, - "Show layout of HLO ops in graph"), - tensorflow::Flag( "xla_hlo_graph_for_compute_constant", &flags->xla_hlo_graph_for_compute_constant, "If true, include hlo dumps of graphs from ComputeConstant." @@ -72,9 +52,6 @@ static void AllocateFlags() { &flags->xla_dump_computations_to, "Dumps computations that XLA executes into the provided " "directory path"), - tensorflow::Flag("xla_dump_hlo_text_to", &flags->xla_dump_hlo_text_to, - "Dumps HLO modules that XLA executes into the provided " - "directory path"), tensorflow::Flag("xla_dump_executions_to", &flags->xla_dump_executions_to, "Dumps parameters and results of computations that XLA " "executes into the provided directory path"), diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.h b/tensorflow/compiler/xla/legacy_flags/service_flags.h index d982506944..72d0c52402 100644 --- a/tensorflow/compiler/xla/legacy_flags/service_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/service_flags.h @@ -34,23 +34,11 @@ void AppendServiceFlags(std::vector<tensorflow::Flag>* flag_list); typedef struct { bool xla_hlo_profile; // Instrument the computation to collect per-HLO cycle // counts - string xla_log_hlo_text; // If non-empty, print the text format of the HLO - // modules whose name partially - // matches this regex. E.g. xla_log_hlo_text=.* - // will dump the text for every module. 
- string xla_generate_hlo_graph; // If non-empty, dump graph of HLO modules - // whose name partially matches this regex. - // E.g. --xla_generate_hlo_graph=.* will dump - // the graph of every module. - bool xla_hlo_graph_addresses; // Show addresses of HLO ops in graph - bool xla_hlo_graph_layout; // Show layout of HLO ops in graph bool xla_hlo_graph_for_compute_constant; // If true, include hlo dumps of // graphs from ComputeConstant. // Such graphs still need to be // matched via // xla_generate_hlo_graph. - string xla_dump_hlo_text_to; // Dumps HLO text for each HLO module that is - // executed into the provided directory path string xla_dump_computations_to; // Dumps computations that XLA executes // into the provided directory path // Dumps parameters and results of computations that XLA executes into diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 427b382211..718a2d798c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1813,6 +1813,7 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 3a9f8dc79e..20eb1aea37 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -21,39 +21,9 @@ limitations under the License. 
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/regexp.h" namespace xla { -/* static */ void Executable::DumpExecutedHlo( - const HloModule& module, const string& label, - const HloExecutionProfile* profile) { - VLOG(2) << "module name = " << module.name(); - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - string generate_hlo_graph_regex; - if (!flags->xla_generate_hlo_graph.empty()) { - generate_hlo_graph_regex = flags->xla_generate_hlo_graph; - } else { - generate_hlo_graph_regex = - module.config().debug_options().xla_generate_hlo_graph(); - } - if (!generate_hlo_graph_regex.empty() && - RE2::PartialMatch(module.name(), generate_hlo_graph_regex)) { - hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, - flags->xla_hlo_graph_addresses, - flags->xla_hlo_graph_layout, profile); - } - if (!flags->xla_log_hlo_text.empty() && - RE2::PartialMatch(module.name(), flags->xla_log_hlo_text)) { - LOG(INFO) << "HLO for module " << module.name(); - LOG(INFO) << "Label: " << label; - XLA_LOG_LINES(2, module.ToString()); - } - if (!flags->xla_dump_hlo_text_to.empty()) { - hlo_graph_dumper::DumpText(module, label, flags->xla_dump_hlo_text_to); - } -} - StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>> Executable::ExecuteOnStreams( tensorflow::gtl::ArraySlice<const ServiceExecutableRunOptions> run_options, diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 291916cd9f..b36a44e19e 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/session.pb.h" @@ -49,10 +50,6 @@ class Executable { shape_size_function_(std::move(shape_size_function)) {} virtual ~Executable() {} - // Dumps the executed HLO according to service-associated flags. - static void DumpExecutedHlo(const HloModule& module, const string& label, - const HloExecutionProfile* profile); - // Enqueues the compilation result on the provided stream, passing the given // arguments. This call is blocking and returns after the execution is done. // @@ -240,7 +237,8 @@ StatusOr<ReturnT> Executable::ExecuteOnStreamWrapper( } } } - DumpExecutedHlo(module(), "Service::Execute", profile_ptr); + hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", + profile_ptr); } return return_value; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index dffb53320c..69166a8d13 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" using ::tensorflow::Env; using ::tensorflow::WriteStringToFile; @@ -593,6 +594,31 @@ void DumpText(const HloModule& module, const string& label, do_prefix ? 
StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); string path = JoinPath(directory_path, filename); TF_CHECK_OK(WriteStringToFile(env, path, module.ToString())); + LOG(INFO) << "dumping module '" << module.name() << "' to " << path; +} + +string MaybeDumpHloModule(const HloModule& module, const string& label, + const HloExecutionProfile* profile) { + VLOG(2) << "MaybeDumpHloModule called on module " << module.name(); + string graph_url; + const DebugOptions& debug_options = module.config().debug_options(); + if (!debug_options.xla_generate_hlo_graph().empty() && + RE2::PartialMatch(module.name(), + debug_options.xla_generate_hlo_graph())) { + graph_url = DumpGraph(*module.entry_computation(), label, + debug_options.xla_hlo_graph_addresses(), + debug_options.xla_hlo_graph_layout(), profile); + } + if (!debug_options.xla_log_hlo_text().empty() && + RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) { + LOG(INFO) << "HLO for module " << module.name(); + LOG(INFO) << "Label: " << label; + XLA_LOG_LINES(2, module.ToString()); + } + if (!debug_options.xla_generate_hlo_text_to().empty()) { + DumpText(module, label, debug_options.xla_generate_hlo_text_to()); + } + return graph_url; } } // namespace hlo_graph_dumper diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 8ed50c3847..eb65cf2d32 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -41,6 +41,13 @@ class GraphRendererInterface { virtual string RenderGraph(const string& graph, GraphKind graph_kind) = 0; }; +// Dump the given HLO module if a dump is requested in its debug options. Based +// on the debug options, either a graph dump, a text dump or both may be +// generated. If a graph dump is generated, the description (e.g. an URL) is +// returned; otherwise an empty string is returned. 
+string MaybeDumpHloModule(const HloModule& module, const string& label, + const HloExecutionProfile* profile = nullptr); + // Dumps a graph of the computation and returns a description of the rendered // graph (e.g., a URL) based on the renderer. The "best" renderer in the // registry is used. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 68441ef17f..2416f150e7 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -168,7 +168,8 @@ Service::CreateComputeConstantBackend() { /* static */ Compiler::HloDumper Service::MakeHloDumper() { return [](const HloModule& module, const string& label) { - return Executable::DumpExecutedHlo(module, label, /*profile=*/nullptr); + hlo_graph_dumper::MaybeDumpHloModule(module, label, + /*profile=*/nullptr); }; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 1def8ec12a..00b2858790 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -56,7 +56,7 @@ struct HloTestBase::EigenThreadPoolWrapper { HloTestBase::HloTestBase() : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) { test_hlo_dumper_ = [](const HloModule& module, const string& label) { - return Executable::DumpExecutedHlo(module, label, /*profile=*/nullptr); + hlo_graph_dumper::MaybeDumpHloModule(module, label, /*profile=*/nullptr); }; VLOG(1) << "executing on platform " << backend_->platform()->Name(); } diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 535e5b605b..8ba76ddc1b 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,7 +36,7 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", - 
"//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", @@ -187,6 +187,7 @@ cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 10efa9f3e8..100e19e6d8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -32,7 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,12 +63,16 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) { } // namespace xla int main(int argc, char** argv) { + std::vector<tensorflow::Flag> flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } tensorflow::port::InitMain(argv[0], &argc, &argv); - xla::legacy_flags::ServiceFlags* 
flags = xla::legacy_flags::GetServiceFlags(); - flags->xla_generate_hlo_graph = ".*"; - flags->xla_hlo_graph_layout = true; - tensorflow::gtl::ArraySlice<char*> args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] xla::tools::RealMain(args); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 850267d319..ebaea0511c 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" @@ -62,10 +63,16 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) { } // namespace xla int main(int argc, char** argv) { - tensorflow::port::InitMain(argv[0], &argc, &argv); + std::vector<tensorflow::Flag> flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } - xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); - flags->xla_generate_hlo_graph = ".*"; + tensorflow::port::InitMain(argv[0], &argc, &argv); xla::legacy_flags::HloGraphDumperFlags* dumper_flags = xla::legacy_flags::GetHloGraphDumperFlags(); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 46dd28c04d..8db0c26da3 100644 --- a/tensorflow/compiler/xla/xla.proto 
+++ b/tensorflow/compiler/xla/xla.proto @@ -28,6 +28,19 @@ message DebugOptions { // dump *all* HLO modules. string xla_generate_hlo_graph = 1; + // Show addresses of HLO ops in graph dump. + bool xla_hlo_graph_addresses = 21; + + // Show layout of HLO ops in graph dump. + bool xla_hlo_graph_layout = 22; + + // HLO modules matching this regex will be dumped to LOG(INFO). Set to ".*" to + // dump *all* HLO modules. + string xla_log_hlo_text = 23; + + // Dump all HLO modules as text into the provided directory path. + string xla_generate_hlo_text_to = 24; + // List of HLO passes to disable. These names must exactly match the pass // names as specified by the HloPassInterface::name() method. repeated string xla_disable_hlo_passes = 2; |