diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-14 09:06:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-14 09:08:50 -0700 |
commit | 0c59fdb9497dba218857dbfab5616ee77fdb70b7 (patch) | |
tree | 61ec47469836a4b2f3fc6aac989b895faec6e59a /tensorflow/compiler/xla/service/instruction_fusion_test.cc | |
parent | 4b1fa0ccdcada19035fe9e685f2b63a5c7a78f21 (diff) |
Pre-factoring: Fix overly specific test expectations to prepare for multi-output fusion.
PiperOrigin-RevId: 196514026
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion_test.cc | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 6dd8fa1ab0..cf9673a38a 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -92,7 +92,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); } // Counts the number of HLO ops with a given op code in the specified module. @@ -151,7 +152,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { .Run(module.get()) .ValueOrDie()) << module->ToString(); - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Subtract(op::Abs(op::Parameter()), op::Parameter())) + << module->ToString(); // Make sure the add hasn't been duplicated. EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); @@ -244,7 +249,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { .Run(module.get()) .ValueOrDie()) << module->ToString(); - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Tuple(op::Subtract(op::Parameter(), op::Parameter()), + op::Subtract(op::Parameter(), op::Parameter()))) + << module->ToString(); // Make sure we didn't duplicate any adds. EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); |