aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2017-04-26 13:19:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-26 14:30:44 -0700
commit0ad55c0ffdb3a2c86881e791d34fbdf1aacb359f (patch)
tree89eb9f5aacf55b10f0664130d8602856e61743bb /tensorflow/compiler
parentb82cb8e93245b0de66794f8986db453d022ae341 (diff)
[XLA] Run transpose_folding on nested computations
We only ran the pass on the entry computation which would make us lose out on optimization opportunities. Visit all computations to find any potential transpose folding opportunities. Change: 154343660
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc10
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.cc23
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc48
3 files changed, 66 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index e8378a7f44..c6e8a2f78b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -59,6 +59,11 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
} // namespace
bool ImplementedAsGemm(const HloInstruction& hlo) {
+ // We can only do this if the HLO is unnested.
+ if (hlo.parent() != hlo.GetModule()->entry_computation()) {
+ return false;
+ }
+
// For certain types of Dot, we can call pre-canned BLAS gemm.
if (hlo.opcode() == HloOpcode::kDot) {
const Shape& lhs_shape = hlo.operand(0)->shape();
@@ -85,6 +90,11 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
}
bool ImplementedAsDnnConvolution(const HloInstruction& hlo) {
+ // We can only do this if the HLO is unnested.
+ if (hlo.parent() != hlo.GetModule()->entry_computation()) {
+ return false;
+ }
+
// Forward convolution.
if (hlo.opcode() == HloOpcode::kConvolution) {
const ConvolutionDimensionNumbers& dnums =
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index cfb90e6e1d..a0c88c6bbc 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -76,8 +76,7 @@ using InstructionOperandsPair =
// the parent HLO computation of `dot`.
//
// Returns whether the module is changed.
-bool FoldTransposeIntoDot(InstructionOperandsPair pair,
- HloComputation* computation) {
+bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
auto* dot = pair.first;
std::vector<HloInstruction*> instructions_to_fuse(1, dot);
for (const int64 operand_index : pair.second) {
@@ -89,7 +88,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair,
return false;
}
- computation->CreateFusionInstruction(
+ dot->parent()->CreateFusionInstruction(
instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot);
return true;
}
@@ -98,8 +97,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair,
// `computation` is the parent HLO computation of `convolution`.
//
// Returns whether the module is changed.
-bool FoldTransposeIntoConvolution(InstructionOperandsPair pair,
- HloComputation* computation) {
+bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
auto& convolution = *pair.first;
// We only support fusing the RHS transpose into convolution.
@@ -135,8 +133,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair,
auto new_conv = HloInstruction::CreateConvolve(
convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
convolution.window(), new_dnums);
- TF_CHECK_OK(computation->ReplaceWithNewInstruction(&convolution,
- std::move(new_conv)));
+ TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
+ &convolution, std::move(new_conv)));
return true;
}
@@ -152,8 +150,6 @@ TransposeFolding::TransposeFolding(
StatusOr<bool> TransposeFolding::Run(HloModule* module) {
// Modifying the graph while traversing is dangerous, so we find all folding
// opportunities before actually folding them.
- HloComputation* entry_computation = module->entry_computation();
-
std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_dots;
std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_convolutions;
auto visit_fn = [this, &foldable_dots,
@@ -175,14 +171,17 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) {
}
return tensorflow::Status::OK();
};
- TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn));
+
+ for (auto& comp : module->computations()) {
+ TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
+ }
bool changed = false;
for (InstructionOperandsPair& pair : foldable_dots) {
- changed |= FoldTransposeIntoDot(pair, entry_computation);
+ changed |= FoldTransposeIntoDot(pair);
}
for (InstructionOperandsPair& pair : foldable_convolutions) {
- changed |= FoldTransposeIntoConvolution(pair, entry_computation);
+ changed |= FoldTransposeIntoConvolution(pair);
}
return changed;
}
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 6643f541da..c72d127ea8 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -41,9 +41,7 @@ class TransposeFoldingTest : public ::testing::Test {
TransposeFolding transpose_folding(
[](const HloInstruction& dot,
const TransposeFolding::OperandIndices& candidate_operands) {
- return gpu::ImplementedAsGemm(dot)
- ? candidate_operands
- : TransposeFolding::OperandIndices{};
+ return candidate_operands;
},
[](const HloInstruction& convolution,
const TransposeFolding::OperandIndices& candidate_operands) {
@@ -159,6 +157,50 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
EXPECT_EQ(6, callee_computation->instructions().size());
}
+TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) {
+ auto builder = HloComputation::Builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
+ /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
+ /*name=*/"y"));
+ HloInstruction* transpose_y =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot,
+ /*lhs=*/x, /*rhs=*/transpose_y));
+
+ HloModule module("test_module");
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build(dot));
+
+ HloInstruction* call = module.OutlineExpressionFromComputation(
+ {transpose_y, dot}, "outlined", entry_computation);
+
+ FoldTranspose(&module);
+
+ // Instructions after folding: x, y, and the fusion.
+ std::unordered_set<HloInstruction*> instruction_set;
+ for (auto& instruction : entry_computation->instructions()) {
+ instruction_set.insert(instruction.get());
+ }
+ CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+ CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+ CHECK_EQ(1, instruction_set.erase(call))
+ << "call is not in entry_computation.";
+ CHECK(instruction_set.empty())
+ << "entry_computation should contain exactly 3 instructions.";
+ HloInstruction* fusion =
+ call->called_computations().front()->root_instruction();
+ EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
+
+ // The fusion instruction should contain two parameters, one transpose and
+ // one dot.
+ EXPECT_EQ(4, fusion->fused_instructions().size());
+}
+
// Test that a two dimension swap of the kernel gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
auto builder = HloComputation::Builder("entry_computation");