diff options
23 files changed, 522 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 073d66bcd2..b3b00e2fff 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1881,6 +1881,61 @@ XlaOp XlaBuilder::CrossReplicaSum( }); } +XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector<ReplicaGroup>& replica_groups) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + + // The HloInstruction for Alltoall currently only handles the data + // communication: it accepts N already split parts and scatters them to N + // cores, and each core gathers the N received parts into a tuple as the + // output. So here we explicitly split the operand before the hlo alltoall, + // and concat the tuple elements. + // + // First, run shape inference to make sure the shapes are valid. + TF_RETURN_IF_ERROR( + ShapeInference::InferAllToAllShape(operand_shape, split_dimension, + concat_dimension, split_count) + .status()); + + // Split into N parts. + std::vector<XlaOp> slices; + slices.reserve(split_count); + const int64 block_size = + operand_shape.dimensions(split_dimension) / split_count; + for (int i = 0; i < split_count; i++) { + slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size, + /*limit_index=*/(i + 1) * block_size, + /*stride=*/1, /*dimno=*/split_dimension)); + } + + // Handle data communication. 
+ HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices)); + std::vector<const Shape*> slice_shape_ptrs; + c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + TF_ASSIGN_OR_RETURN( + XlaOp alltoall, + AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices)); + + // Concat the N received parts. + std::vector<XlaOp> received; + received.reserve(split_count); + for (int i = 0; i < split_count; i++) { + received.push_back(this->GetTupleElement(alltoall, i)); + } + return this->ConcatInDim(received, concat_dimension); + }); +} + XlaOp XlaBuilder::SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice<int64> window_dimensions, @@ -2677,6 +2732,13 @@ XlaOp CrossReplicaSum( replica_group_ids, channel_id); } +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector<ReplicaGroup>& replica_groups) { + return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups); +} + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 3c5f8c8d53..9403d7ca8d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -699,9 +699,9 @@ class XlaBuilder { // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. 
// - // - `channel_id`: for Allreduce nodes from different models, if they have the - // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross models. + // - `channel_id`: for Allreduce nodes from different modules, if they have + // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( @@ -710,6 +710,13 @@ class XlaBuilder { const tensorflow::gtl::optional<ChannelHandle>& channel_id = tensorflow::gtl::nullopt); + // Enqueues an operation that do an Alltoall of the operand cross cores. + // + // TODO(b/110096724): This is NOT YET ready to use. + XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector<ReplicaGroup>& replica_groups); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1246,6 +1253,9 @@ class XlaBuilder { const XlaOp& operand, const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> replica_group_ids, const tensorflow::gtl::optional<ChannelHandle>& channel_id); + friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector<ReplicaGroup>& replica_groups); friend XlaOp SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice<int64> window_dimensions, @@ -1832,9 +1842,9 @@ XlaOp CrossReplicaSum( // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // -// - `channel_id`: for Allreduce nodes from different models, if they have the +// - `channel_id`: for Allreduce nodes from different modules, if they have the // same channel_id, they will be 'Allreduce'd. 
If empty, Allreduce will not be -// applied cross models. +// applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, @@ -1842,6 +1852,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, const tensorflow::gtl::optional<ChannelHandle>& channel_id = tensorflow::gtl::nullopt); +// Enqueues an operation that do an Alltoall of the operand cross cores. +// +// TODO(b/110096724): This is NOT YET ready to use. +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector<ReplicaGroup>& replica_groups = {}); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index afe5be29f0..49a15ec3b4 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -305,6 +305,21 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } +TEST_F(XlaBuilderTest, AllToAll) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); + AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, + /*split_count=*/2); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + // AllToAll is decomposed into slices -> all-to-all -> gte -> concat. 
+ EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll); + EXPECT_TRUE( + ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 9f86749125..86d57581f8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -106,6 +106,7 @@ class DfsHloVisitorBase { virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index ae8a066d62..617a5a2eb4 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } + Status HandleAllToAll(HloInstructionPtr crs) override { + return DefaultAction(crs); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 0b93d97c11..be9098f555 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -151,8 +151,11 @@ message HloInstructionProto { // Backend configuration for the instruction. Has backend-specific meaning. 
string backend_config = 43; - // Cross Replica Sum fields. + // Cross replica op fields. + // TODO(b/112107579): remove replica_group_ids field and always use + // replica_groups. repeated int64 replica_group_ids = 44; + repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; string cross_replica_sum_barrier = 46; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index a2cefd2621..1bbb0ff08e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -543,6 +543,19 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { return Status::OK(); } +Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { + // TODO(b/110096724): Compute correct cost here. + double flops = 0.0; + ShapeUtil::ForEachSubshape(hlo->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; + return Status::OK(); +} + Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. 
For now, assume diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 0a79c92f4a..193a04bea0 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleHostCompute(const HloInstruction* host_compute) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index bfe83cabf1..1efa6eb5bd 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1048,6 +1048,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: return kGray; case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 7591b99204..8690f2cdaa 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -320,6 +320,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( /*all_reduce_id=*/all_reduce_id); break; } + case HloOpcode::kAllToAll: { + instruction = CreateAllToAll( + proto.shape(), all_operands(), + /*replica_groups=*/ + std::vector<ReplicaGroup>(proto.replica_groups().begin(), + proto.replica_groups().end()), + 
/*barrier=*/proto.cross_replica_sum_barrier()); + break; + } case HloOpcode::kConvolution: TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " @@ -671,6 +680,14 @@ HloInstruction::CreateCrossReplicaSum( all_reduce_id); } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups, + tensorflow::StringPiece barrier) { + return MakeUnique<HloAllToAllInstruction>(shape, operands, replica_groups, + barrier); +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -1153,6 +1170,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1620,6 +1638,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -2265,6 +2284,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllToAll: + return visitor->HandleAllToAll(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -3139,12 +3160,23 @@ const std::vector<int64>& HloInstruction::replica_group_ids() const { return Cast<HloAllReduceInstruction>(this)->replica_group_ids(); } +const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { + return Cast<HloAllToAllInstruction>(this)->replica_groups(); 
+} + string HloInstruction::cross_replica_sum_barrier() const { - return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier(); + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier(); + } + return Cast<HloAllToAllInstruction>(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier( + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier( + barrier); + } + return Cast<HloAllToAllInstruction>(this)->set_cross_replica_sum_barrier( barrier); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e722086732..3c575ae6ea 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -449,6 +449,26 @@ class HloInstruction { tensorflow::StringPiece barrier, const tensorflow::gtl::optional<int64>& all_reduce_id); + // This op handles the communication of an Alltoall operation. On each core, + // the operands are N ops in the same shape, where N is the number of cores + // participating the Alltoall. Then the N operands are scattered to N cores, + // e.g., the ith operand is sent to the ith core. Then each core gathers the + // received data into a tuple. + // + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall + // will be applied within subgroups in the specified order. 
For example, + // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied + // within replica 1, 2, 3, and in the gather phase, the received blocks will + // be concatenated in the order of 1, 2, 3; another Alltoall will be applied + // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. + // + // TODO(b/110096724): This is NOT YET ready to use. + static std::unique_ptr<HloInstruction> CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups, + tensorflow::StringPiece barrier); + // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape, @@ -1414,6 +1434,9 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::replica_group_ids. const std::vector<int64>& replica_group_ids() const; + // Delegates to HloAllToAllInstruction::replica_groups. + const std::vector<ReplicaGroup>& replica_groups() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. 
string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1d71a74c40..1de5032670 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -359,6 +359,67 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( cross_replica_sum_barrier(), all_reduce_id()); } +HloAllToAllInstruction::HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups, + tensorflow::StringPiece barrier) + : HloInstruction(HloOpcode::kAllToAll, shape), + replica_groups_(replica_groups), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +bool HloAllToAllInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }) && + cross_replica_sum_barrier() == + casted_other.cross_replica_sum_barrier(); +} + +std::unique_ptr<HloInstruction> +HloAllToAllInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* /*context*/) const { + return MakeUnique<HloAllToAllInstruction>( + shape, new_operands, replica_groups(), cross_replica_sum_barrier()); +} + +std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector<string> result; + std::vector<string> replica_group_str; + for (const 
ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", Join(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + + if (!cross_replica_sum_barrier().empty()) { + result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + } + + return result; +} + +HloInstructionProto HloAllToAllInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + HloReverseInstruction::HloReverseInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 9250b2b846..9586ad6673 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -273,6 +273,47 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::gtl::optional<int64> all_reduce_id_; }; +class HloAllToAllInstruction : public HloInstruction { + public: + explicit HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand, + const std::vector<ReplicaGroup>& replica_groups, + tensorflow::StringPiece barrier); + + const std::vector<ReplicaGroup>& replica_groups() const { + return replica_groups_; + } + + // TODO(b/110096724): rename this. 
+ void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + + HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + std::vector<ReplicaGroup> replica_groups_; + + // The string representation of the barrier config. + string cross_replica_sum_barrier_; +}; + class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 88531b6f20..ec279867e5 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,7 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ V(kBatchNormInference, "batch-norm-inference") \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index de73b38dec..2a8c6ecd92 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -125,6 +125,7 @@ class HloParser { kFloat, kString, kBracedInt64List, + kBracedInt64ListList, kHloComputation, kFftType, kWindow, @@ -205,6 +206,10 @@ class HloParser { bool ParseInt64List(const TokKind start, 
const TokKind end, const TokKind delim, std::vector<tensorflow::int64>* result); + // 'parse_and_add_item' is a lambda to parse an element in the list and add + // the parsed element to the result. It's supposed to capture the result. + bool ParseList(const TokKind start, const TokKind end, const TokKind delim, + const std::function<bool()>& parse_and_add_item); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -619,6 +624,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } break; } + case HloOpcode::kAllToAll: { + optional<std::vector<std::vector<int64>>> tmp_groups; + optional<string> barrier; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector<ReplicaGroup> replica_groups; + if (tmp_groups) { + c_transform(*tmp_groups, std::back_inserter(replica_groups), + [](const std::vector<int64>& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + } + instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( + shape, operands, replica_groups, barrier ? 
*barrier : "")); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -2244,6 +2271,26 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kBracedInt64ListList: { + std::vector<std::vector<tensorflow::int64>> result; + auto parse_and_add_item = [&]() { + std::vector<tensorflow::int64> item; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, + TokKind::kComma, &item)) { + return false; + } + result.push_back(item); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + static_cast<optional<std::vector<std::vector<tensorflow::int64>>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kSliceRanges: { SliceRanges result; if (!ParseSliceRanges(&result)) { @@ -2586,6 +2633,26 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +bool HloParser::ParseList(const TokKind start, const TokKind end, + const TokKind delim, + const std::function<bool()>& parse_and_add_item) { + if (!ParseToken(start, StrCat("expects a list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects a list to end with ", TokKindToString(end))); +} + // param_list_to_shape ::= param_list '->' shape bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 7344679bb6..4cd21841f4 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1072,6 +1072,30 @@ ENTRY CrossReplicaSumWithSubgroups { )" }, +// all-to-all +{ +"AllToAll", +R"(HloModule AllToAll + +ENTRY AllToAll { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={} +} + +)" +}, +// all-to-all with subgroups +{ +"AllToAllWithSubgroups", +R"(HloModule AllToAllWithSubgroups + +ENTRY AllToAllWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc" +} + +)" +}, // Iota { "Iota", diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 1a8c206aaf..3fae61f704 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -105,6 +105,15 @@ Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } +Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { + std::vector<const Shape*> operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(hlo, + ShapeInference::InferAllToAllTupleShape(operand_shapes)); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7feddaeabf..5a56a44f35 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -45,6 +45,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + 
Status HandleAllToAll(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index e2191aedb7..f33942d679 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -120,6 +120,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c888bbf144..a4ea2b28f4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1779,6 +1779,51 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(operand_shape_values); } +/* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape( + const Shape& shape, int64 split_dimension, int64 concat_dimension, + int64 split_count) { + TF_RET_CHECK(split_count > 0); + if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + return InvalidArgument( + "AllToAll split_dimension %lld is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + return InvalidArgument( + "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (shape.dimensions(split_dimension) % split_count != 0) { + return InvalidArgument( + "AllToAll split dimension size %lld must be 
dividable by split_count " + "%lld.", + shape.dimensions(split_dimension), split_count); + } + std::vector<int64> new_dimensions(shape.dimensions().begin(), + shape.dimensions().end()); + new_dimensions[split_dimension] /= split_count; + new_dimensions[concat_dimension] *= split_count; + return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); +} + +/* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) { + // An Alltoall HLO instruction receives N operands (with the same shape) and + // returns a tuple that contains N array shapes. + TF_RET_CHECK(!operand_shapes.empty()); + for (int i = 0; i < operand_shapes.size(); i++) { + if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) { + return InvalidArgument( + "HLO all-to-all has operands with different shapes: the 0th " + "operand shape %s, but the %dth operand has shape %s.", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, + ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + } + } + + return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); +} + /* static */ StatusOr<Shape> ShapeInference::InferReduceShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 33da323b3d..c185b0a1bd 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -119,11 +119,22 @@ class ShapeInference { const Shape& in, FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length); - // Infers the shape produced a cross replica sum with the given operand + // Infers the shape produced by a cross replica sum with the given operand // shapes. 
static StatusOr<Shape> InferCrossReplicaSumShape( tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + // Infers the final shape of an Alltoall operation that is created by the xla + // builder. + static StatusOr<Shape> InferAllToAllShape(const Shape& shape, + int64 split_dimension, + int64 concat_dimension, + int64 split_count); + + // Infers the shape of an HLO all-to-all instruction. + static StatusOr<Shape> InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index fd784e909c..4c35e93d38 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -561,3 +561,11 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belong to the same group. The ordering of the + // ids matters in some ops (e.g., all-to-all). + repeated int64 replica_ids = 1; +} diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 165f6f5914..02af71f8a3 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -13,6 +13,79 @@ arbitrary-dimensional array. For convenience, special cases have more specific and familiar names; for example a *vector* is a 1-dimensional array and a *matrix* is a 2-dimensional array. +## AllToAll + +See also +[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Alltoall is a collective operation that sends data from all cores to all cores. +It has two phases: + +1. the scatter phase. 
On each core, the operand is split into `split_count` + number of blocks along the `split_dimension`, and the blocks are scattered + to all cores, e.g., the ith block is sent to the ith core. +2. the gather phase. Each core concatenates the received blocks along the + `concat_dimension`. + +The participating cores can be configured by: + +- `replica_groups`: each ReplicaGroup contains a list of replica ids. If empty, + all replicas belong to one group in the order of 0 - (n-1). Alltoall will be + applied within subgroups in the specified order. For example, replica + groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica + 1, 2, 3, and in the gather phase, the received blocks will be concatenated + in the order of 1, 2, 3; another Alltoall will be applied within replica 4, + 5, 0, and the concatenation order is 4, 5, 0. + +Prerequisites: + +- The dimension size of the operand on the split_dimension is divisible by + split_count. +- The operand's shape is not tuple. + +<b> `AllToAll(operand, split_dimension, concat_dimension, split_count, +replica_groups)` </b> + + +| Arguments | Type | Semantics | +| ------------------ | --------------------- | ------------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `split_dimension` | `int64` | A value in the interval `[0, | +: : : n)` that names the dimension : +: : : along which the operand is : +: : : split : +| `concat_dimension` | `int64` | a value in the interval `[0, | +: : : n)` that names the dimension : +: : : along which the split blocks : +: : : are concatenated : +| `split_count` | `int64` | the number of cores that | +: : : participate this operation. If : +: : : `replica_groups` is empty, this : +: : : should be the number of : +: : : replicas; otherwise, this : +: : : should be equal to the number : +: : : of replicas in each group. : +| `replica_groups` | `ReplicaGroup` vector | each group contains a list of | +: : : replica id. 
: + +Below shows an example of Alltoall. + +``` +XlaBuilder b("alltoall"); +auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); +AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); +``` + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:100%" src="../../images/xla/ops_alltoall.png"> +</div> + +In this example, there are 4 cores participating in the Alltoall. On each core, the +operand is split into 4 parts along dimension 1, so each part has shape +f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates +the received parts along dimension 0, in the order of cores 0-3. So the output on +each core has shape f32[16,4]. + ## BatchNormGrad See also |