diff options
author | 2018-07-24 11:13:40 -0700 | |
---|---|---|
committer | 2018-07-24 11:25:57 -0700 | |
commit | 3acdbf8f904cf32e5d4d211934ee8d346aa48457 (patch) | |
tree | 1d6e8c32fb3ed8642bb709009a1f2d6402502ee6 /tensorflow/docs_src | |
parent | d53830cddfc74105e46a4bdb703cb1154a288f8f (diff) |
[XLA] Document DynamicSlice and DynamicUpdateSlice semantics.
PiperOrigin-RevId: 205858924
Diffstat (limited to 'tensorflow/docs_src')
-rw-r--r-- | tensorflow/docs_src/performance/xla/operation_semantics.md | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index d6fa8ab5f9..26a7b9e42c 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -791,8 +791,6 @@ DynamicSlice extracts a sub-array from the input array at dynamic `size_indices`, which specify the end point of exclusive slice intervals in each dimension: [start, start + size). The shape of `start_indices` must be rank == 1, with dimension size equal to the rank of `operand`. -Note: handling of out-of-bounds slice indices (generated by incorrect runtime -calculation of 'start_indices') is currently implementation-defined. <b> `DynamicSlice(operand, start_indices, size_indices)` </b> @@ -812,6 +810,17 @@ calculation of 'start_indices') is currently implementation-defined. : : : dimension to avoid wrapping modulo : : : : dimension size. : +The effective slice indices are computed by applying the following +transformation for each index `i` in `[1, N)` before performing the slice: + +``` +start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i]) +``` + +This ensures that the extracted slice is always in-bounds with respect to the +operand array. If the slice is in-bounds before the transformation is applied, +the transformation has no effect. + 1-dimensional example: ``` @@ -847,8 +856,6 @@ The shape of `update` determines the shape of the sub-array of the result which is updated. The shape of `start_indices` must be rank == 1, with dimension size equal to the rank of `operand`. -Note: handling of out-of-bounds slice indices (generated by incorrect runtime -calculation of 'start_indices') is currently implementation-defined. <b> `DynamicUpdateSlice(operand, update, start_indices)` </b> @@ -866,6 +873,17 @@ calculation of 'start_indices') is currently implementation-defined. : : : dimension. Value must be greater than or equal : : : : to zero. : +The effective slice indices are computed by applying the following +transformation for each index `i` in `[1, N)` before performing the slice: + +``` +start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i]) +``` + +This ensures that the updated slice is always in-bounds with respect to the +operand array. If the slice is in-bounds before the transformation is applied, +the transformation has no effect. + 1-dimensional example: ``` |