Diffstat (limited to 'tensorflow/docs_src/performance/xla/operation_semantics.md')
-rw-r--r-- | tensorflow/docs_src/performance/xla/operation_semantics.md | 87 |
1 files changed, 80 insertions, 7 deletions
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index edc777a3c7..02af71f8a3 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -13,6 +13,79 @@ arbitrary-dimensional array. For convenience, special cases have more specific
 and familiar names; for example a *vector* is a 1-dimensional array and a
 *matrix* is a 2-dimensional array.
 
+## AllToAll
+
+See also
+[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+AllToAll is a collective operation that sends data from all cores to all
+cores. It has two phases:
+
+1.  The scatter phase. On each core, the operand is split into `split_count`
+    blocks along `split_dimension`, and the blocks are scattered to all
+    cores, e.g., the ith block is sent to the ith core.
+2.  The gather phase. Each core concatenates the received blocks along
+    `concat_dimension`.
+
+The participating cores can be configured by:
+
+-   `replica_groups`: each `ReplicaGroup` contains a list of replica ids. If
+    empty, all replicas belong to one group, in the order 0 - (n-1). AllToAll
+    is applied within subgroups in the specified order. For example,
+    `replica_groups = {{1,2,3},{4,5,0}}` means that an AllToAll is applied
+    within replicas 1, 2, 3, and in the gather phase the received blocks are
+    concatenated in the order 1, 2, 3; another AllToAll is applied within
+    replicas 4, 5, 0, and the concatenation order is 4, 5, 0.
+
+Prerequisites:
+
+-   The dimension size of the operand on `split_dimension` is divisible by
+    `split_count`.
+-   The operand's shape is not a tuple.
+
+<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
+replica_groups)` </b>
+
+| Arguments          | Type                  | Semantics                       |
+| ------------------ | --------------------- | ------------------------------- |
+| `operand`          | `XlaOp`               | n dimensional input array       |
+| `split_dimension`  | `int64`               | a value in the interval `[0,    |
+:                    :                       : n)` that names the dimension    :
+:                    :                       : along which the operand is      :
+:                    :                       : split                           :
+| `concat_dimension` | `int64`               | a value in the interval `[0,    |
+:                    :                       : n)` that names the dimension    :
+:                    :                       : along which the split blocks    :
+:                    :                       : are concatenated                :
+| `split_count`      | `int64`               | the number of cores that        |
+:                    :                       : participate in this operation.  :
+:                    :                       : If `replica_groups` is empty,   :
+:                    :                       : this should be the number of    :
+:                    :                       : replicas; otherwise, this       :
+:                    :                       : should be equal to the number   :
+:                    :                       : of replicas in each group.      :
+| `replica_groups`   | `ReplicaGroup` vector | each group contains a list of   |
+:                    :                       : replica ids.                    :
+
+Below is an example of AllToAll.
+
+```
+XlaBuilder b("alltoall");
+auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
+AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
+```
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+  <img style="width:100%" src="../../images/xla/ops_alltoall.png">
+</div>
+
+In this example, there are 4 cores participating in the AllToAll. On each
+core, the operand is split into 4 parts along dimension 1, so each part has
+shape f32[4,4]. The 4 parts are scattered to all cores. Then each core
+concatenates the received parts along dimension 0, in the order of cores 0-3.
+So the output on each core has shape f32[16,4].
+
 ## BatchNormGrad
 
 See also
@@ -270,7 +343,7 @@ Clamp(min, operand, max) = s32[3]{0, 5, 6};
 
 See also
 [`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and the @{tf.reshape} operation.
+and the `tf.reshape` operation.
 
 Collapses dimensions of an array into one dimension.
 
@@ -291,7 +364,7 @@ same position in the dimension sequence as those they replace, with the new
 dimension size equal to the product of original dimension sizes. The lowest
 dimension number in `dimensions` is the slowest varying dimension (most major)
 in the loop nest which collapses these dimension, and the highest dimension
-number is fastest varying (most minor). See the @{tf.reshape} operator
+number is fastest varying (most minor). See the `tf.reshape` operator
 if more general collapse ordering is needed.
 
 For example, let v be an array of 24 elements:
@@ -490,8 +563,8 @@ array. The holes are filled with a no-op value, which for convolution means
 zeroes.
 
 Dilation of the rhs is also called atrous convolution. For more details, see
-@{tf.nn.atrous_conv2d}. Dilation of the lhs is also called transposed
-convolution. For more details, see @{tf.nn.conv2d_transpose}.
+`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
+convolution. For more details, see `tf.nn.conv2d_transpose`.
 
 The output shape has these dimensions, in this order:
 
@@ -1270,7 +1343,7 @@ let t: (f32[10], s32) = tuple(v, s);
 let element_1: s32 = gettupleelement(t, 1);  // Inferred shape matches s32.
 ```
 
-See also @{tf.tuple}.
+See also `tf.tuple`.
 
 ## Infeed
 
@@ -1847,7 +1920,7 @@ tensor `operand`, with several slices (at indices specified by
 `update_computation`.
 
 See also
-[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
 
 <b> `scatter(operand, scatter_indices, updates, update_computation,
 index_vector_dim, update_window_dims, inserted_window_dims,
 scatter_dims_to_operand_dims)` </b>
@@ -2250,7 +2323,7 @@ element types.
 
 ## Transpose
 
-See also the @{tf.reshape} operation.
+See also the `tf.reshape` operation.
 
 <b>`Transpose(operand)`</b>
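As a sanity check on the AllToAll semantics added by this diff, the following is a minimal host-side sketch, not XLA code: `Array2D` and `SimulateAllToAll` are hypothetical helpers that model the scatter and gather phases with plain C++ vectors, assuming the example's parameters (`split_dimension=1`, `concat_dimension=0`, `split_count=4`, operand `f32[4,16]` per core).

```
// Host-side simulation of AllToAll's data movement (illustrative only;
// the real operation is built with XlaBuilder::AllToAll).
#include <cstdio>
#include <vector>

// One "core" holds a row-major rank-2 array of floats.
struct Array2D {
  int rows = 0;
  int cols = 0;
  std::vector<float> data;
  float at(int r, int c) const { return data[r * cols + c]; }
};

// Simulates AllToAll with split_dimension=1 and concat_dimension=0.
std::vector<Array2D> SimulateAllToAll(const std::vector<Array2D>& cores) {
  const int n = static_cast<int>(cores.size());  // split_count == #cores
  std::vector<Array2D> outputs(n);
  for (int receiver = 0; receiver < n; ++receiver) {
    const int block_cols = cores[receiver].cols / n;  // width of one block
    Array2D& out = outputs[receiver];
    out.cols = block_cols;
    // Gather phase: concatenate, along dimension 0 and in core order, the
    // block that each sender scattered to `receiver` along dimension 1.
    for (int sender = 0; sender < n; ++sender) {
      const Array2D& src = cores[sender];
      for (int r = 0; r < src.rows; ++r) {
        for (int c = 0; c < block_cols; ++c) {
          out.data.push_back(src.at(r, receiver * block_cols + c));
        }
      }
      out.rows += src.rows;
    }
  }
  return outputs;
}

int main() {
  // Four cores, each holding f32[4,16]; values encode (core id, element).
  std::vector<Array2D> cores(4);
  for (int i = 0; i < 4; ++i) {
    cores[i].rows = 4;
    cores[i].cols = 16;
    for (int j = 0; j < 4 * 16; ++j) cores[i].data.push_back(100.0f * i + j);
  }
  std::vector<Array2D> outputs = SimulateAllToAll(cores);
  // Each output is f32[16,4]: four [4,4] blocks concatenated along
  // dimension 0 in the order of cores 0-3.
  std::printf("output shape on core 0: [%d,%d]\n",
              outputs[0].rows, outputs[0].cols);
  return 0;
}
```

Running the sketch prints `output shape on core 0: [16,4]`, matching the shape derived in the documented example.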