aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-29 21:24:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-29 21:27:20 -0700
commit9c509eedc3888d3846b2ab5ac2879268df9ff8cd (patch)
tree07a597f1409eaea8c38d7039e6580ff0f09e1b09 /tensorflow/compiler/xla/service/hlo_module.cc
parent3f2ba2edf62dc394cfcb4b2606f1638389aa92e2 (diff)
Introduced kDomain HLO instruction set isolation to bound connected sets of instructions with similar attributes (i.e., sharding).
This CL simply adds the infrastructure, but leaves the wire-on to a separate CL. PiperOrigin-RevId: 198503625
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc74
1 files changed, 28 insertions, 46 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index fbf1d58007..e63424c2df 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -496,7 +496,18 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
added_computations.insert(computation.get());
}
}
- CHECK_EQ(post_order.size(), computations_.size());
+ if (post_order.size() != computations_.size()) {
+ for (HloComputation* computation : post_order) {
+ LOG(ERROR) << "Post Order: " << computation->name() << " ("
+ << computation->parent()->name() << ")";
+ }
+ for (auto& computation : computations_) {
+ LOG(ERROR) << "Computations: " << computation->name() << " ("
+ << computation->parent()->name() << ")";
+ }
+ LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
+ << " computation_count=" << computations_.size();
+ }
return post_order;
}
@@ -517,54 +528,25 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
module->entry_computation_handle_ = entry_computation_handle_;
module->has_entry_computation_handle_ = has_entry_computation_handle_;
- std::unordered_map<HloComputation*, HloComputation*> clone_map;
- for (auto& computation : computations_) {
- if (computation->IsFusionComputation()) {
- // Cloning of a fused computation is handled by its fusion instruction.
- continue;
- }
-
- // When cloning a computation, pass in the new module, so that for any
- // fusion instruction in this computation, the fused computation will be
- // deep cloned to the new module.
- auto cloned_computation = computation->Clone(suffix, module.get());
- InsertOrDie(&clone_map, computation.get(), cloned_computation.get());
-
- if (entry_computation_ == computation.get()) {
- module->AddEntryComputation(std::move(cloned_computation));
- } else {
- module->AddEmbeddedComputation(std::move(cloned_computation));
- }
- }
-
- for (auto& cloned_computation : module->computations_) {
- for (auto* instruction : cloned_computation->instructions()) {
- // Rewrite instruction's called_computation to point to the cloned
- // computations.
- instruction->ReplaceCalledComputations([&](HloComputation* hlo) {
- if (hlo->IsFusionComputation()) {
- // Cloning of a fused computation has already been handled when its
- // fusion instruction is cloned. So this hlo computation is already
- // the cloned one.
- return hlo;
- }
- return FindOrDie(clone_map, hlo);
- });
- }
- }
+ HloCloneContext context(module.get(), suffix);
+ auto cloned_computation = entry_computation_->Clone(suffix, &context);
+ module->AddEntryComputation(std::move(cloned_computation));
return module;
}
-HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) {
- HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this));
- TF_CHECK_OK(
- clone->root_instruction()->Accept([this](HloInstruction* instruction) {
- instruction->ReplaceCalledComputations([this](HloComputation* callee) {
- return DeepCloneComputation(callee);
- });
- return Status::OK();
- }));
- return clone;
+HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
+ HloCloneContext* context) {
+ HloComputation* new_computation;
+ if (context != nullptr) {
+ if ((new_computation = context->FindComputation(computation)) != nullptr) {
+ return new_computation;
+ }
+ new_computation =
+ AddEmbeddedComputation(computation->Clone(context->suffix(), context));
+ } else {
+ new_computation = AddEmbeddedComputation(computation->Clone(""));
+ }
+ return new_computation;
}
uint64 HloModule::RandomNew64() const {