diff options
author | 2018-01-24 19:47:58 -0800 | |
---|---|---|
committer | 2018-01-24 19:51:53 -0800 | |
commit | 1a6216e61e804019cd64732d6f95d4d9bbedb594 (patch) | |
tree | b08c9be1600a553faaa91a087563ddf5eb0e1525 | |
parent | b25e892311fbdb308f89605ede30fce1b138c7f6 (diff) |
[XLA] Fix HloModule clone.
Currently the cloning of an instruction is usually "shallow": the called
computations of the instruction are reused in the clone. This mechanism is
useful when the HLO graph needs to be modified in place (e.g. by the inliner
and some HLO passes). One exception is the fusion instruction: it is always
"deep" copied, which means the fused computation is copied as well. So when
deep-cloning an HLO module, don't re-copy the fused computation; instead, let
the instruction's clone function know where to put the copied fused computation.
PiperOrigin-RevId: 183181206
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module.cc | 21 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module_test.cc | 42 |
2 files changed, 60 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 58bb942211..99d8dd04e5 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -523,7 +523,15 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const { std::unordered_map<HloComputation*, HloComputation*> clone_map; for (auto& computation : computations_) { - auto cloned_computation = computation->Clone(suffix); + if (computation->IsFusionComputation()) { + // Cloning of a fused computation is handled by its fusion instruction. + continue; + } + + // When cloning a computation, pass in the new module, so that for any + // fusion instruction in this computation, the fused computation will be + // deep cloned to the new module. + auto cloned_computation = computation->Clone(suffix, module.get()); InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); if (entry_computation_ == computation.get()) { @@ -537,8 +545,15 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const { for (auto* instruction : cloned_computation->instructions()) { // Rewrite instruction's called_computation to point to the cloned // computations. - instruction->ReplaceCalledComputations( - [&](HloComputation* hlo) { return FindOrDie(clone_map, hlo); }); + instruction->ReplaceCalledComputations([&](HloComputation* hlo) { + if (hlo->IsFusionComputation()) { + // Cloning of a fused computation has already been handled when its + // fusion instruction is cloned. So this hlo computation is already + // the cloned one. 
+ return hlo; + } + return FindOrDie(clone_map, hlo); + }); } } return module; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 0f5d3dccb7..cd51fa4e85 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -105,6 +105,48 @@ TEST_F(HloModuleTest, CloneTest) { } } +TEST_F(HloModuleTest, CloneHasFusion) { + auto module = CreateNewModule(); + + // Create the fused computation. + HloComputation* fused_computation; + { + auto b = HloComputation::Builder("Fused"); + auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + b.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x)); + fused_computation = module->AddEmbeddedComputation(b.Build()); + } + + // Create the entry computation. + { + auto b = HloComputation::Builder("Entry"); + auto input = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + b.AddInstruction( + HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput, + /*operands=*/{input}, fused_computation)); + module->AddEntryComputation(b.Build()); + } + + auto post_order = module->MakeComputationPostOrder(); + auto cloned_module = module->Clone("copy"); + auto post_order_copied = cloned_module->MakeComputationPostOrder(); + + EXPECT_EQ(post_order.size(), post_order_copied.size()); + for (auto origin = post_order.begin(), copied = post_order_copied.begin(); + origin != post_order.end() && copied != post_order_copied.end(); + ++origin, ++copied) { + if ((*origin)->name() == "Fused") { + // Clone of the fused computation is handled when its fusion instruction + // is cloned, which always use suffix ".clone". 
+ EXPECT_EQ((*origin)->name() + ".clone", (*copied)->name()); + } else { + EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name()); + } + } +} + TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. auto module = CreateNewModule(); |