aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-08 17:10:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 17:15:13 -0700
commit3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea (patch)
tree926591cdb6e0b62c2dc8c041b1c61cbaf3c2ab85 /tensorflow/compiler/xla/service/hlo_instructions.cc
parentaacb29a4ab88f9fa27c3301977e7f2cc289a3976 (diff)
[XLA] Add the xla interface for AllToAll.
PiperOrigin-RevId: 207971529
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 1d71a74c40..1de5032670 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -359,6 +359,67 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
cross_replica_sum_barrier(), all_reduce_id());
}
+HloAllToAllInstruction::HloAllToAllInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const std::vector<ReplicaGroup>& replica_groups,
+ tensorflow::StringPiece barrier)
+ : HloInstruction(HloOpcode::kAllToAll, shape),
+ replica_groups_(replica_groups),
+ cross_replica_sum_barrier_(barrier.begin(), barrier.end()) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+bool HloAllToAllInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
+ return ContainersEqual(replica_groups(), casted_other.replica_groups(),
+ [](const ReplicaGroup& a, const ReplicaGroup& b) {
+ return ContainersEqual(a.replica_ids(),
+ b.replica_ids());
+ }) &&
+ cross_replica_sum_barrier() ==
+ casted_other.cross_replica_sum_barrier();
+}
+
+std::unique_ptr<HloInstruction>
+HloAllToAllInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* /*context*/) const {
+ return MakeUnique<HloAllToAllInstruction>(
+ shape, new_operands, replica_groups(), cross_replica_sum_barrier());
+}
+
+std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> result;
+ std::vector<string> replica_group_str;
+ for (const ReplicaGroup& group : replica_groups()) {
+ replica_group_str.push_back(
+ StrCat("{", Join(group.replica_ids(), ","), "}"));
+ }
+ result.push_back(
+ StrCat("replica_groups={", Join(replica_group_str, ","), "}"));
+
+ if (!cross_replica_sum_barrier().empty()) {
+ result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
+ }
+
+ return result;
+}
+
+HloInstructionProto HloAllToAllInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_replica_groups() = {replica_groups_.begin(),
+ replica_groups_.end()};
+ proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_);
+ return proto;
+}
+
HloReverseInstruction::HloReverseInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)