aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto6
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc9
4 files changed, 24 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index e201359d3d..d241791060 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -145,12 +145,16 @@ message HloInstructionProto {
repeated int64 operand_ids = 36;
repeated int64 control_predecessor_ids = 37;
repeated int64 called_computation_ids = 38;
- repeated int64 replica_group_ids = 44;
xla.OpSharding sharding = 40;
// Backend configuration for the instruction. Has backend-specific meaning.
string backend_config = 43;
+
+ // Cross Replica Sum fields.
+ repeated int64 replica_group_ids = 44;
+ int64 all_reduce_id = 45;
+ string cross_replica_sum_barrier = 46;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8bedd2a865..8f89b6f255 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -261,12 +261,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
[&instruction_map](int64 operand_id) {
return instruction_map.at(operand_id);
});
+ tensorflow::gtl::optional<int64> all_reduce_id;
+ if (proto.all_reduce_id() > 0) {
+ all_reduce_id = proto.all_reduce_id();
+ }
instruction = CreateCrossReplicaSum(
proto.shape(), all_operands, computations(0),
/*replica_group_ids=*/
std::vector<int64>(proto.replica_group_ids().begin(),
proto.replica_group_ids().end()),
- /*barrier=*/"");
+ /*barrier=*/proto.cross_replica_sum_barrier(),
+ /*all_reduce_id=*/all_reduce_id);
break;
}
default: {
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 5871a6605f..1ebc4c936a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -280,7 +280,7 @@ HloAllReduceInstruction::HloAllReduceInstruction(
cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
all_reduce_id_(all_reduce_id) {
// TODO(b/79737069): Remove the CHECK when supported.
- CHECK(!all_reduce_id_.has_value());
+ CHECK(!all_reduce_id_);
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -292,7 +292,11 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const {
for (int64 i : replica_group_ids_) {
proto.add_replica_group_ids(i);
}
- // TODO(b/79737069): handle barrier and all_reduce_id.
+ // Proto3 is so sad.
+ if (all_reduce_id_) {
+ proto.set_all_reduce_id(*all_reduce_id_);
+ }
+ proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_);
return proto;
}
@@ -303,7 +307,7 @@ std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
if (!cross_replica_sum_barrier().empty()) {
result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
}
- if (all_reduce_id_.has_value()) {
+ if (all_reduce_id_) {
result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
}
return result;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index fef475380c..daa3bc4232 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -590,24 +590,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<string> barrier;
+ optional<int64> all_reduce_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
attrs["replica_group_ids"] = {
/*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids};
attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
+ attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
+ &all_reduce_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
-
if (replica_group_ids) {
instruction =
builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
shape, operands, *to_apply, *replica_group_ids,
- barrier ? *barrier : ""));
+ barrier ? *barrier : "", all_reduce_id));
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
- shape, operands, *to_apply, {}, barrier ? *barrier : ""));
+ shape, operands, *to_apply, {}, barrier ? *barrier : "",
+ all_reduce_id));
}
break;
}