aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar HyoukJoong Lee <hyouklee@google.com>2018-08-22 09:45:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 09:49:44 -0700
commit75ef667e5ab2075ee7a451cb2e46d6fa12dd9e59 (patch)
tree597d7409a50d4361ba58c4808ca9ca07560b5170 /tensorflow/compiler/xla/service/hlo_instructions.cc
parent567189980f7a1c2aa09a5170bd8d01a6ec37d303 (diff)
Change subgroup interface for CrossReplicaSum
PiperOrigin-RevId: 209780185
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc27
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index dbafa35b2a..36fac4a266 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -300,10 +300,10 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
HloAllReduceInstruction::HloAllReduceInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const std::vector<ReplicaGroup>& replica_groups,
tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id)
: HloInstruction(HloOpcode::kCrossReplicaSum, shape),
- replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()),
+ replica_groups_(replica_groups),
cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
all_reduce_id_(all_reduce_id) {
for (auto operand : operands) {
@@ -314,9 +314,8 @@ HloAllReduceInstruction::HloAllReduceInstruction(
HloInstructionProto HloAllReduceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- for (int64 i : replica_group_ids_) {
- proto.add_replica_group_ids(i);
- }
+ *proto.mutable_replica_groups() = {replica_groups_.begin(),
+ replica_groups_.end()};
// Proto3 is so sad.
if (all_reduce_id_) {
proto.set_all_reduce_id(*all_reduce_id_);
@@ -327,8 +326,14 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const {
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& /*options*/) const {
- std::vector<string> result = {
- StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")};
+ std::vector<string> result;
+ std::vector<string> replica_group_str;
+ for (const ReplicaGroup& group : replica_groups()) {
+ replica_group_str.push_back(
+ StrCat("{", Join(group.replica_ids(), ","), "}"));
+ }
+ result.push_back(
+ StrCat("replica_groups={", Join(replica_group_str, ","), "}"));
if (!cross_replica_sum_barrier().empty()) {
result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
}
@@ -343,7 +348,11 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
- return replica_group_ids() == casted_other.replica_group_ids() &&
+ return ContainersEqual(replica_groups(), casted_other.replica_groups(),
+ [](const ReplicaGroup& a, const ReplicaGroup& b) {
+ return ContainersEqual(a.replica_ids(),
+ b.replica_ids());
+ }) &&
eq_computations(to_apply(), casted_other.to_apply()) &&
cross_replica_sum_barrier() ==
casted_other.cross_replica_sum_barrier() &&
@@ -356,7 +365,7 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllReduceInstruction>(
- shape, new_operands, to_apply(), replica_group_ids(),
+ shape, new_operands, to_apply(), replica_groups(),
cross_replica_sum_barrier(), all_reduce_id());
}