diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/scatter_expander.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/scatter_expander.cc | 350 |
1 files changed, 350 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc new file mode 100644 index 0000000000..45ca731153 --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -0,0 +1,350 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/scatter_expander.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +using tensorflow::gtl::ArraySlice; + +// Transposes the given scatter_indices such that the index_vector_dim becomes +// the most-minor dimension. +static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64 index_vector_dim) { + const Shape& scatter_indices_shape = scatter_indices->shape(); + + if (scatter_indices_shape.dimensions_size() == index_vector_dim) { + return scatter_indices; + } + + if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { + return scatter_indices; + } + + std::vector<int64> permutation; + permutation.reserve(scatter_indices_shape.dimensions_size()); + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(scatter_indices, permutation); +} + +// Canonicalizes the scatter_indices tensor in order to keep them uniform while +// performing the scatter operation. +static StatusOr<HloInstruction*> CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64 index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_scatter_indices, + TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + bool indices_are_scalar = + index_vector_dim == scatter_indices->shape().dimensions_size(); + + // The number of dimensions in scatter_indices that are index dimensions. + const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. scatter_indices has rank 1 and this + // scatter is really just a dynamic update slice) add a leading degenerate + // dimension for uniformity. Otherwise create a "collapsed" leading dimension + // that subsumes all of the non-index-vector dimensions. + const Shape& shape = transposed_scatter_indices->shape(); + if (shape.dimensions_size() == index_dims_in_scatter_indices) { + return PrependDegenerateDims(transposed_scatter_indices, 1); + } else { + // Collapse all but the dimensions (0 or 1) in scatter_indices containing + // the index vectors. + return CollapseFirstNDims( + transposed_scatter_indices, + shape.dimensions_size() - index_dims_in_scatter_indices); + } +} + +// Permutes the `updates` tensor such that all the scatter dims appear in the +// major dimensions and all the window dimensions appear in the minor +// dimensions. +static StatusOr<HloInstruction*> PermuteScatterAndWindowDims( + HloInstruction* updates, ArraySlice<int64> update_window_dims) { + std::vector<int64> permutation; + const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + permutation.reserve(updates_rank); + + for (int64 i = 0; i < updates_rank; ++i) { + bool is_scatter_dim = !c_binary_search(update_window_dims, i); + if (is_scatter_dim) { + permutation.push_back(i); + } + } + for (auto window_dim : update_window_dims) { + permutation.push_back(window_dim); + } + + return MakeTransposeHlo(updates, permutation); +} + +// Expands or contracts the scatter indices in the updates tensor. +static StatusOr<HloInstruction*> AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64 index_vector_dim) { + int64 num_scatter_dims = scatter_indices_shape.dimensions_size(); + if (index_vector_dim < scatter_indices_shape.dimensions_size()) { + --num_scatter_dims; + } + if (num_scatter_dims == 0) { + // If there are no scatter dims, this must be a dynamic-update-slice kind of + // scatter. In this case, we prepend a degenerate dimension to work + // uniformly in the while loop. + return PrependDegenerateDims(updates, 1); + } + return CollapseFirstNDims(updates, num_scatter_dims); +} + +// Expands an index vector from the scatter_indices tensor into a vector that +// can be used to dynamic-update-slice to perform the scatter update. +static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector<HloInstruction*> expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.scatter_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +// Body of the while loop that performs the scatter operation using other HLOs. +static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody( + HloInstruction* scatter, HloInstruction* induction_var, + const std::vector<HloInstruction*>& loop_state) { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK_EQ(loop_state.size(), 3); + HloInstruction* operand = loop_state[0]; + HloInstruction* scatter_indices = loop_state[1]; + HloInstruction* updates = loop_state[2]; + + bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; + CHECK_EQ(has_scalar_indices, + dim_numbers.index_vector_dim() == + scatter->operand(1)->shape().dimensions_size()); + + // Build a vector form of the induction variable of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * induction_var_as_vector, + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + + // Pick the index to scatter from scatter_indices based on the induction_var + // and transform that to an index into the `operand` space. + HloInstruction* index_vector; + if (has_scalar_indices) { + TF_ASSIGN_OR_RETURN( + index_vector, + MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1})); + } else { + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_scatter_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + int index_vector_size = scatter_indices->shape().dimensions(1); + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices, + {1, index_vector_size})); + TF_ASSIGN_OR_RETURN(index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * scatter_slice_start, + ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, + operand->shape().dimensions_size())); + + // Extract the slice to be used to update from `updates` tensor for the + // induction_var corresponding to this iteration of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_updates, + PadVectorWithZeros( + induction_var_as_vector, /*zeros_to_prepend=*/0, + /*zeros_to_append=*/updates->shape().dimensions_size() - 1)); + std::vector<int64> update_slice_bounds(updates->shape().dimensions().begin(), + updates->shape().dimensions().end()); + update_slice_bounds[0] = 1; + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice, + MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds)); + TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter, + ElideDegenerateDims(update_slice, {0})); + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice_with_dims_inserted, + InsertDegenerateDims(update_slice_for_scatter, + AsInt64Slice(dim_numbers.inserted_window_dims()))); + + // Extact the slice to update from `operand` tensor. + const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); + TF_ASSIGN_OR_RETURN( + HloInstruction * operand_slice_to_update, + MakeDynamicSliceHlo(operand, scatter_slice_start, + AsInt64Slice(update_slice_shape.dimensions()))); + + // Compute the new value for the slice to be updated in `operand` tensor by + // combining the existing value and the update value using the update + // computation. + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand_slice, + MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted}, + scatter->to_apply())); + + // Write the updated value of the slice into `operand` tensor. + TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operand, updated_operand_slice, + scatter_slice_start)); + + return StatusOr<std::vector<HloInstruction*>>{ + {updated_operand, scatter_indices, updates}}; +} + +// High Level Algorithm. +// +// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where +// each row is an index into the operand. +// 2. Canonicalize the updates tensor such that is has rank `num_window_dims+1` +// and the scatter dim is the most-major dimension. +// 3. Iterate over the set of indices in the canonicalized scatter_indices +// tensor using a while loop, updating the operand for each such index. Each +// iteration of this while loop performs the following: +// a. Pick the index from scatter_indices for this iteration. +// b. Transfrom this index into an index into the operand space. +// c. Extract the slice to be used to update from the updates tensor. +// d. Extract the slice to update from the operand tensor. +// e. Compute the new value for the slice to update by combining the slices +// from c. and d. using the update_computation of scatter. +// f. Write the updated value of the slice into the operand tensor. + +StatusOr<HloInstruction*> ScatterExpander::ExpandScatter( + HloInstruction* scatter) { + HloInstruction* operand = scatter->mutable_operand(0); + HloInstruction* scatter_indices = scatter->mutable_operand(1); + HloInstruction* updates = scatter->mutable_operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + + // If the updates tensor is empty, there is no need to update the operand. We + // can return the operand as is. + if (ShapeUtil::IsZeroElementArray(updates->shape())) { + return operand; + } + + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + const Shape& scatter_indices_shape = scatter_indices->shape(); + int64 scatter_loop_trip_count = 1; + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + if (!IsInt32(scatter_loop_trip_count)) { + return Unimplemented( + "Scatter operations with more than 2147483647 scatter indices are not " + "supported. This error occurred for %s.", + scatter->ToString().c_str()); + } + + // Canonicalize the scatter_indices, after which the size of its most-major + // dimension must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices, + CanonicalizeScatterIndices( + scatter_indices, dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + canonical_scatter_indices->shape().dimensions(0)); + + // Canonicalize the updates, after which the size of its most-major dimension + // must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_updates, + PermuteScatterAndWindowDims( + updates, AsInt64Slice(dim_numbers.update_window_dims()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * adjusted_canonical_updates, + AdjustScatterDims(scatter_indices->shape(), canonical_updates, + dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + adjusted_canonical_updates->shape().dimensions(0)); + + // The while loop that implements the scatter operation. + StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status = + WhileUtil::MakeCountedLoop( + scatter->parent(), scatter_loop_trip_count, + {operand, canonical_scatter_indices, adjusted_canonical_updates}, + [&](HloInstruction* induction_var, + const std::vector<HloInstruction*>& loop_state) { + return ScatterLoopBody(scatter, induction_var, loop_state); + }); + TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result, + scatter_loop_result_status); + return scatter_loop_result.front(); +} + +StatusOr<bool> ScatterExpander::Run(HloModule* module) { + std::vector<HloInstruction*> scatter_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kScatter) { + scatter_instrs.push_back(instr); + } + } + } + + for (auto instr : scatter_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr)); + TF_RETURN_IF_ERROR( + instr->parent()->ReplaceInstruction(instr, expanded_root)); + } + + return !scatter_instrs.empty(); +} + +} // namespace xla |