/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

// NOTE(review): the system header names and all template arguments in this
// file were restored from how the names are used below; confirm them against
// hlo_module.h before submitting.
#include <algorithm>
#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Creates an empty module. The stored name is the sanitized form of `name`,
// and the module takes the next value of the process-wide module id counter.
HloModule::HloModule(const string& name, const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      unique_id_(next_unique_module_id_++) {}

// Returns a mutable pointer to `hlo`. A null `hlo` is passed through as
// nullptr; a non-null `hlo` must belong to this module, otherwise an error
// status is returned.
StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
    const HloInstruction* hlo) {
  if (hlo == nullptr) {
    return nullptr;
  }

  TF_RET_CHECK(hlo->GetModule() == this);

  // TODO(b/78350259): Eliminate const laundering.
  return const_cast<HloInstruction*>(hlo);
}

// Installs `schedule` on this module. Fails without modifying the module if
// the schedule was built for a different module or does not verify.
Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return Status::OK();
}

// Takes ownership of `computation` and returns a raw pointer to it.
//
// If `is_entry` is true the computation becomes the entry computation (there
// must not already be one). If `uniquify_identifiers` is true, names and ids
// of the computation and its instructions are reassigned to be unique within
// this module; otherwise the existing names and ids are kept and only
// registered so that later additions do not collide with them.
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    // If the module configuration has no entry layout computation set, create
    // a default one based on the program shape.
    if (!config_.has_entry_computation_layout()) {
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set unique id to this computation.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or instruction, but we must
    // run the names through the uniquifiers to prevent future name collisions
    // for computations and instructions created later. Also, set the
    // next_unique_id_ to the one greater than the max unique id of any
    // instruction (or the computation) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

// Adds `computation` as the module's entry computation, uniquifying its
// identifiers.
HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true);
}

// Removes (and destroys) the computation `to_remove` from this module.
// Returns an error status, leaving the module untouched, if `to_remove` is
// not one of the computations owned by this module.
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  auto it = std::find_if(
      computations_.begin(), computations_.end(),
      [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  // std::find_if returns end() when nothing matches; end() must not be
  // dereferenced, so test for it explicitly before touching `*it`.
  TF_RET_CHECK(it != computations_.end())
      << "Computation to remove is not owned by module " << name_;
  computations_.erase(it);
  return Status::OK();
}

// Adds `computation` as a non-entry computation, uniquifying its identifiers.
HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_identifiers=*/true);
}

// Redirects every computation reference held by an instruction of this module
// through `replacements` (key -> replacement), drops the computations that
// appear as keys in `replacements` from the module, and remaps the entry
// computation if it was replaced.
void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          HloComputation* new_true_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->true_computation(), nullptr);
          if (new_true_computation != nullptr) {
            instruction->set_true_computation(new_true_computation);
          }
          HloComputation* new_false_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->false_computation(), nullptr);
          if (new_false_computation != nullptr) {
            instruction->set_false_computation(new_false_computation);
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    // Keep only the computations that are not being replaced.
    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

// Renders the module as text: a "HloModule <name>" header followed by every
// computation in post order, with the entry computation prefixed by "ENTRY ".
// When the module has a schedule it is verified first (a failure is fatal),
// and scheduled computations are printed in their scheduled order.
string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  s << "HloModule " << name();
  if (has_schedule()) {
    TF_CHECK_OK(schedule().Verify());
    s << ", is_scheduled=true";
  }
  s << "\n\n";
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
      s << computation->ToString(
               options, schedule().sequence(computation).instructions())
        << "\n\n";
    } else {
      s << computation->ToString(options) << "\n\n";
    }
  }
  return s.str();
}

// Serializes the module: id, name, entry computation name/id, every
// computation in post order, the entry computation's program shape, and the
// schedule if one is set.
HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  proto.set_entry_computation_id(entry_computation_->unique_id());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    if (computation->name() == entry_computation_->name()) {
      *proto.mutable_program_shape() = computation_proto.program_shape();
    }
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  return proto;
}

// Builds a module from `proto` using `module_config`.
//
// Returns an error status if the proto has no program shape, if the config's
// entry layout is incompatible with that program shape, if the entry
// computation is missing, or if names/ids in the proto are not unique. Names
// and ids from the proto are preserved rather than regenerated.
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config) {
  VLOG(2) << "CreateFromProto()";
  XLA_VLOG_LINES(2, proto.DebugString());

  // The ProgramShape in the passed in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_program_shape())
      << "No program shape found in the proto";
  const auto& expected_program_shape = proto.program_shape();
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  // Proto id -> computation, handed to HloComputation::CreateFromProto for
  // each subsequent computation, and the reverse map used for sorting below.
  absl::flat_hash_map<int64, HloComputation*> computation_map;
  absl::flat_hash_map<HloComputation*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloComputation>> computations;
  HloComputation* entry = nullptr;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(computation_proto, computation_map));
    CHECK_NE(computation.get(), nullptr);
    int64 computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    computation_map[computation_id] = computation.get();
    to_proto_id[computation.get()] = computation_id;
    if (computation_id == proto.entry_computation_id()) {
      entry = computation.get();
    }
    computations.push_back(std::move(computation));
  }
  TF_RET_CHECK(entry != nullptr);

  auto module = absl::make_unique<HloModule>(proto.name(), module_config);

  // Sort the computations in the proto id's order.
  std::sort(computations.begin(), computations.end(),
            [&](const std::unique_ptr<HloComputation>& a,
                const std::unique_ptr<HloComputation>& b) {
              return to_proto_id[a.get()] < to_proto_id[b.get()];
            });

  // Add sorted computations to the module.
  for (auto& computation : computations) {
    bool is_entry = computation.get() == entry;
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    module->AddComputationInternal(std::move(computation), is_entry,
                                   /*uniquify_identifiers=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  // Because we didn't uniquify the names or the ids, double-check that the
  // instruction and computation names and ids are unique from the proto.
  absl::flat_hash_set<string> computation_names;
  absl::flat_hash_set<string> instruction_names;
  absl::flat_hash_set<int> computation_ids;
  absl::flat_hash_set<int> instruction_ids;
  for (HloComputation* computation : module->computations()) {
    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());

    TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
        << "Computation id is not unique: " << computation->unique_id();
    computation_ids.insert(computation->unique_id());

    for (HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());

      TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
          << "Instruction id is not unique: " << instruction->unique_id();
      instruction_ids.insert(instruction->unique_id());
    }
  }

  if (proto.has_schedule()) {
    TF_ASSIGN_OR_RETURN(
        HloSchedule schedule,
        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  }

  return std::move(module);
}

// Builds an HloModuleConfig from the program shape stored in `module`, with
// `debug_options` applied and the entry parameter/result layouts copied from
// that program shape. Fails if the proto has no program shape or a layout
// cannot be copied.
/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options) {
  TF_RET_CHECK(module.has_program_shape())
      << "No program shape found in the proto";
  const auto& program_shape = module.program_shape();

  HloModuleConfig module_config(program_shape);
  module_config.set_debug_options(debug_options);

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}

namespace {

// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(
    const HloInstruction& hlo,
    const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) {
  for (HloInstruction* user : hlo.users()) {
    if (!instructions_in_subcomputation.count(user)) {
      return true;
    }
  }
  return false;
}

}  // anonymous namespace

// Moves `instructions_to_outline` (expected in topological order, all living
// in `computation`) into a new embedded computation named
// `outlined_computation_name` and replaces them in `computation` with a single
// call instruction, which is returned. The outlined set must have exactly one
// output; anything else is a fatal error.
HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  std::unordered_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands to their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Creates a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << " " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  // Remove in reverse order so that users are removed before their operands.
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

// Returns the total number of instructions across all computations owned by
// this module.
int64 HloModule::instruction_count() const {
  int64 n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

// Returns all computations of the module ordered so that every computation
// appears after the computations it calls. It is a fatal error if the
// resulting order does not contain exactly the module's computations.
std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  std::set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  std::set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (nonroot_computations.count(computation.get()) == 0) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (added_computations.count(embedded_computation) == 0) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK_EQ(0, added_computations.count(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  if (post_order.size() != computations_.size()) {
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
               << " computation_count=" << computations_.size();
  }
  return post_order;
}

// Returns the module's computations for which IsFusionComputation() is false,
// in the order given by computations().
std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result;
  for (auto* c : computations()) {
    if (c->IsFusionComputation()) {
      continue;
    }
    result.push_back(c);
  }
  return result;
}

// Returns a new module named "<name>-<suffix>" with the same config, whose
// entry computation is a clone of this module's entry computation made
// through an HloCloneContext bound to the new module.
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
  auto module = absl::make_unique<HloModule>(name_ + "-" + suffix, config_);

  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));
  return module;
}

// Clones `computation` into this module as an embedded computation and
// returns the clone. With a non-null `context`, a clone already recorded in
// the context is returned instead of cloning again.
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
                                                HloCloneContext* context) {
  HloComputation* new_computation;
  if (context != nullptr) {
    if ((new_computation = context->FindComputation(computation)) != nullptr) {
      return new_computation;
    }
    new_computation =
        AddEmbeddedComputation(computation->Clone(context->suffix(), context));
  } else {
    new_computation = AddEmbeddedComputation(computation->Clone(""));
  }
  return new_computation;
}

// Returns the next value produced by rng_, drawn while holding rng_mutex_.
uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

// Returns the computation in this module whose name equals `name`, or nullptr
// if there is none.
HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
  auto computations_in_module = computations();
  auto it = absl::c_find_if(
      computations_in_module,
      [&](HloComputation* computation) { return computation->name() == name; });
  return it == computations_in_module.end() ? nullptr : *it;
}

// Source of the ids handed out by the HloModule constructor.
/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla