diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-08 17:10:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 17:15:13 -0700 |
commit | 3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea (patch) | |
tree | 926591cdb6e0b62c2dc8c041b1c61cbaf3c2ab85 /tensorflow/compiler/xla/service/hlo_instruction.h | |
parent | aacb29a4ab88f9fa27c3301977e7f2cc289a3976 (diff) |
[XLA] Add the xla interface for AllToAll.
PiperOrigin-RevId: 207971529
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e722086732..3c575ae6ea 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -449,6 +449,26 @@ class HloInstruction {
       tensorflow::StringPiece barrier,
       const tensorflow::gtl::optional<int64>& all_reduce_id);
 
+  // This op handles the communication of an Alltoall operation. On each core,
+  // the operands are N ops in the same shape, where N is the number of cores
+  // participating the Alltoall. Then the N operands are scattered to N cores,
+  // e.g., the ith operand is sent to the ith core. Then each core gathers the
+  // received data into a tuple.
+  //
+  // - `replica_groups`: each ReplicaGroup contains a list of replica id. If
+  // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall
+  // will be applied within subgroups in the specified order. For example,
+  // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied
+  // within replica 1, 2, 3, and in the gather phase, the received blocks will
+  // be concatenated in the order of 1, 2, 3; another Alltoall will be applied
+  // within replica 4, 5, 0, and the concatenation order is 4, 5, 0.
+  //
+  // TODO(b/110096724): This is NOT YET ready to use.
+  static std::unique_ptr<HloInstruction> CreateAllToAll(
+      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+      const std::vector<ReplicaGroup>& replica_groups,
+      tensorflow::StringPiece barrier);
+
   // Creates a conversion instruction, where operand is the data to convert and
   // shape is the target shape for the conversion.
   static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
@@ -1414,6 +1434,9 @@ class HloInstruction {
   // Delegates to HloAllReduceInstruction::replica_group_ids.
   const std::vector<int64>& replica_group_ids() const;
 
+  // Delegates to HloAllToAllInstruction::replica_groups.
+  const std::vector<ReplicaGroup>& replica_groups() const;
+
   // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
   string cross_replica_sum_barrier() const;
   void set_cross_replica_sum_barrier(const string& barrier);