aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Thomas Joerg <tjoerg@google.com>2018-08-22 00:36:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 00:42:10 -0700
commit05f8ea8e9522a3027d4f3f7a54d716bfafed427a (patch)
treebb40b66a4261c69a80ef3fdfac9c5288d70ec87a
parent611459fd338e7109a1a2f7b21aa3ad5b1a5bd17b (diff)
[XLA:GPU] Do not fuse loop fusions with different output shapes.
PiperOrigin-RevId: 209724594
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc84
2 files changed, 97 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 34a479b289..5575f6c0c6 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -187,6 +187,19 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
return false;
}
+ // Multi-output loop fusions must have equal output shapes to be lowered.
+ if (instr1->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ Shape shape1 = instr1->IsMultiOutputFusion()
+ ? instr1->shape().tuple_shapes(0)
+ : instr1->shape();
+ Shape shape2 = instr2->IsMultiOutputFusion()
+ ? instr2->shape().tuple_shapes(0)
+ : instr2->shape();
+ if (!ShapeUtil::Equal(shape1, shape2)) {
+ return false;
+ }
+ }
+
// Do this check last, as it may be expensive.
return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2);
}
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 14f157a5e5..072f885bc1 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -256,6 +256,90 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2)
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Multiply(), op::Exp(), op::Add()));
+}
+
+TEST_F(MultiOutputFusionTest,
+ MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
ENTRY reduce {