author		2018-06-12 13:57:28 -0700
committer	2018-06-12 14:00:55 -0700
commit		34b071f6b6a14bd4c8d5c30156c1670496b85f04 (patch)
tree		b413892245024a28e63adf88078ad264f1cadd16 /tensorflow
parent		688a09dc6b70a81cae12a7e263515964311f8d86 (diff)
Support subgroup CrossReplicaSum
PiperOrigin-RevId: 200275384
Diffstat (limited to 'tensorflow')
10 files changed, 108 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 5e17cc4dfb..ae8fbdb2dc 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1611,7 +1611,9 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
   });
 }
 
-XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
+XlaOp XlaBuilder::CrossReplicaSum(
+    const XlaOp& operand,
+    tensorflow::gtl::ArraySlice<int64> replica_group_ids) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
     const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
@@ -1619,7 +1621,7 @@ XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
     b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
            b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
     TF_ASSIGN_OR_RETURN(auto computation, b->Build());
-    return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{},
+    return CrossReplicaSum(operand, computation, replica_group_ids,
                            /*channel_id=*/tensorflow::gtl::nullopt);
   });
 }
@@ -1629,7 +1631,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
     tensorflow::gtl::ArraySlice<int64> replica_group_ids,
     const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    if (!replica_group_ids.empty() || channel_id.has_value()) {
+    if (channel_id.has_value()) {
       return Unimplemented(
           "replica_group_ids and channel_id and is not supported in AllReduce");
     }
@@ -1639,6 +1641,9 @@ XlaOp XlaBuilder::CrossReplicaSum(
     TF_ASSIGN_OR_RETURN(
         *instr.mutable_shape(),
         ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
+    for (int64 replica_group_id : replica_group_ids) {
+      instr.add_replica_group_ids(replica_group_id);
+    }
 
     AddCalledComputation(computation, &instr);
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 532cae0148..0329e42ed1 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -528,9 +528,12 @@ class XlaBuilder {
       tensorflow::gtl::ArraySlice<int64> window_strides,
       tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
 
-  // Returns the sum of the operand value across all replicas. All replicas
-  // supply one input to the sum and all replicas receive the resulting sum.
-  XlaOp CrossReplicaSum(const XlaOp& operand);
+  // Returns the sum of the operand value within each subgroup of replicas. All
+  // replicas supply one input to the sum and all replicas receive the resulting
+  // sum for each subgroup.
+  XlaOp CrossReplicaSum(
+      const XlaOp& operand,
+      tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
 
   // Enqueues an operation that do an AllReduce of the operand cross cores. Here
   // AllReduce means doing a reduction on the input operand cross cores and then
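For orientation, a minimal client-side sketch of the new overload (the shape, parameter name, and the four-replica group assignment are illustrative, not part of this change):

    // Assumes four replicas: replicas {0, 1} form group 0 and replicas {2, 3}
    // form group 1, so each replica receives the sum over its own pair only.
    xla::XlaBuilder b("subgroup_crs");
    xla::XlaOp x =
        b.Parameter(/*parameter_number=*/0,
                    xla::ShapeUtil::MakeShape(xla::F32, {1024}), "x");
    xla::XlaOp sum = b.CrossReplicaSum(x, /*replica_group_ids=*/{0, 0, 1, 1});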
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 7fd1e733e9..f7b4c1405d 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -235,7 +235,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
   HloInstruction* crs =
       builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
           ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
-          sum));
+          sum, /*replica_group_ids=*/{}, /*barrier=*/""));
   HloInstruction* gte_a = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
   HloInstruction* gte_b = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 9926661dd3..830f26422b 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -250,8 +250,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
   HloInstruction* crs =
       builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
-          ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b},
-          reduction));
+          ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
+          /*replica_group_ids=*/{}, /*barrier=*/""));
   HloInstruction* gte = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 1f7c1cffd3..e201359d3d 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -145,6 +145,7 @@ message HloInstructionProto {
   repeated int64 operand_ids = 36;
   repeated int64 control_predecessor_ids = 37;
   repeated int64 called_computation_ids = 38;
+  repeated int64 replica_group_ids = 44;
 
   xla.OpSharding sharding = 40;
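Correspondingly, a hedged sketch of a direct HloInstruction call site under the new signature (mirroring the updated test call sites above; `f32_shape`, `operand`, and `add_computation` are hypothetical names):

    // replica_group_ids and barrier are now explicit arguments: an empty group
    // list means all replicas form a single group, and an empty barrier string
    // means no backend-specific barrier configuration.
    auto crs = HloInstruction::CreateCrossReplicaSum(
        f32_shape, {operand}, add_computation,
        /*replica_group_ids=*/{0, 0, 1, 1}, /*barrier=*/"");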
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 28b6d6aefd..a9e73d3a77 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -298,6 +298,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
   instruction->channel_name_ = proto.channel_name();
   instruction->cost_estimate_ns_ = proto.cost_estimate_ns();
 
+  for (int64 replica_group_id : proto.replica_group_ids()) {
+    instruction->replica_group_ids_.push_back(replica_group_id);
+  }
+
   return std::move(instruction);
 }
 
@@ -528,9 +532,9 @@ HloInstruction::CreateCrossReplicaSum(
     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
     HloComputation* reduce_computation,
     tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+    tensorflow::StringPiece barrier,
     const tensorflow::gtl::optional<int64>& channel_id) {
   // TODO(b/79737069): Remove the CHECK when supported.
-  CHECK(replica_group_ids.empty());
   CHECK(!channel_id.has_value());
   auto instruction =
       WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
@@ -538,6 +542,9 @@ HloInstruction::CreateCrossReplicaSum(
     instruction->AppendOperand(operand);
   }
   instruction->called_computations_.push_back(reduce_computation);
+  instruction->replica_group_ids_.assign(replica_group_ids.begin(),
+                                         replica_group_ids.end());
+  instruction->cross_replica_sum_barrier_ = std::string(barrier);
   return instruction;
 }
 
@@ -1138,7 +1145,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
                         *dot_dimension_numbers_);
       break;
     case HloOpcode::kCrossReplicaSum:
-      clone = CreateCrossReplicaSum(shape, new_operands, to_apply());
+      clone =
+          CreateCrossReplicaSum(shape, new_operands, to_apply(),
+                                replica_group_ids_, cross_replica_sum_barrier_);
       break;
     case HloOpcode::kGetTupleElement:
       CHECK_EQ(new_operands.size(), 1);
@@ -1507,7 +1516,9 @@ bool HloInstruction::IdenticalSlowPath(
                                        other.padding_config());
     case HloOpcode::kCall:
     case HloOpcode::kCrossReplicaSum:
-      return eq_computations(to_apply(), other.to_apply());
+      return replica_group_ids() == other.replica_group_ids() &&
+             cross_replica_sum_barrier() == other.cross_replica_sum_barrier() &&
+             eq_computations(to_apply(), other.to_apply());
     case HloOpcode::kCustomCall:
       if ((window_ == nullptr) != (other.window_ == nullptr) ||
           (window_ != nullptr &&
@@ -2086,6 +2097,14 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
                           "\", entry=", operand_side_metadata_->ToString(),
                           ", exit=", user_side_metadata_->ToString(), "}"));
   }
+  if (!replica_group_ids().empty()) {
+    extra.push_back(
+        StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}"));
+  }
+  if (!cross_replica_sum_barrier().empty()) {
+    extra.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
+  }
+
   // By contract, we print the custom call target even if
   // options.print_subcomputation_mode() == kOff, because the call target is not
   // an HloComputation.
@@ -2173,6 +2192,9 @@ HloInstructionProto HloInstruction::ToProto() const {
   proto.set_channel_name(channel_name_);
   proto.set_cost_estimate_ns(cost_estimate_ns_);
 
+  for (int64 replica_group_id : replica_group_ids_) {
+    proto.add_replica_group_ids(replica_group_id);
+  }
   return proto;
 }
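With the printing changes above, a subgroup CrossReplicaSum should render its new attributes in HLO text roughly as follows (attribute order illustrative; the parser test below shows the exact accepted syntax):

    crs = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add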
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 7d1ea129df..fcd175e66f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -443,7 +443,8 @@ class HloInstruction {
   static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
       HloComputation* reduce_computation,
-      tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+      tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+      tensorflow::StringPiece barrier,
       const tensorflow::gtl::optional<int64>& channel_id =
           tensorflow::gtl::nullopt);
 
@@ -1447,6 +1448,20 @@ class HloInstruction {
   void set_fusion_kind(FusionKind kind);
   // Old methods kept for smooth subclassing transition END.
 
+  // Returns the group ids of each replica for CrossReplicaSum op.
+  const std::vector<int64>& replica_group_ids() const {
+    return replica_group_ids_;
+  }
+
+  // Returns the barrier config used for the CrossReplicaSum implementation of
+  // each backend.
+  string cross_replica_sum_barrier() const {
+    return cross_replica_sum_barrier_;
+  }
+  void set_cross_replica_sum_barrier(string barrier) {
+    cross_replica_sum_barrier_ = barrier;
+  }
+
 protected:
   enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
   // Helper class for computing OperandElementUse for kFusion.
@@ -1650,6 +1665,12 @@ class HloInstruction {
   // HLO. See the documentation on backend_config().
   string backend_config_;
 
+  // The group id of each replica for CrossReplicaSum.
+  std::vector<int64> replica_group_ids_;
+
+  // The string representation of the barrier config used for CrossReplicaSum.
+  string cross_replica_sum_barrier_;
+
   // String identifier for instruction.
   string name_;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 4aa4406292..fef475380c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -588,13 +588,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
     }
     case HloOpcode::kCrossReplicaSum: {
       optional<HloComputation*> to_apply;
+      optional<std::vector<int64>> replica_group_ids;
+      optional<string> barrier;
       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                            &to_apply};
+      attrs["replica_group_ids"] = {
+          /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids};
+      attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
-      instruction = builder->AddInstruction(
-          HloInstruction::CreateCrossReplicaSum(shape, operands, *to_apply));
+      if (replica_group_ids) {
+        instruction =
+            builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
+                shape, operands, *to_apply, *replica_group_ids,
+                barrier ? *barrier : ""));
+      } else {
+        instruction =
+            builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
+                shape, operands, *to_apply, {}, barrier ? *barrier : ""));
+      }
       break;
     }
     case HloOpcode::kReshape: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 1c5a47c875..f834d34d57 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -918,6 +918,24 @@ ENTRY CRS {
 
 )"
 },
+// cross-replica-sum with subgroups
+{
+"CrossReplicaSumWithSubgroups",
+R"(HloModule CRS_Subgroups
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CrossReplicaSumWithSubgroups {
+  input = f32[128,32]{0,1} parameter(0)
+  ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), to_apply=add, replica_group_ids={0,0,1,1}, barrier="abc"
+}
+
+)"
+}
 });
   // clang-format on
 }
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 5887c3d88b..f7e116bf0f 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -581,12 +581,21 @@ Computes a sum across replicas.
 Arguments | Type | Semantics
 --------- | ------- | -----------------------------
 `operand` | `XlaOp` | Array to sum across replicas.
+`replica_group_ids` | `int64` vector | Group ID for each replica.
 
 The output shape is the same as the input shape. For example, if there are two
 replicas and the operand has the value `(1.0, 2.5)` and `(3.0, 5.25)`
 respectively on the two replicas, then the output value from this op will be
 `(4.0, 7.75)` on both replicas.
 
+`replica_group_ids` identifies the group ID of each replica. It must either be
+empty (all replicas belong to a single group) or contain one element per
+replica. For example, with eight replicas, `replica_group_ids` =
+{0, 1, 2, 3, 0, 1, 2, 3} defines four subgroups of replica IDs: {0, 4}, {1, 5},
+{2, 6}, and {3, 7}. The size of each subgroup *must* be identical, so, for
+example, `replica_group_ids` = {0, 1, 2, 0} for four replicas is invalid.
+
 Computing the result of CrossReplicaSum requires having one input from each
 replica, so if one replica executes a CrossReplicaSum node more times than
 another, then the former replica will wait forever. Since the replicas are all
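As a concrete illustration of the subgroup semantics documented above (values hypothetical): with four replicas, `replica_group_ids` = {0, 0, 1, 1}, and per-replica operand values `1.0`, `2.0`, `3.0`, and `4.0`, replicas 0 and 1 each receive `3.0` while replicas 2 and 3 each receive `7.0`.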