aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-22 18:23:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 18:28:20 -0700
commitb7b3f571728898c6d822aa1252d20bced15b989d (patch)
tree2c0752b0a588f008d7139105315a3a1ecc14edee
parentef9db4a7ea59fa966659be1466bdbfd500b73cc6 (diff)
[XLA] Extract HloCollectiveInstruction as superclass of all collective instructions.
This will make adding collective instructions slightly easier. PiperOrigin-RevId: 209864869
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc113
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h50
3 files changed, 79 insertions, 96 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index cf1845c8fe..668ed9d6c3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -3183,26 +3183,16 @@ const string& HloInstruction::outfeed_config() const {
}
const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
- if (opcode() == HloOpcode::kCrossReplicaSum) {
- return Cast<HloAllReduceInstruction>(this)->replica_groups();
- }
- return Cast<HloAllToAllInstruction>(this)->replica_groups();
+ return Cast<HloCollectiveInstruction>(this)->replica_groups();
}
string HloInstruction::cross_replica_sum_barrier() const {
- if (opcode() == HloOpcode::kCrossReplicaSum) {
return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
- }
- return Cast<HloAllToAllInstruction>(this)->cross_replica_sum_barrier();
}
void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
- if (opcode() == HloOpcode::kCrossReplicaSum) {
return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
barrier);
- }
- return Cast<HloAllToAllInstruction>(this)->set_cross_replica_sum_barrier(
- barrier);
}
absl::optional<int64> HloInstruction::all_reduce_id() const {
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 345ca0053a..2a99d4d7c4 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -297,34 +297,24 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
}
-HloAllReduceInstruction::HloAllReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* reduce_computation,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id)
- : HloInstruction(HloOpcode::kCrossReplicaSum, shape),
- replica_groups_(replica_groups),
- cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
- all_reduce_id_(all_reduce_id) {
+HloCollectiveInstruction::HloCollectiveInstruction(
+ HloOpcode opcode, const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const std::vector<ReplicaGroup>& replica_groups)
+ : HloInstruction(opcode, shape), replica_groups_(replica_groups) {
for (auto operand : operands) {
AppendOperand(operand);
}
- AppendComputation(reduce_computation);
}
-HloInstructionProto HloAllReduceInstruction::ToProto() const {
+HloInstructionProto HloCollectiveInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_replica_groups() = {replica_groups_.begin(),
replica_groups_.end()};
- // Proto3 is so sad.
- if (all_reduce_id_) {
- proto.set_all_reduce_id(*all_reduce_id_);
- }
- proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_);
return proto;
}
-std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
+std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& /*options*/) const {
std::vector<string> result;
std::vector<string> replica_group_str;
@@ -334,6 +324,48 @@ std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
}
result.push_back(
StrCat("replica_groups={", Join(replica_group_str, ","), "}"));
+ return result;
+}
+
+bool HloCollectiveInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ /*eq_computations*/) const {
+ const auto& casted_other =
+ static_cast<const HloCollectiveInstruction&>(other);
+ return ContainersEqual(replica_groups(), casted_other.replica_groups(),
+ [](const ReplicaGroup& a, const ReplicaGroup& b) {
+ return ContainersEqual(a.replica_ids(),
+ b.replica_ids());
+ });
+}
+
+HloAllReduceInstruction::HloAllReduceInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ const std::vector<ReplicaGroup>& replica_groups,
+ tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id)
+ : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands,
+ replica_groups),
+ cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
+ all_reduce_id_(all_reduce_id) {
+ AppendComputation(reduce_computation);
+}
+
+HloInstructionProto HloAllReduceInstruction::ToProto() const {
+ HloInstructionProto proto = HloCollectiveInstruction::ToProto();
+ // Proto3 is so sad.
+ if (all_reduce_id_) {
+ proto.set_all_reduce_id(*all_reduce_id_);
+ }
+ proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_);
+ return proto;
+}
+
+std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> result =
+ HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
if (!cross_replica_sum_barrier().empty()) {
result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
}
@@ -348,11 +380,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
- return ContainersEqual(replica_groups(), casted_other.replica_groups(),
- [](const ReplicaGroup& a, const ReplicaGroup& b) {
- return ContainersEqual(a.replica_ids(),
- b.replica_ids());
- }) &&
+ return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
eq_computations(to_apply(), casted_other.to_apply()) &&
cross_replica_sum_barrier() ==
casted_other.cross_replica_sum_barrier() &&
@@ -372,24 +400,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
HloAllToAllInstruction::HloAllToAllInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
const std::vector<ReplicaGroup>& replica_groups)
- : HloInstruction(HloOpcode::kAllToAll, shape),
- replica_groups_(replica_groups) {
- for (auto operand : operands) {
- AppendOperand(operand);
- }
-}
-
-bool HloAllToAllInstruction::IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const {
- const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
- return ContainersEqual(replica_groups(), casted_other.replica_groups(),
- [](const ReplicaGroup& a, const ReplicaGroup& b) {
- return ContainersEqual(a.replica_ids(),
- b.replica_ids());
- });
-}
+ : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
+ replica_groups) {}
std::unique_ptr<HloInstruction>
HloAllToAllInstruction::CloneWithNewOperandsImpl(
@@ -400,27 +412,6 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
replica_groups());
}
-std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
- const HloPrintOptions& options) const {
- std::vector<string> result;
- std::vector<string> replica_group_str;
- for (const ReplicaGroup& group : replica_groups()) {
- replica_group_str.push_back(
- StrCat("{", Join(group.replica_ids(), ","), "}"));
- }
- result.push_back(
- StrCat("replica_groups={", Join(replica_group_str, ","), "}"));
-
- return result;
-}
-
-HloInstructionProto HloAllToAllInstruction::ToProto() const {
- HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_replica_groups() = {replica_groups_.begin(),
- replica_groups_.end()};
- return proto;
-}
-
HloReverseInstruction::HloReverseInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 755e560151..19e98c6fb4 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -218,7 +218,31 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
HloCloneContext* context) const override;
};
-class HloAllReduceInstruction : public HloInstruction {
+class HloCollectiveInstruction : public HloInstruction {
+ public:
+ const std::vector<ReplicaGroup>& replica_groups() const {
+ return replica_groups_;
+ }
+
+ protected:
+ explicit HloCollectiveInstruction(
+ HloOpcode opcode, const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const std::vector<ReplicaGroup>& replica_groups);
+
+ HloInstructionProto ToProto() const override;
+
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ std::vector<ReplicaGroup> replica_groups_;
+};
+
+class HloAllReduceInstruction : public HloCollectiveInstruction {
public:
explicit HloAllReduceInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
@@ -227,10 +251,6 @@ class HloAllReduceInstruction : public HloInstruction {
tensorflow::StringPiece barrier,
const absl::optional<int64>& all_reduce_id);
- const std::vector<ReplicaGroup>& replica_groups() const {
- return replica_groups_;
- }
-
// Returns the barrier config used for the CrossReplicaSum implementation of
// each backend.
string cross_replica_sum_barrier() const {
@@ -259,9 +279,6 @@ class HloAllReduceInstruction : public HloInstruction {
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const override;
- // The replica ids of each subgroup for CrossReplicaSum op.
- std::vector<ReplicaGroup> replica_groups_;
-
// The string representation of the barrier config used for CrossReplicaSum.
string cross_replica_sum_barrier_;
@@ -271,33 +288,18 @@ class HloAllReduceInstruction : public HloInstruction {
absl::optional<int64> all_reduce_id_;
};
-class HloAllToAllInstruction : public HloInstruction {
+class HloAllToAllInstruction : public HloCollectiveInstruction {
public:
explicit HloAllToAllInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand,
const std::vector<ReplicaGroup>& replica_groups);
- const std::vector<ReplicaGroup>& replica_groups() const {
- return replica_groups_;
- }
-
- HloInstructionProto ToProto() const override;
-
private:
- std::vector<string> ExtraAttributesToStringImpl(
- const HloPrintOptions& options) const override;
- bool IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const override;
-
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const override;
-
- std::vector<ReplicaGroup> replica_groups_;
};
class HloReverseInstruction : public HloInstruction {