aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-07-17 18:01:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 18:07:00 -0700
commit07cc6474b219ee3ad9f55860e621f61b34bb6bd1 (patch)
treec09bbe69d49db3b6a6acd6f38545915ae2369035 /tensorflow/compiler/xla/service/hlo_instructions.cc
parent1662a105497e60d002e101161987cbbd48ba06c6 (diff)
Add single-sided host send and receive operations.
Adds a bit on kSend/kReceive instructions and their Done variants indicated whether the operations communicates with the host or another device (the default). Host send/recv operations are single-sided without a complementary recv/send instruction in another module. Host send/recv operations are exposed in the XLA builder API as SendToHost and RecvFromHost. PiperOrigin-RevId: 205008138
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc44
1 files changed, 29 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 702f808449..df26a2c744 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -181,8 +181,11 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
const Shape& shape,
- int64 channel_id)
- : HloInstruction(opcode, shape), channel_id_(channel_id) {}
+ int64 channel_id,
+ bool is_host_transfer)
+ : HloInstruction(opcode, shape),
+ channel_id_(channel_id),
+ is_host_transfer_(is_host_transfer) {}
HloInstructionProto HloSendRecvInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
@@ -192,7 +195,12 @@ HloInstructionProto HloSendRecvInstruction::ToProto() const {
std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("channel_id=", channel_id_)};
+ std::vector<string> attrs;
+ attrs.push_back(StrCat("channel_id=", channel_id_));
+ if (is_host_transfer()) {
+ attrs.push_back("is_host_transfer=true");
+ }
+ return attrs;
}
bool HloSendRecvInstruction::IdenticalSlowPath(
@@ -205,13 +213,14 @@ bool HloSendRecvInstruction::IdenticalSlowPath(
// Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction::HloSendInstruction(HloInstruction* operand,
- HloInstruction* token, int64 channel_id)
+ HloInstruction* token, int64 channel_id,
+ bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kSend,
ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
ShapeUtil::MakeShape(U32, {}),
ShapeUtil::MakeTokenShape()}),
- channel_id) {
+ channel_id, is_host_transfer) {
AppendOperand(operand);
AppendOperand(token);
}
@@ -222,12 +231,14 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1],
- channel_id());
+ channel_id(), is_host_transfer());
}
-HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand)
+HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
+ bool is_host_transfer)
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
- CHECK_NOTNULL(operand)->channel_id()) {
+ CHECK_NOTNULL(operand)->channel_id(),
+ is_host_transfer) {
AppendOperand(operand);
}
@@ -238,17 +249,18 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return MakeUnique<HloSendDoneInstruction>(
- Cast<HloSendInstruction>(new_operands[0]));
+ Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
}
// Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction::HloRecvInstruction(const Shape& shape,
- HloInstruction* token, int64 channel_id)
+ HloInstruction* token, int64 channel_id,
+ bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kRecv,
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
ShapeUtil::MakeTokenShape()}),
- channel_id) {
+ channel_id, is_host_transfer) {
AppendOperand(token);
}
@@ -258,16 +270,18 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return MakeUnique<HloRecvInstruction>(
- ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id());
+ ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
+ is_host_transfer());
}
-HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand)
+HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
+ bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kRecvDone,
ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(operand->shape(), 0),
ShapeUtil::MakeTokenShape()}),
- CHECK_NOTNULL(operand)->channel_id()) {
+ CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) {
AppendOperand(operand);
}
@@ -278,7 +292,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return MakeUnique<HloRecvDoneInstruction>(
- Cast<HloRecvInstruction>(new_operands[0]));
+ Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
}
HloAllReduceInstruction::HloAllReduceInstruction(