diff options
author | 2017-02-15 13:28:20 -0800 | |
---|---|---|
committer | 2017-02-15 13:50:45 -0800 | |
commit | 9490027f2dd44ec367dc0e675921bd2404fb71c0 (patch) | |
tree | ee8fc63dccd8aea632b04a6a780dd3bdec60f00c | |
parent | cf3a2356e81b09ee2f3b8048c1b84bb5d12a3a4e (diff) |
[XLA] BufferLiveness: eliminate false interference between loop fusion instructions which use the same tuple-shaped operand (but access different tuple elements).
Change: 147636899
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_liveness.cc | 35 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_liveness_test.cc | 28 |
2 files changed, 59 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 60415d5c06..6d8edabf73 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -100,11 +100,39 @@ namespace { // 'operand'. Returns true otherwise. // Precondition: 'operand' is an operand of 'user'. bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index, - HloInstruction* user) { + HloInstruction* user, + const TuplePointsToAnalysis& points_to_analysis) { if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { // GetTupleElement instructions only access the top-level buffer of their // operand. return false; + } else if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + auto it = std::find_if( + user->fused_parameters().begin(), user->fused_parameters().end(), + [=](HloInstruction* fused_param) { + return user->operand(fused_param->parameter_number()) == operand; + }); + CHECK(it != user->fused_parameters().end()); + // Iterate through all users of all buffer aliases of the buffer in the + // points-to set of fusion parameter at 'index'. + // Return true if any uses are detected at 'index', returns false otherwise. + const LogicalBuffer* buffer = + points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); + for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (!MayUseBufferInOperand(alias.instruction(), alias.index(), + alias_user, points_to_analysis)) { + continue; + } + // Return true: use detected at 'buffer' -> 'alias' -> 'alias_user'. + return true; + } + } + // Return false: found no uses of 'operand' at 'index' in 'user'. + return false; } return true; } @@ -125,7 +153,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( points_to_analysis.GetBufferAliases(*buffer)) { for (HloInstruction* alias_user : alias.instruction()->users()) { if (!MayUseBufferInOperand(alias.instruction(), alias.index(), - alias_user)) { + alias_user, points_to_analysis)) { continue; } for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { @@ -200,7 +228,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { for (auto user : alias.instruction()->users()) { - if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user)) { + if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user, + points_to_analysis())) { continue; } if (user != b.instruction() && diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 43d2a11b1e..e7aa93f8db 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -488,7 +488,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Runs BufferLiveness on this computation. // Returns whether buffer interference is detected between tuple-shaped // parameter and root instructions at tuple element 1. - bool Run(const bool update_uses_tuple_element1) { + bool Run(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -534,6 +535,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { {dynamic_update_slice, starts, update, gte1}, HloInstruction::FusionKind::kLoop); } + // Create fusion instruction for tuple element 0 (if requested). + if (fuse_gte0) { + computation->CreateFusionInstruction({gte0}, + HloInstruction::FusionKind::kLoop); + } + // Run BufferLiveness on 'module'. auto liveness = BufferLiveness::Run(module.get(), @@ -562,6 +569,25 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false)); } +// Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases +// 'fusion1') do not overlap in the presence of another fusion instruction +// (which is a user of 'param0' at a different tuple index). +// BufferLiveness should detect no uses of Param0 at index {1} in Fusion0 +// (because Fusion0 only uses Param0 at index {0}). +// +// Param0 +// / \ +// FusionParam <----- Fusion0 Fusion1 ------> FusionParam +// | | | | +// GTE(0) | | GTE(1) Const Const +// | | \ | / +// \ / DynamicUpdateSlice +// Tuple +// +TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { + EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true)); +} + // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) // do overlap because GTE(1) has two users: // 1) DynamicUpdateSlice at operand 0. |