aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/instruction_fusion.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-26 11:40:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 11:43:56 -0700
commitd66adb41874acddfd9e01f46e064965ee39850ca (patch)
treeb243798b38649cb5ed278ba9dfcaf345c70e10eb /tensorflow/compiler/xla/service/instruction_fusion.cc
parenta8481834bb881f67e7b9523480c28f5b987e62e8 (diff)
Simplify, test and document logic in instruction fusion that decides whether we
allow fusion when an operation needs to be duplicated. PiperOrigin-RevId: 194429279
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion.cc')
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc166
1 files changed, 81 insertions, 85 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index b9ccfeddb5..dc1a39e9fa 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -128,11 +128,11 @@ namespace xla {
return false;
}
-// An "effectively unary" operation is one that has one "large"
+// An "effectively at most unary" operation is one that has at most one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
// for "negligible" memory usage.
-bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) {
+bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
int64 output_rank = 0;
ShapeUtil::ForEachSubshape(
hlo->shape(),
@@ -156,66 +156,91 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) {
}
bool InstructionFusion::CanFuseOnAllPaths(
- const HloReachabilityMap& reachability_map, HloInstruction* producer,
- HloInstruction* consumer, DoNotFuseSet* do_not_fuse) {
- auto could_fuse_on_all_paths = [&] {
- // First check to see if we have already marked this producer as infeasible
- // to fuse into consumer.
- if (do_not_fuse->count(producer) > 0) {
+ HloInstruction* producer, HloInstruction* consumer,
+ const HloReachabilityMap& reachability_map,
+ const DoNotFuseSet& do_not_fuse) {
+ if (consumer == producer) {
+ return true;
+ }
+ if (!consumer->IsFusable()) {
+ return false;
+ }
+ for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
+ auto* consumer_operand = consumer->mutable_operand(i);
+ // If the operand is not on a path to the producer, it doesn't matter
+ // whether it's fusable.
+ if (!reachability_map.IsReachable(producer, consumer_operand)) {
+ continue;
+ }
+ if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
return false;
}
- // Make sure it is possible for producer and consumer to exist in a fusion
- // node.
- if (!producer->IsFusable() || !consumer->IsFusable()) {
+ // The producer is reachable from consumer_operand which means we need
+ // to be able to fuse consumer_operand into consumer in order for
+ // producer to be fusable into consumer on all paths.
+ // Perform the recursive step: make sure producer can be fused into
+ // consumer_operand on all paths.
+ if (!CanFuseOnAllPaths(producer, consumer_operand, reachability_map,
+ do_not_fuse)) {
return false;
}
- // We do an upward walk of the graph from consumer towards all paths which
- // lead to producer to find any unfusable paths.
- for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
- auto* consumer_operand = consumer->mutable_operand(i);
- if (consumer_operand == producer) {
- // This is the base case: our upward crawl ends but we need to make sure
- // that fusion from consumer can happen.
- if (!ShouldFuse(consumer, i)) {
- return false;
- }
- } else if (reachability_map.IsReachable(producer, consumer_operand)) {
- // The reachability map told us that consumer_operand is a node on the
- // path to producer. We need to further investigate from
- // consumer_operand.
-
- // First check if we have already ruled out fusing producer into
- // consumer_operand.
- if (do_not_fuse->count(consumer_operand) > 0) {
- return false;
- }
- // Make sure it is possible for consumer_operand to exist in a fusion
- // node.
- if (!consumer_operand->IsFusable()) {
- return false;
- }
- // The producer is reachable from consumer_operand which means we need
- // to be able to fuse consumer_operand into consumer in order for
- // producer to be fusable into consumer on all paths.
- if (!ShouldFuse(consumer, i)) {
- return false;
- }
- // Perform the recursive step: make sure producer can be fused into
- // consumer_operand on all paths.
- if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand,
- do_not_fuse)) {
- return false;
- }
+ }
+ return true;
+}
+
+InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable(
+ tensorflow::gtl::ArraySlice<HloInstruction*> post_order) {
+ auto reachability = computation_->ComputeReachability();
+
+ // Forbid fusion of producers that:
+ // a) Need to be duplicated, unless they can be fused into all consumers
+ // via all paths.
+ // b) Are more than unary, that is, fusing them would likely lead to an
+ // increase in memory bandwidth use.
+ //
+ // Note that if we allow fusion by these global rules, we may still forbid
+ // fusing operations that require duplication later depending on
+ // is_expensive_().
+ DoNotFuseSet do_not_fuse;
+ for (HloInstruction* consumer : post_order) {
+ for (HloInstruction* producer : consumer->operands()) {
+ if (do_not_fuse.count(producer) > 0) {
+ continue;
}
+
+ // If the producer is effectively not more than unary, duplicating it
+ // will not increase the number of relevant inputs read, as the fusion
+ // node will only need to read at most 1 relevant input (the input of
+ // the producer). In that case, we do not forbid fusion of the operation
+ // here.
+ if (EffectivelyAtMostUnary(producer)) {
+ continue;
+ }
+ // Otherwise we will forbid fusing the op unless we can fuse it into
+ // all of its consumers on all paths.
+ //
+ // That means, that for:
+ // A --> B (fusable)
+ // \-> C (non-fusable)
+ // A will be not allowed to be fused into B, as it cannot be fused into C.
+ //
+ // Similarly, for:
+ // A -------------> B
+ // \-> C -> D -/
+ // If:
+ // - A is fusable into B and C, and D is fusable into B
+ // - C is *not* fusable into D
+ // A will be not allowed to be fused into B, as it cannot be fused via
+ // all paths.
+ if (producer->IsFusable() &&
+ CanFuseOnAllPaths(producer, consumer, *reachability, do_not_fuse)) {
+ continue;
+ }
+ do_not_fuse.insert(producer);
}
- return true;
- };
- if (could_fuse_on_all_paths()) {
- return true;
}
- // We couldn't fuse on all paths, record this result.
- do_not_fuse->insert(producer);
- return false;
+
+ return do_not_fuse;
}
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
@@ -244,36 +269,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
InsertOrDie(&post_order_index, post_order[i], i);
}
- DoNotFuseSet do_not_fuse;
- auto reachability = computation->ComputeReachability();
-
- auto cheap_to_duplicate = [this](HloInstruction* producer) {
- if (producer->opcode() == HloOpcode::kBroadcast) {
- return true;
- }
- if (producer->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsEffectiveScalar(producer->shape())) {
- return true;
- }
- if (EffectivelyUnary(producer)) {
- return true;
- }
- return false;
- };
-
- for (HloInstruction* consumer : post_order) {
- for (HloInstruction* producer : consumer->operands()) {
- if (cheap_to_duplicate(producer)) {
- continue;
- }
- if (CanFuseOnAllPaths(*reachability, producer, consumer,
- &do_not_fuse)) {
- CHECK_EQ(do_not_fuse.count(producer), 0);
- } else {
- CHECK_GT(do_not_fuse.count(producer), 0);
- }
- }
- }
+ DoNotFuseSet do_not_fuse = ComputeGloballyUnfusable(post_order);
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all