author A. Unique TensorFlower <gardener@tensorflow.org> 2018-01-24 19:47:58 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-01-24 19:51:53 -0800
commit 1a6216e61e804019cd64732d6f95d4d9bbedb594 (patch)
tree b08c9be1600a553faaa91a087563ddf5eb0e1525
parent b25e892311fbdb308f89605ede30fce1b138c7f6 (diff)
[XLA] Fix HloModule clone.
Currently the cloning of an instruction is usually "shallow": the called computations of the instruction are reused in the clone. This mechanism is useful when the HLO graph needs to be modified in place (e.g. by the inliner and some HLO passes). One exception is the fusion instruction: it is always "deep" copied, which means the fused computation is copied as well. So when deep-cloning an HLO module, don't re-copy the fused computation, and let the instruction's clone function know where to put the copied fused computation.

PiperOrigin-RevId: 183181206
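To make the shallow-vs-deep distinction concrete, the following is a minimal C++ sketch of the cloning policy this change implements. Comp, Instr, Module, and the Clone helper are hypothetical stand-ins for illustration only, not the real XLA API; only the skip-the-fused-body-then-rewrite-pointers logic mirrors the actual code in the diff below.

    // A minimal, self-contained sketch of the cloning policy described above.
    // Comp, Instr, Module, and Clone are hypothetical stand-ins, not the real
    // XLA classes.
    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Comp;

    struct Instr {
      bool is_fusion = false;
      Comp* called = nullptr;  // computation this instruction calls or fuses
    };

    struct Comp {
      std::string name;
      bool is_fusion_body = false;  // true if owned by a fusion instruction
      std::vector<Instr> instrs;
    };

    struct Module {
      std::vector<std::unique_ptr<Comp>> comps;
    };

    Module Clone(const Module& m, const std::string& suffix) {
      Module out;
      std::unordered_map<const Comp*, Comp*> clone_map;
      // Pass 1: copy every computation except fused bodies; a fusion
      // instruction deep-clones its own body (always with suffix ".clone").
      for (const auto& c : m.comps) {
        if (c->is_fusion_body) continue;  // handled by its fusion instruction
        auto copy = std::make_unique<Comp>(*c);  // shallow: instrs still point
                                                 // at the old computations
        copy->name = c->name + "." + suffix;
        clone_map[c.get()] = copy.get();
        out.comps.push_back(std::move(copy));
      }
      // Pass 2: deep-clone fused bodies through their fusion instructions and
      // rewrite all other called-computation pointers via clone_map.
      std::vector<std::unique_ptr<Comp>> fused_copies;
      for (auto& c : out.comps) {
        for (auto& i : c->instrs) {
          if (i.called == nullptr) continue;
          if (i.is_fusion) {
            auto body = std::make_unique<Comp>(*i.called);
            body->name = i.called->name + ".clone";
            i.called = body.get();
            fused_copies.push_back(std::move(body));
          } else {
            i.called = clone_map.at(i.called);
          }
        }
      }
      for (auto& b : fused_copies) out.comps.push_back(std::move(b));
      return out;
    }

    int main() {
      Module m;
      auto fused = std::make_unique<Comp>();
      fused->name = "Fused";
      fused->is_fusion_body = true;
      auto entry = std::make_unique<Comp>();
      entry->name = "Entry";
      entry->instrs.push_back(Instr{/*is_fusion=*/true, fused.get()});
      m.comps.push_back(std::move(fused));
      m.comps.push_back(std::move(entry));

      Module copy = Clone(m, "copy");
      for (const auto& c : copy.comps) std::cout << c->name << "\n";
      // Prints "Entry.copy" then "Fused.clone": the fused body keeps the
      // ".clone" suffix, exactly what the CloneHasFusion test below expects.
    }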
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.cc      | 21
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_test.cc | 42
2 files changed, 60 insertions(+), 3 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 58bb942211..99d8dd04e5 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -523,7 +523,15 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
std::unordered_map<HloComputation*, HloComputation*> clone_map;
for (auto& computation : computations_) {
- auto cloned_computation = computation->Clone(suffix);
+ if (computation->IsFusionComputation()) {
+ // Cloning of a fused computation is handled by its fusion instruction.
+ continue;
+ }
+
+ // When cloning a computation, pass in the new module, so that for any
+ // fusion instruction in this computation, the fused computation will be
+ // deep cloned to the new module.
+ auto cloned_computation = computation->Clone(suffix, module.get());
InsertOrDie(&clone_map, computation.get(), cloned_computation.get());
if (entry_computation_ == computation.get()) {
@@ -537,8 +545,15 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
for (auto* instruction : cloned_computation->instructions()) {
// Rewrite instruction's called_computation to point to the cloned
// computations.
- instruction->ReplaceCalledComputations(
- [&](HloComputation* hlo) { return FindOrDie(clone_map, hlo); });
+ instruction->ReplaceCalledComputations([&](HloComputation* hlo) {
+ if (hlo->IsFusionComputation()) {
+ // Cloning of a fused computation has already been handled when its
+ // fusion instruction is cloned. So this hlo computation is already
+ // the cloned one.
+ return hlo;
+ }
+ return FindOrDie(clone_map, hlo);
+ });
}
}
return module;
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 0f5d3dccb7..cd51fa4e85 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -105,6 +105,48 @@ TEST_F(HloModuleTest, CloneTest) {
}
}
+TEST_F(HloModuleTest, CloneHasFusion) {
+ auto module = CreateNewModule();
+
+ // Create the fused computation.
+ HloComputation* fused_computation;
+ {
+ auto b = HloComputation::Builder("Fused");
+ auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
+ b.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x));
+ fused_computation = module->AddEmbeddedComputation(b.Build());
+ }
+
+ // Create the entry computation.
+ {
+ auto b = HloComputation::Builder("Entry");
+ auto input = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ b.AddInstruction(
+ HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
+ /*operands=*/{input}, fused_computation));
+ module->AddEntryComputation(b.Build());
+ }
+
+ auto post_order = module->MakeComputationPostOrder();
+ auto cloned_module = module->Clone("copy");
+ auto post_order_copied = cloned_module->MakeComputationPostOrder();
+
+ EXPECT_EQ(post_order.size(), post_order_copied.size());
+ for (auto origin = post_order.begin(), copied = post_order_copied.begin();
+ origin != post_order.end() && copied != post_order_copied.end();
+ ++origin, ++copied) {
+ if ((*origin)->name() == "Fused") {
+ // The clone of a fused computation is handled when its fusion instruction
+ // is cloned, which always uses the suffix ".clone".
+ EXPECT_EQ((*origin)->name() + ".clone", (*copied)->name());
+ } else {
+ EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name());
+ }
+ }
+}
+
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
// Create a module with a diamond call graph of computations.
auto module = CreateNewModule();