#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

#include <functional>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace xla {

namespace {
using absl::StrAppend;
using absl::StrCat;

// Offers `module` to the HLO graph dumper under the label `message` and logs
// the module's textual form at VLOG level 3.
void DumpModuleGraph(const HloModule& module, const string& message) {
  hlo_graph_dumper::MaybeDumpHloModule(module, message);
  VLOG(3) << "HLO " << message << ":";
  XLA_VLOG_LINES(3, module.ToString());
}

// Serializes `module` as an HloProto into the directory `dump_to`.
//
// The file name (passed through SanitizeFileName) has the form
//   module_<unique_id>.<pass_number>.<pipeline_name>.after_<pass_name>
// where <pass_number> counts how many dumps have been made so far for this
// module's unique id. The counter table is a process-wide static shared by
// every pipeline, guarded by `mu`; the lock is held for the whole dump, so
// dumps are serialized. A failed write is fatal (TF_QCHECK_OK).
void DumpModuleProto(const HloModule& module, const string& dump_to,
                     const string& pipeline_name, const string& pass_name) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  // Heap-allocated and never freed — presumably to sidestep static
  // destruction-order problems.
  static auto* const module_id_to_pass_number =
      new tensorflow::gtl::FlatMap<int64, int64>();

  tensorflow::mutex_lock lock(mu);
  const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;

  const string mod_name = SanitizeFileName(
      absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
                      pass_number, pipeline_name, pass_name));

  TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
                                                   dump_to, mod_name));
}
}  // namespace

// Runs every enabled pass in `passes_`, in order, on `module`.
//
// Passes whose names appear in the module's --xla_disable_hlo_passes debug
// option are skipped. The invariant checkers run once before the first pass
// and again after every executed pass; a checker error is returned with a note
// saying where it failed, and a checker that reports having changed the module
// makes Run return an error (TF_RET_CHECK). When the
// xla_dump_per_pass_hlo_proto_to debug option is non-empty, the module is
// dumped as a proto at pipeline start and after each executed pass.
//
// Returns true if any executed pass reported that it changed the module, or
// the first error encountered.
StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
  run_called_ = true;

  VLOG(1) << "Running HLO pass pipeline " << name();

  // Bound by reference: the repeated field is only read to build the set
  // below, so copying it is unnecessary.
  const auto& repeated_field =
      module->config().debug_options().xla_disable_hlo_passes();
  tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
                                                   repeated_field.end());
  if (!disabled_passes.empty()) {
    VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
            << absl::StrJoin(disabled_passes, ", ");
  }

  // Runs every invariant checker on `module`. `message` describes the current
  // point in the pipeline and is appended to any checker failure.
  auto run_invariant_checkers = [this,
                                 module](const string& message) -> Status {
    for (auto& invariant_checker : invariant_checkers_) {
      VLOG(1) << " Invariant checker " << invariant_checker->name();
      StatusOr<bool> changed_status = invariant_checker->Run(module);
      VLOG(1) << " Invariant checker done " << invariant_checker->name();
      if (!changed_status.ok()) {
        VLOG(2) << "Module failed invariant check:";
        XLA_VLOG_LINES(2, module->ToString());
        return Status(changed_status.status().code(),
                      StrCat(changed_status.status().error_message(),
                             "\n\nFailed ", message));
      }
      TF_RET_CHECK(!changed_status.ValueOrDie())
          << "invariant checkers must not change the graph";
    }
    return Status::OK();
  };

  string prefix = StrCat(name(), ": pipeline start");
  bool changed = false;
  string message;
  TF_RETURN_IF_ERROR(
      run_invariant_checkers(StrCat("before running pipeline: ", name())));

  // Kept as a copy rather than a reference into the module config, because it
  // is still read after passes have run on the module.
  const string xla_dump_per_pass_hlo_proto_to =
      module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
  if (!xla_dump_per_pass_hlo_proto_to.empty()) {
    DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
                    "pipeline_start");
  }

  for (auto& pass : passes_) {
    if (disabled_passes.count(string(pass->name())) > 0) {
      VLOG(1) << " Skipping HLO pass " << pass->name()
              << ", disabled by --xla_disable_hlo_passes";
      continue;
    }

    VLOG(1) << " HLO pass " << pass->name();

    // Emit label containing: "after foo-pass, before bar-pass".
    message.clear();
    StrAppend(&message, prefix, ", before ", pass->name());
    DumpModuleGraph(*module, message);

    TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
    TF_RETURN_IF_ERROR(
        run_invariant_checkers(StrCat("after running pass: ", pass->name())));
    if (!xla_dump_per_pass_hlo_proto_to.empty()) {
      DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
                      string(pass->name()));
    }

    changed |= changed_this_pass;
    // Becomes the "after <pass>" half of the next graph-dump label.
    prefix.clear();
    StrAppend(&prefix, name(), ": after ", pass->name());
  }
  DumpModuleGraph(*module, prefix + ", pipeline end");
  return changed;
}

}  // namespace xla