diff options
author | HyoukJoong Lee <hyouklee@google.com> | 2018-08-22 09:45:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 09:49:44 -0700 |
commit | 75ef667e5ab2075ee7a451cb2e46d6fa12dd9e59 (patch) | |
tree | 597d7409a50d4361ba58c4808ca9ca07560b5170 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 567189980f7a1c2aa09a5170bd8d01a6ec37d303 (diff) |
Change subgroup interface for CrossReplicaSum
PiperOrigin-RevId: 209780185
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index dbafa35b2a..36fac4a266 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -300,10 +300,10 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, + const std::vector<ReplicaGroup>& replica_groups, tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id) : HloInstruction(HloOpcode::kCrossReplicaSum, shape), - replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), + replica_groups_(replica_groups), cross_replica_sum_barrier_(barrier.begin(), barrier.end()), all_reduce_id_(all_reduce_id) { for (auto operand : operands) { @@ -314,9 +314,8 @@ HloAllReduceInstruction::HloAllReduceInstruction( HloInstructionProto HloAllReduceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - for (int64 i : replica_group_ids_) { - proto.add_replica_group_ids(i); - } + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; // Proto3 is so sad. if (all_reduce_id_) { proto.set_all_reduce_id(*all_reduce_id_); @@ -327,8 +326,14 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& /*options*/) const { - std::vector<string> result = { - StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")}; + std::vector<string> result; + std::vector<string> replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", Join(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", Join(replica_group_str, ","), "}")); if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } @@ -343,7 +348,11 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const std::function<bool(const HloComputation*, const HloComputation*)>& eq_computations) const { const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other); - return replica_group_ids() == casted_other.replica_group_ids() && + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }) && eq_computations(to_apply(), casted_other.to_apply()) && cross_replica_sum_barrier() == casted_other.cross_replica_sum_barrier() && @@ -356,7 +365,7 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* /*context*/) const { return absl::make_unique<HloAllReduceInstruction>( - shape, new_operands, to_apply(), replica_group_ids(), + shape, new_operands, to_apply(), replica_groups(), cross_replica_sum_barrier(), all_reduce_id()); } |