diff options
author | HyoukJoong Lee <hyouklee@google.com> | 2018-08-22 09:45:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 09:49:44 -0700 |
commit | 75ef667e5ab2075ee7a451cb2e46d6fa12dd9e59 (patch) | |
tree | 597d7409a50d4361ba58c4808ca9ca07560b5170 /tensorflow/compiler/xla/service/hlo_parser.cc | |
parent | 567189980f7a1c2aa09a5170bd8d01a6ec37d303 (diff) |
Change subgroup interface for CrossReplicaSum
PiperOrigin-RevId: 209780185
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 44 |
1 files changed, 25 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ede55510d3..beef96476c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -293,6 +293,20 @@ class HloParser { missing_instruction_hook_; }; +// Creates replica groups from the provided nested array. groups[i] represents +// the replica ids for group 'i'. +std::vector<ReplicaGroup> CreateReplicaGroups( + tensorflow::gtl::ArraySlice<std::vector<int64>> groups) { + std::vector<ReplicaGroup> replica_groups; + absl::c_transform(groups, std::back_inserter(replica_groups), + [](const std::vector<int64>& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + return replica_groups; +} + bool HloParser::Error(LocTy loc, StringPiece msg) { auto line_col = lexer_.GetLineAndColumn(loc); const unsigned line = line_col.first; @@ -637,31 +651,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional<std::vector<std::vector<int64>>> tmp_groups; optional<HloComputation*> to_apply; optional<std::vector<int64>> replica_group_ids; optional<string> barrier; optional<int64> all_reduce_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; - attrs["replica_group_ids"] = { - /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - if (replica_group_ids) { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, *replica_group_ids, - barrier ? *barrier : "", all_reduce_id)); - } else { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, {}, barrier ? *barrier : "", - all_reduce_id)); + std::vector<ReplicaGroup> replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); } + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, replica_groups, + barrier ? *barrier : "", all_reduce_id)); break; } case HloOpcode::kAllToAll: { @@ -675,13 +687,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } std::vector<ReplicaGroup> replica_groups; if (tmp_groups) { - absl::c_transform( - *tmp_groups, std::back_inserter(replica_groups), - [](const std::vector<int64>& ids) { - ReplicaGroup group; - *group.mutable_replica_ids() = {ids.begin(), ids.end()}; - return group; - }); + replica_groups = CreateReplicaGroups(*tmp_groups); } instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( shape, operands, replica_groups, barrier ? *barrier : "")); |