diff options
Diffstat (limited to 'tensorflow')
15 files changed, 152 insertions, 102 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 09e7e87918..a0004281cb 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1850,18 +1850,19 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); - // Send instruction produces a tuple of {aliased operand, U32 context}. + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *send_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); send_instr.set_channel_id(handle.handle()); TF_ASSIGN_OR_RETURN(XlaOp send, AddInstruction(std::move(send_instr), HloOpcode::kSend, {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); send_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, {send}); @@ -1879,19 +1880,32 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); - // Recv instruction produces a tuple of {receive buffer, U32 context}. + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. HloInstructionProto recv_instr; - *recv_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); recv_instr.set_channel_id(handle.handle()); TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; - *recv_done_instr.mutable_shape() = shape; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); recv_done_instr.set_channel_id(handle.handle()); - return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, - {recv}); + TF_ASSIGN_OR_RETURN(XlaOp recv_done, + AddInstruction(std::move(recv_done_instr), + HloOpcode::kRecvDone, {recv})); + + // The RecvDone instruction produces a tuple of the data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto recv_data; + *recv_data.mutable_shape() = shape; + recv_data.set_tuple_index(0); + return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, + {recv_done}); }); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 34b18b0e21..e36bef60a3 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -284,9 +284,8 @@ void HloComputation::set_root_instruction( if (!IsFusionComputation()) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) - << new_root_instruction->shape().ShortDebugString() - << " is incompatible with " - << root_instruction_->shape().ShortDebugString(); + << new_root_instruction->shape() << " is incompatible with " + << root_instruction_->shape(); } bool root_found = false; for (auto& instruction : instructions_) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 8a4a9b5986..ebed2cfc59 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -398,18 +398,17 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) { CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone); bool changed = false; - // RecvDone forwards the operand value at {0} to the output. + // RecvDone forwards the operand value at {0} to element {0} of its output. for (auto& pair : GetInstructionValueSet(recv_done)) { ShapeIndex& index = pair.first; HloValueSet& value_set = pair.second; - ShapeIndex operand_index = {0}; - for (int64 i : index) { - operand_index.push_back(i); + if (index.empty() || index[0] != 0) { + continue; } const HloValueSet& operand_value_set = - GetValueSet(recv_done->operand(0), operand_index); + GetValueSet(recv_done->operand(0), index); if (value_set != operand_value_set) { value_set = operand_value_set; changed = true; @@ -857,14 +856,18 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_top_level_only(); break; case HloOpcode::kRecvDone: - // RecvDone aliases its input tuple element {0}, therefore does not - // define any values. + // RecvDone produces a two-element tuple. Element zero aliases its + // input tuple element {0}; element one is a token. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{1}); break; case HloOpcode::kSend: - // Send produces a tuple of {aliased operand, U32 context}, therefore - // only defines the top-level tuple and the tuple element at {1}. + // Send produces a tuple of {aliased operand, U32 context, token}, + // therefore only defines the top-level tuple and the tuple elements + // at {1} and {2}. define_value_at(/*index=*/{}); define_value_at(/*index=*/{1}); + define_value_at(/*index=*/{2}); break; default: define_all_values(); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 70254e2c1a..343f5e7b39 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1167,20 +1167,21 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 5); + EXPECT_EQ(analysis.values().size(), 6); EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2})); EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), UnorderedElementsAre(analysis.GetValueDefinedAt(param))); } TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { - // Test that a RecvDone forwards its operand tuple element at {0} to the - // output. + // Test that a RecvDone forwards its operand tuple element at {0} to element + // {0} of the output. auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto recv = builder.AddInstruction( @@ -1191,13 +1192,16 @@ TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 4); + EXPECT_EQ(analysis.values().size(), 7); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); - EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); - EXPECT_THAT(HloValuesAt(recv_done), + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1})); + EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}), UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); EXPECT_TRUE( analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index c1412f7c68..3859e4cae6 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -202,12 +202,13 @@ ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 token = token[] after-all() - b = (f32[4], u32[]) send(a, token), channel_id=1, sharding={maximal device=0} - c = () send-done(b), channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[]) recv(token), channel_id=2, sharding={maximal device=0} - e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} - f = f32[4] add(a, e) - g = f32[4] subtract(a, e) + b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + c = token[] send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} + e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} + f = f32[4] add(a, e_element) + g = f32[4] subtract(a, e_element) ROOT h = (f32[4], f32[4]) tuple(f, g) } )"; @@ -220,7 +221,7 @@ ENTRY entry { EXPECT_TRUE(isolator_changed); EXPECT_TRUE(HasDomainEdge(module, "b", "a")); - EXPECT_TRUE(HasDomainEdge(module, "f", "e")); + EXPECT_TRUE(HasDomainEdge(module, "f", "e_element")); EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); EXPECT_FALSE(HasDomainEdge(module, "c", "b")); EXPECT_FALSE(HasDomainEdge(module, "e", "d")); @@ -231,7 +232,7 @@ ENTRY entry { EXPECT_TRUE(remover_changed); EXPECT_FALSE(HasDomainEdge(module, "b", "a")); - EXPECT_FALSE(HasDomainEdge(module, "f", "e")); + EXPECT_FALSE(HasDomainEdge(module, "f", "e_element")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -240,11 +241,12 @@ HloModule Module ENTRY entry { token = token[] after-all(), sharding={maximal device=-1} - a = (f32[4], u32[]) recv(token), channel_id=1, sharding={maximal device=-1} - b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} - c = f32[4] add(b, b), sharding={maximal device=-1} - d = (f32[4], u32[]) send(c, token), channel_id=2, sharding={maximal device=-1} - ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} + b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} + c = f32[4] add(b_element, b_element), sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } )"; @@ -262,11 +264,12 @@ HloModule Module ENTRY entry { token = token[] after-all(), sharding={maximal device=0} - a = (f32[4], u32[]) recv(token), channel_id=1, sharding={maximal device=0} - b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} - c = f32[4] add(b, b) - d = (f32[4], u32[]) send(c, token), channel_id=2, sharding={maximal device=0} - ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} + b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} + c = f32[4] add(b_element, b_element) + d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} } )"; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5533da6eb7..98dac792fa 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1623,8 +1623,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), new_operand->shape())) - << old_operand->shape().ShortDebugString() << " is not compatible with " - << new_operand->shape().ShortDebugString(); + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index dcc1e3c8af..7052e236cd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -207,8 +207,9 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand, HloInstruction* token, int64 channel_id) : HloSendRecvInstruction( HloOpcode::kSend, - ShapeUtil::MakeTupleShape( - {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), + ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(), + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}), channel_id) { AppendOperand(operand); AppendOperand(token); @@ -224,7 +225,7 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl( } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) - : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(), + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), CHECK_NOTNULL(operand)->channel_id()) { AppendOperand(operand); } @@ -244,7 +245,8 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape, HloInstruction* token, int64 channel_id) : HloSendRecvInstruction( HloOpcode::kRecv, - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}), channel_id) { AppendOperand(token); } @@ -261,7 +263,9 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl( HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) : HloSendRecvInstruction( HloOpcode::kRecvDone, - ShapeUtil::GetTupleElementShape(operand->shape(), 0), + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(operand->shape(), 0), + ShapeUtil::MakeTokenShape()}), CHECK_NOTNULL(operand)->channel_id()) { AppendOperand(operand); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index bf33640db1..6bcd7b042d 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -382,7 +382,8 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { // Check if the shapes match for each channel. for (const Channel& channel : channels_) { const Shape& send_shape = channel.send->operand(0)->shape(); - const Shape& recv_shape = channel.recv_done->shape(); + const Shape& recv_shape = + ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0); if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index f40cd60907..88f3309baa 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -277,13 +277,13 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { "SendRecv", R"(HloModule TwoSendRecvBothWayRecvFist_module -ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { %token = token[] after-all() - %recv = (f32[], u32[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} - ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} - %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} } )" @@ -1223,11 +1223,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) { ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %token = token[] after-all() - %recv = (f32[], u32[]) recv(token[] %token), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; @@ -1240,11 +1240,11 @@ TEST_F(HloParserTest, MissingAttribute) { ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %token = token[] after-all() - %recv = (f32[], u32[]) recv(token[] %token), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[]) send(f32[] %constant, token[] %token) - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; @@ -1257,11 +1257,11 @@ TEST_F(HloParserTest, PredecessorUndefined) { ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %token = token[] after-all() - %recv = (f32[], u32[]) recv(token[] %token), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 765245096b..2e6ea14426 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -346,9 +346,10 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); TF_RETURN_IF_ERROR(CheckIsTokenOperand(send, 1)); - return CheckShape( - send, ShapeUtil::MakeTupleShape( - {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); + return CheckShape(send, + ShapeUtil::MakeTupleShape({send->operand(0)->shape(), + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { @@ -357,7 +358,7 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { TF_RET_CHECK(send->opcode() == HloOpcode::kSend); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape(send_done, ShapeUtil::MakeNil()); + return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { @@ -366,9 +367,10 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); TF_RETURN_IF_ERROR(CheckIsTokenOperand(recv, 0)); - return CheckShape(recv, - ShapeUtil::MakeTupleShape( - {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); + return CheckShape( + recv, ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(recv_done->shape(), 0), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { @@ -376,7 +378,9 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { const HloInstruction* recv = recv_done->operand(0); TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv_done, recv->shape().tuple_shapes(0)); + return CheckShape(recv_done, + ShapeUtil::MakeTupleShape({recv->shape().tuple_shapes(0), + ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleBatchNormTraining( diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 36fdfa868d..fedc83c8f8 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1630,7 +1630,8 @@ Status LayoutAssignment::ConstrainChannelLayouts( for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kRecvDone) { const Layout* layout = channel_constraints->ConstrainChannel( - instruction->channel_id(), instruction->shape().layout()); + instruction->channel_id(), + ShapeUtil::GetSubshape(instruction->shape(), {0}).layout()); TF_RET_CHECK(layout == nullptr) << instruction->ToString() << " cannot constrain layout as it was set to " @@ -1647,7 +1648,7 @@ Status LayoutAssignment::ConstrainChannelLayouts( instruction->channel_id(), operand->shape().layout()); if (layout != nullptr) { // We found an already constrained layout which does not match the one - // the kSend wants to impose. Eitehr add a new kCopy, or use the + // the kSend wants to impose. Either add a new kCopy, or use the // existing one to marshal the correct shape. Shape shape = operand->shape(); *shape.mutable_layout() = *layout; diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 4cd584bf8b..a673901c75 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -830,12 +830,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 token = token[] after-all() - recv = (f32[2,2], u32[]) recv(token), channel_id=1, sharding={maximal device=1} - ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1, + recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, sharding={maximal device=1} - send = (f32[2,2], u32[]) send(gte, token), channel_id=1, + ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 + send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, sharding={maximal device=0} - send-done = () send-done(send), channel_id=1, sharding={maximal device=0} + send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} } )"; @@ -854,7 +855,7 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { AssignLayouts(module.get(), &computation_layout, &channel_constraints); EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "recv-done"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::GetSubshape( FindInstruction(module.get(), "send")->shape(), {0}), diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index f410921b4b..5da26d832b 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -131,18 +131,23 @@ Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { - // RecvDone doesn't create a new buffer but rather aliases its input (Recv) - // tuple element at {0} to its output. +Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) { + // RecvDone produces a two-element tuple containing the data value (which + // aliases part of its operand) and a token. Only the tuple index table and + // the token are defined by the RecvDone. + NewLogicalBuffer(recv_done, /*index=*/{}); + NewLogicalBuffer(recv_done, /*index=*/{1}); return Status::OK(); } Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) { - // Send creates new buffers for the top-level tuple and the context (tuple - // element at {1}). Tuple element at {0} is an alias of the Send operand, so - // we don't need to create a new Logical Buffer for that. + // Send creates new buffers for the top-level tuple, the context (tuple + // element at {1}), and the token (tuple element at {2}). Tuple element at {0} + // is an alias of the Send operand, so we don't need to create a new Logical + // Buffer for that. NewLogicalBuffer(send, /*index=*/{}); NewLogicalBuffer(send, /*index=*/{1}); + NewLogicalBuffer(send, /*index=*/{2}); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index d1e1744647..a1aa875009 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -292,22 +292,29 @@ Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { } Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { - // RecvDone aliases its input (Recv) tuple element {0} to its output. + // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its + // output. The other indices ({} and {1}) define their own buffers. PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); + points_to_set.AddPointedToBuffer( + logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}), + /*index=*/{}); + points_to_set.AddPointedToBuffer( + logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}), + /*index=*/{1}); + const PointsToSet& operand_points_to_set = GetPointsToSet(recv_done->operand(0)); - // Recursively copy the points to set of the operand tuple {0}. + // Recursively copy the points to set of the operand tuple {0} to the output + // element {0}. points_to_set.ForEachMutableElement( [this, &points_to_set, &operand_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { - ShapeIndex src_index({0}); - for (auto element : index) { - src_index.push_back(element); + if (index.empty() || index[0] != 0) { + return; } - *buffers = operand_points_to_set.element(src_index); - for (auto& tuple_source : - operand_points_to_set.tuple_sources(src_index)) { + *buffers = operand_points_to_set.element(index); + for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) { points_to_set.add_tuple_source(index, tuple_source); } }); @@ -315,7 +322,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { } Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { - // Send creates a tuple of {aliased operand, U32 context}. + // Send creates a tuple of {aliased operand, U32 context, token}. PointsToSet& points_to_set = CreateEmptyPointsToSet(send); // Creates the points to set for the tuple and its element at {1}. @@ -328,6 +335,10 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { context_buffer->push_back( &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1}))); + auto token_buffer = points_to_set.mutable_element(ShapeIndex({2})); + token_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2}))); + // Recursively copy the points to set of the operand to output tuple {0}. const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0)); operand_points_to_set.ForEachElement( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index a8f885fd86..1e7d058eb6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -357,7 +357,7 @@ TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(recv).element({}), {recv}); - ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}}); + ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}}); } TEST_F(TuplePointsToAnalysisTest, TupleSelect) { |