diff options
-rw-r--r-- | tensorflow/compiler/xla/service/hlo.proto | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 9 |
4 files changed, 24 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index e201359d3d..d241791060 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -145,12 +145,16 @@ message HloInstructionProto { repeated int64 operand_ids = 36; repeated int64 control_predecessor_ids = 37; repeated int64 called_computation_ids = 38; - repeated int64 replica_group_ids = 44; xla.OpSharding sharding = 40; // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; + + // Cross Replica Sum fields. + repeated int64 replica_group_ids = 44; + int64 all_reduce_id = 45; + string cross_replica_sum_barrier = 46; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8bedd2a865..8f89b6f255 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -261,12 +261,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( [&instruction_map](int64 operand_id) { return instruction_map.at(operand_id); }); + tensorflow::gtl::optional<int64> all_reduce_id; + if (proto.all_reduce_id() > 0) { + all_reduce_id = proto.all_reduce_id(); + } instruction = CreateCrossReplicaSum( proto.shape(), all_operands, computations(0), /*replica_group_ids=*/ std::vector<int64>(proto.replica_group_ids().begin(), proto.replica_group_ids().end()), - /*barrier=*/""); + /*barrier=*/proto.cross_replica_sum_barrier(), + /*all_reduce_id=*/all_reduce_id); break; } default: { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 5871a6605f..1ebc4c936a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -280,7 +280,7 @@ HloAllReduceInstruction::HloAllReduceInstruction( cross_replica_sum_barrier_(barrier.begin(), barrier.end()), all_reduce_id_(all_reduce_id) { // TODO(b/79737069): Remove the CHECK when supported. - CHECK(!all_reduce_id_.has_value()); + CHECK(!all_reduce_id_); for (auto operand : operands) { AppendOperand(operand); } @@ -292,7 +292,11 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { for (int64 i : replica_group_ids_) { proto.add_replica_group_ids(i); } - // TODO(b/79737069): handle barrier and all_reduce_id. + // Proto3 is so sad. + if (all_reduce_id_) { + proto.set_all_reduce_id(*all_reduce_id_); + } + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); return proto; } @@ -303,7 +307,7 @@ std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } - if (all_reduce_id_.has_value()) { + if (all_reduce_id_) { result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); } return result; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index fef475380c..daa3bc4232 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -590,24 +590,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<HloComputation*> to_apply; optional<std::vector<int64>> replica_group_ids; optional<string> barrier; + optional<int64> all_reduce_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; attrs["replica_group_ids"] = { /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, + &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - if (replica_group_ids) { instruction = builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( shape, operands, *to_apply, *replica_group_ids, - barrier ? *barrier : "")); + barrier ? *barrier : "", all_reduce_id)); } else { instruction = builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, {}, barrier ? *barrier : "")); + shape, operands, *to_apply, {}, barrier ? *barrier : "", + all_reduce_id)); } break; } |