aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-08 17:10:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 17:15:13 -0700
commit 3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea (patch)
tree 926591cdb6e0b62c2dc8c041b1c61cbaf3c2ab85 /tensorflow/compiler/xla/service/hlo_instruction.h
parent aacb29a4ab88f9fa27c3301977e7f2cc289a3976 (diff)
[XLA] Add the xla interface for AllToAll.
PiperOrigin-RevId: 207971529
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h 23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e722086732..3c575ae6ea 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -449,6 +449,26 @@ class HloInstruction {
tensorflow::StringPiece barrier,
const tensorflow::gtl::optional<int64>& all_reduce_id);
+ // Creates an instruction that handles the communication of an Alltoall
+ // operation. On each core, the operands are N ops of the same shape, where N
+ // is the number of cores participating in the Alltoall. The N operands are
+ // scattered to the N cores, i.e., the ith operand is sent to the ith core.
+ // Each core then gathers the received data into a tuple.
+ //
+ // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
+ // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall
+ // is applied within each subgroup in the specified order. For example,
+ // replica_groups = {{1,2,3},{4,5,0}} means that one Alltoall is applied
+ // among replicas 1, 2, 3, and in the gather phase the received blocks are
+ // concatenated in the order 1, 2, 3; another Alltoall is applied among
+ // replicas 4, 5, 0, and the concatenation order is 4, 5, 0.
+ //
+ // TODO(b/110096724): This is NOT YET ready to use.
+ static std::unique_ptr<HloInstruction> CreateAllToAll(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const std::vector<ReplicaGroup>& replica_groups,
+ tensorflow::StringPiece barrier);
+
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
@@ -1414,6 +1434,9 @@ class HloInstruction {
// Delegates to HloAllReduceInstruction::replica_group_ids.
const std::vector<int64>& replica_group_ids() const;
+ // Delegates to HloAllToAllInstruction::replica_groups.
+ const std::vector<ReplicaGroup>& replica_groups() const;
+
// Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
string cross_replica_sum_barrier() const;
void set_cross_replica_sum_barrier(const string& barrier);