aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-15 13:28:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-15 13:50:45 -0800
commit9490027f2dd44ec367dc0e675921bd2404fb71c0 (patch)
treeee8fc63dccd8aea632b04a6a780dd3bdec60f00c
parentcf3a2356e81b09ee2f3b8048c1b84bb5d12a3a4e (diff)
[XLA] BufferLiveness: eliminate false interference between loop fusion instructions which use the same tuple-shaped operand (but access different tuple elements).
Change: 147636899
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.cc35
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc28
2 files changed, 59 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index 60415d5c06..6d8edabf73 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -100,11 +100,39 @@ namespace {
// 'operand'. Returns true otherwise.
// Precondition: 'operand' is an operand of 'user'.
bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index,
- HloInstruction* user) {
+ HloInstruction* user,
+ const TuplePointsToAnalysis& points_to_analysis) {
if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
// GetTupleElement instructions only access the top-level buffer of their
// operand.
return false;
+ } else if (user->opcode() == HloOpcode::kFusion &&
+ user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ // Find fusion parameter associated with 'operand'.
+ auto it = std::find_if(
+ user->fused_parameters().begin(), user->fused_parameters().end(),
+ [=](HloInstruction* fused_param) {
+ return user->operand(fused_param->parameter_number()) == operand;
+ });
+ CHECK(it != user->fused_parameters().end());
+ // Iterate through all users of all buffer aliases of the buffer in the
+ // points-to set of fusion parameter at 'index'.
+ // Return true if any uses are detected at 'index', returns false otherwise.
+ const LogicalBuffer* buffer =
+ points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
+ for (const BufferAlias& alias :
+ points_to_analysis.GetBufferAliases(*buffer)) {
+ for (HloInstruction* alias_user : alias.instruction()->users()) {
+ if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
+ alias_user, points_to_analysis)) {
+ continue;
+ }
+ // Return true: use detected at 'buffer' -> 'alias' -> 'alias_user'.
+ return true;
+ }
+ }
+ // Return false: found no uses of 'operand' at 'index' in 'user'.
+ return false;
}
return true;
}
@@ -125,7 +153,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
points_to_analysis.GetBufferAliases(*buffer)) {
for (HloInstruction* alias_user : alias.instruction()->users()) {
if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
- alias_user)) {
+ alias_user, points_to_analysis)) {
continue;
}
for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
@@ -200,7 +228,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
// Every user of 'a' must be a predecessor of 'b' or 'b' itself.
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
for (auto user : alias.instruction()->users()) {
- if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user)) {
+ if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user,
+ points_to_analysis())) {
continue;
}
if (user != b.instruction() &&
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 43d2a11b1e..e7aa93f8db 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -488,7 +488,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
// Runs BufferLiveness on this computation.
// Returns whether buffer interference is detected between tuple-shaped
// parameter and root instructions at tuple element 1.
- bool Run(const bool update_uses_tuple_element1) {
+ bool Run(const bool update_uses_tuple_element1,
+ const bool fuse_gte0 = false) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
@@ -534,6 +535,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
{dynamic_update_slice, starts, update, gte1},
HloInstruction::FusionKind::kLoop);
}
+ // Create fusion instruction for tuple element 0 (if requested).
+ if (fuse_gte0) {
+ computation->CreateFusionInstruction({gte0},
+ HloInstruction::FusionKind::kLoop);
+ }
+
// Run BufferLiveness on 'module'.
auto liveness =
BufferLiveness::Run(module.get(),
@@ -562,6 +569,25 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
}
+// Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
+// 'fusion1') do not overlap in the presence of another fusion instruction
+// (which is a user of 'param0' at a different tuple index).
+// BufferLiveness should detect no uses of Param0 at index {1} in Fusion0
+// (because Fusion0 only uses Param0 at index {0}).
+//
+// Param0
+// / \
+// FusionParam <----- Fusion0 Fusion1 ------> FusionParam
+// | | | |
+// GTE(0) | | GTE(1) Const Const
+// | | \ | /
+// \ / DynamicUpdateSlice
+// Tuple
+//
+TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
+ EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
+}
+
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
// do overlap because GTE(1) has two users:
// 1) DynamicUpdateSlice at operand 0.