author     David Majnemer <majnemer@google.com>           2017-04-26 13:19:33 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org> 2017-04-26 14:30:44 -0700
commit     0ad55c0ffdb3a2c86881e791d34fbdf1aacb359f (patch)
tree       89eb9f5aacf55b10f0664130d8602856e61743bb /tensorflow/compiler/xla/service/transpose_folding.cc
parent     b82cb8e93245b0de66794f8986db453d022ae341 (diff)
[XLA] Run transpose_folding on nested computations
Previously we only ran the pass on the entry computation, which made us miss optimization opportunities in nested computations. Visit all computations in the module to find any potential transpose folding opportunities.
Change: 154343660
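For context, here is a minimal sketch contrasting the old entry-only traversal with the whole-module traversal introduced by this change. The wrapper name RunSketch and the visit_all_computations flag are hypothetical; the HloModule/HloComputation calls are the ones that appear in the diff below, and candidate collection and the actual folding are elided.

// Sketch only: shows the two traversal strategies, nothing more.
StatusOr<bool> RunSketch(HloModule* module, bool visit_all_computations) {
  auto visit_fn = [](HloInstruction* instruction) {
    // A real visitor would record foldable dots/convolutions here.
    return tensorflow::Status::OK();
  };
  if (visit_all_computations) {
    // After this change: walk every computation in the module, so dots and
    // convolutions inside nested computations (e.g. while bodies or fusion
    // computations) are also considered for transpose folding.
    for (auto& comp : module->computations()) {
      TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
    }
  } else {
    // Before this change: only the entry computation's graph was walked, so
    // opportunities in nested computations were silently missed.
    TF_RETURN_IF_ERROR(
        module->entry_computation()->root_instruction()->Accept(visit_fn));
  }
  return false;  // Folding omitted in this sketch.
}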
Diffstat (limited to 'tensorflow/compiler/xla/service/transpose_folding.cc')
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.cc  23
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index cfb90e6e1d..a0c88c6bbc 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -76,8 +76,7 @@ using InstructionOperandsPair =
 // the parent HLO computation of `dot`.
 //
 // Returns whether the module is changed.
-bool FoldTransposeIntoDot(InstructionOperandsPair pair,
-                          HloComputation* computation) {
+bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
   auto* dot = pair.first;
   std::vector<HloInstruction*> instructions_to_fuse(1, dot);
   for (const int64 operand_index : pair.second) {
@@ -89,7 +88,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair,
     return false;
   }
 
-  computation->CreateFusionInstruction(
+  dot->parent()->CreateFusionInstruction(
       instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot);
   return true;
 }
@@ -98,8 +97,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair,
 // `computation` is the parent HLO computation of `convolution`.
 //
 // Returns whether the module is changed.
-bool FoldTransposeIntoConvolution(InstructionOperandsPair pair,
-                                  HloComputation* computation) {
+bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
   auto& convolution = *pair.first;
 
   // We only support fusing the RHS transpose into convolution.
@@ -135,8 +133,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair,
   auto new_conv = HloInstruction::CreateConvolve(
       convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
       convolution.window(), new_dnums);
-  TF_CHECK_OK(computation->ReplaceWithNewInstruction(&convolution,
-                                                     std::move(new_conv)));
+  TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
+      &convolution, std::move(new_conv)));
   return true;
 }
@@ -152,8 +150,6 @@ TransposeFolding::TransposeFolding(
 StatusOr<bool> TransposeFolding::Run(HloModule* module) {
   // Modifying the graph while traversing is dangerous, so we find all folding
   // opportunities before actually folding them.
-  HloComputation* entry_computation = module->entry_computation();
-
   std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_dots;
   std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_convolutions;
   auto visit_fn = [this, &foldable_dots,
@@ -175,14 +171,17 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) {
     }
     return tensorflow::Status::OK();
   };
-  TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn));
+
+  for (auto& comp : module->computations()) {
+    TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
+  }
   bool changed = false;
   for (InstructionOperandsPair& pair : foldable_dots) {
-    changed |= FoldTransposeIntoDot(pair, entry_computation);
+    changed |= FoldTransposeIntoDot(pair);
   }
   for (InstructionOperandsPair& pair : foldable_convolutions) {
-    changed |= FoldTransposeIntoConvolution(pair, entry_computation);
+    changed |= FoldTransposeIntoConvolution(pair);
   }
   return changed;
 }