aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc62
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h27
-rw-r--r--tensorflow/compiler/xla/client/xla_builder_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h1
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto5
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc36
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h23
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc61
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h41
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc67
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc45
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h13
-rw-r--r--tensorflow/compiler/xla/xla_data.proto8
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md73
23 files changed, 522 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 073d66bcd2..b3b00e2fff 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1881,6 +1881,61 @@ XlaOp XlaBuilder::CrossReplicaSum(
});
}
+// Builds an AllToAll of `operand`. The operand is cut into `split_count` equal
+// blocks along `split_dimension`, the blocks are passed to a kAllToAll HLO
+// instruction (which carries a copy of `replica_groups`), and the elements of
+// the resulting tuple are concatenated along `concat_dimension`. Any failure
+// is reported through ReportErrorOrReturn.
+XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
+                           int64 concat_dimension, int64 split_count,
+                           const std::vector<ReplicaGroup>& replica_groups) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+
+    // The HloInstruction for Alltoall currently only handles the data
+    // communication: it accepts N already split parts and scatters them to N
+    // cores, and each core gathers the N received parts into a tuple as the
+    // output. So here we explicitly split the operand before the hlo alltoall,
+    // and concat the tuple elements.
+    //
+    // First, run shape inference to make sure the shapes are valid. Only the
+    // status is needed here; the inferred shape itself is discarded.
+    TF_RETURN_IF_ERROR(
+        ShapeInference::InferAllToAllShape(operand_shape, split_dimension,
+                                           concat_dimension, split_count)
+            .status());
+
+    // Split into N parts. The index is int64 so that it has the same type as
+    // `split_count` and `block_size`.
+    std::vector<XlaOp> slices;
+    slices.reserve(split_count);
+    const int64 block_size =
+        operand_shape.dimensions(split_dimension) / split_count;
+    for (int64 i = 0; i < split_count; ++i) {
+      slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
+                                  /*limit_index=*/(i + 1) * block_size,
+                                  /*stride=*/1, /*dimno=*/split_dimension));
+    }
+
+    // Handle data communication.
+    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
+    std::vector<const Shape*> slice_shape_ptrs;
+    c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
+                [](const Shape& shape) { return &shape; });
+    TF_ASSIGN_OR_RETURN(
+        *instr.mutable_shape(),
+        ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
+    for (const ReplicaGroup& group : replica_groups) {
+      *instr.add_replica_groups() = group;
+    }
+    TF_ASSIGN_OR_RETURN(
+        XlaOp alltoall,
+        AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));
+
+    // Concat the N received parts.
+    std::vector<XlaOp> received;
+    received.reserve(split_count);
+    for (int64 i = 0; i < split_count; ++i) {
+      received.push_back(this->GetTupleElement(alltoall, i));
+    }
+    return this->ConcatInDim(received, concat_dimension);
+  });
+}
+
+
XlaOp XlaBuilder::SelectAndScatter(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
@@ -2677,6 +2732,13 @@ XlaOp CrossReplicaSum(
replica_group_ids, channel_id);
}
+// Free-function form of AllToAll: forwards every argument to the AllToAll
+// method of the builder returned by `operand.builder()`.
+XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
+               int64 concat_dimension, int64 split_count,
+               const std::vector<ReplicaGroup>& replica_groups) {
+  XlaBuilder* const builder = operand.builder();
+  return builder->AllToAll(operand, split_dimension, concat_dimension,
+                           split_count, replica_groups);
+}
+
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 3c5f8c8d53..9403d7ca8d 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -699,9 +699,9 @@ class XlaBuilder {
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
- // - `channel_id`: for Allreduce nodes from different models, if they have the
- // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
- // applied cross models.
+ // - `channel_id`: for Allreduce nodes from different modules, if they have
+ // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
+ // not be applied cross modules.
//
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
@@ -710,6 +710,13 @@ class XlaBuilder {
const tensorflow::gtl::optional<ChannelHandle>& channel_id =
tensorflow::gtl::nullopt);
+ // Enqueues an operation that does an Alltoall of the operand across cores.
+ //
+ // TODO(b/110096724): This is NOT YET ready to use.
+ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
+ int64 concat_dimension, int64 split_count,
+ const std::vector<ReplicaGroup>& replica_groups);
+
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
@@ -1246,6 +1253,9 @@ class XlaBuilder {
const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id);
+ friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
+ int64 concat_dimension, int64 split_count,
+ const std::vector<ReplicaGroup>& replica_groups);
friend XlaOp SelectAndScatter(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
@@ -1832,9 +1842,9 @@ XlaOp CrossReplicaSum(
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
-// - `channel_id`: for Allreduce nodes from different models, if they have the
+// - `channel_id`: for Allreduce nodes from different modules, if they have the
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
-// applied cross models.
+// applied cross modules.
//
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
@@ -1842,6 +1852,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
const tensorflow::gtl::optional<ChannelHandle>&
channel_id = tensorflow::gtl::nullopt);
+// Enqueues an operation that does an Alltoall of the operand across cores.
+//
+// TODO(b/110096724): This is NOT YET ready to use.
+XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
+ int64 concat_dimension, int64 split_count,
+ const std::vector<ReplicaGroup>& replica_groups = {});
+
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index afe5be29f0..49a15ec3b4 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -305,6 +305,21 @@ TEST_F(XlaBuilderTest, Transpose) {
EXPECT_THAT(root, op::Transpose(op::Parameter()));
}
+// Checks the HLO that the builder emits for AllToAll on an f32[4,16] parameter
+// split in two along dimension 1 and concatenated along dimension 0.
+TEST_F(XlaBuilderTest, AllToAll) {
+  XlaBuilder builder(TestName());
+  const auto input_shape = ShapeUtil::MakeShape(F32, {4, 16});
+  auto input = Parameter(&builder, 0, input_shape, "x");
+  AllToAll(input, /*split_dimension=*/1, /*concat_dimension=*/0,
+           /*split_count=*/2);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
+  auto root = module->entry_computation()->root_instruction();
+
+  // AllToAll is decomposed into slices -> all-to-all -> gte -> concat, so the
+  // all-to-all sits two operand hops below the root concatenate.
+  EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
+  auto first_gte = root->operand(0);
+  EXPECT_EQ(first_gte->operand(0)->opcode(), HloOpcode::kAllToAll);
+  const auto expected_shape = ShapeUtil::MakeShape(F32, {8, 8});
+  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), expected_shape));
+}
+
+
TEST_F(XlaBuilderTest, ReportError) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 9f86749125..86d57581f8 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -106,6 +106,7 @@ class DfsHloVisitorBase {
virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
+ virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
virtual Status HandleCompare(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
}
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index ae8a066d62..617a5a2eb4 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
+ // All-to-all has no specialized handling here; it falls through to
+ // DefaultAction like the other handlers in this class.
+ Status HandleAllToAll(HloInstructionPtr hlo) override {
+   return DefaultAction(hlo);
+ }
Status HandleRng(HloInstructionPtr random) override {
return DefaultAction(random);
}
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 0b93d97c11..be9098f555 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -151,8 +151,11 @@ message HloInstructionProto {
// Backend configuration for the instruction. Has backend-specific meaning.
string backend_config = 43;
- // Cross Replica Sum fields.
+ // Cross replica op fields.
+ // TODO(b/112107579): remove replica_group_ids field and always use
+ // replica_groups.
repeated int64 replica_group_ids = 44;
+ repeated ReplicaGroup replica_groups = 49;
int64 all_reduce_id = 45;
string cross_replica_sum_barrier = 46;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index a2cefd2621..1bbb0ff08e 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -543,6 +543,19 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
return Status::OK();
}
+// Placeholder cost for all-to-all: records, under kFlopsKey, the total number
+// of elements over every array subshape of the instruction's result shape.
+Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
+  // TODO(b/110096724): Compute correct cost here.
+  double element_total = 0.0;
+  ShapeUtil::ForEachSubshape(
+      hlo->shape(), [&](const Shape& subshape, const ShapeIndex& /*index*/) {
+        // Tuple (non-array) nodes contribute nothing themselves.
+        if (!ShapeUtil::IsArray(subshape)) {
+          return;
+        }
+        element_total += ShapeUtil::ElementsIn(subshape);
+      });
+  current_properties_[kFlopsKey] = element_total;
+  return Status::OK();
+}
+
+
Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
// TODO(b/26346211): Implement better estimates for the RNG cost, since the
// cost changes with the implementation and the distribution. For now, assume
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 0a79c92f4a..193a04bea0 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleConvolution(const HloInstruction* convolution) override;
Status HandleFft(const HloInstruction* fft) override;
Status HandleCrossReplicaSum(const HloInstruction* crs) override;
+ Status HandleAllToAll(const HloInstruction* hlo) override;
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
Status HandleHostCompute(const HloInstruction* host_compute) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index bfe83cabf1..1efa6eb5bd 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1048,6 +1048,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kMap:
return kGray;
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kAllToAll:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kRecv:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 7591b99204..8690f2cdaa 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -320,6 +320,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/*all_reduce_id=*/all_reduce_id);
break;
}
+ case HloOpcode::kAllToAll: {
+ instruction = CreateAllToAll(
+ proto.shape(), all_operands(),
+ /*replica_groups=*/
+ std::vector<ReplicaGroup>(proto.replica_groups().begin(),
+ proto.replica_groups().end()),
+ /*barrier=*/proto.cross_replica_sum_barrier());
+ break;
+ }
case HloOpcode::kConvolution:
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Convolution instruction should have 2 operands but sees "
@@ -671,6 +680,14 @@ HloInstruction::CreateCrossReplicaSum(
all_reduce_id);
}
+// Factory for the all-to-all instruction: constructs an HloAllToAllInstruction
+// from the given arguments and returns it as a base-class pointer.
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
+    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+    const std::vector<ReplicaGroup>& replica_groups,
+    tensorflow::StringPiece barrier) {
+  std::unique_ptr<HloInstruction> all_to_all =
+      MakeUnique<HloAllToAllInstruction>(shape, operands, replica_groups,
+                                         barrier);
+  return all_to_all;
+}
+
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
const Shape& infeed_shape, HloInstruction* token_operand,
const string& config) {
@@ -1153,6 +1170,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kGetTupleElement:
case HloOpcode::kReducePrecision:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kAllToAll:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kConvolution:
@@ -1620,6 +1638,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kAllToAll:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
@@ -2265,6 +2284,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleFft(this);
case HloOpcode::kCrossReplicaSum:
return visitor->HandleCrossReplicaSum(this);
+ case HloOpcode::kAllToAll:
+ return visitor->HandleAllToAll(this);
case HloOpcode::kTuple:
return visitor->HandleTuple(this);
case HloOpcode::kMap:
@@ -3139,12 +3160,23 @@ const std::vector<int64>& HloInstruction::replica_group_ids() const {
return Cast<HloAllReduceInstruction>(this)->replica_group_ids();
}
+// Delegates to HloAllToAllInstruction::replica_groups after casting `this`
+// with Cast<HloAllToAllInstruction>.
+const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
+  const auto* all_to_all = Cast<HloAllToAllInstruction>(this);
+  return all_to_all->replica_groups();
+}
+
string HloInstruction::cross_replica_sum_barrier() const {
- return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
+ if (opcode() == HloOpcode::kCrossReplicaSum) {
+ return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
+ }
+ return Cast<HloAllToAllInstruction>(this)->cross_replica_sum_barrier();
}
void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
- return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
+ if (opcode() == HloOpcode::kCrossReplicaSum) {
+ return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
+ barrier);
+ }
+ return Cast<HloAllToAllInstruction>(this)->set_cross_replica_sum_barrier(
barrier);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e722086732..3c575ae6ea 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -449,6 +449,26 @@ class HloInstruction {
tensorflow::StringPiece barrier,
const tensorflow::gtl::optional<int64>& all_reduce_id);
+ // This op handles the communication of an Alltoall operation. On each core,
+ // the operands are N ops in the same shape, where N is the number of cores
+ // participating in the Alltoall. Then the N operands are scattered to N cores,
+ // e.g., the ith operand is sent to the ith core. Then each core gathers the
+ // received data into a tuple.
+ //
+ // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
+ // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall
+ // will be applied within subgroups in the specified order. For example,
+ // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied
+ // within replica 1, 2, 3, and in the gather phase, the received blocks will
+ // be concatenated in the order of 1, 2, 3; another Alltoall will be applied
+ // within replica 4, 5, 0, and the concatenation order is 4, 5, 0.
+ //
+ // TODO(b/110096724): This is NOT YET ready to use.
+ static std::unique_ptr<HloInstruction> CreateAllToAll(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const std::vector<ReplicaGroup>& replica_groups,
+ tensorflow::StringPiece barrier);
+
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
@@ -1414,6 +1434,9 @@ class HloInstruction {
// Delegates to HloAllReduceInstruction::replica_group_ids.
const std::vector<int64>& replica_group_ids() const;
+ // Delegates to HloAllToAllInstruction::replica_groups.
+ const std::vector<ReplicaGroup>& replica_groups() const;
+
// Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
string cross_replica_sum_barrier() const;
void set_cross_replica_sum_barrier(const string& barrier);
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 1d71a74c40..1de5032670 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -359,6 +359,67 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
cross_replica_sum_barrier(), all_reduce_id());
}
+// Constructs a kAllToAll instruction with result `shape`. Copies
+// `replica_groups`, copies the characters of `barrier` into an owned string,
+// and appends `operands` in the order given.
+HloAllToAllInstruction::HloAllToAllInstruction(
+    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+    const std::vector<ReplicaGroup>& replica_groups,
+    tensorflow::StringPiece barrier)
+    : HloInstruction(HloOpcode::kAllToAll, shape),
+      replica_groups_(replica_groups),
+      cross_replica_sum_barrier_(barrier.begin(), barrier.end()) {
+  for (HloInstruction* input : operands) {
+    AppendOperand(input);
+  }
+}
+
+
+// Returns true when `other` has the same replica groups (same groups in the
+// same order, each with equal replica ids) and the same barrier string.
+// `other` is downcast without a check; `eq_computations` is not consulted.
+bool HloAllToAllInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+    /*eq_computations*/) const {
+  const auto& rhs = static_cast<const HloAllToAllInstruction&>(other);
+  const auto same_replica_ids = [](const ReplicaGroup& a,
+                                   const ReplicaGroup& b) {
+    return ContainersEqual(a.replica_ids(), b.replica_ids());
+  };
+  if (!ContainersEqual(replica_groups(), rhs.replica_groups(),
+                       same_replica_ids)) {
+    return false;
+  }
+  return cross_replica_sum_barrier() == rhs.cross_replica_sum_barrier();
+}
+
+
+// Builds a new all-to-all instruction with the given shape and operands that
+// carries over this instruction's replica groups and barrier. The clone
+// context is not used.
+std::unique_ptr<HloInstruction>
+HloAllToAllInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape,
+    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+    HloCloneContext* /*context*/) const {
+  std::unique_ptr<HloInstruction> clone = MakeUnique<HloAllToAllInstruction>(
+      shape, new_operands, replica_groups(), cross_replica_sum_barrier());
+  return clone;
+}
+
+
+// Renders the instruction-specific attributes. Always emits
+// `replica_groups={{...},{...}}`; emits `barrier="..."` only when the barrier
+// string is non-empty. The print options are not consulted.
+std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& /*options*/) const {
+  // One "{id,id,...}" string per replica group.
+  std::vector<string> group_strs;
+  for (const ReplicaGroup& group : replica_groups()) {
+    group_strs.push_back(StrCat("{", Join(group.replica_ids(), ","), "}"));
+  }
+
+  std::vector<string> attributes;
+  attributes.push_back(
+      StrCat("replica_groups={", Join(group_strs, ","), "}"));
+
+  const string barrier = cross_replica_sum_barrier();
+  if (!barrier.empty()) {
+    attributes.push_back(StrCat("barrier=\"", barrier, "\""));
+  }
+  return attributes;
+}
+
+
+// Serializes this instruction: starts from the base-class proto, then
+// overwrites the replica_groups field with this instruction's groups and sets
+// the barrier string.
+HloInstructionProto HloAllToAllInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  auto* proto_groups = proto.mutable_replica_groups();
+  // Clear first so the field ends up holding exactly replica_groups_.
+  proto_groups->Clear();
+  for (const ReplicaGroup& group : replica_groups_) {
+    *proto_groups->Add() = group;
+  }
+  proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_);
+  return proto;
+}
+
+
HloReverseInstruction::HloReverseInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 9250b2b846..9586ad6673 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -273,6 +273,47 @@ class HloAllReduceInstruction : public HloInstruction {
tensorflow::gtl::optional<int64> all_reduce_id_;
};
+// HLO instruction with opcode kAllToAll. In addition to its operands it stores
+// the replica groups and a barrier string, both of which are serialized by
+// ToProto and compared by IdenticalSlowPath.
+class HloAllToAllInstruction : public HloInstruction {
+ public:
+  // `operands` are appended to the instruction in the order given.
+  // `replica_groups` and the contents of `barrier` are copied.
+  explicit HloAllToAllInstruction(
+      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+      const std::vector<ReplicaGroup>& replica_groups,
+      tensorflow::StringPiece barrier);
+
+  // The replica groups this instruction was created with.
+  const std::vector<ReplicaGroup>& replica_groups() const {
+    return replica_groups_;
+  }
+
+  // TODO(b/110096724): rename this.
+  // Takes the barrier by const reference so the argument is copied only once,
+  // into the member.
+  void set_cross_replica_sum_barrier(const string& barrier) {
+    cross_replica_sum_barrier_ = barrier;
+  }
+  string cross_replica_sum_barrier() const {
+    return cross_replica_sum_barrier_;
+  }
+
+  // Returns the base-class proto with replica_groups and the barrier filled in.
+  HloInstructionProto ToProto() const override;
+
+ private:
+  // Produces the `replica_groups=...` (and, if set, `barrier=...`) attributes.
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  // Compares replica groups and barrier with `other`.
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+
+  // Implementation for non-common logic of CloneWithNewOperands.
+  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+      const Shape& shape,
+      tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+      HloCloneContext* context) const override;
+
+  // Replica groups, stored in the order supplied at construction.
+  std::vector<ReplicaGroup> replica_groups_;
+
+  // The string representation of the barrier config.
+  string cross_replica_sum_barrier_;
+};
+
+
class HloReverseInstruction : public HloInstruction {
public:
explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 88531b6f20..ec279867e5 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -47,6 +47,7 @@ namespace xla {
#define HLO_OPCODE_LIST(V) \
V(kAbs, "abs") \
V(kAdd, "add") \
+ V(kAllToAll, "all-to-all") \
V(kAtan2, "atan2") \
V(kBatchNormGrad, "batch-norm-grad") \
V(kBatchNormInference, "batch-norm-inference") \
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index de73b38dec..2a8c6ecd92 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -125,6 +125,7 @@ class HloParser {
kFloat,
kString,
kBracedInt64List,
+ kBracedInt64ListList,
kHloComputation,
kFftType,
kWindow,
@@ -205,6 +206,10 @@ class HloParser {
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
+ // 'parse_and_add_item' is a lambda to parse an element in the list and add
+ // the parsed element to the result. It's supposed to capture the result.
+ bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
+ const std::function<bool()>& parse_and_add_item);
bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
bool ParseParamList();
@@ -619,6 +624,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
break;
}
+ case HloOpcode::kAllToAll: {
+ optional<std::vector<std::vector<int64>>> tmp_groups;
+ optional<string> barrier;
+ attrs["replica_groups"] = {/*required=*/false,
+ AttrTy::kBracedInt64ListList, &tmp_groups};
+ attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ std::vector<ReplicaGroup> replica_groups;
+ if (tmp_groups) {
+ c_transform(*tmp_groups, std::back_inserter(replica_groups),
+ [](const std::vector<int64>& ids) {
+ ReplicaGroup group;
+ *group.mutable_replica_ids() = {ids.begin(), ids.end()};
+ return group;
+ });
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
+ shape, operands, replica_groups, barrier ? *barrier : ""));
+ break;
+ }
case HloOpcode::kReshape: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -2244,6 +2271,26 @@ bool HloParser::ParseAttributeHelper(
->emplace(result);
return true;
}
+ case AttrTy::kBracedInt64ListList: {
+ std::vector<std::vector<tensorflow::int64>> result;
+ auto parse_and_add_item = [&]() {
+ std::vector<tensorflow::int64> item;
+ if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace,
+ TokKind::kComma, &item)) {
+ return false;
+ }
+ result.push_back(item);
+ return true;
+ };
+ if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item)) {
+ return false;
+ }
+ static_cast<optional<std::vector<std::vector<tensorflow::int64>>>*>(
+ attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
case AttrTy::kSliceRanges: {
SliceRanges result;
if (!ParseSliceRanges(&result)) {
@@ -2586,6 +2633,26 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
}
+// Parses `start item (delim item)* end`, or the empty list `start end`.
+// Each item is consumed by `parse_and_add_item`, which is expected to store
+// the parsed element itself. Returns false as soon as the opening token, any
+// item, or the closing token fails to parse.
+bool HloParser::ParseList(const TokKind start, const TokKind end,
+                          const TokKind delim,
+                          const std::function<bool()>& parse_and_add_item) {
+  if (!ParseToken(start, StrCat("expects a list starting with ",
+                                TokKindToString(start)))) {
+    return false;
+  }
+  // Only look for items when the list is not closed immediately.
+  if (lexer_.GetKind() != end) {
+    while (true) {
+      if (!parse_and_add_item()) {
+        return false;
+      }
+      if (!EatIfPresent(delim)) {
+        break;
+      }
+    }
+  }
+  return ParseToken(
+      end, StrCat("expects a list to end with ", TokKindToString(end)));
+}
+
+
// param_list_to_shape ::= param_list '->' shape
bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 7344679bb6..4cd21841f4 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1072,6 +1072,30 @@ ENTRY CrossReplicaSumWithSubgroups {
)"
},
+// all-to-all
+{
+"AllToAll",
+R"(HloModule AllToAll
+
+ENTRY AllToAll {
+ input = f32[128,32]{0,1} parameter(0)
+ ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={}
+}
+
+)"
+},
+// all-to-all with subgroups
+{
+"AllToAllWithSubgroups",
+R"(HloModule AllToAllWithSubgroups
+
+ENTRY AllToAllWithSubgroups {
+ input = f32[128,32]{0,1} parameter(0)
+ ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc"
+}
+
+)"
+},
// Iota
{
"Iota",
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 1a8c206aaf..3fae61f704 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -105,6 +105,15 @@ Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) {
ShapeInference::InferCrossReplicaSumShape(operand_shapes));
}
+// Verifies that the instruction's shape matches the tuple shape inferred from
+// the shapes of its operands.
+Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
+  const auto& inputs = hlo->operands();
+  std::vector<const Shape*> input_shapes;
+  input_shapes.reserve(inputs.size());
+  for (const HloInstruction* input : inputs) {
+    input_shapes.push_back(&input->shape());
+  }
+  return CheckShape(hlo,
+                    ShapeInference::InferAllToAllTupleShape(input_shapes));
+}
+
+
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
reduce_precision->operand(0)->shape(),
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 7feddaeabf..5a56a44f35 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -45,6 +45,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleAllToAll(HloInstruction* hlo) override;
Status HandleReducePrecision(HloInstruction* reduce_precision) override;
Status HandleInfeed(HloInstruction*) override;
Status HandleOutfeed(HloInstruction*) override;
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index e2191aedb7..f33942d679 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -120,6 +120,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kConditional:
case HloOpcode::kConvolution:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kAllToAll:
case HloOpcode::kCustomCall:
case HloOpcode::kDivide:
case HloOpcode::kDomain:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index c888bbf144..a4ea2b28f4 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1779,6 +1779,51 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return ShapeUtil::MakeTupleShape(operand_shape_values);
}
+// Infers the result shape of a builder-level AllToAll: `shape` with the split
+// dimension divided by `split_count` and the concat dimension multiplied by
+// it. Fails when `split_count` is not positive, when either dimension index is
+// outside [0, rank), or when the split dimension's size is not a multiple of
+// `split_count`.
+/* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
+    const Shape& shape, int64 split_dimension, int64 concat_dimension,
+    int64 split_count) {
+  TF_RET_CHECK(split_count > 0);
+
+  const int64 rank = ShapeUtil::Rank(shape);
+  const auto in_bounds = [rank](int64 dimension) {
+    return 0 <= dimension && dimension < rank;
+  };
+  if (!in_bounds(split_dimension)) {
+    return InvalidArgument(
+        "AllToAll split_dimension %lld is out-of-bounds in shape %s.",
+        split_dimension, ShapeUtil::HumanString(shape).c_str());
+  }
+  if (!in_bounds(concat_dimension)) {
+    return InvalidArgument(
+        "AllToAll concat_dimension %lld is out-of-bounds in shape %s.",
+        concat_dimension, ShapeUtil::HumanString(shape).c_str());
+  }
+
+  const int64 split_size = shape.dimensions(split_dimension);
+  if (split_size % split_count != 0) {
+    return InvalidArgument(
+        "AllToAll split dimension size %lld must be dividable by split_count "
+        "%lld.",
+        split_size, split_count);
+  }
+
+  // Shrink first, then grow: when both indices name the same dimension the two
+  // updates cancel out.
+  std::vector<int64> result_dimensions(shape.dimensions().begin(),
+                                       shape.dimensions().end());
+  result_dimensions[split_dimension] = split_size / split_count;
+  result_dimensions[concat_dimension] *= split_count;
+  return ShapeUtil::MakeShape(shape.element_type(), result_dimensions);
+}
+
+
+// Infers the shape of an HLO all-to-all instruction from its operand shapes.
+// Requires at least one operand and that every operand shape equals the first
+// one; the result is the tuple shape inferred for those operands.
+/* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
+    tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+  // An Alltoall HLO instruction receives N operands (with the same shape) and
+  // returns a tuple that contains N array shapes.
+  TF_RET_CHECK(!operand_shapes.empty());
+  // Operand 0 is the reference shape, so comparison starts at operand 1. The
+  // count is converted to int once to keep the loop comparison signed and the
+  // index compatible with the %d in the message below.
+  const int num_operands = static_cast<int>(operand_shapes.size());
+  for (int i = 1; i < num_operands; ++i) {
+    if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) {
+      return InvalidArgument(
+          "HLO all-to-all has operands with different shapes: the 0th "
+          "operand shape %s, but the %dth operand has shape %s.",
+          ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i,
+          ShapeUtil::HumanString(*operand_shapes[i]).c_str());
+    }
+  }
+
+  return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
+}
+
+
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 33da323b3d..c185b0a1bd 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -119,11 +119,22 @@ class ShapeInference {
const Shape& in, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length);
- // Infers the shape produced a cross replica sum with the given operand
+ // Infers the shape produced by a cross replica sum with the given operand
// shapes.
static StatusOr<Shape> InferCrossReplicaSumShape(
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ // Infers the final shape of an AllToAll operation that is created by the XLA
+ // builder.
+ static StatusOr<Shape> InferAllToAllShape(const Shape& shape,
+ int64 split_dimension,
+ int64 concat_dimension,
+ int64 split_count);
+
+ // Infers the shape of an HLO all-to-all instruction from its operand shapes.
+ static StatusOr<Shape> InferAllToAllTupleShape(
+ tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+
// Infers the shape produced by applying the given reduction computation
// shape to the given input operand shape.
//
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index fd784e909c..4c35e93d38 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -561,3 +561,11 @@ message OpSharding {
// to.
repeated OpSharding tuple_shardings = 5;
}
+
+// Describes the replica groups in a cross replica op (e.g., all-reduce and
+// all-to-all).
+message ReplicaGroup {
+ // The ids of the replicas that belong to the same group. The ordering of the
+ // ids matters in some ops (e.g., all-to-all).
+ repeated int64 replica_ids = 1;
+}
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 165f6f5914..02af71f8a3 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -13,6 +13,79 @@ arbitrary-dimensional array. For convenience, special cases have more specific
and familiar names; for example a *vector* is a 1-dimensional array and a
*matrix* is a 2-dimensional array.
+## AllToAll
+
+See also
+[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Alltoall is a collective operation that sends data from all cores to all cores.
+It has two phases:
+
+1. the scatter phase. On each core, the operand is split into `split_count`
+ number of blocks along the `split_dimension`, and the blocks are scattered
+ to all cores, e.g., the ith block is sent to the ith core.
+2. the gather phase. Each core concatenates the received blocks along the
+ `concat_dimension`.
+
+The participating cores can be configured by:
+
+- `replica_groups`: each ReplicaGroup contains a list of replica ids. If empty,
+ all replicas belong to one group in the order of 0 - (n-1). Alltoall will be
+ applied within subgroups in the specified order. For example, replica
+ groups = {{1,2,3},{4,5,0}} means that an Alltoall will be applied within
+ replicas 1, 2, 3, and in the gather phase, the received blocks will be
+ concatenated in the order of 1, 2, 3; another Alltoall will be applied within
+ replicas 4, 5, 0, and the concatenation order is 4, 5, 0.
+
+Prerequisites:
+
+- The dimension size of the operand on the `split_dimension` is divisible by
+ `split_count`.
+- The operand's shape is not a tuple.
+
+<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
+replica_groups)` </b>
+
+
+| Arguments | Type | Semantics |
+| ------------------ | --------------------- | ------------------------------- |
+| `operand` | `XlaOp` | n dimensional input array |
+| `split_dimension` | `int64` | A value in the interval `[0, |
+: : : n)` that names the dimension :
+: : : along which the operand is :
+: : : split :
+| `concat_dimension` | `int64` | a value in the interval `[0, |
+: : : n)` that names the dimension :
+: : : along which the split blocks :
+: : : are concatenated :
+| `split_count` | `int64` | the number of cores that |
+: : : participate in this operation. :
+: : : If `replica_groups` is empty, :
+: : : this should be the number of :
+: : : replicas; otherwise, this :
+: : : should be equal to the number :
+: : : of replicas in each group. :
+| `replica_groups` | `ReplicaGroup` vector | each group contains a list of |
+: : : replica ids. :
+
+Below is an example of Alltoall.
+
+```
+XlaBuilder b("alltoall");
+auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
+AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
+```
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/xla/ops_alltoall.png">
+</div>
+
+In this example, there are 4 cores participating in the Alltoall. On each core,
+the operand is split into 4 parts along dimension 1, so each part has shape
+f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates
+the received parts along dimension 0, in the order of cores 0-3. So the output
+on each core has shape f32[16,4].
+
## BatchNormGrad
See also