diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_assignment_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment_test.cc | 1051 |
1 files changed, 1051 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc new file mode 100644 index 0000000000..56138a7ee6 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -0,0 +1,1051 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" + +#include <memory> +#include <set> +#include <string> +#include <vector> + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +namespace { + +// DFS visitor that collects the instructions referenced by a computation +// without descending into nested computations, i.e., only from the operands. 
+class InstructionListVisitor : public DfsHloVisitorWithDefault { + public: + explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {} + + Status DefaultAction(HloInstruction* hlo) override { + // For each instruction, just push it on the list after walking the + // operands. + instructions_.push_back(hlo); + VLOG(0) << "List instruction " << hlo->ToString(); + return Status::OK(); + } + + std::vector<const HloInstruction*> GetInstructions() { return instructions_; } + + private: + // The instruction root of the computation. + const HloInstruction* root_; + + // The full set of instructions found (may be duplicates, e.g., kParameter). + std::vector<const HloInstruction*> instructions_; + + TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor); +}; + +const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) { + InstructionListVisitor main_list(root); + TF_CHECK_OK(root->Accept(&main_list)); + return main_list.GetInstructions(); +} + +class BufferAssignmentTest : public HloTestBase { + protected: + BufferAssignmentTest() : computation_tracker_() {} + ~BufferAssignmentTest() override {} + + // Builds an x+1.0 computation to use in a Map. + std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + auto value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); + return builder.Build(); + } + + // Builds a simple compare-to-limit (x < 4) computation for a While. 
+ // + // condition: + // const4[s32] -----------------------------------\ + // \ + // param[(s32,f32[4])] --- get-tuple-element[0] --- less-than + // + std::unique_ptr<HloComputation> BuildWhileConditionComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + auto const4 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); + auto index = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4)); + return builder.Build(); + } + + // Builds a simple body computation for a While. + // + // body: + // constv[f32[4]] --------------------------------------\ + // \ + // /--- get-tuple-elementv[1] --- addv ---\ + // param[(s32,f32[4])] ---| tuple + // \--- get-tuple-elementc[0] --- addc ---/ + // / + // const1[s32] -----------------------------------------/ + // + std::unique_ptr<HloComputation> BuildWhileBodyComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1))); + auto constv = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f}))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); + auto indexc = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(const1->shape(), param, 0)); + auto addc = builder.AddInstruction(HloInstruction::CreateBinary( + indexc->shape(), HloOpcode::kAdd, indexc, const1)); + auto indexv = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(constv->shape(), param, 1)); + auto addv = builder.AddInstruction(HloInstruction::CreateBinary( + constv->shape(), HloOpcode::kAdd, indexv, constv)); + 
builder.AddInstruction(HloInstruction::CreateTuple({addc, addv})); + return builder.Build(); + } + + // Verifies that the given instruction hlo has a valid input buffer assigned, + // i.e., the parameter number matches the op's. + const BufferAllocation& GetAssignedInputAllocation( + const BufferAssignment& buffers, HloInstruction* hlo) { + LOG(INFO) << "Checking input: " << hlo->ToString(); + const BufferAllocation& buffer = + *buffers.GetUniqueTopLevelAllocation(hlo).ConsumeValueOrDie(); + EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number()); + return buffer; + } + + // Verifies that the given instruction hlo has a valid output buffer + // assigned, and returns it. + const BufferAllocation& GetAssignedOutputAllocation( + const BufferAssignment& buffers, HloInstruction* hlo) { + LOG(INFO) << "Checking output: " << hlo->ToString(); + const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo); + return buffer; + } + + // Returns the allocation for the given instruction. + const BufferAllocation& GetAllocation(const BufferAssignment& buffers, + const HloInstruction* hlo, + const ShapeIndex& index) { + return *buffers.GetUniqueAllocation(hlo, index).ConsumeValueOrDie(); + } + const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers, + const HloInstruction* hlo) { + return *buffers.GetUniqueTopLevelAllocation(hlo).ConsumeValueOrDie(); + } + + // Verifies that all instructions in the given instruction list except + // kConstant and nested kParameter have assigned buffers, and returns the + // total size of all assigned buffers. + int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions, + const BufferAssignment& buffers) { + // Verifies all instructions have buffers, and gets the index ranges. 
+ for (const HloInstruction* hlo : instructions) { + if (!buffers.HasTopLevelAllocation(hlo)) { + // If `hlo` has no assigned buffer, it is either a constant or a nested + // parameter. + EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() || + HloOpcode::kParameter == hlo->opcode()); + continue; + } + } + + // Gets the total size of all buffers assigned. + int64 total_size = 0; + for (auto& allocation : buffers.Allocations()) { + total_size += allocation.size(); + } + return total_size; + } + + // Returns true if the buffers assigned to instructions in "a" are distinct + // from the buffers assigned to those in "b" (ie, intersection is empty). + bool BuffersDistinct(const std::vector<const HloInstruction*>& a, + const std::vector<const HloInstruction*>& b, + const BufferAssignment& assignment) { + std::set<BufferAllocation::Index> a_buffers; + for (const HloInstruction* instruction : a) { + if (assignment.HasTopLevelAllocation(instruction)) { + a_buffers.insert(assignment.GetUniqueTopLevelAllocation(instruction) + .ConsumeValueOrDie() + ->index()); + } + } + + for (const HloInstruction* instruction : b) { + if (assignment.HasTopLevelAllocation(instruction)) { + if (a_buffers.count(assignment.GetUniqueTopLevelAllocation(instruction) + .ConsumeValueOrDie() + ->index())) { + return false; + } + } + } + return true; + } + + // Computation tracker for nested computations. + ComputationTracker computation_tracker_; + + // Shapes for use in the examples. 
+ Shape s32_ = ShapeUtil::MakeShape(xla::S32, {}); + Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {}); + Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4}); + Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10}); + Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100}); + Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10}); + Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_}); + Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_}); +}; + +namespace { +std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module) { + return BufferAssigner::Run(module, MakeUnique<DependencyHloOrdering>(module), + /*pointer_size=*/sizeof(void*)) + .ConsumeValueOrDie(); +} +} + +// Tests a computation consisting of a single scalar constant node. +TEST_F(BufferAssignmentTest, ScalarConstant) { + auto builder = HloComputation::Builder(TestName()); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + // Check that the constant does not have a buffer assigned. + EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); +} + +TEST_F(BufferAssignmentTest, BufferForConst) { + // Addition of two vector constants: checks that internal constant nodes have + // no buffers assigned, and their consumer has a buffer. 
+ auto builder = HloComputation::Builder(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f}))); + auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f}))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + // The two constant nodes have no buffers assigned. + EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); + EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); + // The add node has an output buffer. + GetAssignedOutputAllocation(*buffers, add); +} + +TEST_F(BufferAssignmentTest, BufferForOutputConst) { + // This computation copies a constant to output. + auto builder = HloComputation::Builder(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f}))); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + // The copy node now has an output buffer. 
+ GetAssignedOutputAllocation(*buffers, copy); +} + +TEST_F(BufferAssignmentTest, Basic) { + // paramscalar ------- (mul) -- (add) -- (sub) + // / / / + // param0[100] -------/ / / + // / / + // param1[100] --------------/--------/ + auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kSubtract, add, param1)); + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + + // Distinct input buffers were assigned for parameters. + BufferAllocation paramscalar_buffer = + GetAssignedInputAllocation(*buffers, paramscalar); + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1); + EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index()); + EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index()); + EXPECT_NE(param0_buffer.index(), param1_buffer.index()); + + // The mul node has a valid buffer assigned, doesn't share with input. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + EXPECT_NE(mul_buffer.index(), param0_buffer.index()); + + // The add node can reuse the mul node's buffer. 
+ const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); + EXPECT_EQ(add_buffer.index(), mul_buffer.index()); + + // The sub node has a valid output buffer assigned. + GetAssignedOutputAllocation(*buffers, sub); +} + +TEST_F(BufferAssignmentTest, MultipleUsersForNode) { + // This is similar to the Basic test, with the difference that (sub) is + // another user of (mul)'s result, so (mul)'s buffer cannot be reused for + // (add)'s output. + // + // paramscalar -------\ /-----------\ + // \ / \ + // param0[100] ------- (mul) -- (add) -- (sub) + // / + // param1[100] ----------------/ + // + auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul)); + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + + // Input buffers were assigned for parameters. + BufferAllocation paramscalar_buffer = + GetAssignedInputAllocation(*buffers, paramscalar); + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + BufferAllocation param1_index = GetAssignedInputAllocation(*buffers, param1); + EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index()); + EXPECT_NE(paramscalar_buffer.index(), param1_index.index()); + EXPECT_NE(param0_buffer.index(), param1_index.index()); + + // The mul node had a buffer allocated. 
+ const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + + // Now the add node can't reuse the mul node's buffer. + const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); + EXPECT_NE(add_buffer.index(), mul_buffer.index()); + + // Log size information for inspection. + const std::vector<const HloInstruction*> level0 = GetInstructions(sub); + int64 size0 = ValidateBuffers(level0, *buffers); + LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size() + << " for " << level0.size() << " instructions; " + << "total buffer size " << size0; +} + +TEST_F(BufferAssignmentTest, TrivialMap) { + // This tests a trivial x+1 map as the only operation. + // + // param0[100x10] ---> (map x+1) + // + // Builds the map function. + auto module = MakeUnique<HloModule>(TestName()); + auto map_computation = + module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); + auto inner_last = map_computation->root_instruction(); + + // Creates the main kernel and verifies instruction counts. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32a100x10_, "")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation)); + const std::vector<const HloInstruction*> level0 = GetInstructions(map); + EXPECT_EQ(2, level0.size()) << "Invalid main kernel size"; + const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last); + EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; + + module->AddEntryComputation(builder.Build()); + + // Assigns buffers and fetches sizes. 
+ auto buffers = RunBufferAssignment(module.get()); + int64 size0 = ValidateBuffers(level0, *buffers); + int64 size1 = ValidateBuffers(level1, *buffers); + + // Both algorithms assign the map's buffer before processing the embedded + // computation, so we can verify that the buffers aren't shared between them + // by checking: + EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers)) + << "Reuse between main kernel and embedded mapping."; + + // An input buffer was assigned for the parameter. + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + + // An output buffer was assigned for the map. + BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map); + EXPECT_NE(param0_buffer.index(), map_buffer.index()); + + // The final computation node of the map is an add of an f32 parm and a + // constant. + EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode()); + const BufferAllocation& inner_add_buffer = + GetTopLevelAllocation(*buffers, inner_last); + EXPECT_NE(inner_add_buffer.index(), map_buffer.index()); + + // Log size information for inspection. + LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size() + << " for " << level0.size() + level1.size() << " instructions; " + << "total buffer size " << size0 + size1; +} + +TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { + // Make sure that the input buffer of a reduce cannot be reused for its + // output. (Reuse is not safe in the general case, as it reshapes and some + // out-of-order reductions could overwrite an element before a use.) 
+ // + // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3) + auto module = MakeUnique<HloModule>(TestName()); + auto reduce_computation = + module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); + + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32a100x10_, "")); + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))); + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + /*shape=*/f32vec10_, + /*operand=*/exp2, + /*init_value=*/const0, + /*dimensions_to_reduce=*/{0}, reduce_computation)); + auto exp3 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce)); + + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + const std::vector<const HloInstruction*> instrs = GetInstructions(exp3); + ValidateBuffers(instrs, *buffers); + + const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1); + const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2); + const BufferAllocation& reduce_buffer = + GetTopLevelAllocation(*buffers, reduce); + + // The buffer of exp1 is trivially reusable for exp2 - this is just for sanity + // checking. + EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index()); + + // The buffer of exp2 cannot be used for reduce, even though it's the only + // operand. + EXPECT_NE(exp2_buffer.index(), reduce_buffer.index()); +} + +TEST_F(BufferAssignmentTest, ExampleWhile) { + // This tests a While loop example from the ir_semantics document. + // + // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation. 
+ // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation. + // + // const3[s32] -------\ + // const4[f32[4]] --- tuple --- while[condition, body] + // + // Builds the nested condition and body. + auto module = MakeUnique<HloModule>(TestName()); + auto condition_computation = + module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4")); + auto body_computation = + module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update")); + + // Creates the main kernel and verifies instruction counts. + auto builder = HloComputation::Builder(TestName()); + auto const3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0))); + auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f}))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); + auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( + t_s32_f32v4_, condition_computation, body_computation, tuple)); + + const std::vector<const HloInstruction*> level0 = GetInstructions(while_op); + EXPECT_EQ(4, level0.size()) << "Invalid while kernel size"; + const std::vector<const HloInstruction*> levelc = + GetInstructions(condition_computation->root_instruction()); + EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size"; + const std::vector<const HloInstruction*> levelb = + GetInstructions(body_computation->root_instruction()); + EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; + + module->AddEntryComputation(builder.Build()); + + // Assigns buffers and fetches sizes. + auto buffers = RunBufferAssignment(module.get()); + int64 size0 = ValidateBuffers(level0, *buffers); + int64 sizec = ValidateBuffers(levelc, *buffers); + int64 sizeb = ValidateBuffers(levelb, *buffers); + + // BufferAssignment will assign a single allocation for the following + // instructions: while, while.cond.param, while.body.param, while.body.result. 
+ EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers)) + << "Should be reuse between main kernel and embedded condition."; + EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers)) + << "Should be reuse between embedded condition and body."; + // Expect buffer reuse between main kernel and body computation. + EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers)) + << "Should be reuse between main kernel and embedded body."; + + // The final computation node of the while body is a tuple of s32 and + // f32[4] adds. + HloInstruction* body_root = body_computation->root_instruction(); + EXPECT_EQ(HloOpcode::kTuple, body_root->opcode()); + + // Check that buffer for each subshape of 'while_op' shares allocation with + // corresponding buffer from while body computation at same index. + TF_CHECK_OK(ShapeUtil::ForEachSubshape( + while_op->shape(), + [this, &buffers, while_op, body_root](const Shape& /*subshape*/, + const ShapeIndex& index) { + auto while_op_allocation = GetAllocation(*buffers, while_op, index); + auto body_root_allocation = GetAllocation(*buffers, body_root, index); + EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index()); + return Status::OK(); + })); + + // Log size information for inspection. 
+ LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size() + << " for " << level0.size() + levelc.size() + levelb.size() + << " instructions; total buffer size " << size0 + sizec + sizeb; +} + +TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { + // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "")); + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0)); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1)); + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh)); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2)); + + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // tanh and exp2 can reuse exp1's buffer + EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); + auto& buffer_for_exp1 = GetTopLevelAllocation(*assignment, exp1); + EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh)); + EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2)); + EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg)); +} + +TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { + // This computation is a chain of operations which decreases in buffer size + // (via slice) then increases in size (via broadcast): + // + // param ---> (negate) ---> (slice) ---> (broadcast) + // + // The negate should share a buffer with broadcast. 
+ auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); + + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // negate and broadcast should share a buffer. + EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); + auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast); + EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate)); + + // Slice should have its own buffer. + EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { + // This computation is identical to that in ReuseNonOperandBuffer, but the + // negate value is live until the end of the computation (due to it being an + // operand of the output tuple) preventing reuse. + // + // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple) + // \-----------------------------------/ + // + // The negate should not share a buffer with broadcast. 
+ auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); + builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast})); + + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // The instructions should not share buffers. + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, negate)); + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, slice)); + EXPECT_NE(GetTopLevelAllocation(*assignment, negate), + GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { + // This computation is identical to that in ReuseNonOperandBuffer, but the + // negate value is placed into a tuple which lives to the end of the + // computation. This extends the live range of negate's buffer preventing + // reuse due to buffer aliasing. + // + // param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple) + // \-----------------------------------/ + // + // The negate should not share a buffer with broadcast. 
+ auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate})); + auto tuple_element = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10})); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); + builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast})); + + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // The instructions should not share buffers. + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, negate)); + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, slice)); + EXPECT_NE(GetTopLevelAllocation(*assignment, negate), + GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { + // This computation is very similar to ReuseNonOperandBuffer except the + // broadcast has a smaller output than the negate. This should block reuse of + // negate's buffer by broadcast because the output buffer(s) of a computation + // should be exactly sized for the value. + // + // param ---> (negate) ---> (slice) ---> (broadcast) + // + // The negate should *not* share a buffer with broadcast. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + // Negate output is 100 elements. 
+ auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + // Broadcast output is 40 elements. + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); + + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // The instructions should not share buffers. + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, negate)); + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, slice)); + EXPECT_NE(GetTopLevelAllocation(*assignment, negate), + GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { + // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast + // output is exactly the same size as the negate (rather than being + // smaller). This enables reuse of negate's buffer by the broadcast because + // the output buffer will be sized exactly to its value. + // + // param ---> (negate) ---> (slice) ---> (broadcast) + // + // The negate *should* share a buffer with broadcast. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + // Negate output is 100 elements. + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + // Broadcast output is 100 elements. 
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));

  auto module = MakeUnique<HloModule>(TestName());
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
  // This computation is very similar to ReuseNonOperandBuffer except the
  // broadcast has a smaller output than the negate, and the broadcast is
  // contained in the computation output as a tuple element. This should block
  // reuse of the negate's buffer by the broadcast because the output buffer(s)
  // of a computation should be exactly sized for the value. This includes those
  // buffers aliased in the output (eg, contained as tuple elements).
  //
  // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
  //
  // The negate should *not* share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
  // Broadcast output is 40 elements (10x4), smaller than negate's 100.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
  // Place the undersized broadcast in the output tuple.
  builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));

  auto module = MakeUnique<HloModule>(TestName());
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The instructions should not share buffers.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
  EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
            GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
  // Verify that buffers for embedded computations are properly marked as
  // thread-local and that embedded parameters are not marked as
  // is_entry_computation_parameter.
  auto module = MakeUnique<HloModule>(TestName());
  auto vec_shape = ShapeUtil::MakeShape(F32, {42});
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});

  // Create a scalar computation to use in a map.
  auto map_builder = HloComputation::Builder(TestName() + "_map");
  auto map_param = map_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
  auto map_root = map_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
  auto map_computation = module->AddEmbeddedComputation(map_builder.Build());

  // Create a vector computation to use in a kCall.
  auto call_builder = HloComputation::Builder(TestName() + "_call");
  auto call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
  auto call_root = call_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
  auto call_computation = module->AddEmbeddedComputation(call_builder.Build());

  // Create entry computation which kCalls call_computation and then calls map
  // with map_computation on the result.
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  auto call = builder.AddInstruction(
      HloInstruction::CreateCall(vec_shape, {param}, call_computation));
  auto map = builder.AddInstruction(
      HloInstruction::CreateMap(vec_shape, {call}, map_computation));
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  // Allocations for the map computation should be thread-local and not
  // live-out.
  auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
  EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_param_alloc.maybe_live_out());
  EXPECT_TRUE(map_param_alloc.is_thread_local());

  auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
  EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_root_alloc.maybe_live_out());
  EXPECT_TRUE(map_root_alloc.is_thread_local());

  // Allocations for the call computation should not be thread-local and not
  // live-out.
  auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
  EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(call_param_alloc.maybe_live_out());
  EXPECT_FALSE(call_param_alloc.is_thread_local());

  auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
  EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(call_root_alloc.maybe_live_out());
  EXPECT_FALSE(call_root_alloc.is_thread_local());

  // Entry computation allocations can be marked liveout and
  // is_entry_computation_parameter.
  auto& param_alloc = GetTopLevelAllocation(*assignment, param);
  EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(param_alloc.maybe_live_out());
  EXPECT_FALSE(param_alloc.is_thread_local());

  // The map's output is the computation result, so it may be live out.
  auto& map_alloc = GetTopLevelAllocation(*assignment, map);
  EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
  EXPECT_TRUE(map_alloc.maybe_live_out());
  EXPECT_FALSE(map_alloc.is_thread_local());
}

TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
  // Test a computation that returns a tuple parameter.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                    ShapeUtil::MakeShape(F32, {}),
                                    ShapeUtil::MakeShape(S32, {42})}),
      "param0"));

  auto module = MakeUnique<HloModule>(TestName());
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be four allocations: one for vector of pointers, and one for
  // each tuple element.
  EXPECT_EQ(4, assignment->Allocations().size());

  // Verify each buffer allocation is marked as an entry computation parameter
  // and is liveout.
  // Walk every subshape (the tuple itself plus each element) and check its
  // allocation's flags.
  TF_CHECK_OK(ShapeUtil::ForEachSubshape(
      tuple_param->shape(),
      [this, &assignment, tuple_param](const Shape& /*subshape*/,
                                       const ShapeIndex& index) {
        auto allocation = GetAllocation(*assignment, tuple_param, index);
        EXPECT_TRUE(allocation.is_entry_computation_parameter());
        EXPECT_EQ(0, allocation.parameter_number());
        EXPECT_TRUE(allocation.maybe_live_out());
        return Status::OK();
      }));
}

TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
  // Test a computation which returns a GetTupleElement of a nested tuple
  // parameter.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeTupleShape(
             {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
              ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
                                         ShapeUtil::MakeShape(S32, {101})})}),
      "param0"));
  auto tuple_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));

  auto module = MakeUnique<HloModule>(TestName());
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Only some of the elements of the input param are liveout.
  EXPECT_FALSE(
      GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
  // Tuple element at index={1} is live out because GetTupleElement({1})
  // forwards a pointer to this allocation (instead of defining its own buffer).
  EXPECT_TRUE(
      GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
                  .maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
                  .maybe_live_out());

  // The GetTupleElement output is liveout.
  EXPECT_TRUE(
      GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());

  // Verify that the GetTupleElement allocations of its elements match the
  // corresponding tuple parameter allocations because they alias.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
            GetAllocation(*assignment, tuple_element, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
            GetAllocation(*assignment, tuple_element, /*index=*/{1}));

  // GetTupleElement forwards a pointer to its underlying buffer, so verify
  // that it has the same allocation as the corresponding parameter element.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
            GetTopLevelAllocation(*assignment, tuple_element));
}

// TODO(b/32248867): Enable when buffer assignment gives allocations to
// constants.
TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
  // Test that a tuple constant which is forwarded to the computation output is
  // properly handled.
  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
                              LiteralUtil::CreateR0<int64>(1).get()})));

  auto module = MakeUnique<HloModule>(TestName());
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // One allocation for the tuple pointer vector plus one per element.
  EXPECT_EQ(3, assignment->Allocations().size());
}

TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
  // Test a computation which returns a tuple custom call value.
  auto builder = HloComputation::Builder(TestName());
  auto custom_call = builder.AddInstruction(HloInstruction::CreateCustomCall(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(S32, {101})}),
      /*operands=*/{}, /*custom_call_target=*/"foo_function"));
  auto module = MakeUnique<HloModule>(TestName());
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // One allocation for the tuple pointer vector plus one per element; all of
  // them are live out since the custom call is the computation's output.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
}

TEST_F(BufferAssignmentTest, BitcastAsOutput) {
  // Test a computation which returns a bitcast value.
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {42}), "param"));
  auto bitcast = builder.AddInstruction(
      HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param));

  auto module = MakeUnique<HloModule>(TestName());
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Bitcast should get the same allocation as the param.
  EXPECT_EQ(1, assignment->Allocations().size());
  EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, bitcast));
}

TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
  // Test a computation with an output that has an ambiguous points-to set. This
  // is constructed using a select among tuple shapes.
+ auto builder = HloComputation::Builder(TestName()); + auto tuple_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})}); + + auto tuple_param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param0")); + auto tuple_param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, tuple_shape, "param1")); + auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(PRED, {}), "param1")); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1)); + + auto module = MakeUnique<HloModule>(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // Select shallow copies one of its operands so it defines its own top-level + // buffer and receives its own allocation. + auto select_alloc = GetTopLevelAllocation(*assignment, select); + EXPECT_EQ(1, select_alloc.assigned_buffers().size()); + EXPECT_EQ(select, select_alloc.assigned_buffers()[0]->instruction()); + + // The buffer for the tuple element of the select is forwarded from one its + // operands which cannot be determined statically. Therefore its allocation + // should include the allocations of both of the elements in the parameters. + auto element_allocations = assignment->GetAllocations(select, /*index=*/{0}); + EXPECT_EQ(2, element_allocations.size()); + EXPECT_MATCH(testing::SetToVec<BufferAllocation>(element_allocations), + testing::UnorderedMatcher<BufferAllocation>( + *assignment->GetUniqueAllocation(tuple_param0, /*index=*/{0}) + .ConsumeValueOrDie(), + *assignment->GetUniqueAllocation(tuple_param1, /*index=*/{0}) + .ConsumeValueOrDie())); +} + +} // namespace + +} // namespace xla |