diff options
Diffstat (limited to 'tensorflow/docs_src/performance/xla/operation_semantics.md')
-rw-r--r-- | tensorflow/docs_src/performance/xla/operation_semantics.md | 41 |
1 files changed, 36 insertions, 5 deletions
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 68c427a316..fe9afc4ecb 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -791,8 +791,6 @@ DynamicSlice extracts a sub-array from the input array at dynamic `size_indices`, which specify the end point of exclusive slice intervals in each dimension: [start, start + size). The shape of `start_indices` must be rank == 1, with dimension size equal to the rank of `operand`. -Note: handling of out-of-bounds slice indices (generated by incorrect runtime -calculation of 'start_indices') is currently implementation-defined. <b> `DynamicSlice(operand, start_indices, size_indices)` </b> @@ -812,6 +810,17 @@ calculation of 'start_indices') is currently implementation-defined. : : : dimension to avoid wrapping modulo : : : : dimension size. : +The effective slice indices are computed by applying the following +transformation for each index `i` in `[1, N)` before performing the slice: + +``` +start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i]) +``` + +This ensures that the extracted slice is always in-bounds with respect to the +operand array. If the slice is in-bounds before the transformation is applied, +the transformation has no effect. + 1-dimensional example: ``` @@ -847,8 +856,6 @@ The shape of `update` determines the shape of the sub-array of the result which is updated. The shape of `start_indices` must be rank == 1, with dimension size equal to the rank of `operand`. -Note: handling of out-of-bounds slice indices (generated by incorrect runtime -calculation of 'start_indices') is currently implementation-defined. <b> `DynamicUpdateSlice(operand, update, start_indices)` </b> @@ -866,6 +873,17 @@ calculation of 'start_indices') is currently implementation-defined. : : : dimension. Value must be greater than or equal : : : : to zero. : +The effective slice indices are computed by applying the following +transformation for each index `i` in `[1, N)` before performing the slice: + +``` +start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i]) +``` + +This ensures that the updated slice is always in-bounds with respect to the +operand array. If the slice is in-bounds before the transformation is applied, +the transformation has no effect. + 1-dimensional example: ``` @@ -1293,6 +1311,19 @@ Infeed of the device. > which case the compiler will provide information about how the Infeed > operations are serialized in the compiled program. +## Iota + +<b> `Iota()` </b> + +Builds a constant literal on device rather than a potentially large host +transfer. Creates a rank 1 tensor of values starting at zero and incrementing +by one. + +Arguments | Type | Semantics +------------------ | --------------- | --------------------------- +`type` | `PrimitiveType` | type U +`size` | `int64` | The number of elements in the tensor. + ## Map See also @@ -1303,7 +1334,7 @@ See also | Arguments | Type | Semantics | | ----------------- | ---------------------- | ------------------------------ | | `operands` | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} | -| `computation` | `XlaComputation` | computation of type `T_0, T_1, | +| `computation` | `XlaComputation` | computation of type `T_0, T_1, | : : : ..., T_{N + M -1} -> S` with N : : : : parameters of type T and M of : : : : arbitrary type : |