aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
authorGravatar HyoukJoong Lee <hyouklee@google.com>2018-08-22 09:45:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 09:49:44 -0700
commit75ef667e5ab2075ee7a451cb2e46d6fa12dd9e59 (patch)
tree597d7409a50d4361ba58c4808ca9ca07560b5170 /tensorflow/compiler/xla/service/hlo_parser.cc
parent567189980f7a1c2aa09a5170bd8d01a6ec37d303 (diff)
Change subgroup interface for CrossReplicaSum
PiperOrigin-RevId: 209780185
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc44
1 files changed, 25 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ede55510d3..beef96476c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -293,6 +293,20 @@ class HloParser {
missing_instruction_hook_;
};
+// Creates replica groups from the provided nested array. groups[i] represents
+// the replica ids for group 'i'.
+std::vector<ReplicaGroup> CreateReplicaGroups(
+ tensorflow::gtl::ArraySlice<std::vector<int64>> groups) {
+ std::vector<ReplicaGroup> replica_groups;
+ absl::c_transform(groups, std::back_inserter(replica_groups),
+ [](const std::vector<int64>& ids) {
+ ReplicaGroup group;
+ *group.mutable_replica_ids() = {ids.begin(), ids.end()};
+ return group;
+ });
+ return replica_groups;
+}
+
bool HloParser::Error(LocTy loc, StringPiece msg) {
auto line_col = lexer_.GetLineAndColumn(loc);
const unsigned line = line_col.first;
@@ -637,31 +651,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kCrossReplicaSum: {
+ optional<std::vector<std::vector<int64>>> tmp_groups;
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<string> barrier;
optional<int64> all_reduce_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
- attrs["replica_group_ids"] = {
- /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids};
+ attrs["replica_groups"] = {/*required=*/false,
+ AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
&all_reduce_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- if (replica_group_ids) {
- instruction =
- builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
- shape, operands, *to_apply, *replica_group_ids,
- barrier ? *barrier : "", all_reduce_id));
- } else {
- instruction =
- builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
- shape, operands, *to_apply, {}, barrier ? *barrier : "",
- all_reduce_id));
+ std::vector<ReplicaGroup> replica_groups;
+ if (tmp_groups) {
+ replica_groups = CreateReplicaGroups(*tmp_groups);
}
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
+ shape, operands, *to_apply, replica_groups,
+ barrier ? *barrier : "", all_reduce_id));
break;
}
case HloOpcode::kAllToAll: {
@@ -675,13 +687,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
- absl::c_transform(
- *tmp_groups, std::back_inserter(replica_groups),
- [](const std::vector<int64>& ids) {
- ReplicaGroup group;
- *group.mutable_replica_ids() = {ids.begin(), ids.end()};
- return group;
- });
+ replica_groups = CreateReplicaGroups(*tmp_groups);
}
instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
shape, operands, replica_groups, barrier ? *barrier : ""));