diff options
author | 2018-07-02 12:22:15 -0700 | |
---|---|---|
committer | 2018-07-02 12:27:03 -0700 | |
commit | 20e27ad56b95e19ebeb23e34db1aff22e0bd473e (patch) | |
tree | 0ff0a458b28a527acbc82b6ae1f5a36aeb96d1c8 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 0967cbb9a34b69ec14238802460971abbec9cbb4 (diff) |
Change Send and Recv HLOs to take a token operand.
Send and Recv HLOs now have an additional required operand which must be token-shaped. XLA client interface for these operations is unchanged and will be updated in follow up CLs.
PiperOrigin-RevId: 202993121
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e2f43f5810..dcc1e3c8af 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -204,21 +204,23 @@ bool HloSendRecvInstruction::IdenticalSlowPath( // Send instruction produces a tuple of {aliased operand, U32 context}. HloSendInstruction::HloSendInstruction(HloInstruction* operand, - int64 channel_id) + HloInstruction* token, int64 channel_id) : HloSendRecvInstruction( HloOpcode::kSend, ShapeUtil::MakeTupleShape( {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), channel_id) { AppendOperand(operand); + AppendOperand(token); } std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloSendInstruction>(new_operands[0], channel_id()); + CHECK_EQ(new_operands.size(), 2); + return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1], + channel_id()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) @@ -238,19 +240,22 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( } // Recv instruction produces a tuple of {receive buffer, U32 context}. -HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id) +HloRecvInstruction::HloRecvInstruction(const Shape& shape, + HloInstruction* token, int64 channel_id) : HloSendRecvInstruction( HloOpcode::kRecv, ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), - channel_id) {} + channel_id) { + AppendOperand(token); +} std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 0); + CHECK_EQ(new_operands.size(), 1); return MakeUnique<HloRecvInstruction>( - ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); + ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id()); } HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) |