aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc58
1 files changed, 48 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 652b5c7687..6fef720853 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -71,7 +73,6 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
// In that case, the operand of the reduce needs to have the same shape
// as the other tuple operands, but also we need to compare the output
// shapes of the reduces.
- // TODO(tjoerg): Allow differences in fp precision.
auto* element_instr_1 = get_element_instr(instr1);
auto* element_instr_2 = get_element_instr(instr2);
if (element_instr_1->opcode() == HloOpcode::kReduce &&
@@ -80,8 +81,8 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
return false;
}
// The elementwise output shapes must be the same (including layout), except
// that the floating-point precision of the element types may differ.
- return ShapeUtil::Equal(get_element_shape(element_instr_1),
- get_element_shape(element_instr_2));
+ return ShapeUtil::EqualIgnoringFpPrecision(
+ get_element_shape(element_instr_1), get_element_shape(element_instr_2));
}
namespace {
@@ -107,16 +108,34 @@ bool IsInputFusibleReduction(HloInstruction* instr) {
return IsReductionToVector(*instr);
}
}
+
+// The code emitted for reduction suffers from poor data locality if the layouts
+// of input parameters differ. In such situations it is beneficial not to fuse.
+// We consider input params with maximum rank only. Params with smaller ranks
+// will be broadcasted and have not been observed to cause data locality issues.
+//
+// Returns true iff every fused parameter of `instr` whose rank equals the
+// maximum rank among the fused parameters has the same layout as the first
+// parameter that attains that maximum rank. If no fused parameter has
+// rank > 0 there is no layout to compare against and the result is true.
+// TODO(b/110927656): Improve reduce emitters to remove this limitation.
+bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
+  int64 max_rank = 0;
+  // Layout of the first parameter that attains `max_rank`. This must start
+  // out as nullptr: it is only assigned when some parameter has rank > 0, and
+  // would otherwise be dereferenced uninitialized in the comparison below.
+  const Layout* max_rank_layout = nullptr;
+  for (HloInstruction* param : instr->fused_parameters()) {
+    const int64 rank = ShapeUtil::Rank(param->shape());
+    if (rank > max_rank) {
+      max_rank = rank;
+      max_rank_layout = &param->shape().layout();
+    }
+  }
+  if (max_rank_layout == nullptr) {
+    // No parameter has rank > 0, so there are no layouts that could differ.
+    return true;
+  }
+  return c_all_of(instr->fused_parameters(), [&](HloInstruction* param) {
+    // Lower-rank params are ignored; max-rank params must share one layout.
+    return (ShapeUtil::Rank(param->shape()) < max_rank) ||
+           (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
+  });
+}
+
} // namespace
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
// We can fuse reduces and loop fusions.
return IsInputFusibleReduction(instr) ||
(instr->opcode() == HloOpcode::kFusion &&
- instr->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- // TODO(b/110202584): bitcasts make nested fusions, GPU has no support
- // for nested fusions.
- instr->fused_expression_root()->opcode() != HloOpcode::kBitcast);
+ instr->fusion_kind() == HloInstruction::FusionKind::kLoop);
}
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
@@ -145,16 +164,22 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) {
return false;
}
+
// If we're fusing fusions only do it if the fusion kind matches. Loop fusions
// merge into bigger loop fusions and input (reduce) fusions become fusions
// with multiple reduce outputs. We could fuse reduce and loop fusions
// together too (the result being an input fusion) if we find cases where this
// improves things.
CHECK(instr1->opcode() == HloOpcode::kFusion);
- if (instr2->opcode() == HloOpcode::kFusion) {
- return instr1->fusion_kind() == instr2->fusion_kind();
+ if ((instr2->opcode() == HloOpcode::kFusion &&
+ instr1->fusion_kind() != instr2->fusion_kind()) ||
+ (instr2->opcode() != HloOpcode::kFusion &&
+ instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) {
+ return false;
}
- return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop;
+
+ // Do this check last, as it may be expensive.
+ return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2);
}
bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
@@ -176,29 +201,41 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
// fusions operands.
for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) {
if (consumer->user_count() == 0) {
+ VLOG(3) << consumer->name() << " has no users.";
continue;
}
if (!IsInputFusibleReduction(consumer)) {
+ VLOG(3) << consumer->name() << " is not an input-fusable reduction.";
continue;
}
+ VLOG(3) << consumer->name()
+ << " is a fusion candidate. Looking for fuseable operands.";
auto consumer_operands = consumer->operands();
for (size_t i = 0; i < consumer_operands.size(); ++i) {
HloInstruction* producer = consumer_operands[i];
if (!producer->IsFusable()) {
+ VLOG(3) << producer->name() << " is not fusable.";
continue;
}
const bool is_loop_fusion =
producer->opcode() == HloOpcode::kFusion &&
producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
if (!is_loop_fusion) {
+ VLOG(3) << producer->name() << " is not a loop fusion.";
continue;
}
if (!ShapesCompatibleForFusion(producer, consumer)) {
+ VLOG(3) << producer->name() << " has an incompatible shape.";
+ continue;
+ }
+ if (!ReduceFriendlyInputLayouts(producer)) {
+ VLOG(3) << producer->name() << " has inputs with mixed layouts.";
continue;
}
// If we have already decided to fuse this producer, skip it.
if (ContainsKey(to_fuse, producer)) {
+ VLOG(3) << producer->name() << " will be fused with another consumer.";
continue;
}
// Do not fuse a producer if the other operands of the fusion are
@@ -207,6 +244,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
return producer != operand &&
reachability()->IsReachable(producer, operand);
})) {
+ VLOG(3) << producer->name() << " would introduce a cycle when fused.";
break;
}
to_fuse.insert(producer);