Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc  |  94
1 file changed, 82 insertions(+), 12 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 979ea79243..451e49f23a 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -27,7 +27,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace gpu {
-using InstructionFusionTest = HloTestBase;
+using MultiOutputFusionTest = HloTestBase;
const char kModulePrefix[] = R"(
HloModule test_module
@@ -40,10 +40,10 @@ const char kModulePrefix[] = R"(
scalar_mul_computation {
scalar_lhs.1 = f32[] parameter(0)
scalar_rhs.1 = f32[] parameter(1)
- ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
+ ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
})";
-TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
// Fusion with reduce instruction root and a sibling reduce instruction
// sharing the same input param.
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
@@ -72,7 +72,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
op::Tuple(op::Reduce(), op::Reduce()));
}
-TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
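+  // Sibling reduces with different input shapes must not be combined into a
+  // multi-output fusion.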
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[6400]{0} parameter(1)
@@ -99,7 +99,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
-TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
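+  // Sibling reduces with different output shapes must not be combined into a
+  // multi-output fusion.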
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
@@ -126,7 +126,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
-TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
// Two sibling fusions with reduce instruction roots sharing the same input
// param.
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
@@ -160,7 +160,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
op::Tuple(op::Reduce(), op::Reduce()));
}
-TEST_F(InstructionFusionTest,
+TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
// Multi-output fusion with two reduce instructions root and a sibling reduce
// instruction sharing the same input param.
@@ -193,7 +193,7 @@ TEST_F(InstructionFusionTest,
op::Tuple(op::Reduce(), op::Reduce(), op::Reduce()));
}
-TEST_F(InstructionFusionTest,
+TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
// Verify that if we already have a multi-output fusion that we prefer to pick
// a reduce op from its operands for checking shape compatibility.
@@ -226,7 +226,7 @@ TEST_F(InstructionFusionTest,
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
-TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
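+  // Two sibling loop fusions are combined into a single multi-output fusion.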
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[6400]{0} parameter(0)
@@ -255,7 +255,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
-TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
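+  // A producer loop fusion is fused into its consumer reduce; both results
+  // become outputs of the multi-output fusion.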
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_add {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
@@ -282,7 +282,7 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
op::Tuple(op::Reduce(), op::Add()));
}
-TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
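+  // A producer loop fusion is fused into its consumer reduce fusion; all
+  // results become outputs of the multi-output fusion.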
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
@@ -323,7 +323,7 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
}
-TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
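+  // Verify that the element-wise loop fusion is not fused into the consumer
+  // reduce fusion.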
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_element_wise {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
@@ -349,5 +349,75 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
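+  // Same as ProducerConsumerFusionLoopFusionAndReduceFusion, but with fp16
+  // inputs that the reduce fusion converts to f32.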
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_select {
+ p1.1 = f16[2,2,2]{2,1,0} parameter(1)
+ c0 = f16[] constant(0)
+ broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
+    greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast)
+ p0.1 = f16[2,2,2]{2,1,0} parameter(0)
+ ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[2,2,2]{2,1,0} parameter(0)
+ convert = f32[2,2,2]{2,1,0} convert(p0.2)
+ c1 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(convert, c1), dimensions={2}, to_apply=scalar_add_computation
+ mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+ ENTRY reduce {
+ p0 = f16[2,2,2]{2,1,0} parameter(0)
+ p1 = f16[2,2,2]{2,1,0} parameter(1)
+ select = f16[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+ fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
+ gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
+ gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
+ ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
+ op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
+}
+
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
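+  // The producer loop fusion mixes input layouts and is therefore
+  // reduce-unfriendly; it must not be fused into the consumer reduce fusion.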
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ mixed_input_layouts_computation {
+ p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
+ c0 = f16[] constant(0)
+ broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+ greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
+ ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+ ENTRY reduce {
+ p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
+ loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+ reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
} // namespace gpu
} // namespace xla