diff options
author | 2018-08-22 18:23:42 -0700 | |
---|---|---|
committer | 2018-08-22 18:28:20 -0700 | |
commit | b7b3f571728898c6d822aa1252d20bced15b989d (patch) | |
tree | 2c0752b0a588f008d7139105315a3a1ecc14edee /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | ef9db4a7ea59fa966659be1466bdbfd500b73cc6 (diff) |
[XLA] Extract HloCollectiveInstruction as superclass of all collective instructions.
This will make adding collective instructions slightly easier.
PiperOrigin-RevId: 209864869
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 113 |
1 files changed, 52 insertions, 61 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 345ca0053a..2a99d4d7c4 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -297,34 +297,24 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer()); } -HloAllReduceInstruction::HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloComputation* reduce_computation, - const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id) - : HloInstruction(HloOpcode::kCrossReplicaSum, shape), - replica_groups_(replica_groups), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()), - all_reduce_id_(all_reduce_id) { +HloCollectiveInstruction::HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups) + : HloInstruction(opcode, shape), replica_groups_(replica_groups) { for (auto operand : operands) { AppendOperand(operand); } - AppendComputation(reduce_computation); } -HloInstructionProto HloAllReduceInstruction::ToProto() const { +HloInstructionProto HloCollectiveInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_replica_groups() = {replica_groups_.begin(), replica_groups_.end()}; - // Proto3 is so sad. - if (all_reduce_id_) { - proto.set_all_reduce_id(*all_reduce_id_); - } - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); return proto; } -std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( +std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& /*options*/) const { std::vector<string> result; std::vector<string> replica_group_str; @@ -334,6 +324,48 @@ std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( } result.push_back( StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + return result; +} + +bool HloCollectiveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + /*eq_computations*/) const { + const auto& casted_other = + static_cast<const HloCollectiveInstruction&>(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }); +} + +HloAllReduceInstruction::HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloComputation* reduce_computation, + const std::vector<ReplicaGroup>& replica_groups, + tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id) + : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + replica_groups), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()), + all_reduce_id_(all_reduce_id) { + AppendComputation(reduce_computation); +} + +HloInstructionProto HloAllReduceInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); + // Proto3 is so sad. + if (all_reduce_id_) { + proto.set_all_reduce_id(*all_reduce_id_); + } + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + +std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector<string> result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } @@ -348,11 +380,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const std::function<bool(const HloComputation*, const HloComputation*)>& eq_computations) const { const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }) && + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && cross_replica_sum_barrier() == casted_other.cross_replica_sum_barrier() && @@ -372,24 +400,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( HloAllToAllInstruction::HloAllToAllInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, const std::vector<ReplicaGroup>& replica_groups) - : HloInstruction(HloOpcode::kAllToAll, shape), - replica_groups_(replica_groups) { - for (auto operand : operands) { - AppendOperand(operand); - } -} - -bool HloAllToAllInstruction::IdenticalSlowPath( - const HloInstruction& other, - const std::function<bool(const HloComputation*, const HloComputation*)>& - eq_computations) const { - const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }); -} + : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, + replica_groups) {} std::unique_ptr<HloInstruction> HloAllToAllInstruction::CloneWithNewOperandsImpl( @@ -400,27 +412,6 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl( replica_groups()); } -std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& options) const { - std::vector<string> result; - std::vector<string> replica_group_str; - for (const ReplicaGroup& group : replica_groups()) { - replica_group_str.push_back( - StrCat("{", Join(group.replica_ids(), ","), "}")); - } - result.push_back( - StrCat("replica_groups={", Join(replica_group_str, ","), "}")); - - return result; -} - -HloInstructionProto HloAllToAllInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_replica_groups() = {replica_groups_.begin(), - replica_groups_.end()}; - return proto; -} - HloReverseInstruction::HloReverseInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) |