diff options
author | Mark Heffernan <meheff@google.com> | 2018-07-17 18:01:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-17 18:07:00 -0700 |
commit | 07cc6474b219ee3ad9f55860e621f61b34bb6bd1 (patch) | |
tree | c09bbe69d49db3b6a6acd6f38545915ae2369035 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 1662a105497e60d002e101161987cbbd48ba06c6 (diff) |
Add single-sided host send and receive operations.
Adds a bit on kSend/kReceive instructions and their Done variants indicated whether the operations communicates with the host or another device (the default). Host send/recv operations are single-sided without a complementary recv/send instruction in another module.
Host send/recv operations are exposed in the XLA builder API as SendToHost and RecvFromHost.
PiperOrigin-RevId: 205008138
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 44 |
1 files changed, 29 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 702f808449..df26a2c744 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -181,8 +181,11 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl( HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, - int64 channel_id) - : HloInstruction(opcode, shape), channel_id_(channel_id) {} + int64 channel_id, + bool is_host_transfer) + : HloInstruction(opcode, shape), + channel_id_(channel_id), + is_host_transfer_(is_host_transfer) {} HloInstructionProto HloSendRecvInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); @@ -192,7 +195,12 @@ HloInstructionProto HloSendRecvInstruction::ToProto() const { std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("channel_id=", channel_id_)}; + std::vector<string> attrs; + attrs.push_back(StrCat("channel_id=", channel_id_)); + if (is_host_transfer()) { + attrs.push_back("is_host_transfer=true"); + } + return attrs; } bool HloSendRecvInstruction::IdenticalSlowPath( @@ -205,13 +213,14 @@ bool HloSendRecvInstruction::IdenticalSlowPath( // Send instruction produces a tuple of {aliased operand, U32 context}. HloSendInstruction::HloSendInstruction(HloInstruction* operand, - HloInstruction* token, int64 channel_id) + HloInstruction* token, int64 channel_id, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kSend, ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}), - channel_id) { + channel_id, is_host_transfer) { AppendOperand(operand); AppendOperand(token); } @@ -222,12 +231,14 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1], - channel_id()); + channel_id(), is_host_transfer()); } -HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, + bool is_host_transfer) : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), - CHECK_NOTNULL(operand)->channel_id()) { + CHECK_NOTNULL(operand)->channel_id(), + is_host_transfer) { AppendOperand(operand); } @@ -238,17 +249,18 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return MakeUnique<HloSendDoneInstruction>( - Cast<HloSendInstruction>(new_operands[0])); + Cast<HloSendInstruction>(new_operands[0]), is_host_transfer()); } // Recv instruction produces a tuple of {receive buffer, U32 context}. HloRecvInstruction::HloRecvInstruction(const Shape& shape, - HloInstruction* token, int64 channel_id) + HloInstruction* token, int64 channel_id, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kRecv, ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}), - channel_id) { + channel_id, is_host_transfer) { AppendOperand(token); } @@ -258,16 +270,18 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return MakeUnique<HloRecvInstruction>( - ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id()); + ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), + is_host_transfer()); } -HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) +HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kRecvDone, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(operand->shape(), 0), ShapeUtil::MakeTokenShape()}), - CHECK_NOTNULL(operand)->channel_id()) { + CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) { AppendOperand(operand); } @@ -278,7 +292,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return MakeUnique<HloRecvDoneInstruction>( - Cast<HloRecvInstruction>(new_operands[0])); + Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer()); } HloAllReduceInstruction::HloAllReduceInstruction( |