Diffstat (limited to 'tensorflow/compiler/xla/client/xla_client/xla_builder.cc')
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 147
1 file changed, 113 insertions(+), 34 deletions(-)
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index aac7df4383..a9a4b3bc5d 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1845,10 +1845,6 @@ XlaOp XlaBuilder::CrossReplicaSum(
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- if (channel_id.has_value()) {
- return Unimplemented("channel_id is not supported in AllReduce");
- }
-
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1858,6 +1854,10 @@ XlaOp XlaBuilder::CrossReplicaSum(
instr.add_replica_group_ids(replica_group_id);
}
+ if (channel_id.has_value()) {
+ instr.set_all_reduce_id(channel_id->handle());
+ }
+
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum,
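
Note: this hunk relaxes CrossReplicaSum so that a channel_id is no longer rejected; it is instead recorded as all_reduce_id on the instruction, letting matching AllReduce ops in different modules be paired up. A minimal usage sketch, assuming a builder, an input op, and a scalar-add computation built elsewhere (all names here are illustrative, not part of this change):

    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

    // Sum `input` across all replicas, tagging the op with a channel id so
    // the backend can match it against AllReduces in other modules.
    xla::XlaOp CrossModuleSum(xla::XlaBuilder* builder, const xla::XlaOp& input,
                              const xla::XlaComputation& add_f32,
                              const xla::ChannelHandle& channel) {
      // An empty replica_group_ids list places every replica in one group.
      return builder->CrossReplicaSum(input, add_f32, /*replica_group_ids=*/{},
                                      channel);
    }
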
@@ -1940,28 +1940,17 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));
- // Send instruction produces a tuple of {aliased operand, U32 context,
- // token}.
- HloInstructionProto send_instr;
- TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
- *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
- {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
- send_instr.set_channel_id(handle.handle());
- TF_ASSIGN_OR_RETURN(XlaOp send,
- AddInstruction(std::move(send_instr), HloOpcode::kSend,
- {operand, token}));
-
- HloInstructionProto send_done_instr;
- *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
- send_done_instr.set_channel_id(handle.handle());
- return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
- {send});
+ return SendWithToken(operand, token, handle);
});
}
XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token,
const ChannelHandle& handle) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
+ return InvalidArgument("Send must use a device-to-device channel");
+ }
+
// Send instruction produces a tuple of {aliased operand, U32 context,
// token}.
HloInstructionProto send_instr;
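
Note: Send now synthesizes a token (an operand-less kAfterAll) and delegates to SendWithToken, which additionally rejects channels that are not device-to-device. A sketch of the explicit token-threaded form, assuming existing ops (the handle value and names are illustrative):

    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

    // Thread an explicit token through a device-to-device Send. The result
    // is itself a token that can order subsequent Send/Recv ops.
    xla::XlaOp SendThroughChannel(const xla::XlaOp& operand,
                                  const xla::XlaOp& token) {
      xla::ChannelHandle handle;
      handle.set_handle(1);  // illustrative handle value
      handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
      return operand.builder()->SendWithToken(operand, token, handle);
    }
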
@@ -1992,6 +1981,27 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));
+ XlaOp recv = RecvWithToken(token, shape, handle);
+
+ // The RecvDone instruction produces a tuple of the data and a token
+ // type. Return XLA op containing the data.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto recv_data;
+ *recv_data.mutable_shape() = shape;
+ recv_data.set_tuple_index(0);
+ return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
+ {recv});
+ });
+}
+
+XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
+ return InvalidArgument("Recv must use a device-to-device channel");
+ }
+
// Recv instruction produces a tuple of {receive buffer, U32 context,
// token}.
HloInstructionProto recv_instr;
@@ -2005,31 +2015,81 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
*recv_done_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
recv_done_instr.set_channel_id(handle.handle());
- TF_ASSIGN_OR_RETURN(XlaOp recv_done,
- AddInstruction(std::move(recv_done_instr),
- HloOpcode::kRecvDone, {recv}));
+ return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
+ {recv});
+ });
+}
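
Note: RecvWithToken now ends at the kRecvDone instruction, whose result is a {data, token} tuple; extracting the data is left to the caller (the old Recv keeps doing that itself for compatibility, per the TODO above). A sketch of consuming the tuple (names are illustrative):

    // `recv` is the result of RecvWithToken(token, shape, handle).
    xla::XlaOp data = xla::GetTupleElement(recv, 0);       // received buffer
    xla::XlaOp new_token = xla::GetTupleElement(recv, 1);  // ordering token
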
- // The RecvDone instruction produces a tuple of the data and a token
- // type. Return XLA op containing the data.
- // TODO(b/80000000): Remove this when clients have been updated to handle
- // tokens.
- HloInstructionProto recv_data;
- *recv_data.mutable_shape() = shape;
- recv_data.set_tuple_index(0);
- return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
- {recv_done});
+XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (!LayoutUtil::HasLayout(shape_with_layout)) {
+ return InvalidArgument("Shape passed to SendToHost must have a layout");
+ }
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
+ return InvalidArgument(
+ "SendToHost shape %s must be compatible with operand shape %s",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ }
+ // TODO(b/111544877): Support tuple shapes.
+ if (!ShapeUtil::IsArray(operand_shape)) {
+ return InvalidArgument("SendToHost only supports array shapes, shape: %s",
+ ShapeUtil::HumanString(operand_shape).c_str());
+ }
+
+ if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
+ return InvalidArgument("SendToHost must use a device-to-host channel");
+ }
+
+ // Send instruction produces a tuple of {aliased operand, U32 context,
+ // token}.
+ HloInstructionProto send_instr;
+ *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape_with_layout, ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()});
+ send_instr.set_channel_id(handle.handle());
+ send_instr.set_is_host_transfer(true);
+ TF_ASSIGN_OR_RETURN(XlaOp send,
+ AddInstruction(std::move(send_instr), HloOpcode::kSend,
+ {operand, token}));
+
+ HloInstructionProto send_done_instr;
+ *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ send_done_instr.set_channel_id(handle.handle());
+ send_done_instr.set_is_host_transfer(true);
+ return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
+ {send});
});
}
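
Note: SendToHost requires an explicitly laid-out array shape and a device-to-host channel, and marks both kSend and kSendDone as host transfers. A usage sketch, assuming existing `builder`, `operand`, and `token` ops (names and values are illustrative):

    // Host transfers need a concrete layout; MakeShapeWithLayout supplies one.
    xla::Shape shape = xla::ShapeUtil::MakeShapeWithLayout(
        xla::F32, /*dimensions=*/{2, 3}, /*minor_to_major=*/{1, 0});
    xla::ChannelHandle handle;
    handle.set_handle(2);  // illustrative handle value
    handle.set_type(xla::ChannelHandle::DEVICE_TO_HOST);
    xla::XlaOp done_token = builder->SendToHost(operand, token, shape, handle);
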
-XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
- const ChannelHandle& handle) {
+XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Shape passed to RecvFromHost must have a layout");
+ }
+
+ // TODO(b/111544877): Support tuple shapes.
+ if (!ShapeUtil::IsArray(shape)) {
+ return InvalidArgument(
+ "RecvFromHost only supports array shapes, shape: %s",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+
+ if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
+ return InvalidArgument("RecvFromHost must use a host-to-device channel");
+ }
+
// Recv instruction produces a tuple of {receive buffer, U32 context,
// token}.
HloInstructionProto recv_instr;
*recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
recv_instr.set_channel_id(handle.handle());
+ recv_instr.set_is_host_transfer(true);
TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
HloOpcode::kRecv, {token}));
@@ -2037,6 +2097,7 @@ XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
*recv_done_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
recv_done_instr.set_channel_id(handle.handle());
+ recv_done_instr.set_is_host_transfer(true);
return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
{recv});
});
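
Note: RecvFromHost mirrors SendToHost with a host-to-device channel and the same layout and array-shape checks; like RecvWithToken, it yields a {data, token} tuple. A sketch under the same assumptions (illustrative names and handle value):

    xla::Shape shape = xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {16}, {0});
    xla::ChannelHandle handle;
    handle.set_handle(3);  // illustrative handle value
    handle.set_type(xla::ChannelHandle::HOST_TO_DEVICE);
    xla::XlaOp recv = builder->RecvFromHost(token, shape, handle);
    xla::XlaOp data = xla::GetTupleElement(recv, 0);  // received buffer
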
@@ -2760,6 +2821,17 @@ XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
return token.builder()->RecvWithToken(token, shape, handle);
}
+XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout, const ChannelHandle& handle) {
+ return operand.builder()->SendToHost(operand, token, shape_with_layout,
+ handle);
+}
+
+XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return token.builder()->RecvFromHost(token, shape, handle);
+}
+
XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
const string& config) {
return token.builder()->InfeedWithToken(token, shape, config);
@@ -2801,4 +2873,11 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
grad_output, epsilon, feature_index);
}
+XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size});
+ return builder->ReportErrorOrReturn(
+ builder->AddInstruction(std::move(instr), HloOpcode::kIota));
+}
+
} // namespace xla
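
Note: IotaGen builds a rank-1 kIota instruction directly from a primitive type and a length. A minimal sketch:

    xla::XlaBuilder b("iota_example");
    // Produces an s32[8] containing {0, 1, 2, 3, 4, 5, 6, 7}.
    xla::XlaOp iota = xla::IotaGen(&b, xla::S32, 8);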