aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/instruction_fusion.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-27 08:30:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-27 08:34:32 -0700
commitf938347aeba95afa55d5dd0d3f911689eff31821 (patch)
tree7b71fe2a83fe20bd6317bbae48220113c786d9cf /tensorflow/compiler/xla/service/instruction_fusion.cc
parentfeb4e648bf72ebc6e6dc377e95329a93821e5eba (diff)
Minor bugfix to HloInstruction::MergeFusionInstruction, and allow InstructionFusion::Fuse to be overridden by derived classes.
PiperOrigin-RevId: 166631382
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion.cc')
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc15
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index edfcb0922d..d449c637b5 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -212,7 +212,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
bool changed = false;
-
+ module_ = module;
std::vector<HloComputation*> computations;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
@@ -395,7 +395,6 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
VLOG(2) << "Fusing " << producer->ToString() << " into "
<< consumer->ToString();
-
auto kind = ChooseKind(producer, consumer);
if (consumer->opcode() == HloOpcode::kFusion) {
fusion_instruction = consumer;
@@ -407,8 +406,8 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction));
}
- fusion_instruction->FuseInstruction(producer);
+ fusion_instruction->FuseInstruction(producer);
return fusion_instruction;
}
@@ -423,13 +422,15 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
if (consumer->opcode() == HloOpcode::kFusion &&
consumer->fusion_kind() != HloInstruction::FusionKind::kLoop &&
- consumer->fusion_kind() != HloInstruction::FusionKind::kInput) {
+ consumer->fusion_kind() != HloInstruction::FusionKind::kInput &&
+ consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) {
return false;
}
- // Cost condition: not fuse (expensive producers) and (consumers who reuse
- // operand elements).
- if (consumer->ReusesOperandElements(operand_index) &&
+ // Cost condition: not fuse (simple, expensive producers) and (consumers who
+ // reuse operand elements).
+ if (producer->opcode() != HloOpcode::kFusion &&
+ consumer->ReusesOperandElements(operand_index) &&
is_expensive_(*producer)) {
return false;
}