author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-01 21:44:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-01 21:48:45 -0700 |
commit | e3bcc0aa6e52867d9a12d9efded921325ecc5966 (patch) | |
tree | c4fe57f956fe9334f1e27342a4846250e915fe4a /tensorflow/docs_src | |
parent | 3379bae787d73d6db67d66a284bd1a076b2cbdba (diff) |
[XLA] Add Scatter HLO.
PiperOrigin-RevId: 207045468
Diffstat (limited to 'tensorflow/docs_src')
-rw-r--r-- | tensorflow/docs_src/performance/xla/operation_semantics.md | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 5f7482f90f..3981aaaf75 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1801,6 +1801,138 @@ is implementation-defined.
 : : : limit of interval :
 | `shape` | `Shape` | Output shape of type T |
 
+## Scatter
+
+The XLA scatter operation generates a result which is the value of the input
+tensor `operand`, with several slices (at indices specified by
+`scatter_indices`) updated with the values in `updates` using
+`update_computation`.
+
+See also
+[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+
+<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` </b>
+
+|Arguments | Type | Semantics |
+|------------------|------------------------|----------------------------------|
+|`operand` | `XlaOp` | Tensor to be scattered into. |
+|`scatter_indices` | `XlaOp` | Tensor containing the starting |
+: : : indices of the slices that must :
+: : : be scattered to. :
+|`updates` | `XlaOp` | Tensor containing the values that|
+: : : must be used for scattering. :
+|`update_computation`| `XlaComputation` | Computation to be used for |
+: : : combining the existing values in :
+: : : the input tensor and the updates :
+: : : during scatter. This computation :
+: : : should be of type `T, T -> T`. :
+|`index_vector_dim`| `int64` | The dimension in |
+: : : `scatter_indices` that contains :
+: : : the starting indices. :
+|`update_window_dims`| `ArraySlice<int64>` | The set of dimensions in |
+: : : `updates` shape that are _window :
+: : : dimensions_. :
+|`inserted_window_dims`| `ArraySlice<int64>`| The set of _window dimensions_ |
+: : : that must be inserted into :
+: : : `updates` shape. :
+|`scatter_dims_to_operand_dims`| `ArraySlice<int64>` | A dimensions map from |
+: : : the scatter indices to the :
+: : : operand index space. This array :
+: : : is interpreted as mapping `i` to :
+: : : `scatter_dims_to_operand_dims[i]`:
+: : : . It has to be one-to-one and :
+: : : total. :
+
+If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider
+`scatter_indices` to have a trailing `1` dimension.
+
+We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of
+dimensions in `updates` shape that are not in `update_window_dims`, in ascending
+order.
+
+The arguments of scatter should follow these constraints:
+
+  - `updates` tensor must be of rank `update_window_dims.size +
+    scatter_indices.rank - 1`.
+
+  - Bounds of dimension `i` in `updates` must conform to the following:
+      - If `i` is present in `update_window_dims` (i.e. equal to
+        `update_window_dims`[`k`] for some `k`), then the bound of dimension
+        `i` in `updates` must not exceed the corresponding bound of `operand`
+        after accounting for the `inserted_window_dims` (i.e.
+        `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains
+        the bounds of `operand` with the bounds at indices
+        `inserted_window_dims` removed).
+      - If `i` is present in `update_scatter_dims` (i.e. equal to
+        `update_scatter_dims`[`k`] for some `k`), then the bound of dimension
+        `i` in `updates` must be equal to the corresponding bound of
+        `scatter_indices`, skipping `index_vector_dim` (i.e.
+        `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and
+        `scatter_indices.shape.dims`[`k+1`] otherwise).
+
+  - `update_window_dims` must be in ascending order, not have any repeating
+    dimension numbers, and be in the range `[0, updates.rank)`.
+
+  - `inserted_window_dims` must be in ascending order, not have any
+    repeating dimension numbers, and be in the range `[0, operand.rank)`.
+
+  - `scatter_dims_to_operand_dims.size` must be equal to
+    `scatter_indices`[`index_vector_dim`], and its values must be in the range
+    `[0, operand.rank)`.
+
+For a given index `U` in the `updates` tensor, the corresponding index `I` in
+the `operand` tensor into which this update has to be applied is computed as
+follows:
+
+  1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up
+     an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] =
+     `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at
+     positions `index_vector_dim` into A.
+  2. Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering
+     `S` using the `scatter_dims_to_operand_dims` map. More formally:
+       1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if
+          `k` < `scatter_dims_to_operand_dims.size`.
+       2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
+  3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
+     at `update_window_dims` in `U` according to `inserted_window_dims`.
+     More formally:
+       1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if
+          `k` < `update_window_dims.size`, where `window_dims_to_operand_dims`
+          is the monotonic function with domain [`0`, `update_window_dims.size`)
+          and range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For
+          example, if `update_window_dims.size` is `4`, `operand.rank` is `6`,
+          and `inserted_window_dims` is {`0`, `2`} then
+          `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`,
+          `3`→`5`}).
+       2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
+  4. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+     addition.
+
+In summary, the scatter operation can be defined as follows.
+
+  - Initialize `output` with `operand`, i.e. for all indices `O` in the
+    `operand` tensor:\
+    `output`[`O`] = `operand`[`O`]
+  - For every index `U` in the `updates` tensor and the corresponding index `O`
+    in the `operand` tensor:\
+    `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`])
+
+The order in which updates are applied is non-deterministic. So, when multiple
+indices in `updates` refer to the same index in `operand`, the corresponding
+value in `output` will be non-deterministic.
+
+Note that the first parameter that is passed into the `update_computation` will
+always be the current value from the `output` tensor and the second parameter
+will always be the value from the `updates` tensor. This is important
+specifically for cases when the `update_computation` is _not commutative_.
+
+Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e.
+the scatter op updates the elements in the input that are extracted by the
+corresponding gather op.
+
+For a detailed informal description and examples, refer to the
+"Informal Description" section under `Gather`.
+
 ## Select
 
 See also
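
Reviewer's sketch (not part of the commit): the index computation that the added Scatter section describes can be traced with a small plain-Python reference over nested lists. This is only an illustration under stated assumptions, not the XLA implementation. The names `scatter_reference`, `_shape`, `_get` and `_set` are hypothetical. Step 3 is read as taking `U[update_window_dims[k]]` for the `k`-th window dimension, and out-of-range indices are not handled because the text above does not specify them. Updates are applied in one fixed iteration order here, whereas the documented order is non-deterministic.

```python
import copy
import itertools


def _shape(t):
    """Shape of a rectangular nested list (a scalar has shape ())."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def _get(t, idx):
    """Read the element of nested list `t` at multi-index `idx`."""
    for i in idx:
        t = t[i]
    return t


def _set(t, idx, value):
    """Write `value` into nested list `t` at multi-index `idx` (rank >= 1)."""
    for i in idx[:-1]:
        t = t[i]
    t[idx[-1]] = value


def scatter_reference(operand, scatter_indices, updates, update_computation,
                      index_vector_dim, update_window_dims,
                      inserted_window_dims, scatter_dims_to_operand_dims):
    """Hypothetical reference for the documented scatter semantics."""
    operand_rank = len(_shape(operand))
    updates_shape = _shape(updates)
    update_scatter_dims = [d for d in range(len(updates_shape))
                           if d not in update_window_dims]
    # Assumption: the k-th window dimension maps to the k-th operand
    # dimension that is not listed in inserted_window_dims.
    window_dims_to_operand_dims = [d for d in range(operand_rank)
                                   if d not in inserted_window_dims]
    # index_vector_dim == rank means an implicit trailing dimension of size 1.
    implicit_trailing = index_vector_dim == len(_shape(scatter_indices))

    output = copy.deepcopy(operand)  # output starts as a copy of operand
    for u in itertools.product(*(range(n) for n in updates_shape)):
        # Step 1: G picks one index vector S out of scatter_indices.
        g = [u[k] for k in update_scatter_dims]
        s = []
        for i in range(len(scatter_dims_to_operand_dims)):
            if implicit_trailing:
                s.append(_get(scatter_indices, g))
            else:
                combined = g[:index_vector_dim] + [i] + g[index_vector_dim:]
                s.append(_get(scatter_indices, combined))
        # Step 2: S_in places S into operand index space, zero elsewhere.
        s_in = [0] * operand_rank
        for k, d in enumerate(scatter_dims_to_operand_dims):
            s_in[d] = s[k]
        # Step 3: W_in places the window part of U, zero elsewhere.
        w_in = [0] * operand_rank
        for k, u_dim in enumerate(update_window_dims):
            w_in[window_dims_to_operand_dims[k]] = u[u_dim]
        # Step 4: I = W_in + S_in; the current output value is passed first.
        target = [w + s_k for w, s_k in zip(w_in, s_in)]
        _set(output, target,
             update_computation(_get(output, target), _get(updates, u)))
    return output


# Usage: add whole rows of `updates` into rows 0 and 2 of a 3x3 operand.
result = scatter_reference(
    operand=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    scatter_indices=[0, 2],
    updates=[[10, 20, 30], [70, 80, 90]],
    update_computation=lambda old, new: old + new,
    index_vector_dim=1,
    update_window_dims=[1],
    inserted_window_dims=[0],
    scatter_dims_to_operand_dims=[0])
```

In the usage call, `index_vector_dim` equals the rank of `scatter_indices`, so each entry is treated as a one-element index vector, and `scatter_dims_to_operand_dims` has size 1. That matches reading the last constraint as a statement about the bound of `scatter_indices` along `index_vector_dim`, which is an interpretation rather than something the diff states.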