aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-14 17:56:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-14 19:11:27 -0700
commite0d0c676ec111c711099bf89eb51278bc4493678 (patch)
treec15420e4b83c79f620d6b9f9c35bce9d3305e16a /tensorflow/compiler/xla
parent830cde8776d9adb6bdbb2e0b3173d16780d52df7 (diff)
Refactor logic from buffer_liveness to use in HeapSimulator.
Also added some simple tests. Change: 150144113
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/service/BUILD32
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.cc128
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc16
-rw-r--r--tensorflow/compiler/xla/service/liveness_util.cc151
-rw-r--r--tensorflow/compiler/xla/service/liveness_util.h51
-rw-r--r--tensorflow/compiler/xla/service/liveness_util_test.cc189
6 files changed, 435 insertions, 132 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 156cb85f66..692d186b14 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -494,6 +494,36 @@ cc_library(
)
cc_library(
+ name = "liveness_util",
+ srcs = ["liveness_util.cc"],
+ hdrs = ["liveness_util.h"],
+ deps = [
+ ":hlo",
+ ":logical_buffer",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "liveness_util_test",
+ srcs = ["liveness_util_test.cc"],
+ deps = [
+ ":hlo",
+ ":liveness_util",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "buffer_liveness",
srcs = [
"buffer_liveness.cc",
@@ -504,6 +534,7 @@ cc_library(
deps = [
":hlo",
":hlo_ordering",
+ ":liveness_util",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -586,6 +617,7 @@ cc_library(
],
deps = [
":hlo",
+ ":liveness_util",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index b5a2936b67..0fe6e37c00 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -17,11 +17,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
-#include <set>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -92,128 +92,6 @@ string BufferLiveness::ToString() const {
return tensorflow::str_util::Join(pieces, "\n");
}
-namespace {
-
-// Returns false if 'user' cannot possibly use the buffer at 'index' in
-// 'operand'. Returns true otherwise.
-// Precondition: 'operand' is an operand of 'user'.
-bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index,
- HloInstruction* user,
- const TuplePointsToAnalysis& points_to_analysis) {
- if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
- // GetTupleElement instructions only access the top-level buffer of their
- // operand.
- return false;
- } else if (user->opcode() == HloOpcode::kFusion &&
- user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
- // Find fusion parameter associated with 'operand'.
- auto it = std::find_if(
- user->fused_parameters().begin(), user->fused_parameters().end(),
- [=](HloInstruction* fused_param) {
- return user->operand(fused_param->parameter_number()) == operand;
- });
- CHECK(it != user->fused_parameters().end());
- // Iterate through all users of all buffer aliases of the buffer in the
- // points-to set of fusion parameter at 'index'.
- // Return true if any uses are detected at 'index', returns false otherwise.
- const LogicalBuffer* buffer =
- points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
- for (const BufferAlias& alias :
- points_to_analysis.GetBufferAliases(*buffer)) {
- for (HloInstruction* alias_user : alias.instruction()->users()) {
- if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
- alias_user, points_to_analysis)) {
- continue;
- }
- // Return true: use detected at 'buffer' -> 'alias' -> 'alias_user'.
- return true;
- }
- }
- // Return false: found no uses of 'operand' at 'index' in 'user'.
- return false;
- }
- return true;
-}
-
-// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
-// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
-// where 'user' is a user of an alias of 'intruction' at 'index', and
-// 'operand_index' is the operand index at which the alias appears in the
-// operand list of 'user'.
-std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
- HloInstruction* instruction, const ShapeIndex& index,
- const TuplePointsToAnalysis& points_to_analysis) {
- std::vector<std::pair<HloInstruction*, int64>> uses;
- const std::vector<const LogicalBuffer*>& points_to =
- points_to_analysis.GetPointsToSet(instruction).element(index);
- for (const LogicalBuffer* buffer : points_to) {
- for (const BufferAlias& alias :
- points_to_analysis.GetBufferAliases(*buffer)) {
- for (HloInstruction* alias_user : alias.instruction()->users()) {
- if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
- alias_user, points_to_analysis)) {
- continue;
- }
- for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
- uses.emplace_back(alias_user, op_idx);
- }
- }
- }
- }
- return uses;
-}
-
-// Returns true if 'user' (at 'user_index') can share a buffer with its operand
-// 'operand' (at 'operand_index').
-// Returns false otherwise.
-// User and operand can share buffers iff both instructions emit the same shape
-// and layout, and 'user' meets one of the following two qualifications:
-// *) Is element-wise.
-// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
-// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
-// at operand 0.
-bool CanShareOperandBufferWithUser(
- HloInstruction* operand, const ShapeIndex& operand_index,
- HloInstruction* user, const ShapeIndex& user_index,
- const TuplePointsToAnalysis& points_to_analysis) {
- Shape operand_subshape =
- ShapeUtil::GetSubshape(operand->shape(), operand_index);
- Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
- // Check that operand and user emit the same shape and layout.
- if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
- return false;
- }
- // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
- // fused root instruction.
- if (user->opcode() == HloOpcode::kFusion &&
- user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- user->fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice) {
- for (auto& fused_param : user->fused_parameters()) {
- // Find fusion parameter associated with 'operand'.
- if (user->operand(fused_param->parameter_number()) != operand) {
- continue;
- }
- // Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
- auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
- fused_param, operand_index, points_to_analysis);
- // Return true iff there is exactly one use of 'operand' at 'index', and
- // this singleton use is the fused root at operand index 0.
- if (fused_param_uses.size() == 1 &&
- fused_param_uses[0].first == user->fused_expression_root() &&
- fused_param_uses[0].second == 0) {
- return true;
- }
- break;
- }
- return false;
- }
- // Check if 'user' is element-wise.
- return user->IsElementwise();
-}
-
-} // anonymous namespace
-
bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
const LogicalBuffer& b) const {
TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
@@ -226,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
// Every user of 'a' must be a predecessor of 'b' or 'b' itself.
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
for (auto user : alias.instruction()->users()) {
- if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user,
- points_to_analysis())) {
+ if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user,
+ points_to_analysis())) {
continue;
}
if (user != b.instruction() &&
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 76702f52e0..46c0d8edea 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -26,6 +27,8 @@ namespace xla {
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
+namespace {
+
// Returns the set of buffers that may be sources of all operands of the given
// instruction. The returned buffers are guaranteed to have no duplicates, and
// to be sorted in a deterministic order.
@@ -46,6 +49,8 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
return sorted;
}
+} // namespace
+
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm,
@@ -145,13 +150,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// we must be the last user of the buffer.
bool shared = false;
for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) {
- // The operand buffer can be shared if we have the same shape, and we're
- // an elementwise instruction.
- //
- // TODO(b/35903632): Refactor and use the CanShareOperandBufferWithUser
- // logic from buffer_liveness.cc
- if (ShapeUtil::Equal(buffer->shape(), operand_buffer->shape()) &&
- instruction->IsElementwise()) {
+ if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
+ CanShareOperandBufferWithUser(
+ operand_buffer->instruction(), operand_buffer->index(),
+ buffer->instruction(), buffer->index(), points_to_analysis)) {
heap.ShareBuffer(buffer, operand_buffer);
shared = true;
break;
diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc
new file mode 100644
index 0000000000..7d157e8fd5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/liveness_util.cc
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/liveness_util.h"
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index,
+ HloInstruction* user,
+ const TuplePointsToAnalysis& points_to_analysis) {
+ CHECK(user->IsUserOf(operand))
+ << "user: " << user->ToString() << " operand: " << operand->ToString();
+ if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
+ // GetTupleElement instructions only access the top-level buffer of their
+ // operand.
+ return true;
+ } else if (user->opcode() == HloOpcode::kFusion &&
+ user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ // Find fusion parameter associated with 'operand'.
+ auto it = std::find_if(
+ user->fused_parameters().begin(), user->fused_parameters().end(),
+ [=](HloInstruction* fused_param) {
+ return user->operand(fused_param->parameter_number()) == operand;
+ });
+ CHECK(it != user->fused_parameters().end());
+ // Iterate through all users of all buffer aliases of the buffer in the
+ // points-to set of fusion parameter at 'index'.
+ // Return false if any uses are detected at 'index', returns true otherwise.
+ const LogicalBuffer* buffer =
+ points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
+ for (const BufferAlias& alias :
+ points_to_analysis.GetBufferAliases(*buffer)) {
+ for (HloInstruction* alias_user : alias.instruction()->users()) {
+ if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
+ alias_user, points_to_analysis)) {
+ continue;
+ }
+ // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
+ return false;
+ }
+ }
+ // Return true: found no uses of 'operand' at 'index' in 'user'.
+ return true;
+ }
+ return false;
+}
+
+namespace {
+
+// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
+// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
+// where 'user' is a user of an alias of 'instruction' at 'index', and
+// 'operand_index' is the operand index at which the alias appears in the
+// operand list of 'user'.
+std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
+ HloInstruction* instruction, const ShapeIndex& index,
+ const TuplePointsToAnalysis& points_to_analysis) {
+ std::vector<std::pair<HloInstruction*, int64>> uses;
+ const std::vector<const LogicalBuffer*>& points_to =
+ points_to_analysis.GetPointsToSet(instruction).element(index);
+ for (const LogicalBuffer* buffer : points_to) {
+ for (const BufferAlias& alias :
+ points_to_analysis.GetBufferAliases(*buffer)) {
+ for (HloInstruction* alias_user : alias.instruction()->users()) {
+ if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
+ alias_user, points_to_analysis)) {
+ continue;
+ }
+ for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
+ uses.emplace_back(alias_user, op_idx);
+ }
+ }
+ }
+ }
+ return uses;
+}
+
+} // namespace
+
+// User and operand can share buffers iff both instructions emit the same shape
+// and layout, and 'user' meets one of the following two qualifications:
+// *) Is element-wise.
+// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
+// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
+// at operand 0.
+bool CanShareOperandBufferWithUser(
+ HloInstruction* operand, const ShapeIndex& operand_index,
+ HloInstruction* user, const ShapeIndex& user_index,
+ const TuplePointsToAnalysis& points_to_analysis) {
+ CHECK(user->IsUserOf(operand))
+ << "user: " << user->ToString() << " operand: " << operand->ToString();
+ Shape operand_subshape =
+ ShapeUtil::GetSubshape(operand->shape(), operand_index);
+ Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
+ // Check that operand and user emit the same shape and layout.
+ if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
+ return false;
+ }
+ // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
+ // fused root instruction.
+ if (user->opcode() == HloOpcode::kFusion &&
+ user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+ user->fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice) {
+ for (auto& fused_param : user->fused_parameters()) {
+ // Find fusion parameter associated with 'operand'.
+ if (user->operand(fused_param->parameter_number()) != operand) {
+ continue;
+ }
+ // Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
+ auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
+ fused_param, operand_index, points_to_analysis);
+ // Return true iff there is exactly one use of 'operand' at 'index', and
+ // this singleton use is the fused root at operand index 0.
+ if (fused_param_uses.size() == 1 &&
+ fused_param_uses[0].first == user->fused_expression_root() &&
+ fused_param_uses[0].second == 0) {
+ return true;
+ }
+ break;
+ }
+ return false;
+ }
+ // Check if 'user' is element-wise.
+ return user->IsElementwise();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h
new file mode 100644
index 0000000000..410a7b1b51
--- /dev/null
+++ b/tensorflow/compiler/xla/service/liveness_util.h
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A collection of utilities on the HLO graph.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
+
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+// Returns true if 'user' cannot possibly use the buffer at 'index' in
+// 'operand'. Returns false otherwise.
+//
+// REQUIRES: 'operand' is an operand of 'user'.
+bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index,
+ HloInstruction* user,
+ const TuplePointsToAnalysis& points_to_analysis);
+
+// Returns true if 'user' (at 'user_index') can share a buffer with its operand
+// 'operand' (at 'operand_index').
+// Returns false otherwise.
+//
+// REQUIRES: 'operand' is an operand of 'user'.
+bool CanShareOperandBufferWithUser(
+ HloInstruction* operand, const ShapeIndex& operand_index,
+ HloInstruction* user, const ShapeIndex& user_index,
+ const TuplePointsToAnalysis& points_to_analysis);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc
new file mode 100644
index 0000000000..2ff71d6f3c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/liveness_util_test.cc
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/liveness_util.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+class PointsToAnalysisTestBase : public HloTestBase {
+ protected:
+ void BuildModule(std::unique_ptr<HloComputation> computation) {
+ module_ = MakeUnique<HloModule>(TestName());
+ computation_ = module_->AddEntryComputation(std::move(computation));
+ }
+
+ void RunAnalysis() {
+ CHECK_NOTNULL(module_.get());
+ points_to_analysis_ =
+ TuplePointsToAnalysis::Run(module_.get(),
+ /*include_loop_fusion_instructions=*/true)
+ .ConsumeValueOrDie();
+ }
+
+ void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
+ BuildModule(std::move(computation));
+ RunAnalysis();
+ }
+
+ std::unique_ptr<HloModule> module_;
+ HloComputation* computation_ = nullptr;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
+};
+
+class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
+
+TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ // GetTupleElement instructions only access the top-level buffer of their
+ // operand.
+ EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_));
+ EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_));
+ EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_));
+ EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_));
+}
+
+TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, gte1, update, starts));
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {dynamic_update_slice, starts, update, gte1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction never uses tuple element 0, but does use element 1.
+ EXPECT_TRUE(
+ DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_));
+ EXPECT_FALSE(
+ DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_));
+}
+
+class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
+
+TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape shape = ShapeUtil::MakeShape(F32, {8});
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
+ auto log = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape in_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, in_shape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, in_shape, "param1"));
+ auto result = builder.AddInstruction(
+ HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
+ *points_to_analysis_));
+ EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
+ *points_to_analysis_));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, gte1, update, starts));
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {dynamic_update_slice, starts, update, gte1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction can share with tuple element 1.
+ EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
+ *points_to_analysis_));
+ EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
+ *points_to_analysis_));
+}
+
+} // namespace
+} // namespace xla