diff options
author | 2017-04-26 13:19:33 -0800 | |
---|---|---|
committer | 2017-04-26 14:30:44 -0700 | |
commit | 0ad55c0ffdb3a2c86881e791d34fbdf1aacb359f (patch) | |
tree | 89eb9f5aacf55b10f0664130d8602856e61743bb /tensorflow/compiler/xla/service/transpose_folding.cc | |
parent | b82cb8e93245b0de66794f8986db453d022ae341 (diff) |
[XLA] Run transpose_folding on nested computations
We only ran the pass on the entry computation which would make us lose out on
optimization opportunities. Visit all computations to find any potential
transpose folding opportunities.
Change: 154343660
Diffstat (limited to 'tensorflow/compiler/xla/service/transpose_folding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/transpose_folding.cc | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index cfb90e6e1d..a0c88c6bbc 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -76,8 +76,7 @@ using InstructionOperandsPair = // the parent HLO computation of `dot`. // // Returns whether the module is changed. -bool FoldTransposeIntoDot(InstructionOperandsPair pair, - HloComputation* computation) { +bool FoldTransposeIntoDot(InstructionOperandsPair pair) { auto* dot = pair.first; std::vector<HloInstruction*> instructions_to_fuse(1, dot); for (const int64 operand_index : pair.second) { @@ -89,7 +88,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair, return false; } - computation->CreateFusionInstruction( + dot->parent()->CreateFusionInstruction( instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); return true; } @@ -98,8 +97,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair, // `computation` is the parent HLO computation of `convolution`. // // Returns whether the module is changed. -bool FoldTransposeIntoConvolution(InstructionOperandsPair pair, - HloComputation* computation) { +bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; // We only support fusing the RHS transpose into convolution. @@ -135,8 +133,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair, auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), convolution.mutable_operand(0), &transpose_operand, convolution.window(), new_dnums); - TF_CHECK_OK(computation->ReplaceWithNewInstruction(&convolution, - std::move(new_conv))); + TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( + &convolution, std::move(new_conv))); return true; } @@ -152,8 +150,6 @@ TransposeFolding::TransposeFolding( StatusOr<bool> TransposeFolding::Run(HloModule* module) { // Modifying the graph while traversing is dangerous, so we find all folding // opportunities before actually folding them. - HloComputation* entry_computation = module->entry_computation(); - std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_dots; std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_convolutions; auto visit_fn = [this, &foldable_dots, @@ -175,14 +171,17 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) { } return tensorflow::Status::OK(); }; - TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn)); + + for (auto& comp : module->computations()) { + TF_RETURN_IF_ERROR(comp->Accept(visit_fn)); + } bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - changed |= FoldTransposeIntoDot(pair, entry_computation); + changed |= FoldTransposeIntoDot(pair); } for (InstructionOperandsPair& pair : foldable_convolutions) { - changed |= FoldTransposeIntoConvolution(pair, entry_computation); + changed |= FoldTransposeIntoConvolution(pair); } return changed; } |