diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-22 18:23:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 18:28:20 -0700 |
commit | b7b3f571728898c6d822aa1252d20bced15b989d (patch) | |
tree | 2c0752b0a588f008d7139105315a3a1ecc14edee | |
parent | ef9db4a7ea59fa966659be1466bdbfd500b73cc6 (diff) |
[XLA] Extract HloCollectiveInstruction as superclass of all collective instructions.
This will make adding collective instructions slightly easier.
PiperOrigin-RevId: 209864869
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 113 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 50 |
3 files changed, 79 insertions, 96 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index cf1845c8fe..668ed9d6c3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -3183,26 +3183,16 @@ const string& HloInstruction::outfeed_config() const { } const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { - if (opcode() == HloOpcode::kCrossReplicaSum) { - return Cast<HloAllReduceInstruction>(this)->replica_groups(); - } - return Cast<HloAllToAllInstruction>(this)->replica_groups(); + return Cast<HloCollectiveInstruction>(this)->replica_groups(); } string HloInstruction::cross_replica_sum_barrier() const { - if (opcode() == HloOpcode::kCrossReplicaSum) { return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier(); - } - return Cast<HloAllToAllInstruction>(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - if (opcode() == HloOpcode::kCrossReplicaSum) { return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier( barrier); - } - return Cast<HloAllToAllInstruction>(this)->set_cross_replica_sum_barrier( - barrier); } absl::optional<int64> HloInstruction::all_reduce_id() const { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 345ca0053a..2a99d4d7c4 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -297,34 +297,24 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer()); } -HloAllReduceInstruction::HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloComputation* reduce_computation, - const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id) - : HloInstruction(HloOpcode::kCrossReplicaSum, shape), - replica_groups_(replica_groups), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()), - all_reduce_id_(all_reduce_id) { +HloCollectiveInstruction::HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups) + : HloInstruction(opcode, shape), replica_groups_(replica_groups) { for (auto operand : operands) { AppendOperand(operand); } - AppendComputation(reduce_computation); } -HloInstructionProto HloAllReduceInstruction::ToProto() const { +HloInstructionProto HloCollectiveInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_replica_groups() = {replica_groups_.begin(), replica_groups_.end()}; - // Proto3 is so sad. - if (all_reduce_id_) { - proto.set_all_reduce_id(*all_reduce_id_); - } - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); return proto; } -std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( +std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& /*options*/) const { std::vector<string> result; std::vector<string> replica_group_str; @@ -334,6 +324,48 @@ std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( } result.push_back( StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + return result; +} + +bool HloCollectiveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + /*eq_computations*/) const { + const auto& casted_other = + static_cast<const HloCollectiveInstruction&>(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }); +} + +HloAllReduceInstruction::HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloComputation* reduce_computation, + const std::vector<ReplicaGroup>& replica_groups, + tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id) + : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + replica_groups), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()), + all_reduce_id_(all_reduce_id) { + AppendComputation(reduce_computation); +} + +HloInstructionProto HloAllReduceInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); + // Proto3 is so sad. + if (all_reduce_id_) { + proto.set_all_reduce_id(*all_reduce_id_); + } + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + +std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector<string> result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } @@ -348,11 +380,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const std::function<bool(const HloComputation*, const HloComputation*)>& eq_computations) const { const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }) && + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && cross_replica_sum_barrier() == casted_other.cross_replica_sum_barrier() && @@ -372,24 +400,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( HloAllToAllInstruction::HloAllToAllInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, const std::vector<ReplicaGroup>& replica_groups) - : HloInstruction(HloOpcode::kAllToAll, shape), - replica_groups_(replica_groups) { - for (auto operand : operands) { - AppendOperand(operand); - } -} - -bool HloAllToAllInstruction::IdenticalSlowPath( - const HloInstruction& other, - const std::function<bool(const HloComputation*, const HloComputation*)>& - eq_computations) const { - const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }); -} + : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, + replica_groups) {} std::unique_ptr<HloInstruction> HloAllToAllInstruction::CloneWithNewOperandsImpl( @@ -400,27 +412,6 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl( replica_groups()); } -std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& options) const { - std::vector<string> result; - std::vector<string> replica_group_str; - for (const ReplicaGroup& group : replica_groups()) { - replica_group_str.push_back( - StrCat("{", Join(group.replica_ids(), ","), "}")); - } - result.push_back( - StrCat("replica_groups={", Join(replica_group_str, ","), "}")); - - return result; -} - -HloInstructionProto HloAllToAllInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_replica_groups() = {replica_groups_.begin(), - replica_groups_.end()}; - return proto; -} - HloReverseInstruction::HloReverseInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 755e560151..19e98c6fb4 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -218,7 +218,31 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { HloCloneContext* context) const override; }; -class HloAllReduceInstruction : public HloInstruction { +class HloCollectiveInstruction : public HloInstruction { + public: + const std::vector<ReplicaGroup>& replica_groups() const { + return replica_groups_; + } + + protected: + explicit HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups); + + HloInstructionProto ToProto() const override; + + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + + std::vector<ReplicaGroup> replica_groups_; +}; + +class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, @@ -227,10 +251,6 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id); - const std::vector<ReplicaGroup>& replica_groups() const { - return replica_groups_; - } - // Returns the barrier config used for the CrossReplicaSum implementation of // each backend. string cross_replica_sum_barrier() const { @@ -259,9 +279,6 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const override; - // The replica ids of each subgroup for CrossReplicaSum op. - std::vector<ReplicaGroup> replica_groups_; - // The string representation of the barrier config used for CrossReplicaSum. string cross_replica_sum_barrier_; @@ -271,33 +288,18 @@ class HloAllReduceInstruction : public HloInstruction { absl::optional<int64> all_reduce_id_; }; -class HloAllToAllInstruction : public HloInstruction { +class HloAllToAllInstruction : public HloCollectiveInstruction { public: explicit HloAllToAllInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand, const std::vector<ReplicaGroup>& replica_groups); - const std::vector<ReplicaGroup>& replica_groups() const { - return replica_groups_; - } - - HloInstructionProto ToProto() const override; - private: - std::vector<string> ExtraAttributesToStringImpl( - const HloPrintOptions& options) const override; - bool IdenticalSlowPath( - const HloInstruction& other, - const std::function<bool(const HloComputation*, const HloComputation*)>& - eq_computations) const override; - // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const override; - - std::vector<ReplicaGroup> replica_groups_; }; class HloReverseInstruction : public HloInstruction { |