Diffstat:
 tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc | 148
 1 file changed, 136 insertions(+), 12 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 979ea79243..ec4234b8d9 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -27,7 +28,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace gpu {
-using InstructionFusionTest = HloTestBase;
+using MultiOutputFusionTest = HloTestBase;
const char kModulePrefix[] = R"(
HloModule test_module
@@ -40,10 +41,10 @@ const char kModulePrefix[] = R"(
scalar_mul_computation {
scalar_lhs.1 = f32[] parameter(0)
scalar_rhs.1 = f32[] parameter(1)
- ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
+ ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
})";
-TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
// Fusion with a reduce-instruction root and a sibling reduce instruction
// sharing the same input param.
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
@@ -72,7 +73,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
op::Tuple(op::Reduce(), op::Reduce()));
}
-TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[6400]{0} parameter(1)
@@ -99,7 +100,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
-TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
@@ -126,7 +127,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
-TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
// Two sibling fusions with reduce instruction roots sharing the same input
// param.
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
@@ -160,7 +161,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
op::Tuple(op::Reduce(), op::Reduce()));
}
-TEST_F(InstructionFusionTest,
+TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
// Multi-output fusion with two reduce instructions as its root and a sibling
// reduce instruction sharing the same input param.
@@ -193,7 +194,7 @@ TEST_F(InstructionFusionTest,
op::Tuple(op::Reduce(), op::Reduce(), op::Reduce()));
}
-TEST_F(InstructionFusionTest,
+TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
// Verify that, if we already have a multi-output fusion, we prefer to pick a
// reduce op from its operands for checking shape compatibility.
@@ -226,7 +227,7 @@ TEST_F(InstructionFusionTest,
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
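The heuristic that comment describes, rendered as a minimal C++ sketch; the helper name is invented here, and the real logic lives inside GpuMultiOutputFusion:

    // Hypothetical sketch: choose the instruction whose shape is used for the
    // multi-output fusion shape-compatibility check. For a candidate that is
    // already a multi-output fusion, prefer a reduce among the operands of its
    // fused root tuple.
    const HloInstruction* ShapeCheckRepresentative(const HloInstruction* instr) {
      if (instr->IsMultiOutputFusion()) {
        for (const HloInstruction* operand :
             instr->fused_expression_root()->operands()) {
          if (operand->opcode() == HloOpcode::kReduce) {
            return operand;
          }
        }
      }
      return instr;
    }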
-TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[6400]{0} parameter(0)
@@ -255,7 +256,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
-TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_add {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
@@ -282,7 +283,7 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
op::Tuple(op::Reduce(), op::Add()));
}
-TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
@@ -323,7 +324,7 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
}
-TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_element_wise {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
@@ -349,5 +350,128 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_select {
+ p1.1 = f16[2,2,2]{2,1,0} parameter(1)
+ c0 = f16[] constant(0)
+ broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
+ greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast)
+ p0.1 = f16[2,2,2]{2,1,0} parameter(0)
+ ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[2,2,2]{2,1,0} parameter(0)
+ convert = f32[2,2,2]{2,1,0} convert(p0.2)
+ c1 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(convert, c1), dimensions={2}, to_apply=scalar_add_computation
+ mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+ ENTRY reduce {
+ p0 = f16[2,2,2]{2,1,0} parameter(0)
+ p1 = f16[2,2,2]{2,1,0} parameter(1)
+ select = f16[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+ fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
+ gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
+ gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
+ ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) tuple(gte0, gte1, select)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
+ op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
+}
+
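For orientation, the structure these assertions pin down, sketched in HLO (an illustration of the expected result, not output copied from the pass): the select producer is duplicated into the reduce fusion, which becomes a three-output kInput fusion read through get-tuple-elements:

    fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p0, p1),
        kind=kInput, calls=<merged computation with ROOT tuple(r1, r2, select)>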
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ mixed_input_layouts_computation {
+ p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
+ c0 = f16[] constant(0)
+ broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+ greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
+ ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+ ENTRY reduce {
+ p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
+ loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+ reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
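The negative expectation above hinges on layouts: mixed_input_layouts_computation copies p1.1 from layout {3,2,1,0} to {1,3,2,0}, a physical transpose, which makes the loop fusion a poor producer for a reduce input fusion. A minimal sketch of that kind of guard, with a hypothetical helper name (the actual check is internal to the pass):

    // Hypothetical sketch: a producer whose fused computation contains a
    // layout-changing copy (i.e. a physical transpose) is rejected as a
    // candidate for merging into a reduce input fusion.
    bool IsReduceFriendly(const HloInstruction* producer) {
      for (const HloInstruction* instr :
           producer->fused_instructions_computation()->instructions()) {
        if (instr->opcode() == HloOpcode::kCopy &&
            !LayoutUtil::Equal(instr->shape().layout(),
                               instr->operand(0)->shape().layout())) {
          return false;
        }
      }
      return true;
    }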
+// Check that we limit the number of operands and outputs of the fusions we create.
+TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) {
+ constexpr int64 kNumParams = 200;
+ ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion);
+
+ // Compute
+ // p0 * p1,
+ // p0 * p1 + p1 * p2
+ // p0 * p1 + p1 * p2 + p2 * p3
+ // ...
+ // where each of the (pi * pj)'s is represented as a fusion node so that
+ // multi-output fusion will pay attention to it.
+ auto module = CreateNewModule();
+ HloComputation::Builder b(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
+
+ std::vector<HloInstruction*> params;
+ for (int64 i = 0; i < kNumParams; ++i) {
+ params.push_back(
+ b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
+ }
+
+ // Creates a fusion node that calculates x*y.
+ auto make_fusion = [&](HloInstruction* x, HloInstruction* y) {
+ HloComputation::Builder sub_builder("subcomp");
+ auto* p0 = sub_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "p"));
+ auto* p1 = sub_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "p"));
+ sub_builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
+ HloComputation* subcomp =
+ module->AddEmbeddedComputation(sub_builder.Build());
+ return HloInstruction::CreateFusion(
+ shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp);
+ };
+
+ auto* sum = b.AddInstruction(make_fusion(params[0], params[1]));
+ for (int64 i = 2; i < kNumParams; ++i) {
+ sum = b.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, sum,
+ b.AddInstruction(make_fusion(params[i - 1], params[i]))));
+ }
+ auto computation = module->AddEntryComputation(b.Build());
+ EXPECT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ for (const HloInstruction* instr : computation->instructions()) {
+ EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()),
+ GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion)
+ << instr->ToString();
+ }
+}
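The invariant the final loop checks can be read as a single per-instruction quantity; as a standalone sketch (hypothetical helper name, mirroring the test's own expression):

    // Operand count plus the number of subshapes in the result shape; for a
    // multi-output fusion's tuple result, SubshapeCount counts the tuple
    // itself and each element.
    int64 FusionSize(const HloInstruction* instr) {
      return instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape());
    }
    // After the pass runs, every remaining instruction must satisfy
    //   FusionSize(instr) <= GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion.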
+
} // namespace gpu
} // namespace xla