diff options
author | Mark Heffernan <meheff@google.com> | 2018-07-02 12:22:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-02 12:27:03 -0700 |
commit | 20e27ad56b95e19ebeb23e34db1aff22e0bd473e (patch) | |
tree | 0ff0a458b28a527acbc82b6ae1f5a36aeb96d1c8 /tensorflow | |
parent | 0967cbb9a34b69ec14238802460971abbec9cbb4 (diff) |
Change Send and Recv HLOs to take a token operand.
Send and Recv HLOs now have an additional required operand which must be token-shaped. XLA client interface for these operations is unchanged and will be updated in follow up CLs.
PiperOrigin-RevId: 202993121
Diffstat (limited to 'tensorflow')
20 files changed, 164 insertions, 94 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 95342af6a7..09e7e87918 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -48,6 +48,7 @@ int64 GetUniqueId() { // computation. bool CanBeRoot(HloOpcode opcode) { switch (opcode) { + case HloOpcode::kAfterAll: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kOutfeed: @@ -1586,6 +1587,7 @@ XlaOp XlaBuilder::Reduce( TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferReduceShape( operand_shape, init_shape, dimensions_to_reduce, @@ -1839,16 +1841,24 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - HloInstructionProto instr; + // Send HLO takes two operands: a data operand and a token. Generate the + // token to pass into the send. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); // Send instruction produces a tuple of {aliased operand, U32 context}. + HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *instr.mutable_shape() = + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN( - XlaOp send, - AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + send_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); HloInstructionProto send_done_instr; *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); @@ -1860,14 +1870,22 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - HloInstructionProto instr; + // Recv HLO takes a single token operand. Generate the token to pass into + // the Recv and RecvDone instructions. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); // Recv instruction produces a tuple of {receive buffer, U32 context}. - *instr.mutable_shape() = + HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp recv, - AddInstruction(std::move(instr), HloOpcode::kRecv, {})); + recv_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), + HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; *recv_done_instr.mutable_shape() = shape; diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index f623aef67a..7833ebe73b 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -327,11 +327,12 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param)); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto recv = builder.AddInstruction( - HloInstruction::CreateRecv(vec_, /*channel_id=*/0)); + HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); auto send = builder.AddInstruction( - HloInstruction::CreateSend(recv_done, /*channel_id=*/1)); + HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c38719d50e..68f6ffc6b7 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -119,10 +119,12 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* true_computation = conditional->true_computation(); + auto* token = + true_computation->AddInstruction(HloInstruction::CreateAfterAll({})); auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( true_computation->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))), - /*channel_id=*/0)); + token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } @@ -133,8 +135,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* true_computation = conditional->true_computation(); + auto* token = + true_computation->AddInstruction(HloInstruction::CreateAfterAll({})); auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( - ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 35ecd4428d..436d103f23 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -51,14 +51,18 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Reduce operation. + // Skip Constant, Parameter, Reduce, and AfterAll operation. // TODO(b/35975797): Enable Reduce operation once arbitrary computation // are supported by the evaluator. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. + // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one + // operand in which case constant folding will be impossible and this + // special case is not necessary. if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kReduce) { + instruction->opcode() == HloOpcode::kReduce || + instruction->opcode() == HloOpcode::kAfterAll) { continue; } // Skip instructions with non-constant operands. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 0ea8bdcab6..70254e2c1a 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1158,15 +1158,16 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto send = builder.AddInstruction( - HloInstruction::CreateSend(param, /*channel_id=*/0)); + HloInstruction::CreateSend(param, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 4); + EXPECT_EQ(analysis.values().size(), 5); EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); @@ -1181,15 +1182,16 @@ TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { // Test that a RecvDone forwards its operand tuple element at {0} to the // output. auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto recv = builder.AddInstruction( - HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); + HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 3); + EXPECT_EQ(analysis.values().size(), 4); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 2822ecd788..f5524dc6fe 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -75,19 +75,20 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); builder.AddInstruction( - HloInstruction::CreateSend(constant, /*channel_id=*/0)); + HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); builder.AddInstruction(HloInstruction::CreateTuple({})); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(4, computation->instruction_count()); HloDCE dce; EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); - EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(4, computation->instruction_count()); } TEST_F(HloDceTest, DeadParameters) { diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index abc5b1c8ef..c1412f7c68 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -201,9 +201,10 @@ HloModule Module ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 - b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0} + token = token[] after-all() + b = (f32[4], u32[]) send(a, token), channel_id=1, sharding={maximal device=0} c = () send-done(b), channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[]) recv(token), channel_id=2, sharding={maximal device=0} e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} f = f32[4] add(a, e) g = f32[4] subtract(a, e) @@ -238,10 +239,11 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { HloModule Module ENTRY entry { - a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1} + token = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[]) recv(token), channel_id=1, sharding={maximal device=-1} b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} c = f32[4] add(b, b), sharding={maximal device=-1} - d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1} + d = (f32[4], u32[]) send(c, token), channel_id=2, sharding={maximal device=-1} ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} } )"; @@ -259,10 +261,11 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { HloModule Module ENTRY entry { - a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0} + token = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[]) recv(token), channel_id=1, sharding={maximal device=0} b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} c = f32[4] add(b, b) - d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[]) send(c, token), channel_id=2, sharding={maximal device=0} ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} } )"; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e0e3d301be..5b416d9654 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -112,10 +112,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kSend: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Send instruction should have 1 operand but sees " + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Send instruction should have 2 operand but sees " << proto.operand_ids_size(); - instruction = CreateSend(operands(0), proto.channel_id()); + instruction = CreateSend(operands(0), operands(1), proto.channel_id()); break; case HloOpcode::kSendDone: TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -124,11 +124,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction = CreateSendDone(operands(0)); break; case HloOpcode::kRecv: - TF_RET_CHECK(proto.operand_ids_size() == 0) - << "Recv instruction should have 0 operand but sees " + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Recv instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = - CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0), + proto.channel_id()); break; case HloOpcode::kRecvDone: TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -650,8 +650,8 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend( - HloInstruction* operand, int64 channel_id) { - return MakeUnique<HloSendInstruction>(operand, channel_id); + HloInstruction* operand, HloInstruction* token, int64 channel_id) { + return MakeUnique<HloSendInstruction>(operand, token, channel_id); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone( @@ -663,8 +663,8 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv( - const Shape& shape, int64 channel_id) { - return MakeUnique<HloRecvInstruction>(shape, channel_id); + const Shape& shape, HloInstruction* token, int64 channel_id) { + return MakeUnique<HloRecvInstruction>(shape, token, channel_id); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 0459072127..34e7dcb43d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -477,7 +477,7 @@ class HloInstruction { const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of infeed are + // TODO(b/80000000): Remove this overload when all uses of outfeed are // converted to take tokens. static std::unique_ptr<HloInstruction> CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, @@ -487,6 +487,7 @@ class HloInstruction { // initiates sending the operand data to a unique receive instruction in // another computation that has the same channel id. static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand, + HloInstruction* token, int64 channel_id); // Blocks until data transfer for the Send instruction (operand) is complete. @@ -498,6 +499,7 @@ class HloInstruction { // which allocates resources to receive data of the given shape from a unique // send instruction in another computation that has the same channel id. static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape, + HloInstruction* token, int64 channel_id); // Blocks until data transfer for the Recv instruction (operand) is complete diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e2f43f5810..dcc1e3c8af 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -204,21 +204,23 @@ bool HloSendRecvInstruction::IdenticalSlowPath( // Send instruction produces a tuple of {aliased operand, U32 context}. HloSendInstruction::HloSendInstruction(HloInstruction* operand, - int64 channel_id) + HloInstruction* token, int64 channel_id) : HloSendRecvInstruction( HloOpcode::kSend, ShapeUtil::MakeTupleShape( {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), channel_id) { AppendOperand(operand); + AppendOperand(token); } std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloSendInstruction>(new_operands[0], channel_id()); + CHECK_EQ(new_operands.size(), 2); + return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1], + channel_id()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) @@ -238,19 +240,22 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( } // Recv instruction produces a tuple of {receive buffer, U32 context}. -HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id) +HloRecvInstruction::HloRecvInstruction(const Shape& shape, + HloInstruction* token, int64 channel_id) : HloSendRecvInstruction( HloOpcode::kRecv, ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), - channel_id) {} + channel_id) { + AppendOperand(token); +} std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 0); + CHECK_EQ(new_operands.size(), 1); return MakeUnique<HloRecvInstruction>( - ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); + ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id()); } HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index ec8a42bd3b..df6969c410 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -161,7 +161,8 @@ class HloSendRecvInstruction : public HloInstruction { class HloSendInstruction : public HloSendRecvInstruction { public: - explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, + int64 channel_id); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -185,7 +186,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { class HloRecvInstruction : public HloSendRecvInstruction { public: - explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, + int64 channel_id); private: // Implementation for non-common logic of CloneWithNewOperands. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 6ffed62a09..5b0f09a498 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -670,12 +670,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kRecv: { optional<tensorflow::int64> channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/0) || + if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id)); + instruction = builder->AddInstruction(HloInstruction::CreateRecv( + shape.tuple_shapes(0), operands[0], *channel_id)); break; } case HloOpcode::kRecvDone: { @@ -695,12 +695,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kSend: { optional<tensorflow::int64> channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/1) || + if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], *channel_id)); + HloInstruction::CreateSend(operands[0], operands[1], *channel_id)); break; } case HloOpcode::kSendDone: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 504ea3fe7a..f40cd60907 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -278,10 +278,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} + %token = token[] after-all() + %recv = (f32[], u32[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} } @@ -1221,10 +1222,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[]) recv(token[] %token), channel_id=15 %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv + %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv %send-done = () send-done((f32[], u32[]) %send), channel_id=16 } @@ -1237,10 +1239,11 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[]) recv(token[] %token), channel_id=15 %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[]) send(f32[] %constant) + %send = (f32[], u32[]) send(f32[] %constant, token[] %token) %send-done = () send-done((f32[], u32[]) %send), channel_id=16 } @@ -1253,10 +1256,11 @@ TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[]) recv(token[] %token), channel_id=15 %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done} + %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} %send-done = () send-done((f32[], u32[]) %send), channel_id=16 } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 27c9529b11..765245096b 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -108,17 +108,29 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } +namespace { + +Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { + const HloInstruction* token = instruction->operand(operand_no); + if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { + return InternalError( + "Expected operand %lld to be token-shaped, actual shape is" + "%s:\n%s", + operand_no, ShapeUtil::HumanString(token->shape()).c_str(), + instruction->ToString().c_str()); + } + return Status::OK(); +} + +} // namespace + Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction); // Infeed has an optional single token operand. // TODO(b/80000000): Update when token is not optional. - if (infeed->operand_count() == 1 && - !ShapeUtil::Equal(infeed->operand(0)->shape(), - ShapeUtil::MakeTokenShape())) { - return InternalError( - "Expected infeed operand to be token-shaped, actual shape is %s:\n%s", - ShapeUtil::HumanString(infeed->operand(0)->shape()).c_str(), - infeed->ToString().c_str()); + if (infeed->operand_count() == 1) { + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); } // The output of infeed is a tuple containing the data value and a token. @@ -131,13 +143,8 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction); // Outfeed has an optional token operand (operand 1). // TODO(b/80000000): Update when token is not optional. - if (outfeed->operand_count() == 2 && - !ShapeUtil::Equal(outfeed->operand(1)->shape(), - ShapeUtil::MakeTokenShape())) { - return InternalError( - "Expected operand 1 of outfeed to be a token, actual shape is %s:\n%s", - ShapeUtil::HumanString(outfeed->operand(1)->shape()).c_str(), - outfeed->ToString().c_str()); + if (outfeed->operand_count() == 2) { + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); } // Outfeed has a separate shape field for the value which is outfed to the @@ -338,6 +345,7 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { const HloInstruction* send_done = send->users().front(); TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + TF_RETURN_IF_ERROR(CheckIsTokenOperand(send, 1)); return CheckShape( send, ShapeUtil::MakeTupleShape( {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); @@ -348,6 +356,7 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { const HloInstruction* send = send_done->operand(0); TF_RET_CHECK(send->opcode() == HloOpcode::kSend); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape(send_done, ShapeUtil::MakeNil()); } @@ -356,6 +365,7 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { const HloInstruction* recv_done = recv->users().front(); TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + TF_RETURN_IF_ERROR(CheckIsTokenOperand(recv, 0)); return CheckShape(recv, ShapeUtil::MakeTupleShape( {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 21db233899..bb7231c8c8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -167,7 +167,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); HloInstruction* binary1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0)); HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); @@ -258,7 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -288,7 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} add(p0, p0) log = f32[4,3]{1,0} log(p0) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -321,7 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) @@ -352,7 +356,8 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); HloInstruction* unary1 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0)); - builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + builder.AddInstruction(HloInstruction::CreateSend(unary1, token, 0)); HloInstruction* unary2 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); @@ -375,7 +380,8 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); HloInstruction* binary1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0)); HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 67e2cf6c77..4cd584bf8b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -829,10 +829,11 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - recv = (f32[2,2], u32[]) recv(), channel_id=1, sharding={maximal device=1} + token = token[] after-all() + recv = (f32[2,2], u32[]) recv(token), channel_id=1, sharding={maximal device=1} ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1, sharding={maximal device=1} - send = (f32[2,2], u32[]) send(gte), channel_id=1, + send = (f32[2,2], u32[]) send(gte, token), channel_id=1, sharding={maximal device=0} send-done = () send-done(send), channel_id=1, sharding={maximal device=0} } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d05e995a95..81f071ecc5 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -69,11 +69,11 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, } const Shape& accumulator_shape = reducer_shape.result(); - if (ShapeUtil::Rank(accumulator_shape) != 0) { + if (!ShapeUtil::IsArray(accumulator_shape) || + ShapeUtil::Rank(accumulator_shape) != 0) { return InvalidArgument( - "Reduction function must have rank 0 (rank %lld reduction function " - "given).", - ShapeUtil::Rank(accumulator_shape)); + "Reduction function must produce a scalar but has shape: %s", + ShapeUtil::HumanString(accumulator_shape).c_str()); } // Check that the accumulator can be passed in as the first argument. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 5734f28407..a8f885fd86 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -318,8 +318,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto send = builder.AddInstruction( - HloInstruction::CreateSend(constant, /*channel_id=*/0)); + HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); BuildModuleAndRunAnalysis(builder.Build()); @@ -342,8 +343,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { // RecvDone forwards its operand tuple element at {0} to the output. auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto recv = builder.AddInstruction(HloInstruction::CreateRecv( - ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0)); + ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); BuildModuleAndRunAnalysis(builder.Build()); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 0536c99b67..3c83049216 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -175,9 +175,11 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); + auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); auto* send = while_body->AddInstruction(HloInstruction::CreateSend( while_body->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))), + token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -190,8 +192,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); + auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); auto* recv = while_body->AddInstruction( - HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), + HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index f5331280ee..c6bd013a1a 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -67,7 +67,9 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) { } TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) { - builder_.AddInstruction(HloInstruction::CreateSend(zero_sized_param_, 0)); + auto token = builder_.AddInstruction(HloInstruction::CreateAfterAll({})); + builder_.AddInstruction( + HloInstruction::CreateSend(zero_sized_param_, token, 0)); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); EXPECT_FALSE(changed); } |