Diffstat (limited to 'tensorflow/compiler/xla/client/xla_client/xla_builder.cc')
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc  431
1 file changed, 347 insertions(+), 84 deletions(-)
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 0145f60483..152335e22a 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/client/sharding_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -48,6 +49,7 @@ int64 GetUniqueId() {
// computation.
bool CanBeRoot(HloOpcode opcode) {
switch (opcode) {
+ case HloOpcode::kAfterAll:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kOutfeed:
@@ -60,36 +62,18 @@ bool CanBeRoot(HloOpcode opcode) {
} // namespace
-XlaOp operator-(const XlaOp& x) { return x.builder()->Neg(x); }
-XlaOp operator+(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Add(x, y);
-}
-XlaOp operator-(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Sub(x, y);
-}
-XlaOp operator*(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Mul(x, y);
-}
-XlaOp operator/(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Div(x, y);
-}
-XlaOp operator%(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Rem(x, y);
-}
+XlaOp operator-(const XlaOp& x) { return Neg(x); }
+XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); }
+XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); }
+XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); }
+XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); }
+XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); }
-XlaOp operator~(const XlaOp& x) { return x.builder()->Not(x); }
-XlaOp operator&(const XlaOp& x, const XlaOp& y) {
- return x.builder()->And(x, y);
-}
-XlaOp operator|(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Or(x, y);
-}
-XlaOp operator^(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Xor(x, y);
-}
-XlaOp operator<<(const XlaOp& x, const XlaOp& y) {
- return x.builder()->ShiftLeft(x, y);
-}
+XlaOp operator~(const XlaOp& x) { return Not(x); }
+XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); }
+XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); }
+XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); }
+XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); }
XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
XlaBuilder* builder = x.builder();
@@ -101,9 +85,9 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
ShapeUtil::HumanString(shape).c_str());
}
if (ShapeUtil::ElementIsSigned(shape)) {
- return builder->ShiftRightArithmetic(x, y);
+ return ShiftRightArithmetic(x, y);
} else {
- return builder->ShiftRightLogical(x, y);
+ return ShiftRightLogical(x, y);
}
});
}
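
For orientation, a minimal sketch of how the reworked operator overloads read at a call site; the builder name, parameter shapes, and setup below are illustrative assumptions, not part of this change (later sketches reuse this builder `b`):

    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

    // The overloads now forward to the free functions (Neg, Add, Mul, ...)
    // rather than to XlaBuilder member functions.
    xla::XlaBuilder b("operators_example");
    xla::XlaOp x =
        xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {8}), "x");
    xla::XlaOp y =
        xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {8}), "y");
    xla::XlaOp z = (x + y) * -x;   // builds Add, Neg, and Mul instructions
    xla::XlaOp r = x / y - x % y;  // Div, Rem, Sub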
@@ -550,6 +534,14 @@ XlaOp XlaBuilder::Broadcast(
});
}
+XlaOp XlaBuilder::BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return InDimBroadcast(shape, operand, broadcast_dimensions);
+ });
+}
+
StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
TF_RETURN_IF_ERROR(first_error_);
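
A usage sketch of the new BroadcastInDim entry point, reusing the builder `b` from the earlier sketch; the operand values and dimension mapping are illustrative:

    // Broadcast an f32[3] operand into f32[2,3]; operand dimension 0 maps to
    // output dimension 1, so each row of the result repeats the vector.
    xla::XlaOp vec = xla::ConstantR1<float>(&b, {1.0f, 2.0f, 3.0f});
    xla::XlaOp mat = xla::BroadcastInDim(
        vec, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}),
        /*broadcast_dimensions=*/{1});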
@@ -745,14 +737,22 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
- *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
+ *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
}
XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
const XlaOp& on_false) {
- return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false);
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true));
+ TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false));
+ TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) ==
+ ShapeUtil::IsTuple(false_shape));
+ HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? HloOpcode::kTupleSelect
+ : HloOpcode::kSelect;
+ return TernaryOp(opcode, pred, on_true, on_false);
+ });
}
XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
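
Select now dispatches on the operand shapes: array operands lower to kSelect and tuple operands to the new kTupleSelect. A small sketch of both cases, reusing `b`, `x`, and `y` from the first sketch:

    // Array case: the predicate is an element-wise PRED array of the same
    // dimensions, and the op lowers to kSelect.
    xla::XlaOp mask = xla::Gt(x, y);
    xla::XlaOp larger = xla::Select(mask, x, y);
    // Tuple case: a scalar predicate picks one of two tuples via kTupleSelect.
    xla::XlaOp flag = xla::ConstantR0<bool>(&b, true);
    xla::XlaOp t = xla::Tuple(&b, {x, y});
    xla::XlaOp f = xla::Tuple(&b, {y, x});
    xla::XlaOp chosen = xla::Select(flag, t, f);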
@@ -1118,6 +1118,35 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
});
}
+XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Given shape to Infeed must have a layout");
+ }
+ const Shape infeed_instruction_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
+ *instr.mutable_shape() = infeed_instruction_shape;
+ instr.set_infeed_config(config);
+
+ if (ShapeUtil::IsArray(shape) && sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) {
+ // TODO(b/110793772): Support tiled array-shaped infeeds.
+ return InvalidArgument(
+ "Tiled sharding is not yet supported for array-shaped infeeds");
+ }
+
+ if (sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+ return InvalidArgument(
+ "Replicated sharding is not yet supported for infeeds");
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
+ });
+}
+
void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1163,6 +1192,53 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
});
}
+XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+
+ // Check and set outfeed shape.
+ if (!LayoutUtil::HasLayout(shape_with_layout)) {
+ return InvalidArgument("Given shape to Outfeed must have a layout");
+ }
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
+ return InvalidArgument(
+ "Outfeed shape %s must be compatible with operand shape %s",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ }
+ *instr.mutable_outfeed_shape() = shape_with_layout;
+
+ instr.set_outfeed_config(outfeed_config);
+
+ return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
+ {operand, token});
+ });
+}
+
+XlaOp XlaBuilder::CreateToken() {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
+ });
+}
+
+XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (tokens.empty()) {
+ return InvalidArgument("AfterAll requires at least one operand");
+ }
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
+ });
+}
+
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape) {
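
A sketch of how the new token plumbing for infeed and outfeed might be wired together, reusing `b`; the data shape and empty config strings are assumptions:

    // Create an initial token, thread it through an infeed, then an outfeed.
    xla::XlaOp tok = xla::CreateToken(&b);
    xla::Shape data_shape =
        xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {16}, {0});
    xla::XlaOp infeed = xla::InfeedWithToken(tok, data_shape, /*config=*/"");
    xla::XlaOp data = xla::GetTupleElement(infeed, 0);        // payload
    xla::XlaOp infeed_tok = xla::GetTupleElement(infeed, 1);  // new token
    xla::XlaOp outfeed_tok = xla::OutfeedWithToken(
        data, infeed_tok, data_shape, /*outfeed_config=*/"");
    // AfterAll joins several tokens into a single ordering dependency.
    xla::XlaOp joined = xla::AfterAll(&b, {tok, outfeed_tok});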
@@ -1366,13 +1442,31 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::Sort(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kSort, operand);
-}
-
-XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(0.5),
- /*broadcast_dimensions=*/{});
+XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ std::vector<const Shape*> operand_shape_ptrs;
+ TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
+ operand_shape_ptrs.push_back(&keys_shape);
+ Shape values_shape;
+ if (values.has_value()) {
+ TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values));
+ operand_shape_ptrs.push_back(&values_shape);
+ }
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, operand_shape_ptrs));
+ if (dimension == -1) {
+ TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
+ dimension = ShapeUtil::Rank(keys_shape) - 1;
+ }
+ instr.add_dimensions(dimension);
+ return values.has_value()
+ ? AddInstruction(std::move(instr), HloOpcode::kSort,
+ {keys, *values})
+ : AddInstruction(std::move(instr), HloOpcode::kSort, {keys});
+ });
}
XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
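
The reworked Sort takes the keys, an optional values operand, and a sort dimension (-1 meaning the last dimension of the keys). A sketch with illustrative constants, reusing `b`:

    // Co-sort a values operand by the keys; the result is a tuple of the
    // sorted keys and the correspondingly reordered values.
    xla::XlaOp keys = xla::ConstantR1<xla::int32>(&b, {3, 1, 2});
    xla::XlaOp vals = xla::ConstantR1<float>(&b, {30.f, 10.f, 20.f});
    xla::XlaOp sorted = xla::Sort(keys, vals, /*dimension=*/-1);
    xla::XlaOp sorted_keys = xla::GetTupleElement(sorted, 0);  // {1, 2, 3}
    xla::XlaOp sorted_vals = xla::GetTupleElement(sorted, 1);  // {10, 20, 30}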
@@ -1405,16 +1499,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::SquareF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(2.0),
- /*broadcast_dimensions=*/{});
-}
-
-XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(-1.0),
- /*broadcast_dimensions=*/{});
-}
-
XlaOp XlaBuilder::Neg(const XlaOp& operand) {
return UnaryOp(HloOpcode::kNegate, operand);
}
@@ -1594,6 +1678,7 @@ XlaOp XlaBuilder::Reduce(
TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
+
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferReduceShape(
operand_shape, init_shape, dimensions_to_reduce,
@@ -1761,10 +1846,6 @@ XlaOp XlaBuilder::CrossReplicaSum(
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- if (channel_id.has_value()) {
- return Unimplemented("channel_id is not supported in AllReduce");
- }
-
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1774,6 +1855,10 @@ XlaOp XlaBuilder::CrossReplicaSum(
instr.add_replica_group_ids(replica_group_id);
}
+ if (channel_id.has_value()) {
+ instr.set_all_reduce_id(channel_id->handle());
+ }
+
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum,
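
CrossReplicaSum now threads an optional channel_id into all_reduce_id instead of rejecting it. A hedged sketch, assuming `add_scalars` is an XlaComputation summing two F32 scalars, `channel` is a ChannelHandle obtained from the client, and `data` is the infeed payload from the sketch above:

    // Empty replica_group_ids means all replicas participate; all-reduce ops
    // sharing the same channel id are intended to be combined across modules.
    xla::XlaOp summed = b.CrossReplicaSum(
        data, add_scalars, /*replica_group_ids=*/{}, /*channel_id=*/channel);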
@@ -1847,19 +1932,39 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
+ // Send HLO takes two operands: a data operand and a token. Generate the
+ // token to pass into the send.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto token_instr;
+ *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
+ HloOpcode::kAfterAll, {}));
+
+ return SendWithToken(operand, token, handle);
+ });
+}
+
+XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
+ return InvalidArgument("Send must use a device-to-device channel");
+ }
- // Send instruction produces a tuple of {aliased operand, U32 context}.
+ // Send instruction produces a tuple of {aliased operand, U32 context,
+ // token}.
+ HloInstructionProto send_instr;
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
- *instr.mutable_shape() =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- instr.set_channel_id(handle.handle());
- TF_ASSIGN_OR_RETURN(
- XlaOp send,
- AddInstruction(std::move(instr), HloOpcode::kSend, {operand}));
+ *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ send_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp send,
+ AddInstruction(std::move(send_instr), HloOpcode::kSend,
+ {operand, token}));
HloInstructionProto send_done_instr;
- *send_done_instr.mutable_shape() = ShapeUtil::MakeNil();
+ *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
send_done_instr.set_channel_id(handle.handle());
return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
{send});
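
A sketch of the token-threaded device-to-device transfer. The sender and receiver would normally be built in separate computations, and `channel` is assumed to be a device-to-device ChannelHandle (e.g. from Client::CreateChannelHandle()):

    // Sender side: SendWithToken returns the token produced by kSendDone.
    xla::XlaBuilder sender("sender");
    xla::XlaOp payload = xla::ConstantR0<float>(&sender, 42.0f);
    xla::XlaOp send_done =
        xla::SendWithToken(payload, xla::CreateToken(&sender), channel);

    // Receiver side: RecvWithToken returns the {data, token} tuple produced
    // by kRecvDone.
    xla::XlaBuilder receiver("receiver");
    xla::XlaOp recvd = xla::RecvWithToken(
        xla::CreateToken(&receiver),
        xla::ShapeUtil::MakeShape(xla::F32, {}), channel);
    xla::XlaOp value = xla::GetTupleElement(recvd, 0);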
@@ -1868,18 +1973,132 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
+ // Recv HLO takes a single token operand. Generate the token to pass into
+ // the Recv and RecvDone instructions.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto token_instr;
+ *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
+ HloOpcode::kAfterAll, {}));
- // Recv instruction produces a tuple of {receive buffer, U32 context}.
- *instr.mutable_shape() =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- instr.set_channel_id(handle.handle());
- TF_ASSIGN_OR_RETURN(XlaOp recv,
- AddInstruction(std::move(instr), HloOpcode::kRecv, {}));
+ XlaOp recv = RecvWithToken(token, shape, handle);
+
+ // The RecvDone instruction produces a tuple of the data and a token
+ // type. Return XLA op containing the data.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto recv_data;
+ *recv_data.mutable_shape() = shape;
+ recv_data.set_tuple_index(0);
+ return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
+ {recv});
+ });
+}
+
+XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
+ return InvalidArgument("Recv must use a device-to-device channel");
+ }
+
+ // Recv instruction produces a tuple of {receive buffer, U32 context,
+ // token}.
+ HloInstructionProto recv_instr;
+ *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ recv_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
+ HloOpcode::kRecv, {token}));
+
+ HloInstructionProto recv_done_instr;
+ *recv_done_instr.mutable_shape() =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
+ recv_done_instr.set_channel_id(handle.handle());
+ return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
+ {recv});
+ });
+}
+
+XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (!LayoutUtil::HasLayout(shape_with_layout)) {
+ return InvalidArgument("Shape passed to SendToHost must have a layout");
+ }
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
+ return InvalidArgument(
+ "SendToHost shape %s must be compatible with operand shape %s",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ }
+ // TODO(b/111544877): Support tuple shapes.
+ if (!ShapeUtil::IsArray(operand_shape)) {
+ return InvalidArgument("SendToHost only supports array shapes, shape: %s",
+ ShapeUtil::HumanString(operand_shape).c_str());
+ }
+
+ if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
+ return InvalidArgument("SendToHost must use a device-to-host channel");
+ }
+
+ // Send instruction produces a tuple of {aliased operand, U32 context,
+ // token}.
+ HloInstructionProto send_instr;
+ *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape_with_layout, ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()});
+ send_instr.set_channel_id(handle.handle());
+ send_instr.set_is_host_transfer(true);
+ TF_ASSIGN_OR_RETURN(XlaOp send,
+ AddInstruction(std::move(send_instr), HloOpcode::kSend,
+ {operand, token}));
+
+ HloInstructionProto send_done_instr;
+ *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ send_done_instr.set_channel_id(handle.handle());
+ send_done_instr.set_is_host_transfer(true);
+ return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
+ {send});
+ });
+}
+
+XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Shape passed to RecvFromHost must have a layout");
+ }
+
+ // TODO(b/111544877): Support tuple shapes.
+ if (!ShapeUtil::IsArray(shape)) {
+ return InvalidArgument(
+ "RecvFromHost only supports array shapes, shape: %s",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+
+ if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
+ return InvalidArgument("RecvFromHost must use a host-to-device channel");
+ }
+
+ // Recv instruction produces a tuple of {receive buffer, U32 context,
+ // token}.
+ HloInstructionProto recv_instr;
+ *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ recv_instr.set_channel_id(handle.handle());
+ recv_instr.set_is_host_transfer(true);
+ TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
+ HloOpcode::kRecv, {token}));
HloInstructionProto recv_done_instr;
- *recv_done_instr.mutable_shape() = shape;
+ *recv_done_instr.mutable_shape() =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
recv_done_instr.set_channel_id(handle.handle());
+ recv_done_instr.set_is_host_transfer(true);
return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
{recv});
});
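
The host-transfer variants follow the same pattern but require layouts and host channels. A sketch with assumed shapes, reusing `b`; `to_host` is assumed to be a DEVICE_TO_HOST channel and `from_host` a HOST_TO_DEVICE channel:

    // Device-to-host: the operand must be an array whose shape is compatible
    // with the given shape_with_layout.
    xla::Shape f32_vec =
        xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {4}, {0});
    xla::XlaOp host_payload = xla::ConstantR1<float>(&b, {1.f, 2.f, 3.f, 4.f});
    xla::XlaOp sent_tok = xla::SendToHost(
        host_payload, xla::CreateToken(&b), f32_vec, to_host);
    // Host-to-device: kRecvDone yields a {data, token} tuple.
    xla::XlaOp received = xla::RecvFromHost(sent_tok, f32_vec, from_host);
    xla::XlaOp host_data = xla::GetTupleElement(received, 0);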
@@ -2140,6 +2359,13 @@ XlaOp Broadcast(const XlaOp& operand,
return operand.builder()->Broadcast(operand, broadcast_sizes);
}
+XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return operand.builder()->BroadcastInDim(operand, shape,
+ broadcast_dimensions);
+}
+
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config) {
return operand.builder()->Pad(operand, padding_value, padding_config);
@@ -2498,14 +2724,6 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
-XlaOp SqrtF32(const XlaOp& operand) {
- return operand.builder()->SqrtF32(operand);
-}
-
-XlaOp SquareF32(const XlaOp& operand) {
- return operand.builder()->SquareF32(operand);
-}
-
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
@@ -2523,10 +2741,6 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
return operand.builder()->BitcastConvertType(operand, new_element_type);
}
-XlaOp ReciprocalF32(const XlaOp& operand) {
- return operand.builder()->ReciprocalF32(operand);
-}
-
XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
XlaOp Transpose(const XlaOp& operand,
@@ -2538,7 +2752,10 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
return operand.builder()->Rev(operand, dimensions);
}
-XlaOp Sort(const XlaOp& operand) { return operand.builder()->Sort(operand); }
+XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension) {
+ return keys.builder()->Sort(keys, std::move(values), dimension);
+}
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
return min.builder()->Clamp(min, operand, max);
@@ -2595,6 +2812,45 @@ XlaOp Recv(XlaBuilder* builder, const Shape& shape,
return builder->Recv(shape, handle);
}
+XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle) {
+ return operand.builder()->SendWithToken(operand, token, handle);
+}
+
+XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return token.builder()->RecvWithToken(token, shape, handle);
+}
+
+XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout, const ChannelHandle& handle) {
+ return operand.builder()->SendToHost(operand, token, shape_with_layout,
+ handle);
+}
+
+XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return token.builder()->RecvFromHost(token, shape, handle);
+}
+
+XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config) {
+ return token.builder()->InfeedWithToken(token, shape, config);
+}
+
+XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
+ outfeed_config);
+}
+
+XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }
+
+XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+ return builder->AfterAll(tokens);
+}
+
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, float epsilon,
int64 feature_index) {
@@ -2618,4 +2874,11 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
grad_output, epsilon, feature_index);
}
+XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size});
+ return builder->ReportErrorOrReturn(
+ builder->AddInstruction(std::move(instr), HloOpcode::kIota));
+}
+
} // namespace xla
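
Finally, a one-line sketch of the new IotaGen helper, reusing `b`; the element type and length are illustrative:

    // Builds a rank-1 iota [0, 1, ..., 9] with S32 elements.
    xla::XlaOp indices = xla::IotaGen(&b, xla::S32, /*size=*/10);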