diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_client/xla_builder.cc')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 147 |
1 files changed, 113 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index aac7df4383..a9a4b3bc5d 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1845,10 +1845,6 @@ XlaOp XlaBuilder::CrossReplicaSum( tensorflow::gtl::ArraySlice<int64> replica_group_ids, const tensorflow::gtl::optional<ChannelHandle>& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - if (channel_id.has_value()) { - return Unimplemented("channel_id is not supported in AllReduce"); - } - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1858,6 +1854,10 @@ XlaOp XlaBuilder::CrossReplicaSum( instr.add_replica_group_ids(replica_group_id); } + if (channel_id.has_value()) { + instr.set_all_reduce_id(channel_id->handle()); + } + AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, @@ -1940,28 +1940,17 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); - // Send instruction produces a tuple of {aliased operand, U32 context, - // token}. - HloInstructionProto send_instr; - TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); - send_instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp send, - AddInstruction(std::move(send_instr), HloOpcode::kSend, - {operand, token})); - - HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); - send_done_instr.set_channel_id(handle.handle()); - return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, - {send}); + return SendWithToken(operand, token, handle); }); } XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token, const ChannelHandle& handle) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { + return InvalidArgument("Send must use a device-to-device channel"); + } + // Send instruction produces a tuple of {aliased operand, U32 context, // token}. HloInstructionProto send_instr; @@ -1992,6 +1981,27 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); + XlaOp recv = RecvWithToken(token, shape, handle); + + // The RecvDone instruction produces a tuple of the data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto recv_data; + *recv_data.mutable_shape() = shape; + recv_data.set_tuple_index(0); + return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, + {recv}); + }); +} + +XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { + return InvalidArgument("Recv must use a device-to-device channel"); + } + // Recv instruction produces a tuple of {receive buffer, U32 context, // token}. HloInstructionProto recv_instr; @@ -2005,31 +2015,81 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { *recv_done_instr.mutable_shape() = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); recv_done_instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp recv_done, - AddInstruction(std::move(recv_done_instr), - HloOpcode::kRecvDone, {recv})); + return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, + {recv}); + }); +} - // The RecvDone instruction produces a tuple of the data and a token - // type. Return XLA op containing the data. - // TODO(b/80000000): Remove this when clients have been updated to handle - // tokens. - HloInstructionProto recv_data; - *recv_data.mutable_shape() = shape; - recv_data.set_tuple_index(0); - return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, - {recv_done}); +XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Shape passed to SendToHost must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "SendToHost shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + // TODO(b/111544877): Support tuple shapes. + if (!ShapeUtil::IsArray(operand_shape)) { + return InvalidArgument("SendToHost only supports array shapes, shape: %s", + ShapeUtil::HumanString(operand_shape).c_str()); + } + + if (handle.type() != ChannelHandle::DEVICE_TO_HOST) { + return InvalidArgument("SendToHost must use a device-to-host channel"); + } + + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + HloInstructionProto send_instr; + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape_with_layout, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}); + send_instr.set_channel_id(handle.handle()); + send_instr.set_is_host_transfer(true); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); + + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + send_done_instr.set_channel_id(handle.handle()); + send_done_instr.set_is_host_transfer(true); + return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, + {send}); }); } -XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, - const ChannelHandle& handle) { +XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Shape passed to RecvFromHost must have a layout"); + } + + // TODO(b/111544877): Support tuple shapes. + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "RecvFromHost only supports array shapes, shape: %s", + ShapeUtil::HumanString(shape).c_str()); + } + + if (handle.type() != ChannelHandle::HOST_TO_DEVICE) { + return InvalidArgument("RecvFromHost must use a host-to-device channel"); + } + // Recv instruction produces a tuple of {receive buffer, U32 context, // token}. HloInstructionProto recv_instr; *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); recv_instr.set_channel_id(handle.handle()); + recv_instr.set_is_host_transfer(true); TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), HloOpcode::kRecv, {token})); @@ -2037,6 +2097,7 @@ XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, *recv_done_instr.mutable_shape() = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); recv_done_instr.set_channel_id(handle.handle()); + recv_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, {recv}); }); @@ -2760,6 +2821,17 @@ XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, return token.builder()->RecvWithToken(token, shape, handle); } +XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle) { + return operand.builder()->SendToHost(operand, token, shape_with_layout, + handle); +} + +XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return token.builder()->RecvFromHost(token, shape, handle); +} + XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, const string& config) { return token.builder()->InfeedWithToken(token, shape, config); @@ -2801,4 +2873,11 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, grad_output, epsilon, feature_index); } +XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); + return builder->ReportErrorOrReturn( + builder->AddInstruction(std::move(instr), HloOpcode::kIota)); +} + } // namespace xla |