aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-27 08:30:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-27 08:34:32 -0700
commitf938347aeba95afa55d5dd0d3f911689eff31821 (patch)
tree7b71fe2a83fe20bd6317bbae48220113c786d9cf
parentfeb4e648bf72ebc6e6dc377e95329a93821e5eba (diff)
Minor bugfix to HloInstruction::MergeFusionInstruction, and allow InstructionFusion::Fuse to be overridden by derived classes.
PiperOrigin-RevId: 166631382
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc9
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc15
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h7
3 files changed, 21 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 28ca915310..f5a081a9dc 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -532,6 +532,8 @@ void HloInstruction::MergeFusionInstruction(
HloInstruction* instruction_to_merge) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
+ CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
+ operands().end());
// Clone the instruction from which to merge fused instructions.
std::unique_ptr<HloInstruction> clone = instruction_to_merge->Clone();
// Replace uses of fused parameters with the corresponding operand of the
@@ -563,6 +565,11 @@ void HloInstruction::MergeFusionInstruction(
}
CHECK_EQ(0, clone->user_count());
clone->DetachFromOperands();
+
+ if (GetModule()) {
+ TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(
+ clone->fused_instructions_computation()));
+ }
}
void HloInstruction::MergeFusionInstructionIntoMultiOutput(
@@ -2131,6 +2138,7 @@ using DFSStack =
// cycle was detected, and true otherwise.
inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack,
HloInstruction* child) {
+ CHECK(child != nullptr);
const int id = child->unique_id();
CHECK_GE(id, 0) << "instruction may not have a parent computation";
switch (visitor->GetVisitState(id)) {
@@ -2193,7 +2201,6 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
visitor->SetVisitState(current_id, DfsHloVisitor::kVisiting);
const size_t old_dfs_stack_size = dfs_stack.size();
-
for (HloInstruction* child : current_node->operands()) {
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
return FailedPrecondition(
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index edfcb0922d..d449c637b5 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -212,7 +212,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
bool changed = false;
-
+ module_ = module;
std::vector<HloComputation*> computations;
for (auto& computation : module->computations()) {
if (computation->IsFusionComputation()) {
@@ -395,7 +395,6 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
VLOG(2) << "Fusing " << producer->ToString() << " into "
<< consumer->ToString();
-
auto kind = ChooseKind(producer, consumer);
if (consumer->opcode() == HloOpcode::kFusion) {
fusion_instruction = consumer;
@@ -407,8 +406,8 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction));
}
- fusion_instruction->FuseInstruction(producer);
+ fusion_instruction->FuseInstruction(producer);
return fusion_instruction;
}
@@ -423,13 +422,15 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
if (consumer->opcode() == HloOpcode::kFusion &&
consumer->fusion_kind() != HloInstruction::FusionKind::kLoop &&
- consumer->fusion_kind() != HloInstruction::FusionKind::kInput) {
+ consumer->fusion_kind() != HloInstruction::FusionKind::kInput &&
+ consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) {
return false;
}
- // Cost condition: not fuse (expensive producers) and (consumers who reuse
- // operand elements).
- if (consumer->ReusesOperandElements(operand_index) &&
+ // Cost condition: not fuse (simple, expensive producers) and (consumers who
+ // reuse operand elements).
+ if (producer->opcode() != HloOpcode::kFusion &&
+ consumer->ReusesOperandElements(operand_index) &&
is_expensive_(*producer)) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index f6f37bb79b..3ac13ffda0 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -66,12 +66,15 @@ class InstructionFusion : public HloPassInterface {
virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
const HloInstruction* consumer);
+ // Fuses producer into consumer.
+ virtual HloInstruction* Fuse(HloInstruction* producer,
+ HloInstruction* consumer);
+
// Current HloComputation instance the loop fuser is traversing.
HloComputation* computation_;
+ HloModule* module_;
private:
- HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);
-
// The set of producers whose consumers we cannot fuse into.
using DoNotFuseSet = std::unordered_set<HloInstruction*>;