aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-07-02 12:22:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-02 12:27:03 -0700
commit20e27ad56b95e19ebeb23e34db1aff22e0bd473e (patch)
tree0ff0a458b28a527acbc82b6ae1f5a36aeb96d1c8 /tensorflow/compiler/xla/service/hlo_instructions.cc
parent0967cbb9a34b69ec14238802460971abbec9cbb4 (diff)
Change Send and Recv HLOs to take a token operand.
Send and Recv HLOs now have an additional required operand which must be token-shaped. XLA client interface for these operations is unchanged and will be updated in follow up CLs. PiperOrigin-RevId: 202993121
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc19
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e2f43f5810..dcc1e3c8af 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -204,21 +204,23 @@ bool HloSendRecvInstruction::IdenticalSlowPath(
// Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction::HloSendInstruction(HloInstruction* operand,
- int64 channel_id)
+ HloInstruction* token, int64 channel_id)
: HloSendRecvInstruction(
HloOpcode::kSend,
ShapeUtil::MakeTupleShape(
{CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}),
channel_id) {
AppendOperand(operand);
+ AppendOperand(token);
}
std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSendInstruction>(new_operands[0], channel_id());
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1],
+ channel_id());
}
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand)
@@ -238,19 +240,22 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl(
}
// Recv instruction produces a tuple of {receive buffer, U32 context}.
-HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id)
+HloRecvInstruction::HloRecvInstruction(const Shape& shape,
+ HloInstruction* token, int64 channel_id)
: HloSendRecvInstruction(
HloOpcode::kRecv,
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
- channel_id) {}
+ channel_id) {
+ AppendOperand(token);
+}
std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 0);
+ CHECK_EQ(new_operands.size(), 1);
return MakeUnique<HloRecvInstruction>(
- ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
+ ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id());
}
HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand)