author    Mark Heffernan <meheff@google.com> 2018-07-02 12:22:15 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-02 12:27:03 -0700
commit    20e27ad56b95e19ebeb23e34db1aff22e0bd473e (patch)
tree      0ff0a458b28a527acbc82b6ae1f5a36aeb96d1c8 /tensorflow
parent    0967cbb9a34b69ec14238802460971abbec9cbb4 (diff)
Change Send and Recv HLOs to take a token operand.
Send and Recv HLOs now have an additional required operand which must be token-shaped. The XLA client interface for these operations is unchanged and will be updated in follow-up CLs.

PiperOrigin-RevId: 202993121
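As a minimal sketch (not part of this patch), the updated HloInstruction creation API is used like this; the names, shapes, and channel ids below are illustrative, mirroring the test updates in the diff:

    // Both Send and Recv now require a token operand; the updated tests
    // synthesize one with a zero-operand AfterAll.
    auto builder = HloComputation::Builder("send_recv_example");
    auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
    // Recv produces a tuple of {receive buffer, U32 context}.
    auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
        ShapeUtil::MakeShape(F32, {4}), token, /*channel_id=*/1));
    auto recv_done =
        builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
    // Send takes the data operand plus the token.
    auto send = builder.AddInstruction(
        HloInstruction::CreateSend(recv_done, token, /*channel_id=*/2));
    builder.AddInstruction(HloInstruction::CreateSendDone(send));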
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 40
-rw-r--r-- tensorflow/compiler/xla/service/buffer_liveness_test.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/conditional_simplifier_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/hlo_constant_folding.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dce_test.cc | 7
-rw-r--r-- tensorflow/compiler/xla/service/hlo_domain_test.cc | 15
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 22
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc | 19
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.h | 6
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser_test.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_verifier.cc | 38
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion_test.cc | 18
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment_test.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/shape_inference.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_simplifier_test.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc | 4
20 files changed, 164 insertions, 94 deletions
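In textual HLO, the same change reads as follows; this before/after sketch is based on the parser and domain test updates in this diff:

    Before:
      %recv = (f32[], u32[]) recv(), channel_id=15
      %send = (f32[], u32[]) send(f32[] %constant), channel_id=16

    After:
      %token = token[] after-all()
      %recv = (f32[], u32[]) recv(token[] %token), channel_id=15
      %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16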
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 95342af6a7..09e7e87918 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -48,6 +48,7 @@ int64 GetUniqueId() {
// computation.
bool CanBeRoot(HloOpcode opcode) {
switch (opcode) {
+ case HloOpcode::kAfterAll:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kOutfeed:
@@ -1586,6 +1587,7 @@ XlaOp XlaBuilder::Reduce(
TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
+
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferReduceShape(
operand_shape, init_shape, dimensions_to_reduce,
@@ -1839,16 +1841,24 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
+ // Send HLO takes two operands: a data operand and a token. Generate the
+ // token to pass into the send.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto token_instr;
+ *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
+ HloOpcode::kAfterAll, {}));
// Send instruction produces a tuple of {aliased operand, U32 context}.
+ HloInstructionProto send_instr;
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
- *instr.mutable_shape() =
+ *send_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- instr.set_channel_id(handle.handle());
- TF_ASSIGN_OR_RETURN(
- XlaOp send,
- AddInstruction(std::move(instr), HloOpcode::kSend, {operand}));
+ send_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp send,
+ AddInstruction(std::move(send_instr), HloOpcode::kSend,
+ {operand, token}));
HloInstructionProto send_done_instr;
*send_done_instr.mutable_shape() = ShapeUtil::MakeNil();
@@ -1860,14 +1870,22 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
+ // Recv HLO takes a single token operand. Generate the token to pass into
+ // the Recv and RecvDone instructions.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto token_instr;
+ *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
+ HloOpcode::kAfterAll, {}));
// Recv instruction produces a tuple of {receive buffer, U32 context}.
- *instr.mutable_shape() =
+ HloInstructionProto recv_instr;
+ *recv_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- instr.set_channel_id(handle.handle());
- TF_ASSIGN_OR_RETURN(XlaOp recv,
- AddInstruction(std::move(instr), HloOpcode::kRecv, {}));
+ recv_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
+ HloOpcode::kRecv, {token}));
HloInstructionProto recv_done_instr;
*recv_done_instr.mutable_shape() = shape;
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index f623aef67a..7833ebe73b 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -327,11 +327,12 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto recv = builder.AddInstruction(
- HloInstruction::CreateRecv(vec_, /*channel_id=*/0));
+ HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
auto send = builder.AddInstruction(
- HloInstruction::CreateSend(recv_done, /*channel_id=*/1));
+ HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
auto module = CreateNewModule();
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index c38719d50e..68f6ffc6b7 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -119,10 +119,12 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* true_computation = conditional->true_computation();
+ auto* token =
+ true_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
auto* send = true_computation->AddInstruction(HloInstruction::CreateSend(
true_computation->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
- /*channel_id=*/0));
+ token, /*channel_id=*/0));
true_computation->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
}
@@ -133,8 +135,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* true_computation = conditional->true_computation();
+ auto* token =
+ true_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv(
- ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0));
+ ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0));
true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 35ecd4428d..436d103f23 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -51,14 +51,18 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
computation->root_instruction() != instruction) {
continue;
}
- // Skip Constant, Parameter, Reduce operation.
+ // Skip Constant, Parameter, Reduce, and AfterAll operations.
// TODO(b/35975797): Enable Reduce operation once arbitrary computations
// are supported by the evaluator.
// TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
+ // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
+ // operand in which case constant folding will be impossible and this
+ // special case is not necessary.
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kTuple ||
- instruction->opcode() == HloOpcode::kReduce) {
+ instruction->opcode() == HloOpcode::kReduce ||
+ instruction->opcode() == HloOpcode::kAfterAll) {
continue;
}
// Skip instructions with non-constant operands.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 0ea8bdcab6..70254e2c1a 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1158,15 +1158,16 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto send = builder.AddInstruction(
- HloInstruction::CreateSend(param, /*channel_id=*/0));
+ HloInstruction::CreateSend(param, token, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_EQ(analysis.values().size(), 4);
+ EXPECT_EQ(analysis.values().size(), 5);
EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
@@ -1181,15 +1182,16 @@ TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
// Test that a RecvDone forwards its operand tuple element at {0} to the
// output.
auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto recv = builder.AddInstruction(
- HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
+ HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_EQ(analysis.values().size(), 3);
+ EXPECT_EQ(analysis.values().size(), 4);
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index 2822ecd788..f5524dc6fe 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -75,19 +75,20 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
builder.AddInstruction(
- HloInstruction::CreateSend(constant, /*channel_id=*/0));
+ HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
builder.AddInstruction(HloInstruction::CreateTuple({}));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(4, computation->instruction_count());
HloDCE dce;
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
- EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(4, computation->instruction_count());
}
TEST_F(HloDceTest, DeadParameters) {
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index abc5b1c8ef..c1412f7c68 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -201,9 +201,10 @@ HloModule Module
ENTRY entry {
p0 = (f32[4]) parameter(0)
a = f32[4] get-tuple-element(p0), index=0
- b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0}
+ token = token[] after-all()
+ b = (f32[4], u32[]) send(a, token), channel_id=1, sharding={maximal device=0}
c = () send-done(b), channel_id=1, sharding={maximal device=0}
- d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0}
+ d = (f32[4], u32[]) recv(token), channel_id=2, sharding={maximal device=0}
e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0}
f = f32[4] add(a, e)
g = f32[4] subtract(a, e)
@@ -238,10 +239,11 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) {
HloModule Module
ENTRY entry {
- a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1}
+ token = token[] after-all(), sharding={maximal device=-1}
+ a = (f32[4], u32[]) recv(token), channel_id=1, sharding={maximal device=-1}
b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1}
c = f32[4] add(b, b), sharding={maximal device=-1}
- d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1}
+ d = (f32[4], u32[]) send(c, token), channel_id=2, sharding={maximal device=-1}
ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1}
}
)";
@@ -259,10 +261,11 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) {
HloModule Module
ENTRY entry {
- a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0}
+ token = token[] after-all(), sharding={maximal device=0}
+ a = (f32[4], u32[]) recv(token), channel_id=1, sharding={maximal device=0}
b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0}
c = f32[4] add(b, b)
- d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0}
+ d = (f32[4], u32[]) send(c, token), channel_id=2, sharding={maximal device=0}
ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0}
}
)";
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index e0e3d301be..5b416d9654 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -112,10 +112,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kSend:
- TF_RET_CHECK(proto.operand_ids_size() == 1)
- << "Send instruction should have 1 operand but sees "
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Send instruction should have 2 operand but sees "
<< proto.operand_ids_size();
- instruction = CreateSend(operands(0), proto.channel_id());
+ instruction = CreateSend(operands(0), operands(1), proto.channel_id());
break;
case HloOpcode::kSendDone:
TF_RET_CHECK(proto.operand_ids_size() == 1)
@@ -124,11 +124,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction = CreateSendDone(operands(0));
break;
case HloOpcode::kRecv:
- TF_RET_CHECK(proto.operand_ids_size() == 0)
- << "Recv instruction should have 0 operand but sees "
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Recv instruction should have 1 operand but sees "
<< proto.operand_ids_size();
- instruction =
- CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id());
+ instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0),
+ proto.channel_id());
break;
case HloOpcode::kRecvDone:
TF_RET_CHECK(proto.operand_ids_size() == 1)
@@ -650,8 +650,8 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
- HloInstruction* operand, int64 channel_id) {
- return MakeUnique<HloSendInstruction>(operand, channel_id);
+ HloInstruction* operand, HloInstruction* token, int64 channel_id) {
+ return MakeUnique<HloSendInstruction>(operand, token, channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
@@ -663,8 +663,8 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
- const Shape& shape, int64 channel_id) {
- return MakeUnique<HloRecvInstruction>(shape, channel_id);
+ const Shape& shape, HloInstruction* token, int64 channel_id) {
+ return MakeUnique<HloRecvInstruction>(shape, token, channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 0459072127..34e7dcb43d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -477,7 +477,7 @@ class HloInstruction {
const Shape& outfeed_shape, HloInstruction* operand,
HloInstruction* token_operand, tensorflow::StringPiece outfeed_config);
// Overload which does not require a token.
- // TODO(b/80000000): Remove this overload when all uses of infeed are
+ // TODO(b/80000000): Remove this overload when all uses of outfeed are
// converted to take tokens.
static std::unique_ptr<HloInstruction> CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
@@ -487,6 +487,7 @@ class HloInstruction {
// initiates sending the operand data to a unique receive instruction in
// another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
+ HloInstruction* token,
int64 channel_id);
// Blocks until data transfer for the Send instruction (operand) is complete.
@@ -498,6 +499,7 @@ class HloInstruction {
// which allocates resources to receive data of the given shape from a unique
// send instruction in another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
+ HloInstruction* token,
int64 channel_id);
// Blocks until data transfer for the Recv instruction (operand) is complete
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e2f43f5810..dcc1e3c8af 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -204,21 +204,23 @@ bool HloSendRecvInstruction::IdenticalSlowPath(
// Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction::HloSendInstruction(HloInstruction* operand,
- int64 channel_id)
+ HloInstruction* token, int64 channel_id)
: HloSendRecvInstruction(
HloOpcode::kSend,
ShapeUtil::MakeTupleShape(
{CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}),
channel_id) {
AppendOperand(operand);
+ AppendOperand(token);
}
std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSendInstruction>(new_operands[0], channel_id());
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1],
+ channel_id());
}
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand)
@@ -238,19 +240,22 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl(
}
// Recv instruction produces a tuple of {receive buffer, U32 context}.
-HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id)
+HloRecvInstruction::HloRecvInstruction(const Shape& shape,
+ HloInstruction* token, int64 channel_id)
: HloSendRecvInstruction(
HloOpcode::kRecv,
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
- channel_id) {}
+ channel_id) {
+ AppendOperand(token);
+}
std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 0);
+ CHECK_EQ(new_operands.size(), 1);
return MakeUnique<HloRecvInstruction>(
- ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
+ ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id());
}
HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand)
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index ec8a42bd3b..df6969c410 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -161,7 +161,8 @@ class HloSendRecvInstruction : public HloInstruction {
class HloSendInstruction : public HloSendRecvInstruction {
public:
- explicit HloSendInstruction(HloInstruction* operand, int64 channel_id);
+ explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
+ int64 channel_id);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -185,7 +186,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
class HloRecvInstruction : public HloSendRecvInstruction {
public:
- explicit HloRecvInstruction(const Shape& shape, int64 channel_id);
+ explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
+ int64 channel_id);
private:
// Implementation for non-common logic of CloneWithNewOperands.
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 6ffed62a09..5b0f09a498 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -670,12 +670,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kRecv: {
optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
- if (!ParseOperands(&operands, /*expected_size=*/0) ||
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
+ instruction = builder->AddInstruction(HloInstruction::CreateRecv(
+ shape.tuple_shapes(0), operands[0], *channel_id));
break;
}
case HloOpcode::kRecvDone: {
@@ -695,12 +695,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kSend: {
optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateSend(operands[0], *channel_id));
+ HloInstruction::CreateSend(operands[0], operands[1], *channel_id));
break;
}
case HloOpcode::kSendDone: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 504ea3fe7a..f40cd60907 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -278,10 +278,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
R"(HloModule TwoSendRecvBothWayRecvFist_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1}
+ %token = token[] after-all()
+ %recv = (f32[], u32[]) recv(token[] %token), channel_id=15, sharding={maximal device=1}
ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
- %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
+ %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
%send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0}
}
@@ -1221,10 +1222,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) {
const string original = R"(HloModule unexpected_attr_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = (f32[], u32[]) recv(), channel_id=15
+ %token = token[] after-all()
+ %recv = (f32[], u32[]) recv(token[] %token), channel_id=15
%recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
- %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv
+ %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv
%send-done = () send-done((f32[], u32[]) %send), channel_id=16
}
@@ -1237,10 +1239,11 @@ TEST_F(HloParserTest, MissingAttribute) {
const string original = R"(HloModule missing_attr_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = (f32[], u32[]) recv(), channel_id=15
+ %token = token[] after-all()
+ %recv = (f32[], u32[]) recv(token[] %token), channel_id=15
%recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
ROOT %constant = f32[] constant(-2.1)
- %send = (f32[], u32[]) send(f32[] %constant)
+ %send = (f32[], u32[]) send(f32[] %constant, token[] %token)
%send-done = () send-done((f32[], u32[]) %send), channel_id=16
}
@@ -1253,10 +1256,11 @@ TEST_F(HloParserTest, PredecessorUndefined) {
const string original = R"(HloModule pre_not_found_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = (f32[], u32[]) recv(), channel_id=15
+ %token = token[] after-all()
+ %recv = (f32[], u32[]) recv(token[] %token), channel_id=15
%recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
- %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done}
+ %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done}
%send-done = () send-done((f32[], u32[]) %send), channel_id=16
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 27c9529b11..765245096b 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -108,17 +108,29 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
+namespace {
+
+Status CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no) {
+ const HloInstruction* token = instruction->operand(operand_no);
+ if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
+ return InternalError(
+ "Expected operand %lld to be token-shaped, actual shape is"
+ "%s:\n%s",
+ operand_no, ShapeUtil::HumanString(token->shape()).c_str(),
+ instruction->ToString().c_str());
+ }
+ return Status::OK();
+}
+
+} // namespace
+
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
// Infeed has an optional single token operand.
// TODO(b/80000000): Update when token is not optional.
- if (infeed->operand_count() == 1 &&
- !ShapeUtil::Equal(infeed->operand(0)->shape(),
- ShapeUtil::MakeTokenShape())) {
- return InternalError(
- "Expected infeed operand to be token-shaped, actual shape is %s:\n%s",
- ShapeUtil::HumanString(infeed->operand(0)->shape()).c_str(),
- infeed->ToString().c_str());
+ if (infeed->operand_count() == 1) {
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
}
// The output of infeed is a tuple containing the data value and a token.
@@ -131,13 +143,8 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
// Outfeed has an optional token operand (operand 1).
// TODO(b/80000000): Update when token is not optional.
- if (outfeed->operand_count() == 2 &&
- !ShapeUtil::Equal(outfeed->operand(1)->shape(),
- ShapeUtil::MakeTokenShape())) {
- return InternalError(
- "Expected operand 1 of outfeed to be a token, actual shape is %s:\n%s",
- ShapeUtil::HumanString(outfeed->operand(1)->shape()).c_str(),
- outfeed->ToString().c_str());
+ if (outfeed->operand_count() == 2) {
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
}
// Outfeed has a separate shape field for the value which is outfed to the
@@ -338,6 +345,7 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) {
const HloInstruction* send_done = send->users().front();
TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(send, 1));
return CheckShape(
send, ShapeUtil::MakeTupleShape(
{send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
@@ -348,6 +356,7 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
const HloInstruction* send = send_done->operand(0);
TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
+
return CheckShape(send_done, ShapeUtil::MakeNil());
}
@@ -356,6 +365,7 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
const HloInstruction* recv_done = recv->users().front();
TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(recv, 0));
return CheckShape(recv,
ShapeUtil::MakeTupleShape(
{recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index 21db233899..bb7231c8c8 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -167,7 +167,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* binary1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
- builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
@@ -258,7 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
add = f32[4,3]{1,0} add(p0, p0)
abs1 = f32[4,3]{1,0} abs(add)
log = f32[4,3]{1,0} log(abs1)
- send = f32[4,3]{1,0} send(log), channel_id=0
+ token = token[] after-all()
+ send = f32[4,3]{1,0} send(log, token), channel_id=0
abs2 = f32[4,3]{1,0} abs(log)
ROOT root = f32[4,3]{1,0} subtract(abs2, add)
})")
@@ -288,7 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
p0 = f32[4,3]{1,0} parameter(0)
add1 = f32[4,3]{1,0} add(p0, p0)
log = f32[4,3]{1,0} log(p0)
- send = f32[4,3]{1,0} send(log), channel_id=0
+ token = token[] after-all()
+ send = f32[4,3]{1,0} send(log, token), channel_id=0
add2 = f32[4,3]{1,0} add(log, add1)
ROOT root = f32[4,3]{1,0} subtract(add1, add2)
})")
@@ -321,7 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
add1 = f32[4,3]{1,0} add(p0, p0)
add2 = f32[4,3]{1,0} add(add1, add1)
log = f32[4,3]{1,0} log(add2)
- send = f32[4,3]{1,0} send(log), channel_id=0
+ token = token[] after-all()
+ send = f32[4,3]{1,0} send(log, token), channel_id=0
sub1 = f32[4,3]{1,0} subtract(log, add2)
sub2 = f32[4,3]{1,0} subtract(add2, add1)
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
@@ -352,7 +356,8 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
HloInstruction* unary1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
- builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ builder.AddInstruction(HloInstruction::CreateSend(unary1, token, 0));
HloInstruction* unary2 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
@@ -375,7 +380,8 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* binary1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
- builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 67e2cf6c77..4cd584bf8b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -829,10 +829,11 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
ENTRY entry_computation {
param = (f32[2,2]) parameter(0)
gte = f32[2,2] get-tuple-element(param), index=0
- recv = (f32[2,2], u32[]) recv(), channel_id=1, sharding={maximal device=1}
+ token = token[] after-all()
+ recv = (f32[2,2], u32[]) recv(token), channel_id=1, sharding={maximal device=1}
ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1,
sharding={maximal device=1}
- send = (f32[2,2], u32[]) send(gte), channel_id=1,
+ send = (f32[2,2], u32[]) send(gte, token), channel_id=1,
sharding={maximal device=0}
send-done = () send-done(send), channel_id=1, sharding={maximal device=0}
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index d05e995a95..81f071ecc5 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -69,11 +69,11 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape,
}
const Shape& accumulator_shape = reducer_shape.result();
- if (ShapeUtil::Rank(accumulator_shape) != 0) {
+ if (!ShapeUtil::IsArray(accumulator_shape) ||
+ ShapeUtil::Rank(accumulator_shape) != 0) {
return InvalidArgument(
- "Reduction function must have rank 0 (rank %lld reduction function "
- "given).",
- ShapeUtil::Rank(accumulator_shape));
+ "Reduction function must produce a scalar but has shape: %s",
+ ShapeUtil::HumanString(accumulator_shape).c_str());
}
// Check that the accumulator can be passed in as the first argument.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 5734f28407..a8f885fd86 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -318,8 +318,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto send = builder.AddInstruction(
- HloInstruction::CreateSend(constant, /*channel_id=*/0));
+ HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
BuildModuleAndRunAnalysis(builder.Build());
@@ -342,8 +343,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
// RecvDone forwards its operand tuple element at {0} to the output.
auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
- ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
+ ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
BuildModuleAndRunAnalysis(builder.Build());
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 0536c99b67..3c83049216 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -175,9 +175,11 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
+ auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({}));
auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
while_body->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
+ token,
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
@@ -190,8 +192,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
+ auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({}));
auto* recv = while_body->AddInstruction(
- HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
+ HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token,
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
index f5331280ee..c6bd013a1a 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
@@ -67,7 +67,9 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) {
}
TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) {
- builder_.AddInstruction(HloInstruction::CreateSend(zero_sized_param_, 0));
+ auto token = builder_.AddInstruction(HloInstruction::CreateAfterAll({}));
+ builder_.AddInstruction(
+ HloInstruction::CreateSend(zero_sized_param_, token, 0));
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination());
EXPECT_FALSE(changed);
}