From 3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 8 Aug 2018 17:10:33 -0700
Subject: [XLA] Add the xla interface for AllToAll.

PiperOrigin-RevId: 207971529
---
 .../performance/xla/operation_semantics.md | 73 ++++++++++++++++++++++
 1 file changed, 73 insertions(+)

diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 165f6f5914..02af71f8a3 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -13,6 +13,79 @@ arbitrary-dimensional array. For convenience, special cases have more specific
and familiar names; for example a *vector* is a 1-dimensional array and a
*matrix* is a 2-dimensional array.

## AllToAll

See also
[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

AllToAll is a collective operation that sends data from all cores to all cores.
It has two phases:

1. The scatter phase. On each core, the operand is split into `split_count`
   blocks along `split_dimension`, and the blocks are scattered to all cores;
   for example, the ith block is sent to the ith core.
2. The gather phase. Each core concatenates the received blocks along
   `concat_dimension`.

The participating cores can be configured by:

- `replica_groups`: each `ReplicaGroup` contains a list of replica ids. If
  empty, all replicas belong to one group, in the order 0 through (n-1). An
  AllToAll is applied within each subgroup in the specified order. For example,
  `replica_groups = {{1,2,3},{4,5,0}}` means that one AllToAll is applied
  within replicas 1, 2, and 3, with the received blocks concatenated in the
  order 1, 2, 3 during the gather phase; another AllToAll is applied within
  replicas 4, 5, and 0, with concatenation order 4, 5, 0.
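The two phases above can be sketched as a small NumPy model of the semantics. This is an illustrative sketch only, not the XLA implementation; `all_to_all` is a hypothetical helper that operates on a plain list of per-replica arrays:

```python
import numpy as np

def all_to_all(operands, split_dimension, concat_dimension, split_count,
               replica_groups=None):
    """Illustrative NumPy model of AllToAll over per-replica arrays."""
    n = len(operands)
    if not replica_groups:
        replica_groups = [list(range(n))]
    results = [None] * n
    for group in replica_groups:
        # Scatter phase: split each replica's operand into split_count blocks
        # along split_dimension.
        blocks = {r: np.split(operands[r], split_count, axis=split_dimension)
                  for r in group}
        # Gather phase: the replica at position i in the group receives block i
        # from every replica in the group, then concatenates the blocks along
        # concat_dimension, in group order.
        for i, r in enumerate(group):
            received = [blocks[s][i] for s in group]
            results[r] = np.concatenate(received, axis=concat_dimension)
    return results
```

Note that a replica's position within its group, not its raw replica id, determines which block it receives and where its contribution lands in the concatenation, which is why the group order matters in the `{{1,2,3},{4,5,0}}` example.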
Prerequisites:

- The dimension size of the operand on `split_dimension` is divisible by
  `split_count`.
- The operand's shape is not a tuple.

`AllToAll(operand, split_dimension, concat_dimension, split_count,
replica_groups)`

| Arguments          | Type                  | Semantics                       |
| ------------------ | --------------------- | ------------------------------- |
| `operand`          | `XlaOp`               | n-dimensional input array       |
| `split_dimension`  | `int64`               | A value in the interval `[0, n)` that names the dimension along which the operand is split |
| `concat_dimension` | `int64`               | A value in the interval `[0, n)` that names the dimension along which the split blocks are concatenated |
| `split_count`      | `int64`               | The number of cores that participate in this operation. If `replica_groups` is empty, this should be the number of replicas; otherwise, this should be equal to the number of replicas in each group. |
| `replica_groups`   | `ReplicaGroup` vector | Each group contains a list of replica ids. |

Below shows an example of AllToAll.

```
XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
```
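The shape arithmetic of the example above can be checked with a short NumPy sketch (illustrative only; the four hypothetical replicas here stand in for the participating cores):

```python
import numpy as np

# Four hypothetical replicas, each holding an f32[4,16] operand filled with
# its own replica id so the data movement is visible.
operands = [np.full((4, 16), r, dtype=np.float32) for r in range(4)]

# Scatter: split each operand into 4 blocks of shape (4, 4) along dimension 1.
blocks = [np.split(op, 4, axis=1) for op in operands]

# Gather: core i receives block i from every core and concatenates the blocks
# along dimension 0, giving an f32[16,4] result per core.
outputs = [np.concatenate([blocks[s][i] for s in range(4)], axis=0)
           for i in range(4)]

assert all(out.shape == (16, 4) for out in outputs)
```

Rows 0-3 of each output come from core 0, rows 4-7 from core 1, and so on, matching the gather order described below.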
In this example, there are 4 cores participating in the AllToAll. On each core,
the operand is split into 4 parts along dimension 1, so each part has shape
f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates
the received parts along dimension 0, in the order of cores 0-3, so the output
on each core has shape f32[16,4].

## BatchNormGrad

See also