Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/instruction_fusion.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion.cc | 70
1 file changed, 68 insertions(+), 2 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 64ed3d748f..af6259ae83 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -73,6 +73,67 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
}
}
+// This function limits the maximum number of operands to a fusion.
+//
+// There's a cap on how many parameters we can pass to a CUDA kernel, but
+// the exact limit is hazy, since it depends on (among other things) how
+// much GPU constant memory is in use for other purposes.
+//
+// Moreover, when we run fusion we don't even know how many arguments the
+// CUDA kernel for a fusion node will have: that depends on buffer
+// assignment, where we decide which of the fusion's operands live in XLA's
+// big temp buffer and which live in other allocations.
+//
+// As a heuristic, we simply cap the number of fusion operands plus outputs at
+// kMaxOperandsAndOutputsPerFusion. This puts an upper bound on the number of
+// parameters to the kernel, working around the correctness problem.
+//
+// This limit is also often good for performance. In a fusion with many
+// operands, each GPU thread likely has to do a lot of work, and so possibly
+// uses a lot of registers, thus limiting occupancy.
+/*static*/ bool GpuInstructionFusion::FusionWouldBeTooLarge(
+ const HloInstruction* a, const HloInstruction* b) {
+ // Compute the number of outputs of the (possibly multi-output) fusion node
+ // we're considering creating.
+ //
+ // This isn't precise; we may be off by one if
+ // - We're creating a multi-output fusion out of two non-MOFs. Creating a
+ // MOF adds a new buffer, namely, the tuple buffer.
+ // - We're merging two MOFs. In this case, we should count the tuple buffer
+ // only once.
+ // - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
+ // `a`. In this case the result of `a` is not part of the output of the
+ // fusion.
+ //
+ // But because this is a heuristic and our limit
+ // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a
+ // big difference), we ignore this small inaccuracy in favor of simplicity.
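+ //
+ // As an example of the counting: an array-shaped instruction such as
+ // f32[1024] has SubshapeCount 1, while a multi-output fusion producing
+ // the tuple (f32[10], f32[20]) has SubshapeCount 3: the tuple buffer
+ // plus its two leaf buffers.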
+ int64 num_output_buffers = ShapeUtil::SubshapeCount(a->shape()) +
+ ShapeUtil::SubshapeCount(b->shape());
+
+ // The new fusion will have no more operands and outputs than
+ // producer_operands + consumer_operands - 1 + num_output_buffers
+ // (minus one because we may be fusing a producer->consumer edge between `a`
+ // and `b`).
+ //
+ // This fact may be enough to let us avoid having to compute the true total
+ // number of operands, which can be expensive.
+ if (a->operand_count() + b->operand_count() - 1 + num_output_buffers <=
+ kMaxOperandsAndOutputsPerFusion) {
+ return false;
+ }
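+
+ // For instance, with an illustrative limit of 64, fusing a 4-operand
+ // producer into a 3-operand consumer where both produce array shapes
+ // gives 4 + 3 - 1 + 2 = 8, well under the cap, so the cheap check above
+ // suffices and we never build the operand set below.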
+
+ // Compute the precise number of operands to the new fusion.
+ tensorflow::gtl::FlatSet<const HloInstruction*> operands(
+ a->operands().begin(), a->operands().end());
+ operands.insert(b->operands().begin(), b->operands().end());
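+ // Note that the set deduplicates: an operand that feeds both `a` and `b`
+ // is counted once here, even though the conservative bound above counted
+ // it twice.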
+ // If there's an edge between `a` and `b`, don't count it: We're fusing that
+ // producer -> consumer relationship.
+ operands.erase(a);
+ operands.erase(b);
+ return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion;
+}
+
bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
int64 operand_index) {
HloInstruction* producer = consumer->mutable_operand(operand_index);
@@ -183,8 +244,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return true;
}
- return IsFusile(*producer) && IsFusile(*consumer) &&
- InstructionFusion::ShouldFuse(consumer, operand_index);
+ if (!IsFusile(*producer) || !IsFusile(*consumer) ||
+ !InstructionFusion::ShouldFuse(consumer, operand_index)) {
+ return false;
+ }
+
+ // We put this check last because it's potentially expensive.
+ return !FusionWouldBeTooLarge(consumer, producer);
}
bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
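
The two-stage check in FusionWouldBeTooLarge (a cheap conservative bound
first, then an exact set-based count only when the bound is exceeded) is a
generic pattern. Below is a minimal standalone sketch of the same idea; the
Node type, its field names, and the use of std::unordered_set in place of
tensorflow::gtl::FlatSet are illustrative assumptions, not part of XLA's
API.

#include <unordered_set>
#include <vector>

namespace {

// Illustrative stand-in for kMaxOperandsAndOutputsPerFusion.
constexpr int kMaxOperandsAndOutputs = 64;

// Hypothetical, minimal stand-in for HloInstruction.
struct Node {
  std::vector<const Node*> operands;
  int num_output_buffers = 1;  // 1 for an array shape, more for a tuple.
};

bool FusionWouldBeTooLarge(const Node& a, const Node& b) {
  const int num_outputs = a.num_output_buffers + b.num_output_buffers;

  // Stage 1: cheap conservative bound. The fused node has at most
  // |a.operands| + |b.operands| - 1 operands (minus one for the fused
  // producer->consumer edge), so if even that fits, we are done.
  const int bound = static_cast<int>(a.operands.size()) +
                    static_cast<int>(b.operands.size()) - 1 + num_outputs;
  if (bound <= kMaxOperandsAndOutputs) {
    return false;
  }

  // Stage 2: exact count. Deduplicate operands shared by `a` and `b`, and
  // drop any edge between the two nodes themselves.
  std::unordered_set<const Node*> operands(a.operands.begin(),
                                           a.operands.end());
  operands.insert(b.operands.begin(), b.operands.end());
  operands.erase(&a);
  operands.erase(&b);
  return static_cast<int>(operands.size()) + num_outputs >
         kMaxOperandsAndOutputs;
}

}  // namespace

The point of stage one is that on the common path, where the candidate
fusion is clearly small enough, we never pay for constructing the hash set.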