aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-26 11:40:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 11:43:56 -0700
commitd66adb41874acddfd9e01f46e064965ee39850ca (patch)
treeb243798b38649cb5ed278ba9dfcaf345c70e10eb /tensorflow/compiler/xla/service
parenta8481834bb881f67e7b9523480c28f5b987e62e8 (diff)
Simplify, test and document logic in instruction fusion that decides whether we
allow fusion when an operation needs to be duplicated. PiperOrigin-RevId: 194429279
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc166
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h17
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc156
4 files changed, 249 insertions, 91 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index d55da3686c..f39bfb8012 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1206,6 +1206,7 @@ tf_cc_test(
":instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index b9ccfeddb5..dc1a39e9fa 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -128,11 +128,11 @@ namespace xla {
return false;
}
-// An "effectively unary" operation is one that has one "large"
+// An "effectively at most unary" operation is one that has at most one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
// for "negligible" memory usage.
-bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) {
+bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
int64 output_rank = 0;
ShapeUtil::ForEachSubshape(
hlo->shape(),
@@ -156,66 +156,91 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) {
}
bool InstructionFusion::CanFuseOnAllPaths(
- const HloReachabilityMap& reachability_map, HloInstruction* producer,
- HloInstruction* consumer, DoNotFuseSet* do_not_fuse) {
- auto could_fuse_on_all_paths = [&] {
- // First check to see if we have already marked this producer as infeasible
- // to fuse into consumer.
- if (do_not_fuse->count(producer) > 0) {
+ HloInstruction* producer, HloInstruction* consumer,
+ const HloReachabilityMap& reachability_map,
+ const DoNotFuseSet& do_not_fuse) {
+ if (consumer == producer) {
+ return true;
+ }
+ if (!consumer->IsFusable()) {
+ return false;
+ }
+ for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
+ auto* consumer_operand = consumer->mutable_operand(i);
+ // If the operand is not on a path to the producer, it doesn't matter
+ // whether it's fusable.
+ if (!reachability_map.IsReachable(producer, consumer_operand)) {
+ continue;
+ }
+ if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
return false;
}
- // Make sure it is possible for producer and consumer to exist in a fusion
- // node.
- if (!producer->IsFusable() || !consumer->IsFusable()) {
+ // The producer is reachable from consumer_operand which means we need
+ // to be able to fuse consumer_operand into consumer in order for
+ // producer to be fusable into consumer on all paths.
+ // Perform the recursive step: make sure producer can be fused into
+ // consumer_operand on all paths.
+ if (!CanFuseOnAllPaths(producer, consumer_operand, reachability_map,
+ do_not_fuse)) {
return false;
}
- // We do an upward walk of the graph from consumer towards all paths which
- // lead to producer to find any unfusable paths.
- for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
- auto* consumer_operand = consumer->mutable_operand(i);
- if (consumer_operand == producer) {
- // This is the base case: our upward crawl ends but we need to make sure
- // that fusion from consumer can happen.
- if (!ShouldFuse(consumer, i)) {
- return false;
- }
- } else if (reachability_map.IsReachable(producer, consumer_operand)) {
- // The reachability map told us that consumer_operand is a node on the
- // path to producer. We need to further investigate from
- // consumer_operand.
-
- // First check if we have already ruled out fusing producer into
- // consumer_operand.
- if (do_not_fuse->count(consumer_operand) > 0) {
- return false;
- }
- // Make sure it is possible for consumer_operand to exist in a fusion
- // node.
- if (!consumer_operand->IsFusable()) {
- return false;
- }
- // The producer is reachable from consumer_operand which means we need
- // to be able to fuse consumer_operand into consumer in order for
- // producer to be fusable into consumer on all paths.
- if (!ShouldFuse(consumer, i)) {
- return false;
- }
- // Perform the recursive step: make sure producer can be fused into
- // consumer_operand on all paths.
- if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand,
- do_not_fuse)) {
- return false;
- }
+ }
+ return true;
+}
+
+InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable(
+ tensorflow::gtl::ArraySlice<HloInstruction*> post_order) {
+ auto reachability = computation_->ComputeReachability();
+
+ // Forbid fusion of producers that:
+ // a) Need to be duplicated, unless they can be fused into all consumers
+ // via all paths.
+ // b) Are more than unary, that is, fusing them would likely lead to an
+ // increase in memory bandwidth use.
+ //
+ // Note that if we allow fusion by these global rules, we may still forbid
+ // fusing operations that require duplication later depending on
+ // is_expensive_().
+ DoNotFuseSet do_not_fuse;
+ for (HloInstruction* consumer : post_order) {
+ for (HloInstruction* producer : consumer->operands()) {
+ if (do_not_fuse.count(producer) > 0) {
+ continue;
}
+
+ // If the producer is effectively not more than unary, duplicating it
+ // will not increase the number of relevant inputs read, as the fusion
+ // node will only need to read at most 1 relevant input (the input of
+ // the producer). In that case, we do not forbid fusion of the operation
+ // here.
+ if (EffectivelyAtMostUnary(producer)) {
+ continue;
+ }
+ // Otherwise we will forbid fusing the op unless we can fuse it into
+ // all of its consumers on all paths.
+ //
+ // That means, that for:
+ // A --> B (fusable)
+ // \-> C (non-fusable)
+ // A will be not allowed to be fused into B, as it cannot be fused into C.
+ //
+ // Similarly, for:
+ // A -------------> B
+ // \-> C -> D -/
+ // If:
+ // - A is fusable into B and C, and D is fusable into B
+ // - C is *not* fusable into D
+ // A will be not allowed to be fused into B, as it cannot be fused via
+ // all paths.
+ if (producer->IsFusable() &&
+ CanFuseOnAllPaths(producer, consumer, *reachability, do_not_fuse)) {
+ continue;
+ }
+ do_not_fuse.insert(producer);
}
- return true;
- };
- if (could_fuse_on_all_paths()) {
- return true;
}
- // We couldn't fuse on all paths, record this result.
- do_not_fuse->insert(producer);
- return false;
+
+ return do_not_fuse;
}
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
@@ -244,36 +269,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
InsertOrDie(&post_order_index, post_order[i], i);
}
- DoNotFuseSet do_not_fuse;
- auto reachability = computation->ComputeReachability();
-
- auto cheap_to_duplicate = [this](HloInstruction* producer) {
- if (producer->opcode() == HloOpcode::kBroadcast) {
- return true;
- }
- if (producer->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsEffectiveScalar(producer->shape())) {
- return true;
- }
- if (EffectivelyUnary(producer)) {
- return true;
- }
- return false;
- };
-
- for (HloInstruction* consumer : post_order) {
- for (HloInstruction* producer : consumer->operands()) {
- if (cheap_to_duplicate(producer)) {
- continue;
- }
- if (CanFuseOnAllPaths(*reachability, producer, consumer,
- &do_not_fuse)) {
- CHECK_EQ(do_not_fuse.count(producer), 0);
- } else {
- CHECK_GT(do_not_fuse.count(producer), 0);
- }
- }
- }
+ DoNotFuseSet do_not_fuse = ComputeGloballyUnfusable(post_order);
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 152d0886ee..2ea1fcf937 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -70,11 +70,11 @@ class InstructionFusion : public HloPassInterface {
virtual HloInstruction* Fuse(HloInstruction* producer,
HloInstruction* consumer);
- // An "effectively unary" operation is one that has one "large"
+ // An "effectively unary" operation is one that has at most one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
// for "negligible" memory usage.
- bool EffectivelyUnary(HloInstruction* hlo);
+ bool EffectivelyAtMostUnary(HloInstruction* hlo);
// Returns true if fusing producer into consumer would cause producer to be
// duplicated. This is the case if producer has uses other than consumer.
@@ -95,11 +95,16 @@ class InstructionFusion : public HloPassInterface {
// The set of producers whose consumers we cannot fuse into.
using DoNotFuseSet = std::unordered_set<HloInstruction*>;
- // Whether or not we can fuse consumer into original_producer on all paths
+ // Whether or not we can fuse producer into consumer on all paths
// from the producer to the consumer where nodes are HLOs and edges are uses.
- bool CanFuseOnAllPaths(const HloReachabilityMap& reachability_map,
- HloInstruction* producer, HloInstruction* consumer,
- DoNotFuseSet* do_not_fuse);
+ bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer,
+ const HloReachabilityMap& reachability_map,
+ const DoNotFuseSet& do_not_fuse);
+
+ // Computes the set of nodes that we do not want to fuse into any of their
+ // consumers based on a global analysis of the HLO graph.
+ DoNotFuseSet ComputeGloballyUnfusable(
+ tensorflow::gtl::ArraySlice<HloInstruction*> post_order);
// Used to determine if an HLO is expensive. Expensive operations will not be
// duplicated.
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index 0fa2c95fb4..e78b99a80c 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
@@ -92,6 +93,161 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
.ValueOrDie());
}
+// Counts the number of HLO ops with a given op code in the specified module.
+static int Count(const HloModule& module, HloOpcode op) {
+ int count = 0;
+ for (const auto* computation : module.computations()) {
+ for (const auto* instruction : computation->instructions()) {
+ if (instruction->opcode() == op) {
+ ++count;
+ }
+ }
+ }
+ return count;
+}
+
+TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ add = f32[4,3]{1,0} add(p0, p0)
+ ROOT root = f32[4,3]{1,0} subtract(add, add)
+ })")
+ .ValueOrDie();
+ // Expect the add and subtraction to be fused.
+ EXPECT_TRUE(
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+
+ // Make sure the add hasn't been duplicated.
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+}
+
+TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
+ // Make sure we do not duplicate the add, as we cannot fuse through the rng.
+ //
+ // p0 -> add -------------------------> sub
+ // \-> abs1 -> rng -> abs2 -/
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ add = f32[4,3]{1,0} add(p0, p0)
+ abs1 = f32[4,3]{1,0} abs(add)
+ rng = f32[4,3]{1,0} rng(abs1), distribution=rng_uniform
+ abs2 = f32[4,3]{1,0} abs(rng)
+ ROOT root = f32[4,3]{1,0} subtract(abs2, add)
+ })")
+ .ValueOrDie();
+ // We expect abs2 to be fused into root.
+ EXPECT_TRUE(
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+
+ // Make sure the add hasn't been duplicated.
+ EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
+
+ // Use a log node with a second consumer to break the fusion.
+ //
+ // p0 -> add -------------------------> sub
+ // \-> abs1 -> log -> abs2 -/
+ // \-> send
+ module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ add = f32[4,3]{1,0} add(p0, p0)
+ abs1 = f32[4,3]{1,0} abs(add)
+ log = f32[4,3]{1,0} log(abs1)
+ send = f32[4,3]{1,0} send(log), channel_id=0
+ abs2 = f32[4,3]{1,0} abs(log)
+ ROOT root = f32[4,3]{1,0} subtract(abs2, add)
+ })")
+ .ValueOrDie();
+
+ // We expect abs2 to be fused into root and abs1 to be fused into log.
+ EXPECT_TRUE(
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 2) << module->ToString();
+
+ // Make sure the add hasn't been duplicated.
+ EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
+
+ // Make sure we still fuse ops where one operand in the chain to the producer
+ // can't be fused.
+ //
+ // p0 ---> add1 -----------> sub
+ // \ \-> add2 -/
+ // \-> log -/
+ // \-> send
+ module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ add1 = f32[4,3]{1,0} add(p0, p0)
+ log = f32[4,3]{1,0} log(p0)
+ send = f32[4,3]{1,0} send(log), channel_id=0
+ add2 = f32[4,3]{1,0} add(log, add1)
+ ROOT root = f32[4,3]{1,0} subtract(add1, add2)
+ })")
+ .ValueOrDie();
+
+ // Expect the add1 and add2 to be fused into root.
+ EXPECT_TRUE(
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+
+ // Make sure we didn't duplicate any adds.
+ EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
+
+ // A variant of the above that allows the algorithm to put add2 into the set
+ // of unfusable ops to short-circuit the decision whether add1 should be fused
+ // into sub2.
+ //
+ // /---------------\
+ // p0 ---> add1 ---> add2 ------> sub2
+ // \------> sub1
+ // log -/
+ // \-> send
+ module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ add1 = f32[4,3]{1,0} add(p0, p0)
+ add2 = f32[4,3]{1,0} add(add1, add1)
+ log = f32[4,3]{1,0} log(add2)
+ send = f32[4,3]{1,0} send(log), channel_id=0
+ sub1 = f32[4,3]{1,0} subtract(log, add2)
+ sub2 = f32[4,3]{1,0} subtract(add2, add1)
+ ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
+ })")
+ .ValueOrDie();
+
+ // Expect sub1 and sub2 to be fused into root.
+ EXPECT_TRUE(
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+
+ // Make sure we didn't duplicate any adds.
+ EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
+}
+
TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
HloComputation::Builder builder(TestName());
auto shape = ShapeUtil::MakeShape(F32, {16, 16});