diff options
author | David Majnemer <majnemer@google.com> | 2017-03-29 13:26:34 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-29 14:48:18 -0700 |
commit | 5149785eb7175a791acbd9859872e07439b968b6 (patch) | |
tree | 301dec5bfebc225557ce25638c41dfde99fdbaad /tensorflow/compiler/xla/service/instruction_fusion_test.cc | |
parent | cd021175181431c57fdebe0a82e99ffabcc0897f (diff) |
Fusion uses the same cost model for all backends when it comes to
rematerializing for fusion. Let subtypes override the default cost model.
Change: 151626001
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion_test.cc | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 2e3742ed75..a4c269f0eb 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -36,7 +36,9 @@ TEST_F(InstructionFusionTest, auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(broadcast2, computation->root_instruction()); } @@ -55,7 +57,9 @@ TEST_F(InstructionFusionTest, auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); } @@ -73,7 +77,9 @@ TEST_F(InstructionFusionTest, auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); } @@ -91,7 +97,9 @@ TEST_F(InstructionFusionTest, auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); } @@ -106,7 +114,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { @@ -120,7 +130,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { @@ -134,7 +146,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } } // namespace xla |