aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/instruction_fusion_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-14 09:06:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-14 09:08:50 -0700
commit0c59fdb9497dba218857dbfab5616ee77fdb70b7 (patch)
tree61ec47469836a4b2f3fc6aac989b895faec6e59a /tensorflow/compiler/xla/service/instruction_fusion_test.cc
parent4b1fa0ccdcada19035fe9e685f2b63a5c7a78f21 (diff)
Pre-factoring: Fix overly specific test expectations to prepare for multi-output fusion.
PiperOrigin-RevId: 196514026
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index 6dd8fa1ab0..cf9673a38a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -92,7 +92,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
EXPECT_FALSE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
- .ValueOrDie());
+ .ValueOrDie())
+ << module->ToString();
}
// Counts the number of HLO ops with a given op code in the specified module.
@@ -151,7 +152,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
.Run(module.get())
.ValueOrDie())
<< module->ToString();
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Subtract(op::Abs(op::Parameter()), op::Parameter()))
+ << module->ToString();
// Make sure the add hasn't been duplicated.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
@@ -244,7 +249,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
.Run(module.get())
.ValueOrDie())
<< module->ToString();
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+ root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Tuple(op::Subtract(op::Parameter(), op::Parameter()),
+ op::Subtract(op::Parameter(), op::Parameter())))
+ << module->ToString();
// Make sure we didn't duplicate any adds.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();