author     Thomas Joerg <tjoerg@google.com>              2018-08-03 05:23:47 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-03 05:28:00 -0700
commit     b2933c618260edc039fb8a7e2dce4d2e185f0892 (patch)
tree       57aabd3d47f32b8f1bdd196c6859a255a8dea1c4 /tensorflow
parent     37b48fac2c365c49373467abf5fc58c4678e700e (diff)
[XLA:GPU] Allow multi-output fusion of element-wise instructions, in addition to loop fusions.
PiperOrigin-RevId: 207253181
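
In short, the change relaxes the producer check in GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion: the producer no longer has to be a kLoop fusion; an unfused element-wise instruction (such as the exponential in the new test) now also qualifies for producer-consumer multi-output fusion. A minimal sketch of the relaxed predicate follows, assuming XLA's HloInstruction API; the helper name is hypothetical, and the actual commit inlines this condition rather than factoring it out:

// Hypothetical helper; the real change inlines this condition in
// GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion().
bool ProducerQualifiesForMultiOutputFusion(HloInstruction* producer) {
  const bool is_loop_fusion =
      producer->opcode() == HloOpcode::kFusion &&
      producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
  // New in this commit: plain element-wise producers are accepted as well.
  return producer->IsElementwise() || is_loop_fusion;
}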
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc       | 14
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc  | 20
2 files changed, 31 insertions(+), 3 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c67dcbce77..c62bae0628 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -115,15 +115,23 @@ bool IsInputFusibleReduction(HloInstruction* instr) {
// will be broadcasted and have not been observed to cause data locality issues.
// TODO(b/111977086): Improve reduce emitters to remove this limitation.
bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
+ std::vector<HloInstruction*> params;
+ if (instr->opcode() == HloOpcode::kFusion) {
+ params = instr->fused_parameters();
+ } else {
+ for (HloInstruction* operand : instr->operands()) {
+ params.push_back(operand);
+ }
+ }
int64 max_rank = 0;
const Layout* max_rank_layout;
- for (HloInstruction* param : instr->fused_parameters()) {
+ for (HloInstruction* param : params) {
if (ShapeUtil::Rank(param->shape()) > max_rank) {
max_rank = ShapeUtil::Rank(param->shape());
max_rank_layout = &param->shape().layout();
}
}
- return c_all_of(instr->fused_parameters(), [&](HloInstruction* param) {
+ return c_all_of(params, [&](HloInstruction* param) {
return (ShapeUtil::Rank(param->shape()) < max_rank) ||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
});
@@ -221,7 +229,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
const bool is_loop_fusion =
producer->opcode() == HloOpcode::kFusion &&
producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
- if (!is_loop_fusion) {
+ if (!producer->IsElementwise() && !is_loop_fusion) {
VLOG(3) << producer->name() << " is not a loop fusion.";
continue;
}
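
For context, the ReduceFriendlyInputLayouts change above is what lets the pass evaluate unfused element-wise producers at all: for a fusion it keeps inspecting fused_parameters(), while for a plain instruction it now walks the operands directly; the accepted layouts themselves are unchanged. To make the layout condition concrete, here is a small hypothetical illustration using XLA's shape utilities (not part of the commit): an input whose rank is below that of the largest-rank input passes unconditionally because it will be broadcasted, while inputs of the maximal rank must share that input's layout.

// Hypothetical illustration of the condition enforced by
// ReduceFriendlyInputLayouts; assumes the standard XLA shape/layout utilities.
Shape rank3 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 1, 0});
Shape scalar = ShapeUtil::MakeShape(F32, {});  // rank 0
Shape transposed = ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2});

// Inputs {rank3, scalar}: max_rank is 3 with layout {2,1,0}. The scalar's
// rank (0) is below max_rank, so it is accepted without a layout comparison;
// the check succeeds.
//
// Inputs {rank3, transposed}: both have rank 3, but
// LayoutUtil::Equal(rank3.layout(), transposed.layout()) is false, so the
// check fails and the instruction is not considered reduce-friendly.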
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index ec4234b8d9..14f157a5e5 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -256,6 +256,26 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ exp = f32[2,2,2]{2,1,0} exponential(p0)
+ reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation
+ ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Exp()));
+}
+
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_add {