aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/instruction_fusion_test.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2017-03-29 13:26:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-29 14:48:18 -0700
commit5149785eb7175a791acbd9859872e07439b968b6 (patch)
tree301dec5bfebc225557ce25638c41dfde99fdbaad /tensorflow/compiler/xla/service/instruction_fusion_test.cc
parentcd021175181431c57fdebe0a82e99ffabcc0897f (diff)
Fusion uses the same cost model for all backends when it comes to
rematerializing for fusion. Let subtypes override the default cost model. Change: 151626001
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc28
1 files changed, 21 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index 2e3742ed75..a4c269f0eb 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -36,7 +36,9 @@ TEST_F(InstructionFusionTest,
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(broadcast2, computation->root_instruction());
EXPECT_TRUE(
- InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
EXPECT_EQ(broadcast2, computation->root_instruction());
}
@@ -55,7 +57,9 @@ TEST_F(InstructionFusionTest,
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(broadcast2, computation->root_instruction());
EXPECT_TRUE(
- InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
}
@@ -73,7 +77,9 @@ TEST_F(InstructionFusionTest,
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape2, computation->root_instruction());
EXPECT_TRUE(
- InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
}
@@ -91,7 +97,9 @@ TEST_F(InstructionFusionTest,
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose2, computation->root_instruction());
EXPECT_TRUE(
- InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
}
@@ -106,7 +114,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape1, computation->root_instruction());
EXPECT_FALSE(
- InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
}
TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
@@ -120,7 +130,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape1, computation->root_instruction());
EXPECT_FALSE(
- InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
}
TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
@@ -134,7 +146,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose1, computation->root_instruction());
EXPECT_FALSE(
- InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
}
} // namespace xla