diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_client/xla_builder.cc')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 431 |
1 files changed, 347 insertions, 84 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 0145f60483..152335e22a 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -22,6 +22,7 @@ limitations under the License. #include <utility> #include "tensorflow/compiler/xla/client/sharding_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -48,6 +49,7 @@ int64 GetUniqueId() { // computation. bool CanBeRoot(HloOpcode opcode) { switch (opcode) { + case HloOpcode::kAfterAll: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kOutfeed: @@ -60,36 +62,18 @@ bool CanBeRoot(HloOpcode opcode) { } // namespace -XlaOp operator-(const XlaOp& x) { return x.builder()->Neg(x); } -XlaOp operator+(const XlaOp& x, const XlaOp& y) { - return x.builder()->Add(x, y); -} -XlaOp operator-(const XlaOp& x, const XlaOp& y) { - return x.builder()->Sub(x, y); -} -XlaOp operator*(const XlaOp& x, const XlaOp& y) { - return x.builder()->Mul(x, y); -} -XlaOp operator/(const XlaOp& x, const XlaOp& y) { - return x.builder()->Div(x, y); -} -XlaOp operator%(const XlaOp& x, const XlaOp& y) { - return x.builder()->Rem(x, y); -} +XlaOp operator-(const XlaOp& x) { return Neg(x); } +XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); } +XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); } +XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); } +XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); } +XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); } -XlaOp operator~(const XlaOp& x) { return x.builder()->Not(x); } -XlaOp operator&(const XlaOp& x, const XlaOp& y) { - return x.builder()->And(x, y); -} -XlaOp operator|(const XlaOp& x, const XlaOp& y) { - return x.builder()->Or(x, y); -} -XlaOp operator^(const XlaOp& x, const XlaOp& y) { - return x.builder()->Xor(x, y); -} -XlaOp operator<<(const XlaOp& x, const XlaOp& y) { - return x.builder()->ShiftLeft(x, y); -} +XlaOp operator~(const XlaOp& x) { return Not(x); } +XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); } +XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); } +XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); } +XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); } XlaOp operator>>(const XlaOp& x, const XlaOp& y) { XlaBuilder* builder = x.builder(); @@ -101,9 +85,9 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) { ShapeUtil::HumanString(shape).c_str()); } if (ShapeUtil::ElementIsSigned(shape)) { - return builder->ShiftRightArithmetic(x, y); + return ShiftRightArithmetic(x, y); } else { - return builder->ShiftRightLogical(x, y); + return ShiftRightLogical(x, y); } }); } @@ -550,6 +534,14 @@ XlaOp XlaBuilder::Broadcast( }); } +XlaOp XlaBuilder::BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + return InDimBroadcast(shape, operand, broadcast_dimensions); + }); +} + StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); @@ -745,14 +737,22 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { - return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false); + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true)); + TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false)); + TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) == + ShapeUtil::IsTuple(false_shape)); + HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? HloOpcode::kTupleSelect + : HloOpcode::kSelect; + return TernaryOp(opcode, pred, on_true, on_false); + }); } XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) { @@ -1118,6 +1118,35 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { }); } +XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Infeed must have a layout"); + } + const Shape infeed_instruction_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + *instr.mutable_shape() = infeed_instruction_shape; + instr.set_infeed_config(config); + + if (ShapeUtil::IsArray(shape) && sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { + // TODO(b/110793772): Support tiled array-shaped infeeds. + return InvalidArgument( + "Tiled sharding is not yet supported for array-shaped infeeds"); + } + + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + return InvalidArgument( + "Replicated sharding is not yet supported for infeeds"); + } + + return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token}); + }); +} + void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config) { ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { @@ -1163,6 +1192,53 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, }); } +XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + + // Check and set outfeed shape. + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Given shape to Outfeed must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "Outfeed shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + *instr.mutable_outfeed_shape() = shape_with_layout; + + instr.set_outfeed_config(outfeed_config); + + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, + {operand, token}); + }); +} + +XlaOp XlaBuilder::CreateToken() { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(instr), HloOpcode::kAfterAll); + }); +} + +XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (tokens.empty()) { + return InvalidArgument("AfterAll requires at least one operand"); + } + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens); + }); +} + XlaOp XlaBuilder::CustomCall(const string& call_target_name, tensorflow::gtl::ArraySlice<XlaOp> operands, const Shape& shape) { @@ -1366,13 +1442,31 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSort, operand); -} - -XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(0.5), - /*broadcast_dimensions=*/{}); +XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values, + int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + std::vector<const Shape*> operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + operand_shape_ptrs.push_back(&keys_shape); + Shape values_shape; + if (values.has_value()) { + TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); + operand_shape_ptrs.push_back(&values_shape); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + if (dimension == -1) { + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + dimension = ShapeUtil::Rank(keys_shape) - 1; + } + instr.add_dimensions(dimension); + return values.has_value() + ? AddInstruction(std::move(instr), HloOpcode::kSort, + {keys, *values}) + : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + }); } XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, @@ -1405,16 +1499,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, }); } -XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(2.0), - /*broadcast_dimensions=*/{}); -} - -XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(-1.0), - /*broadcast_dimensions=*/{}); -} - XlaOp XlaBuilder::Neg(const XlaOp& operand) { return UnaryOp(HloOpcode::kNegate, operand); } @@ -1594,6 +1678,7 @@ XlaOp XlaBuilder::Reduce( TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferReduceShape( operand_shape, init_shape, dimensions_to_reduce, @@ -1761,10 +1846,6 @@ XlaOp XlaBuilder::CrossReplicaSum( tensorflow::gtl::ArraySlice<int64> replica_group_ids, const tensorflow::gtl::optional<ChannelHandle>& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - if (channel_id.has_value()) { - return Unimplemented("channel_id is not supported in AllReduce"); - } - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1774,6 +1855,10 @@ XlaOp XlaBuilder::CrossReplicaSum( instr.add_replica_group_ids(replica_group_id); } + if (channel_id.has_value()) { + instr.set_all_reduce_id(channel_id->handle()); + } + AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, @@ -1847,19 +1932,39 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - HloInstructionProto instr; + // Send HLO takes two operands: a data operand and a token. Generate the + // token to pass into the send. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); + + return SendWithToken(operand, token, handle); + }); +} + +XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { + return InvalidArgument("Send must use a device-to-device channel"); + } - // Send instruction produces a tuple of {aliased operand, U32 context}. + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN( - XlaOp send, - AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + send_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); send_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, {send}); @@ -1868,18 +1973,132 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - HloInstructionProto instr; + // Recv HLO takes a single token operand. Generate the token to pass into + // the Recv and RecvDone instructions. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); - // Recv instruction produces a tuple of {receive buffer, U32 context}. - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp recv, - AddInstruction(std::move(instr), HloOpcode::kRecv, {})); + XlaOp recv = RecvWithToken(token, shape, handle); + + // The RecvDone instruction produces a tuple of the data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto recv_data; + *recv_data.mutable_shape() = shape; + recv_data.set_tuple_index(0); + return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, + {recv}); + }); +} + +XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { + return InvalidArgument("Recv must use a device-to-device channel"); + } + + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. + HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + recv_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), + HloOpcode::kRecv, {token})); + + HloInstructionProto recv_done_instr; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + recv_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, + {recv}); + }); +} + +XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Shape passed to SendToHost must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "SendToHost shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + // TODO(b/111544877): Support tuple shapes. + if (!ShapeUtil::IsArray(operand_shape)) { + return InvalidArgument("SendToHost only supports array shapes, shape: %s", + ShapeUtil::HumanString(operand_shape).c_str()); + } + + if (handle.type() != ChannelHandle::DEVICE_TO_HOST) { + return InvalidArgument("SendToHost must use a device-to-host channel"); + } + + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + HloInstructionProto send_instr; + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape_with_layout, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}); + send_instr.set_channel_id(handle.handle()); + send_instr.set_is_host_transfer(true); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); + + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + send_done_instr.set_channel_id(handle.handle()); + send_done_instr.set_is_host_transfer(true); + return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, + {send}); + }); +} + +XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Shape passed to RecvFromHost must have a layout"); + } + + // TODO(b/111544877): Support tuple shapes. + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "RecvFromHost only supports array shapes, shape: %s", + ShapeUtil::HumanString(shape).c_str()); + } + + if (handle.type() != ChannelHandle::HOST_TO_DEVICE) { + return InvalidArgument("RecvFromHost must use a host-to-device channel"); + } + + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. + HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + recv_instr.set_channel_id(handle.handle()); + recv_instr.set_is_host_transfer(true); + TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), + HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; - *recv_done_instr.mutable_shape() = shape; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); recv_done_instr.set_channel_id(handle.handle()); + recv_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, {recv}); }); @@ -2140,6 +2359,13 @@ XlaOp Broadcast(const XlaOp& operand, return operand.builder()->Broadcast(operand, broadcast_sizes); } +XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return operand.builder()->BroadcastInDim(operand, shape, + broadcast_dimensions); +} + XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config) { return operand.builder()->Pad(operand, padding_value, padding_config); @@ -2498,14 +2724,6 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } -XlaOp SqrtF32(const XlaOp& operand) { - return operand.builder()->SqrtF32(operand); -} - -XlaOp SquareF32(const XlaOp& operand) { - return operand.builder()->SquareF32(operand); -} - XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); @@ -2523,10 +2741,6 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { return operand.builder()->BitcastConvertType(operand, new_element_type); } -XlaOp ReciprocalF32(const XlaOp& operand) { - return operand.builder()->ReciprocalF32(operand); -} - XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } XlaOp Transpose(const XlaOp& operand, @@ -2538,7 +2752,10 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(const XlaOp& operand) { return operand.builder()->Sort(operand); } +XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values, + int64 dimension) { + return keys.builder()->Sort(keys, std::move(values), dimension); +} XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); @@ -2595,6 +2812,45 @@ XlaOp Recv(XlaBuilder* builder, const Shape& shape, return builder->Recv(shape, handle); } +XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle) { + return operand.builder()->SendWithToken(operand, token, handle); +} + +XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return token.builder()->RecvWithToken(token, shape, handle); +} + +XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle) { + return operand.builder()->SendToHost(operand, token, shape_with_layout, + handle); +} + +XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return token.builder()->RecvFromHost(token, shape, handle); +} + +XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config) { + return token.builder()->InfeedWithToken(token, shape, config); +} + +XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config) { + return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout, + outfeed_config); +} + +XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); } + +XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens) { + return builder->AfterAll(tokens); +} + XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index) { @@ -2618,4 +2874,11 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, grad_output, epsilon, feature_index); } +XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); + return builder->ReportErrorOrReturn( + builder->AddInstruction(std::move(instr), HloOpcode::kIota)); +} + } // namespace xla |