diff options
28 files changed, 1103 insertions, 1153 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 35de96e0aa..44140304fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -95,11 +95,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // operand = s32[3,3] parameter(0) // indices = s32[2] parameter(1) // gather = s32[3,2] gather(operand, indices), - // output_window_dims={0}, - // elided_window_dims={1}, - // gather_dims_to_operand_dims={1}, + // offset_dims={0}, + // collapsed_slice_dims={1}, + // start_index_map={1}, // index_vector_dim=1, - // window_bounds={3, 1} + // slice_sizes={3, 1} // // // Example of an N-D gather pulling out slices of shape [1,1,2] out of a @@ -108,42 +108,42 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // operand = s32[3,3,2] parameter(0) // indices = s32[2,2] parameter(1) // gather = s32[2,2] gather(operand, indices), - // output_window_dims={1}, - // elided_window_dims={0,1}, - // gather_dims_to_operand_dims={0,1}, + // offset_dims={1}, + // collapsed_slice_dims={0,1}, + // start_index_map={0,1}, // index_vector_dim=0, - // window_bounds={1,1,2} + // slice_sizes={1,1,2} xla::GatherDimensionNumbers dim_numbers; - std::vector<int64> window_bounds; - window_bounds.reserve(input_shape.dims()); + std::vector<int64> slice_sizes; + slice_sizes.reserve(input_shape.dims()); for (int64 i = 0; i < input_shape.dims(); i++) { int64 window_bound; if (axis <= i && i < (axis + num_index_dims)) { - dim_numbers.add_elided_window_dims(i); + dim_numbers.add_collapsed_slice_dims(i); window_bound = 1; } else { window_bound = input_shape.dim_size(i); } - window_bounds.push_back(window_bound); + slice_sizes.push_back(window_bound); if (i < axis) { - dim_numbers.add_output_window_dims(i); + dim_numbers.add_offset_dims(i); } else if (i >= (axis + num_index_dims)) { int64 indices_rank = indices_are_nd ? 
(indices_shape.dims() - 1) : indices_shape.dims(); - dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims); + dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); } } dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims()); for (int64 i = axis; i < axis + num_index_dims; i++) { - dim_numbers.add_gather_dims_to_operand_dims(i); + dim_numbers.add_start_index_map(i); } - *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds); + *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 04fa10108c..febb638e5e 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -57,7 +57,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // We can grab entire blocks using gather if (n > block_size) { // Construct the starting indices of the diagonal blocks - auto gather_indices = + auto start_indices = Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), xla::ConstantR0<int32>(builder, block_size)), /*broadcast_sizes=*/{2}), @@ -65,13 +65,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Gather the diagonal blocks xla::GatherDimensionNumbers dim_numbers; - dim_numbers.add_output_window_dims(ndims - 1); - dim_numbers.add_output_window_dims(ndims); - dim_numbers.add_gather_dims_to_operand_dims(ndims - 2); - dim_numbers.add_gather_dims_to_operand_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims); + dim_numbers.add_start_index_map(ndims - 2); + dim_numbers.add_start_index_map(ndims - 1); dim_numbers.set_index_vector_dim(1); - diag_blocks = Gather(a, gather_indices, dim_numbers, - /*window_bounds=*/{block_size, block_size}); + diag_blocks = Gather(a, start_indices, dim_numbers, + 
/*slice_sizes=*/{block_size, block_size}); } // The last block might be smaller than the block size, diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index aa47f992bc..4dffab3c2c 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1631,27 +1631,27 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, }); } -XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds) { + tensorflow::gtl::ArraySlice<int64> slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); - TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape, - GetShape(gather_indices)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), - ShapeInference::InferGatherShape(input_shape, gather_indices_shape, - dimension_numbers, window_bounds)); + ShapeInference::InferGatherShape(input_shape, start_indices_shape, + dimension_numbers, slice_sizes)); *instr.mutable_gather_dimension_numbers() = dimension_numbers; - for (int64 bound : window_bounds) { - instr.add_gather_window_bounds(bound); + for (int64 bound : slice_sizes) { + instr.add_gather_slice_sizes(bound); } return AddInstruction(std::move(instr), HloOpcode::kGather, - {input, gather_indices}); + {input, start_indices}); }); } @@ -2906,11 +2906,11 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, mantissa_bits); } -XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds) { - return 
input.builder()->Gather(input, gather_indices, dimension_numbers, - window_bounds); + tensorflow::gtl::ArraySlice<int64> slice_sizes) { + return input.builder()->Gather(input, start_indices, dimension_numbers, + slice_sizes); } XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 78aec770a6..469d5048b2 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -877,9 +877,9 @@ class XlaBuilder { const int mantissa_bits); // Enqueues a Gather node onto the computation. - XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds); + tensorflow::gtl::ArraySlice<int64> slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -1328,9 +1328,9 @@ class XlaBuilder { const XlaComputation& false_computation); friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); - friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds); + tensorflow::gtl::ArraySlice<int64> slice_sizes); friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, @@ -2024,9 +2024,9 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); // Enqueues a Gather node onto the computation. 
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds); + tensorflow::gtl::ArraySlice<int64> slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 8cbdc36f84..e6130c7d76 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -792,11 +792,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[3,2] broadcast(one), dimensions={} ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) @@ -808,11 +808,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) @@ -824,11 +824,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + 
collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -840,11 +840,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -856,11 +856,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -872,11 +872,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[1,1] broadcast(one), dimensions={} ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) @@ -888,11 +888,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - 
gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 2e9d6be2de..891ae42141 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1672,22 +1672,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1); for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { operand_index.push_back(index.GetConstantWithIndexType(0)); } else { - int64 output_window_dim = - dim_numbers.output_window_dims(operand_index_dim++); + int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); operand_to_output_dim[i] = output_window_dim; operand_index.push_back(index[output_window_dim]); } } - // This is the index of the index vector in the gather_indices tensor. + // This is the index of the index vector in the start_indices tensor. 
IrArray::Index gather_index_index(index_type); { std::vector<llvm::Value*> gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + if (!c_binary_search(dim_numbers.offset_dims(), i)) { gather_index_index.push_back(index[i]); } } @@ -1700,7 +1699,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = b_->CreateSExtOrTrunc(index_component, index_type); - int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim); + int64 operand_dim = dim_numbers.start_index_map(dim); int64 output_dim = operand_to_output_dim[operand_dim]; // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. // This means we set the iteration index to 0, so for the purpose of the diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index e3a42d0d06..9370c88710 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -27,85 +27,85 @@ namespace xla { using tensorflow::gtl::ArraySlice; static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast( - HloInstruction* gather_indices, int64 index_vector_dim) { - const Shape& gather_indices_shape = gather_indices->shape(); + HloInstruction* start_indices, int64 index_vector_dim) { + const Shape& start_indices_shape = start_indices->shape(); - if (gather_indices_shape.dimensions_size() == index_vector_dim) { - return gather_indices; + if (start_indices_shape.dimensions_size() == index_vector_dim) { + return start_indices; } - if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) { - return gather_indices; + if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) { + return start_indices; } std::vector<int64> permutation; - 
permutation.reserve(gather_indices_shape.dimensions_size()); - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + permutation.reserve(start_indices_shape.dimensions_size()); + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != index_vector_dim) { permutation.push_back(i); } } permutation.push_back(index_vector_dim); - return MakeTransposeHlo(gather_indices, permutation); + return MakeTransposeHlo(start_indices, permutation); } -// Canonicalizes the gather_indices tensors so that we only have deal with some +// Canonicalizes the start_indices tensors so that we only have deal with some // specific cases in the while loop that does the heavy lifting. // // See the "High Level Algorithm" section for a broader picture. static StatusOr<HloInstruction*> CanonicalizeGatherIndices( - HloInstruction* gather_indices, int64 index_vector_dim) { + HloInstruction* start_indices, int64 index_vector_dim) { // Transpose the non-index-vector dimensions to the front. TF_ASSIGN_OR_RETURN( - HloInstruction * transposed_gather_indices, - TransposeIndexVectorDimToLast(gather_indices, index_vector_dim)); + HloInstruction * transposed_start_indices, + TransposeIndexVectorDimToLast(start_indices, index_vector_dim)); bool indices_are_scalar = - index_vector_dim == gather_indices->shape().dimensions_size(); + index_vector_dim == start_indices->shape().dimensions_size(); - // The number of dimensions in gather_indices that are index dimensions. - const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1; + // The number of dimensions in start_indices that are index dimensions. + const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1; - // If there is only one index (i.e. gather_indices has rank 1 and this gather + // If there is only one index (i.e. start_indices has rank 1 and this gather // is really just a dynamic slice) add a leading degenerate dimension for // uniformity. 
Otherwise create a "collapsed" leading dimension that subsumes // all of the non-index-vector dimensions. - const Shape& shape = transposed_gather_indices->shape(); - if (shape.dimensions_size() == index_dims_in_gather_indices) { - return PrependDegenerateDims(transposed_gather_indices, 1); + const Shape& shape = transposed_start_indices->shape(); + if (shape.dimensions_size() == index_dims_in_start_indices) { + return PrependDegenerateDims(transposed_start_indices, 1); } else { - // Collapse all but the dimensions (0 or 1) in gather_indices containing the + // Collapse all but the dimensions (0 or 1) in start_indices containing the // index vectors. return CollapseFirstNDims( - transposed_gather_indices, - shape.dimensions_size() - index_dims_in_gather_indices); + transposed_start_indices, + shape.dimensions_size() - index_dims_in_start_indices); } } // Expands out or contracts away the gather dimensions in the accumulator // produced by the while loop. -static StatusOr<HloInstruction*> AdjustGatherDimsInAccumulator( - const Shape& gather_indices_shape, HloInstruction* accumulator, +static StatusOr<HloInstruction*> AdjustBatchDimsInAccumulator( + const Shape& start_indices_shape, HloInstruction* accumulator, int64 index_vector_dim) { - std::vector<int64> output_gather_dim_bounds; - output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size()); - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + std::vector<int64> batch_dim_bounds; + batch_dim_bounds.reserve(start_indices_shape.dimensions_size()); + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != index_vector_dim) { - output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i)); + batch_dim_bounds.push_back(start_indices_shape.dimensions(i)); } } - if (output_gather_dim_bounds.empty()) { - // If output_gather_dim_bounds is empty we must be lowering a (effectively) + if (batch_dim_bounds.empty()) { + // If batch_dim_bounds is empty 
we must be lowering a (effectively) // dynamic-slice. In that case, there is a leading degenerate gather // dimension that we added to make this special case play well with the // general while loop which we need to remove now. return ElideDegenerateDims(accumulator, {0}); } - return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); + return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds); } -// Expand an index vector from the gather_indices tensor into a vector that can +// Expand an index vector from the start_indices tensor into a vector that can // be used to dynamic-slice out of the gather operand. static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace( HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, @@ -121,10 +121,8 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace( std::vector<HloInstruction*> expanded_index_components; for (int i = 0; i < operand_rank; i++) { - int64 index_vector_dim_index = - FindIndex(dim_numbers.gather_dims_to_operand_dims(), i); - if (index_vector_dim_index != - dim_numbers.gather_dims_to_operand_dims_size()) { + int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i); + if (index_vector_dim_index != dim_numbers.start_index_map_size()) { TF_ASSIGN_OR_RETURN( HloInstruction * component_to_concat, MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, @@ -147,10 +145,10 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody( const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); CHECK_EQ(incoming_loop_state.size(), 3); HloInstruction* const operand = incoming_loop_state[0]; - HloInstruction* const gather_indices = incoming_loop_state[1]; + HloInstruction* const start_indices = incoming_loop_state[1]; HloInstruction* const output_accumulator = incoming_loop_state[2]; - bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1; + bool has_scalar_indices = 
start_indices->shape().dimensions_size() == 1; CHECK_EQ(has_scalar_indices, dim_numbers.index_vector_dim() == gather.operand(1)->shape().dimensions_size()); @@ -163,24 +161,24 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody( HloInstruction* index_vector; if (has_scalar_indices) { - // In this case gather_indices has rank 1 and induction_var_as_vector (of + // In this case start_indices has rank 1 and induction_var_as_vector (of // shape {1}) is an index into this rank 1 tensor. TF_ASSIGN_OR_RETURN( index_vector, - MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1})); + MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1})); } else { - // In this case gather_indices has rank 2 and induction_var_as_vector (of + // In this case start_indices has rank 2 and induction_var_as_vector (of // shape {1}) is an index into just the first dimension of this rank 2 // tensor. TF_ASSIGN_OR_RETURN( - HloInstruction * index_into_gather_indices, + HloInstruction * index_into_start_indices, PadVectorWithZeros(induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); - int64 index_vector_size = gather_indices->shape().dimensions(1); + int64 index_vector_size = start_indices->shape().dimensions(1); TF_ASSIGN_OR_RETURN( HloInstruction * index_vector_2d, - MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + MakeDynamicSliceHlo(start_indices, index_into_start_indices, {1, index_vector_size})); TF_ASSIGN_OR_RETURN(index_vector, @@ -194,26 +192,26 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody( TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, MakeDynamicSliceHlo(operand, gathered_slice_start, - gather.gather_window_bounds())); + gather.gather_slice_sizes())); TF_ASSIGN_OR_RETURN( - HloInstruction * gathered_slice_with_dims_elided, + HloInstruction* const gathered_slice_with_dims_collapsed, ElideDegenerateDims(gathered_slice, - AsInt64Slice(dim_numbers.elided_window_dims()))); + 
AsInt64Slice(dim_numbers.collapsed_slice_dims()))); TF_ASSIGN_OR_RETURN( - HloInstruction * gathered_slice_for_update, - PrependDegenerateDims(gathered_slice_with_dims_elided, 1)); + HloInstruction* const gathered_slice_for_update, + PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1)); TF_ASSIGN_OR_RETURN( - HloInstruction * index_vector_into_accumulator, + HloInstruction* const index_vector_into_accumulator, PadVectorWithZeros( induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/ - gathered_slice_with_dims_elided->shape().dimensions_size())); + gathered_slice_with_dims_collapsed->shape().dimensions_size())); TF_ASSIGN_OR_RETURN( - HloInstruction * updated_accumulator, + HloInstruction* const updated_accumulator, MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, index_vector_into_accumulator)); @@ -221,19 +219,19 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody( // WhileUtil::MakeCountedLoop functions takes care of the induction variable // and the while loop exit condition. 
return StatusOr<std::vector<HloInstruction*>>{ - {operand, gather_indices, updated_accumulator}}; + {operand, start_indices, updated_accumulator}}; } static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, - ArraySlice<int64> window_bounds, int64 gather_loop_trip_count, + ArraySlice<int64> slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { std::vector<int64> accumulator_state_shape_dims; - accumulator_state_shape_dims.reserve(1 + window_bounds.size()); + accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); - for (int64 i = 0; i < window_bounds.size(); i++) { - if (!c_binary_search(dim_numbers.elided_window_dims(), i)) { - accumulator_state_shape_dims.push_back(window_bounds[i]); + for (int64 i = 0; i < slice_sizes.size(); i++) { + if (!c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + accumulator_state_shape_dims.push_back(slice_sizes[i]); } } return BroadcastZeros(computation, element_type, @@ -241,23 +239,23 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue( } // `accumulator` is almost the tensor the gather operation would have produced, -// except that it has the dimensions in the wrong order -- the gather dimensions -// are the major dimensions and the window dimensions are the minor dimensions. +// except that it has the dimensions in the wrong order -- the batch dimensions +// are the major dimensions and the offset dimensions are the minor dimensions. // Fix this up with a transpose. 
-static StatusOr<HloInstruction*> PermuteGatherAndWindowDims( - HloInstruction* accumulator, ArraySlice<int64> output_window_dims, +static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims( + HloInstruction* accumulator, ArraySlice<int64> offset_dims, int64 output_rank) { std::vector<int64> permutation; permutation.reserve(output_rank); - int64 gather_idx_counter = 0; - int64 window_idx_counter = output_rank - output_window_dims.size(); + int64 batch_idx_counter = 0; + int64 offset_idx_counter = output_rank - offset_dims.size(); for (int64 i = 0; i < output_rank; i++) { - bool is_window_dim = c_binary_search(output_window_dims, i); - if (is_window_dim) { - permutation.push_back(window_idx_counter++); + bool is_offset_dim = c_binary_search(offset_dims, i); + if (is_offset_dim) { + permutation.push_back(offset_idx_counter++); } else { - permutation.push_back(gather_idx_counter++); + permutation.push_back(batch_idx_counter++); } } @@ -268,11 +266,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims( // // We follow the following steps in sequence: // -// 1. We canonicalize the gather_indices tensor such that it has rank +// 1. We canonicalize the start_indices tensor such that it has rank // 2 (i.e. is a matrix) where each row is an index vector into the // operand. // 2. We iterate over the set of indices in the canonicalized -// gather_indices tensor using a while loop, accumulating slices +// start_indices tensor using a while loop, accumulating slices // of the operand tensor into an accumulator using // DynamicUpdateSlice. // 3. 
The accumulator result from the while loop from (2) is then @@ -287,11 +285,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims( // operand = s32[3,3] parameter(0) // indices = s32[2,2] parameter(1) // ROOT gather = s32[2,3,2] gather(operand, indices), -// output_window_dims={1}, -// elided_window_dims={1}, -// gather_dims_to_operand_dims={1}, +// offset_dims={1}, +// collapsed_slice_dims={1}, +// start_index_map={1}, // index_vector_dim=2, -// window_bounds={3, 1} +// slice_sizes={3, 1} // } // // We'd first reshape indices to s32[4,1], where each row is an index @@ -305,8 +303,8 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather( HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); - HloInstruction* gather_indices = gather_instr->mutable_operand(1); - const Shape& gather_indices_shape = gather_indices->shape(); + HloInstruction* start_indices = gather_instr->mutable_operand(1); + const Shape& start_indices_shape = start_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); @@ -314,9 +312,9 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather( gather_instr->gather_dimension_numbers(); int64 gather_loop_trip_count = 1; - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != dim_numbers.index_vector_dim()) { - gather_loop_trip_count *= gather_indices_shape.dimensions(i); + gather_loop_trip_count *= start_indices_shape.dimensions(i); } } @@ -327,24 +325,24 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather( gather_instr->ToString().c_str()); } - TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, - CanonicalizeGatherIndices( - gather_indices, dim_numbers.index_vector_dim())); + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_start_indices, + CanonicalizeGatherIndices(start_indices, 
dim_numbers.index_vector_dim())); CHECK_EQ(gather_loop_trip_count, - canonical_gather_indices->shape().dimensions(0)); + canonical_start_indices->shape().dimensions(0)); TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_init, CreateGatherLoopAccumulatorInitValue( computation, output_shape.element_type(), - gather_instr->gather_window_bounds(), gather_loop_trip_count, + gather_instr->gather_slice_sizes(), gather_loop_trip_count, gather_instr->gather_dimension_numbers())); StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( computation, gather_loop_trip_count, - {operand, canonical_gather_indices, accumulator_init}, + {operand, canonical_start_indices, accumulator_init}, [&](HloInstruction* indvar, const std::vector<HloInstruction*>& loop_state) { return GatherLoopBody(*gather_instr, indvar, loop_state); @@ -356,13 +354,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather( HloInstruction* accumulator_result = gather_loop_result.back(); TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_with_output_gather_dims_decanonicalized, - AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result, - dim_numbers.index_vector_dim())); + HloInstruction* const accumulator_with_batch_dims_decanonicalized, + AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result, + dim_numbers.index_vector_dim())); - return PermuteGatherAndWindowDims( - accumulator_with_output_gather_dims_decanonicalized, - AsInt64Slice(dim_numbers.output_window_dims()), output_rank); + return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized, + AsInt64Slice(dim_numbers.offset_dims()), + output_rank); } StatusOr<bool> GatherExpander::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 020ffcd106..141dd4d6f1 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ 
b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -28,11 +28,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2147483647,5] parameter(1) ROOT gather = s32[2147483647,3,5] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, @@ -55,11 +55,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 9d24b42401..fa218657fe 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -139,7 +139,7 @@ message HloInstructionProto { // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; - repeated int64 gather_window_bounds = 34; + repeated int64 gather_slice_sizes = 34; // Compute Host. string channel_name = 41; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 51353eea6e..36d6a2eed6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -555,43 +555,39 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -// Returns an ShapeUtil::IndexIterationSpace that iterates over the output -// gather dimensions while keeping the rest of the output dimensions clamped to -// 0. 
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices( +// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch +// dimensions while keeping the rest of the output dimensions clamped to 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { int64 output_rank = output_shape.dimensions_size(); std::vector<int64> index_base(output_rank, 0); std::vector<int64> index_count; index_count.reserve(output_rank); for (int64 i = 0; i < output_rank; i++) { - bool is_output_gather_dim = - !c_binary_search(dim_numbers.output_window_dims(), i); - index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i) - : 1); + bool is_output_batch_dim = !c_binary_search(dim_numbers.offset_dims(), i); + index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1); } return {std::move(index_base), std::move(index_count), std::vector<int64>(output_rank, 1)}; } -// Return an ShapeUtil::IndexIterationSpace that iterates over the output window +// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice // dimensions while keeping the rest of the output dimensions clamped to 0. 
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( - int64 output_rank, ArraySlice<int64> window_bounds, +ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices( + int64 output_rank, ArraySlice<int64> slice_sizes, const GatherDimensionNumbers& dim_numbers) { std::vector<int64> index_base(output_rank, 0); std::vector<int64> index_count(output_rank, 1); - int64 window_bounds_idx = 0; + int64 slice_sizes_idx = 0; for (int64 i = 0; i < output_rank; i++) { - bool is_output_window_dim = - c_binary_search(dim_numbers.output_window_dims(), i); + bool is_output_window_dim = c_binary_search(dim_numbers.offset_dims(), i); if (is_output_window_dim) { - while (c_binary_search(dim_numbers.elided_window_dims(), - window_bounds_idx)) { - window_bounds_idx++; + while (c_binary_search(dim_numbers.collapsed_slice_dims(), + slice_sizes_idx)) { + slice_sizes_idx++; } - index_count[i] = window_bounds[window_bounds_idx++]; + index_count[i] = slice_sizes[slice_sizes_idx++]; } } @@ -599,30 +595,30 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( std::vector<int64>(output_rank, 1)}; } -// This functor computes the contribution of gather_indices to an input index +// This functor computes the contribution of start_indices to an input index // corresponding to an output index. That is, given an output index I, it picks -// out the gather output indices in I and uses them to look up a gather index, -// G, from the gather indices tensor, and expands G into the input space -// according to gather_dims_to_operand_dims. -class OutputGatherIndexToInputIndex { +// out the batch indices in I and uses them to look up a starting index, G, from +// the start indices tensor, and expands G into the input space according to +// start_index_map. +class OutputBatchIndexToInputIndex { public: // The constructor does some setup work that is amortized across all // iterations. 
- explicit OutputGatherIndexToInputIndex( + explicit OutputBatchIndexToInputIndex( const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, - const Shape& output_shape, const Literal* gather_indices) - : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) { + const Shape& output_shape, const Literal* start_indices) + : dim_numbers_(*dim_numbers), start_indices_(*start_indices) { for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - output_dim_is_gather_dims_.push_back( - !c_binary_search(dim_numbers_.output_window_dims(), i)); + output_dim_is_batch_dims_.push_back( + !c_binary_search(dim_numbers_.offset_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { int64 index_of_input_dim_in_index_vector = - std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(), - c_find(dim_numbers_.gather_dims_to_operand_dims(), i)); + std::distance(dim_numbers_.start_index_map().begin(), + c_find(dim_numbers_.start_index_map(), i)); if (index_of_input_dim_in_index_vector == - dim_numbers_.gather_dims_to_operand_dims_size()) { + dim_numbers_.start_index_map_size()) { input_dim_value_to_index_vector_.push_back(-1); } else { input_dim_value_to_index_vector_.push_back( @@ -630,14 +626,14 @@ class OutputGatherIndexToInputIndex { } } - index_vector_index_.resize(gather_indices_.shape().dimensions_size()); + index_vector_index_.resize(start_indices_.shape().dimensions_size()); input_index_.resize(input_shape.dimensions_size()); int64 index_vector_size = - gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + start_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); index_vector_.resize(index_vector_size); } - // Returns the contribution of gather_indices to the input index corresponding + // Returns the contribution of start_indices to the input index corresponding // to output_index. See gather_inner_loop_body. 
// // This is conceptually a stateless transformation from output_index to the @@ -659,7 +655,7 @@ class OutputGatherIndexToInputIndex { } private: - // Propagates the gather index dimensions from the output index into + // Propagates the batch dimensions from the output index into // index_vector_index_ by mutating index_vector_index_ in place. Does not // update the dim_numbers.index_vector_dim() dimension -- that's the dimension // we iterate over in FetchIndexVector. @@ -667,7 +663,7 @@ class OutputGatherIndexToInputIndex { ArraySlice<int64> output_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = output_index.size(); i < e; i++) { - if (!output_dim_is_gather_dims_[i]) { + if (!output_dim_is_batch_dims_[i]) { continue; } @@ -679,14 +675,14 @@ class OutputGatherIndexToInputIndex { } } - // Populates index_vector_ by iterating over gather_indices_ according to + // Populates index_vector_ by iterating over start_indices_ according to // index_vector_index_. Status FetchIndexVector() { int64 index_vector_dim = dim_numbers_.index_vector_dim(); for (int64 i = 0, e = index_vector_.size(); i < e; i++) { index_vector_index_[index_vector_dim] = i; - TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64( - index_vector_index_)); + TF_ASSIGN_OR_RETURN(index_vector_[i], + start_indices_.GetIntegralAsS64(index_vector_index_)); } return Status::OK(); } @@ -708,15 +704,15 @@ class OutputGatherIndexToInputIndex { // PropagateIndexVectorToInputIndex. std::vector<int64> input_dim_value_to_index_vector_; - // output_dim_is_gather_dims_[i] is true iff the output index i is a gather + // output_dim_is_batch_dims_[i] is true iff the output index i is a gather // dimension. - std::vector<bool> output_dim_is_gather_dims_; + std::vector<bool> output_dim_is_batch_dims_; - // The buffer into which we construct an index into gather_indices_ to fetch + // The buffer into which we construct an index into start_indices_ to fetch // the index vector. 
std::vector<int64> index_vector_index_; - // The index vector fetched from gather_indices_. + // The index vector fetched from start_indices_. std::vector<int64> index_vector_; // The result computed by this functor. operator() returns an ArraySlice into @@ -724,24 +720,23 @@ class OutputGatherIndexToInputIndex { std::vector<int64> input_index_; const GatherDimensionNumbers& dim_numbers_; - const Literal& gather_indices_; + const Literal& start_indices_; }; -// This functor computes the contribution of the window indices in an output +// This functor computes the contribution of the offset indices in an output // index to an input index. That is, given an output index I it picks out the -// output window indices in I and expands it into a window index into the input -// shape. -class OutputWindowIndexToInputIndex { +// output offset indices in I and expands it into an index into the input shape. +class OutputOffsetIndexToInputIndex { public: // The constructor does some setup work that is amortized across all // iterations. 
- explicit OutputWindowIndexToInputIndex( + explicit OutputOffsetIndexToInputIndex( const GatherDimensionNumbers& dim_numbers, const Shape& input_shape, const Shape& output_shape) { std::vector<int64> window_index_to_output_index; int64 output_index_count = 0; for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.output_window_dims(), i)) { + if (c_binary_search(dim_numbers.offset_dims(), i)) { window_index_to_output_index.push_back(output_index_count++); } else { output_index_count++; @@ -750,7 +745,7 @@ class OutputWindowIndexToInputIndex { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { input_dim_value_to_output_index_.push_back(-1); } else { input_dim_value_to_output_index_.push_back( @@ -808,20 +803,20 @@ class OutputWindowIndexToInputIndex { // Rehapes the gather indices input to have a trailing degenerate `1` dimension // if necessary. Hands over the ownership of the newly created literal (if -// there is one) to `reshaped_gather_indices`. +// there is one) to `reshaped_start_indices`. 
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices( - int64 index_vector_dim, const Literal& gather_indices, - std::unique_ptr<Literal>* reshaped_gather_indices) { - if (gather_indices.shape().dimensions_size() != index_vector_dim) { - return std::cref(gather_indices); + int64 index_vector_dim, const Literal& start_indices, + std::unique_ptr<Literal>* reshaped_start_indices) { + if (start_indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(start_indices); } - std::vector<int64> new_shape(gather_indices.shape().dimensions().begin(), - gather_indices.shape().dimensions().end()); + std::vector<int64> new_shape(start_indices.shape().dimensions().begin(), + start_indices.shape().dimensions().end()); new_shape.push_back(1); - TF_ASSIGN_OR_RETURN(*reshaped_gather_indices, - gather_indices.Reshape(new_shape)); - return std::cref(**reshaped_gather_indices); + TF_ASSIGN_OR_RETURN(*reshaped_start_indices, + start_indices.Reshape(new_shape)); + return std::cref(**reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { @@ -830,34 +825,33 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr<Literal> reshaped_gather_indices; + std::unique_ptr<Literal> reshaped_start_indices; TF_ASSIGN_OR_RETURN( - const Literal& gather_indices, + const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), GetEvaluatedLiteralFor(gather->operand(1)), - &reshaped_gather_indices)); + &reshaped_start_indices)); // We iterate over the gather dimensions in the output shape in an outer loop // nest, and iterate over the window dimensions in the output shape in an // inner loop nest. 
- ShapeUtil::IndexIterationSpace gather_indices_iteration_space = - IterationSpaceForOutputGatherIndices(shape, dim_numbers); - ShapeUtil::IndexIterationSpace window_indices_iteration_space = - IterationSpaceForOutputWindowIndices( - shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers); + ShapeUtil::IndexIterationSpace start_indices_iteration_space = + IterationSpaceForOutputBatchIndices(shape, dim_numbers); + ShapeUtil::IndexIterationSpace offset_indices_iteration_space = + IterationSpaceForOutputOffsetIndices( + shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers); // Scratch buffers that hold an index in the output shape and the // corresponding index in the input shape. std::vector<int64> input_index(operand.shape().dimensions_size()); std::vector<int64> output_index(gather->shape().dimensions_size()); - std::vector<int64> input_gather_index_clamped( - operand.shape().dimensions_size()); + std::vector<int64> input_index_clamped(operand.shape().dimensions_size()); - OutputGatherIndexToInputIndex output_gather_index_to_input_index( + OutputBatchIndexToInputIndex output_batch_index_to_input_index( &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), - /*output_shape=*/shape, &gather_indices); - OutputWindowIndexToInputIndex output_window_index_to_input_index( + /*output_shape=*/shape, &start_indices); + OutputOffsetIndexToInputIndex output_offset_index_to_input_index( gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), /*output_shape=*/shape); @@ -869,29 +863,29 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { ArraySlice<int64> output_gather_index) -> StatusOr<bool> { TF_ASSIGN_OR_RETURN( ArraySlice<int64> input_window_index, - output_window_index_to_input_index(output_window_index)); + output_offset_index_to_input_index(output_window_index)); for (int i = 0, e = output_index.size(); i < e; i++) { output_index[i] = output_gather_index[i] + output_window_index[i]; 
DCHECK_LT(output_index[i], shape.dimensions(i)); } for (int i = 0, e = input_gather_index.size(); i < e; i++) { int64 output_dim = - output_window_index_to_input_index.input_dim_value_to_output_index(i); + output_offset_index_to_input_index.input_dim_value_to_output_index(i); // If 'output_dim' is -1, it means 'i' is an elided window dim. This means // we set the iteration index to 0, so for the purpose of the following // calculations we can consider the output dimension size to be 1. int64 output_dim_size = output_dim == -1 ? 1 : shape.dimensions(output_dim); // Clamp the gather index so that the gather region fits in the operand. - // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0, + // input_index_clamped[i] = clamp(input_gather_index[i], 0, // operand_shape.dimensions(i) - // output_dim_size); - input_gather_index_clamped[i] = + input_index_clamped[i] = std::min(operand_shape.dimensions(i) - output_dim_size, std::max(0LL, input_gather_index[i])); } for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_gather_index_clamped[i] + input_window_index[i]; + input_index[i] = input_index_clamped[i] + input_window_index[i]; DCHECK_GE(input_index[i], 0); DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } @@ -902,18 +896,17 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { auto gather_outer_loop_body = [&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> { - TF_ASSIGN_OR_RETURN( - ArraySlice<int64> input_gather_index, - output_gather_index_to_input_index(output_gather_index)); + TF_ASSIGN_OR_RETURN(ArraySlice<int64> input_gather_index, + output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - shape, window_indices_iteration_space, + shape, offset_indices_iteration_space, std::bind(gather_inner_loop_body, std::placeholders::_1, input_gather_index, output_gather_index))); return true; }; TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - shape, 
gather_indices_iteration_space, gather_outer_loop_body)); + shape, start_indices_iteration_space, gather_outer_loop_body)); evaluated_[gather] = std::move(result); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 1cef3549e0..1394be68e4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -1826,21 +1826,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({0, 2}); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1851,21 +1850,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({0, 2}); + 
std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1876,22 +1874,22 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR3<int32>( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1902,11 +1900,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; ParseAndVerifyModule(hlo_text); @@ -1914,11 +1912,11 @@ ENTRY main { LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = 
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, @@ -1930,11 +1928,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; ParseAndVerifyModule(hlo_text); @@ -1942,11 +1940,11 @@ ENTRY main { LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1957,21 +1955,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({1, 1}); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}), - 
*Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1982,21 +1979,21 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2007,20 +2004,19 @@ ENTRY main { operand = s32[3,0] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,0] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 0} + slice_sizes={1, 0} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({0, 2}); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2031,21 +2027,21 @@ ENTRY 
main { operand = s32[3] parameter(0) indices = s32[2,2,1] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1} + slice_sizes={1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 4aaef1941b..57e75cf931 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -392,13 +392,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << "Gather instruction should have GatherDimensionNumbers set."; std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers = MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); - std::vector<int64> gather_window_bounds; - for (int64 bound : proto.gather_window_bounds()) { - gather_window_bounds.push_back(bound); + std::vector<int64> gather_slice_sizes; + for (int64 bound : proto.gather_slice_sizes()) { + gather_slice_sizes.push_back(bound); } - instruction = - CreateGather(proto.shape(), operands(0), operands(1), - *gather_dimension_numbers, gather_window_bounds); + instruction = CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_slice_sizes); break; } case HloOpcode::kScatter: { @@ -1078,11 +1077,11 @@ 
bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds) { - return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices, - gather_dim_numbers, window_bounds); + tensorflow::gtl::ArraySlice<int64> slice_sizes) { + return MakeUnique<HloGatherInstruction>(shape, operand, start_indices, + gather_dim_numbers, slice_sizes); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter( @@ -3226,9 +3225,8 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { return Cast<HloGatherInstruction>(this)->gather_dimension_numbers(); } -tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds() - const { - return Cast<HloGatherInstruction>(this)->gather_window_bounds(); +tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_slice_sizes() const { + return Cast<HloGatherInstruction>(this)->gather_slice_sizes(); } const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index b3eee90099..8d8f149ee3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -667,9 +667,9 @@ class HloInstruction { static std::unique_ptr<HloInstruction> CreateGather( const Shape& shape, HloInstruction* operand, - HloInstruction* gather_indices, + HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds); + tensorflow::gtl::ArraySlice<int64> slice_sizes); static std::unique_ptr<HloInstruction> CreateScatter( const Shape& shape, HloInstruction* 
operand, @@ -1489,8 +1489,8 @@ class HloInstruction { // Delegates to HloGatherInstruction::gather_dimension_numbers. const GatherDimensionNumbers& gather_dimension_numbers() const; - // Delegates to HloGatherInstruction::gather_window_bounds. - tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const; + // Delegates to HloGatherInstruction::gather_slice_sizes. + tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const; // Delegates to HloScatterInstruction::scatter_dimension_numbers(). const ScatterDimensionNumbers& scatter_dimension_numbers() const; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8a694dde80..504b13043f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1355,7 +1355,7 @@ TEST_F(HloInstructionTest, Stringification) { TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); @@ -1363,19 +1363,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - 
/*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1383,15 +1382,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=4, window_bounds={30,29,28,27,26}"); + "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=4, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyGather_1) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); @@ -1399,19 +1398,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - 
builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1419,10 +1417,10 @@ TEST_F(HloInstructionTest, StringifyGather_1) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=2, window_bounds={30,29,28,27,26}"); + "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=2, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyScatter) { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 233cdda7b0..4fdf4360e6 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1965,51 +1965,50 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( } HloGatherInstruction::HloGatherInstruction( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const 
Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds) + tensorflow::gtl::ArraySlice<int64> slice_sizes) : HloInstruction(HloOpcode::kGather, shape) { AppendOperand(operand); - AppendOperand(gather_indices); + AppendOperand(start_indices); gather_dimension_numbers_ = MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); + c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } string HloGatherInstruction::GatherDimensionNumbersToString() const { CHECK(gather_dimension_numbers_ != nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string offset_dims = + StrCat("offset_dims={", + Join(gather_dimension_numbers_->offset_dims(), ","), "}"); + string collapsed_slice_dims = + StrCat("collapsed_slice_dims={", + Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + string start_index_map = + StrCat("start_index_map={", + Join(gather_dimension_numbers_->start_index_map(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); return Join<std::initializer_list<string>>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, + {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, ", "); } /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - 
tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + tensorflow::gtl::ArraySlice<int64> offset_dims, + tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims, + tensorflow::gtl::ArraySlice<int64> start_index_map, int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); + for (int64 output_window_dim : offset_dims) { + gather_dim_numbers.add_offset_dims(output_window_dim); } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); + for (int64 elided_window_dim : collapsed_slice_dims) { + gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim); } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + for (int64 gather_dim_to_input_dim : start_index_map) { + gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim); } gather_dim_numbers.set_index_vector_dim(index_vector_dim); @@ -2019,8 +2018,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { HloInstructionProto HloGatherInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers(); - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); + for (int64 bound : gather_slice_sizes()) { + proto.add_gather_slice_sizes(bound); } return proto; } @@ -2028,7 +2027,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {GatherDimensionNumbersToString(), - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; + StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")}; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2039,7 
+2038,7 @@ bool HloGatherInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals( gather_dimension_numbers(), casted_other.gather_dimension_numbers()) && - gather_window_bounds() == casted_other.gather_window_bounds(); + gather_slice_sizes() == casted_other.gather_slice_sizes(); } std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( @@ -2049,7 +2048,7 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return MakeUnique<HloGatherInstruction>( shape, new_operands[0], new_operands[1], gather_dimension_numbers(), - gather_window_bounds()); + gather_slice_sizes()); } HloScatterInstruction::HloScatterInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 546949bc72..803dbeabeb 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1212,15 +1212,15 @@ class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( const Shape& shape, HloInstruction* operand, - HloInstruction* gather_indices, + HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds); + tensorflow::gtl::ArraySlice<int64> slice_sizes); const GatherDimensionNumbers& gather_dimension_numbers() const { CHECK(gather_dimension_numbers_ != nullptr); return *gather_dimension_numbers_; } - tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { - return gather_window_bounds_; + tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const { + return gather_slice_sizes_; } // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; @@ -1229,9 +1229,9 @@ class HloGatherInstruction : public HloInstruction { // Creates an instance of GatherDimensionNumbers. 
static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + tensorflow::gtl::ArraySlice<int64> offset_dims, + tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims, + tensorflow::gtl::ArraySlice<int64> start_index_map, int64 index_vector_dim); private: @@ -1247,7 +1247,7 @@ class HloGatherInstruction : public HloInstruction { HloCloneContext* context) const override; std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; - std::vector<int64> gather_window_bounds_; + std::vector<int64> gather_slice_sizes_; }; class HloScatterInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index eb48337cd7..ab57a8b07f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1233,22 +1233,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional<std::vector<tensorflow::int64>> output_window_dims; - attrs["output_window_dims"] = { - /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional<std::vector<tensorflow::int64>> elided_window_dims; - attrs["elided_window_dims"] = { - /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional<std::vector<tensorflow::int64>> gather_dims_to_operand_dims; - attrs["gather_dims_to_operand_dims"] = {/*required=*/true, - AttrTy::kBracedInt64List, - &gather_dims_to_operand_dims}; + optional<std::vector<tensorflow::int64>> offset_dims; + attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, + &offset_dims}; + optional<std::vector<tensorflow::int64>> collapsed_slice_dims; + attrs["collapsed_slice_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims}; + 
optional<std::vector<tensorflow::int64>> start_index_map; + attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List, + &start_index_map}; optional<tensorflow::int64> index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional<std::vector<tensorflow::int64>> window_bounds; - attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, - &window_bounds}; + optional<std::vector<tensorflow::int64>> slice_sizes; + attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, + &slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1257,14 +1256,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, GatherDimensionNumbers dim_numbers = HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/*output_window_dims, - /*elided_window_dims=*/*elided_window_dims, - /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*offset_dims=*/*offset_dims, + /*collapsed_slice_dims=*/*collapsed_slice_dims, + /*start_index_map=*/*start_index_map, /*index_vector_dim=*/*index_vector_dim); instruction = builder->AddInstruction(HloInstruction::CreateGather( - shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], - dim_numbers, *window_bounds)); + shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + dim_numbers, *slice_sizes)); break; } case HloOpcode::kScatter: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 6fa3c63d83..0d7919346b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -752,10 +752,10 @@ ENTRY %sparse_f32_r1 () -> f32[9] { "gather", R"(HloModule StringifyGather -ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: 
s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) - %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) - ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} + %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26} } )" @@ -1030,8 +1030,8 @@ R"(HloModule gather ENTRY Gather { input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) - gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) - ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} + start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26} } )" diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 949a4d1110..ac1a663633 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -572,7 +572,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather, ShapeInference::InferGatherShape( gather->operand(0)->shape(), gather->operand(1)->shape(), - 
gather->gather_dimension_numbers(), gather->gather_window_bounds())); + gather->gather_dimension_numbers(), gather->gather_slice_sizes())); } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 3531b7223f..8d17c03afc 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -153,7 +153,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), - instr->gather_window_bounds(), + instr->gather_slice_sizes(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, instr->operand(1)))); } else if (instr->opcode() == HloOpcode::kReshape) { @@ -251,24 +251,23 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather( StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source, + tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; } - CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + CHECK_EQ(dim_numbers.start_index_map_size(), 1); - // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should - // it become relevant. + // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here, + // should it become relevant. 
- if (dim_numbers.elided_window_dims_size() != 1 || - dim_numbers.elided_window_dims(0) != - dim_numbers.gather_dims_to_operand_dims(0)) { + if (dim_numbers.collapsed_slice_dims_size() != 1 || + dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) { VLOG(3) << "ComputeArrayForGather: gather operations must elide " - "gather_dims_to_operand_dims[0] and " - "gather_dims_to_operand_dims[0] only"; + "start_index_map[0] and " + "start_index_map[0] only"; return nullptr; } @@ -277,21 +276,21 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather( // arrays from an array of size [7,4,6]. We check that condition down below: for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { - if (i != dim_numbers.elided_window_dims(0) && - source->shape().dimensions(i) != window_bounds[i]) { - VLOG(3) << "ComputeArrayForGather: window_bounds[" << i + if (i != dim_numbers.collapsed_slice_dims(0) && + source->shape().dimensions(i) != slice_sizes[i]) { + VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i << "] != source->shape().dimensions(" << i << ") -- " - << source->shape().dimensions(i) << " vs. " << window_bounds[i] - << " with dim_numbers.elided_window_dims(0) = " - << dim_numbers.elided_window_dims(0); + << source->shape().dimensions(i) << " vs. 
" << slice_sizes[i] + << " with dim_numbers.collapsed_slice_dims(0) = " + << dim_numbers.collapsed_slice_dims(0); return nullptr; } } - int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + int64 source_dim = dim_numbers.start_index_map(0); std::vector<int64> output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + if (!c_binary_search(dim_numbers.offset_dims(), i)) { output_dims.push_back(i); } } @@ -735,11 +734,11 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( // operand = s32[3,5,2] constant({...}) // indices = s32[7] parameter(0) // gather = s32[3,2,7] gather(operand, indices), - // output_window_dims={0,1}, - // elided_window_dims={1}, - // gather_dims_to_operand_dims={1}, + // offset_dims={0,1}, + // collapsed_slice_dims={1}, + // start_index_map={1}, // index_vector_dim=1, - // window_bounds={3,1,2} + // slice_sizes={3,1,2} // reshape = s32[6,7] reshape(gather) // // In this case the gather maps to: diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index e923dc39f7..675eb31d26 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -265,7 +265,7 @@ class IndexedArrayAnalysis { StatusOr<Array*> ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source, + tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source, Array* indices); StatusOr<Array*> ComputeArrayForDotWithIndexedLhs( diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 5f4b42799b..97052edf7d 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -82,11 
+82,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -102,11 +102,11 @@ ENTRY main { operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -122,11 +122,11 @@ ENTRY main { operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} } )"; @@ -141,11 +141,11 @@ ENTRY main { operand = s32[3,3,1] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,2}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0,2}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3,1} + slice_sizes={1,3,1} } )"; @@ -160,11 +160,11 @@ ENTRY main { operand = s32[3,3,1] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,2,3] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={2}, - gather_dims_to_operand_dims={0}, + offset_dims={1,2}, + collapsed_slice_dims={2}, + start_index_map={0}, index_vector_dim=1, - window_bounds={2,3,1} + slice_sizes={2,3,1} } )"; @@ -179,11 +179,11 @@ ENTRY main { operand = s32[3,3] parameter(0) 
indices = s32[5] parameter(1) ROOT gather = s32[5,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,2} + slice_sizes={1,2} } )"; @@ -199,17 +199,17 @@ ENTRY main { indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} ROOT gather_b = s32[2,3] gather(gather_a, indices_b), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -228,17 +228,17 @@ ENTRY main { indices_a = s32[5,7] parameter(1) indices_b = s32[2] parameter(2) gather_a = s32[5,3,7] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3,1} + slice_sizes={3,1} ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), - output_window_dims={0,1}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={0,1}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=1, - window_bounds={5,3,1} + slice_sizes={5,3,1} } )"; @@ -256,17 +256,17 @@ ENTRY main { indices_a = s32[2] parameter(1) indices_b = s32[5,7] parameter(2) gather_a = s32[2,6] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT gather_b = 
s32[5,6,7] gather(gather_a, indices_b), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,6} + slice_sizes={1,6} } )"; @@ -284,17 +284,17 @@ ENTRY main { indices_a = s32[5,7] parameter(1) indices_b = s32[4,8] parameter(2) gather_a = s32[5,3,7] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3,1} + slice_sizes={3,1} ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), - output_window_dims={1,2}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={1,2}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=2, - window_bounds={5,3,1} + slice_sizes={5,3,1} } )"; @@ -312,11 +312,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2] reshape(gather) } )"; @@ -333,11 +333,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2,7] reshape(gather) } )"; @@ -358,11 +358,11 @@ ENTRY main { {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[5,7] parameter(0) gather = s32[5,2,6,7] gather(operand, indices), - output_window_dims={1,2}, 
- elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1,2}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,2,6} + slice_sizes={1,2,6} ROOT reshape = s32[5,3,4,7] reshape(gather) } )"; @@ -381,11 +381,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,6] reshape(gather) } )"; @@ -408,14 +408,14 @@ ENTRY main { operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) - g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2}, - elided_window_dims={0}, gather_dims_to_operand_dims={0}, - index_vector_dim=2, window_bounds={1,3} + g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, + index_vector_dim=2, slice_sizes={1,3} i.1 = s64[1] parameter(1) - g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2}, - elided_window_dims={1}, gather_dims_to_operand_dims={1}, - index_vector_dim=1, window_bounds={1,1,3} + g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2}, + collapsed_slice_dims={1}, start_index_map={1}, + index_vector_dim=1, slice_sizes={1,1,3} ROOT reshape = s32[1,3]{1,0} reshape(g.1) } @@ -441,11 +441,11 @@ ENTRY main { operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,6] reshape(gather) } )"; @@ -469,11 +469,11 @@ ENTRY main { 
{1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1,2}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={1,1,6} + slice_sizes={1,1,6} ROOT reshape = s32[1,1,1,6] reshape(gather) } )"; @@ -500,11 +500,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), - output_window_dims={2}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={2}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,5,6] reshape(gather) } )"; @@ -530,11 +530,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2,2,3] reshape(gather) } )"; @@ -562,11 +562,11 @@ ENTRY main { {{1,2},{3,4},{5,6},{7,8},{9,10}}}) indices = s32[7] parameter(0) gather = s32[3,2,7] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0,1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1,2} + slice_sizes={3,1,2} ROOT reshape = s32[6,7] reshape(gather) } )"; @@ -594,11 +594,11 @@ ENTRY main { {{1},{2},{3},{4}}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6,1] gather(operand, indices), - output_window_dims={1,3}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1,3}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - 
window_bounds={1,4,1} + slice_sizes={1,4,1} ROOT reshape = s32[5,2,2,2,3,1] reshape(gather) } )"; @@ -623,11 +623,11 @@ ENTRY main { operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT tanh = f32[5,4] tanh(gather) } )"; @@ -650,11 +650,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -678,11 +678,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) } )"; @@ -706,11 +706,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT sub = s32[5,4] subtract(constant_broadcasted, gather) } )"; @@ -733,11 +733,11 @@ ENTRY main { constant_broadcasted = s32[5,4] 
broadcast(constant_vect), dimensions={1} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -760,11 +760,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -808,11 +808,11 @@ ENTRY main { dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; @@ -835,11 +835,11 @@ ENTRY main { dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0} } )"; @@ -863,11 +863,11 @@ ENTRY main { dot_lhs_constant = s32[4,3] 
constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; @@ -892,11 +892,11 @@ ENTRY main { dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; @@ -921,11 +921,11 @@ ENTRY main { dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = s32[2,3,4] gather(gather_operand, indices), - output_window_dims={0,1}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={0,1}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=1, - window_bounds={2,3,1} + slice_sizes={2,3,1} ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} @@ -952,11 +952,11 @@ ENTRY main { dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[3,3] 
dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index ec5743a777..cc1ec1704e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2492,201 +2492,196 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, static Status ValidateGatherDimensionNumbers( const Shape& input_shape, - tensorflow::gtl::ArraySlice<int64> gather_indices_shape, + tensorflow::gtl::ArraySlice<int64> start_indices_shape, const GatherDimensionNumbers& dim_numbers) { - if (!c_is_sorted(dim_numbers.output_window_dims())) { + if (!c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + Join(dim_numbers.offset_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.output_window_dims()) != - dim_numbers.output_window_dims().end()) { + if (c_adjacent_find(dim_numbers.offset_dims()) != + dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + Join(dim_numbers.offset_dims(), ", ").c_str()); } - const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); const int64 output_shape_rank = - output_window_dim_count + gather_indices_shape.size() - 1; + output_offset_dim_count + start_indices_shape.size() - 1; - for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { - int64 window_index = dim_numbers.output_window_dims(i); - if (window_index < 0 || window_index >= output_shape_rank) { + for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) { + int64 offset_dim = dim_numbers.offset_dims(i); + if 
(offset_dim < 0 || offset_dim >= output_shape_rank) { return InvalidArgument( - "Window index %d in gather op is out of bounds; got %lld, but should " + "Offset dimension %d in gather op is out of bounds; got %lld, but " + "should " "have been in [0,%lld).", - i, window_index, output_shape_rank); + i, offset_dim, output_shape_rank); } } - if (dim_numbers.gather_dims_to_operand_dims_size() != - gather_indices_shape[dim_numbers.index_vector_dim()]) { + if (dim_numbers.start_index_map_size() != + start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( - "Gather op has %d elements in gather_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of gather_indices is " + "Gather op has %d elements in start_index_map and the " + "bound of dimension index_vector_dim=%lld of start_indices is " "%lld. These two numbers must be equal.", - dim_numbers.gather_dims_to_operand_dims_size(), - dim_numbers.index_vector_dim(), - gather_indices_shape[dim_numbers.index_vector_dim()]); + dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), + start_indices_shape[dim_numbers.index_vector_dim()]); } - for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { - int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); - if (gather_dim_to_input_dim < 0 || - gather_dim_to_input_dim >= input_shape.dimensions_size()) { + for (int i = 0; i < dim_numbers.start_index_map_size(); i++) { + int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i); + if (operand_dim_for_start_index_i < 0 || + operand_dim_for_start_index_i >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", - input_shape.dimensions_size(), i, gather_dim_to_input_dim); + "Invalid start_index_map; domain is [0, %d), got: %d->%lld.", + input_shape.dimensions_size(), i, operand_dim_for_start_index_i); } } - std::vector<int64> 
sorted_gather_dims_to_operand_dims( - dim_numbers.gather_dims_to_operand_dims().begin(), - dim_numbers.gather_dims_to_operand_dims().end()); + std::vector<int64> sorted_start_index_map( + dim_numbers.start_index_map().begin(), + dim_numbers.start_index_map().end()); - c_sort(sorted_gather_dims_to_operand_dims); + c_sort(sorted_start_index_map); - if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != - sorted_gather_dims_to_operand_dims.end()) { + if (c_adjacent_find(sorted_start_index_map) != sorted_start_index_map.end()) { return InvalidArgument( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + Join(dim_numbers.start_index_map(), ", ").c_str()); } - for (int64 elided_dim : dim_numbers.elided_window_dims()) { - if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) { + if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid elided_window_dims set in gather op; valid range is [0, " + "Invalid collapsed_slice_dims set in gather op; valid range is [0, " "%d), got: %lld.", - input_shape.dimensions_size(), elided_dim); + input_shape.dimensions_size(), collapsed_dim); } } - if (!c_is_sorted(dim_numbers.elided_window_dims())) { + if (!c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( - "elided_window_dims in gather op must be sorted; got: %s", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + "collapsed_slice_dims in gather op must be sorted; got: %s", + Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.elided_window_dims()) != - dim_numbers.elided_window_dims().end()) { + if (c_adjacent_find(dim_numbers.collapsed_slice_dims()) != + dim_numbers.collapsed_slice_dims().end()) { return InvalidArgument( 
- "Repeated dimensions not allowed in elided_window_dims in gather op; " + "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } return Status::OK(); } /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds) { + tensorflow::gtl::ArraySlice<int64> slice_sizes) { TF_RETURN_IF_ERROR( ExpectArray(input_shape, "input tensor operand gather op")); TF_RETURN_IF_ERROR( - ExpectArray(gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(start_indices_shape, "gather indices operand of gather op")); - if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( "Gather indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(gather_indices_shape).c_str()); + ShapeUtil::HumanString(start_indices_shape).c_str()); } // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if // index_vector_dim is rank(P). The bounds of this expanded shape is - // stored in expanded_gather_indices_shape. + // stored in expanded_start_indices_shape. - if (gather_indices_shape.dimensions_size() < + if (start_indices_shape.dimensions_size() < gather_dim_numbers.index_vector_dim() || gather_dim_numbers.index_vector_dim() < 0) { return InvalidArgument( - "Gather index leaf dimension must be within [0, rank(gather_indices) + " - "1). rank(gather_indices) is %d and gather index leaf dimension is " + "Gather index leaf dimension must be within [0, rank(start_indices) + " + "1). 
rank(start_indices) is %d and gather index leaf dimension is " "%lld.", - gather_indices_shape.dimensions_size(), + start_indices_shape.dimensions_size(), gather_dim_numbers.index_vector_dim()); } - std::vector<int64> expanded_gather_indices_shape; - expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); - c_copy(gather_indices_shape.dimensions(), - std::back_inserter(expanded_gather_indices_shape)); - if (expanded_gather_indices_shape.size() == + std::vector<int64> expanded_start_indices_shape; + expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); + c_copy(start_indices_shape.dimensions(), + std::back_inserter(expanded_start_indices_shape)); + if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { - expanded_gather_indices_shape.push_back(1); + expanded_start_indices_shape.push_back(1); } TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( - input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + input_shape, expanded_start_indices_shape, gather_dim_numbers)); - if (window_bounds.size() != input_shape.dimensions_size()) { + if (slice_sizes.size() != input_shape.dimensions_size()) { return InvalidArgument( - "Gather op must have one window bound for every input dimension; got: " - "len(window_bounds)=%lu, input_shape.rank=%d.", - window_bounds.size(), input_shape.dimensions_size()); + "Gather op must have one slice size for every input dimension; got: " + "len(slice_sizes)=%lu, input_shape.rank=%d.", + slice_sizes.size(), input_shape.dimensions_size()); } - if (window_bounds.size() != - gather_dim_numbers.output_window_dims_size() + - gather_dim_numbers.elided_window_dims_size()) { + if (slice_sizes.size() != + gather_dim_numbers.offset_dims_size() + + gather_dim_numbers.collapsed_slice_dims_size()) { return InvalidArgument( - "All components of the window index in a gather op must either be a " - "output window index or explicitly elided; got len(window_bounds)=%lu, " - 
"output_window_bounds=%s, elided_window_bounds=%s.", - window_bounds.size(), - Join(gather_dim_numbers.output_window_dims(), ",").c_str(), - Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + "All components of the offset index in a gather op must either be a " + "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " + "output_slice_sizes=%s, collapsed_slice_dims=%s.", + slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(), + Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str()); } - for (int i = 0; i < window_bounds.size(); i++) { - int64 window_bound = window_bounds[i]; - int64 corresponding_input_bound = input_shape.dimensions(i); - if (window_bound < 0 || window_bound > corresponding_input_bound) { + for (int i = 0; i < slice_sizes.size(); i++) { + int64 slice_size = slice_sizes[i]; + int64 corresponding_input_size = input_shape.dimensions(i); + if (slice_size < 0 || slice_size > corresponding_input_size) { return InvalidArgument( - "Window bound at index %d in gather op is out of range, must be " - "within " - "[0, %lld), got %lld.", - i, corresponding_input_bound + 1, window_bound); + "Slice size at index %d in gather op is out of range, must be " + "within [0, %lld), got %lld.", + i, corresponding_input_size + 1, slice_size); } } - for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { - if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) { + if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { return InvalidArgument( - "Gather op can only elide window indices with bound 1, but bound is " + "Gather op can only collapse slice dims with bound 1, but bound is " "%lld for index %lld at position %d.", - window_bounds[gather_dim_numbers.elided_window_dims(i)], - gather_dim_numbers.elided_window_dims(i), i); + slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], + 
gather_dim_numbers.collapsed_slice_dims(i), i); } } - int64 result_rank = gather_dim_numbers.output_window_dims_size() + - (expanded_gather_indices_shape.size() - 1); - int64 window_dims_seen = 0; + int64 result_rank = gather_dim_numbers.offset_dims_size() + + (expanded_start_indices_shape.size() - 1); + int64 offset_dims_seen = 0; int64 gather_dims_seen = 0; std::vector<int64> output_dim_bounds; output_dim_bounds.reserve(result_rank); for (int64 i = 0; i < result_rank; i++) { int64 current_bound; - bool is_window_index = - c_binary_search(gather_dim_numbers.output_window_dims(), i); + bool is_window_index = c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { - while (c_binary_search(gather_dim_numbers.elided_window_dims(), - window_dims_seen)) { - window_dims_seen++; + while (c_binary_search(gather_dim_numbers.collapsed_slice_dims(), + offset_dims_seen)) { + offset_dims_seen++; } - current_bound = window_bounds[window_dims_seen++]; + current_bound = slice_sizes[offset_dims_seen++]; } else { if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { gather_dims_seen++; } - current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + current_bound = expanded_start_indices_shape[gather_dims_seen++]; } output_dim_bounds.push_back(current_bound); @@ -2837,25 +2832,25 @@ Status ValidateScatterDimensionNumbers( scatter_dim_numbers)); int64 inserted_dims_seen = 0; - std::vector<int64> max_update_window_bounds; + std::vector<int64> max_update_slice_sizes; for (int i = 0; i < operand_shape.dimensions_size(); ++i) { if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { ++inserted_dims_seen; } else { - max_update_window_bounds.push_back(operand_shape.dimensions(i)); + max_update_slice_sizes.push_back(operand_shape.dimensions(i)); } } for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { auto update_window_dim = 
scatter_dim_numbers.update_window_dims(i); if (updates_shape.dimensions(update_window_dim) > - max_update_window_bounds[i]) { + max_update_slice_sizes[i]) { return InvalidArgument( "Bounds of the window dimensions of updates must not exceed the " "bounds of the corresponding dimensions of operand. For dimension " "%lld, updates bound is %lld, operand bound is %lld.", update_window_dim, updates_shape.dimensions(update_window_dim), - max_update_window_bounds[i]); + max_update_slice_sizes[i]); } } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index bfd79a4433..4974ac9916 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -276,9 +276,9 @@ class ShapeInference { // with the given input shape, gather indices shape and gather dimension // numbers. static StatusOr<Shape> InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds); + tensorflow::gtl::ArraySlice<int64> slice_sizes); // Helper that validates the given input shape, scatter indices shape, updates // shape, and scatter dimension numbers that constitute a scatter operation, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index a73fa181cd..4ed8fc6b86 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1654,11 +1654,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + 
/*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); + /*slice_sizes=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); @@ -1669,11 +1669,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{1}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1684,11 +1684,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{4}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1700,11 +1700,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 
/*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) @@ -1717,11 +1717,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1735,11 +1735,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1749,16 +1749,15 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. 
- TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape( - f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3, 4}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{0, 1, 2, 3, 4}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) @@ -1772,11 +1771,11 @@ TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{0, 1, 2, 3}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/0), - /*window_bounds=*/{1, 30, 29, 28, 27})); + /*slice_sizes=*/{1, 30, 29, 28, 27})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) @@ -1787,11 +1786,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), 
HasSubstr("Expected array argument for input")) @@ -1802,11 +1801,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for gather indices")) @@ -1817,11 +1816,11 @@ TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather indices parameter must be an integral tensor")) @@ -1833,11 +1832,11 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 8, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 8, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1850,11 +1849,11 @@ 
TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1867,14 +1866,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 99, 100, 101}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 99, 100, 101}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 2 in gather op is out of bounds")) + HasSubstr("Offset dimension 2 in gather op is out of bounds")) << statusor.status(); } @@ -1883,14 +1882,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 9}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 9}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); 
ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 4 in gather op is out of bounds")) + HasSubstr("Offset dimension 4 in gather op is out of bounds")) << statusor.status(); } @@ -1899,16 +1898,16 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{4}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{4}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr("All components of the window index in a gather op must either " - "be a output window index or explicitly elided")) + HasSubstr("All components of the offset index in a gather op must either " + "be a offset dimension or explicitly collapsed")) << statusor.status(); } @@ -1917,14 +1916,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 19}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 19}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Invalid elided_window_dims set in gather op; valid " + HasSubstr("Invalid collapsed_slice_dims set in gather op; valid " "range is [0, 5), got: 19")) << 
statusor.status(); } @@ -1934,16 +1933,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 3}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 3}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr( - "Repeated dimensions not allowed in elided_window_dims in gather op")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Repeated dimensions not allowed in " + "collapsed_slice_dims in gather op")) << statusor.status(); } @@ -1952,17 +1950,16 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " - "the bound of dimension index_vector_dim=4 of " - "gather_indices is 5. These two numbers must be equal.")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op has 4 elements in start_index_map and " + "the bound of dimension index_vector_dim=4 of " + "start_indices is 5. 
These two numbers must be equal.")) << statusor.status(); } @@ -1971,16 +1968,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 7}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " - "[0, 5), got: 4->7")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7")) << statusor.status(); } @@ -1989,16 +1984,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + HasSubstr("Repeated dimensions are not allowed in start_index_map")) << statusor.status(); } @@ -2007,14 +2001,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - 
/*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{2, 1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{2, 1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 1, 28, 27, 26}); + /*slice_sizes=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("elided_window_dims in gather op must be sorted")) + HasSubstr("collapsed_slice_dims in gather op must be sorted")) << statusor.status(); } @@ -2023,15 +2017,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{2}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{2}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 1, 300, 26}); + /*slice_sizes=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window bound at index 3 in gather op is out of range, " - "must be within [0, 48), got 300")) + HasSubstr("Slice size at index 3 in gather op is out of range, " + "must be within [0, 48), got 300.")) << statusor.status(); } @@ -2040,16 +2034,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26}); + /*slice_sizes=*/{30, 29, 28, 26}); 
ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Gather op must have one window bound for every input dimension")) + HasSubstr("Gather op must have one slice size for every input dimension")) << statusor.status(); } @@ -2058,15 +2051,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26, 20}); + /*slice_sizes=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather op can only elide window indices with bound 1, " - "but bound is 29 for index 1 at position 0")) + HasSubstr("Gather op can only collapse slice dims with bound 1, " + "but bound is 29 for index 1 at position 0.")) << statusor.status(); } @@ -2074,16 +2067,16 @@ TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/32), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather index leaf dimension must be within [0, " - "rank(gather_indices) + 1)")) + "rank(start_indices) + 1)")) << statusor.status(); } diff --git 
a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index b77bece85a..f866ed6519 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -30,8 +30,8 @@ using tensorflow::gtl::nullopt; class GatherOperationTest : public HloTestBase { protected: void RunTest(const string& hlo_text, Literal* operand, - Literal* gather_indices) { - RunTest(hlo_text, {operand, gather_indices}); + Literal* start_indices) { + RunTest(hlo_text, {operand, start_indices}); } void RunTest(const string& hlo_text, @@ -52,18 +52,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { @@ -74,18 +73,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - 
RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { @@ -96,18 +94,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { @@ -118,18 +116,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { @@ -140,18 +138,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) ROOT 
gather = s32[2,1,1,2] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { @@ -162,20 +160,20 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { @@ -186,20 +184,20 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; std::unique_ptr<Literal> 
operand = LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, DynamicSlice) { @@ -210,18 +208,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({1, 1}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { @@ -232,18 +229,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } 
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { @@ -254,17 +251,16 @@ ENTRY main { operand = s32[3,0] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,0] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 0} + slice_sizes={1, 0} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { @@ -278,19 +274,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { @@ -304,19 +300,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = u32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - 
gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>( + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<uint32>( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { @@ -330,19 +326,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { @@ -356,19 +352,19 @@ ENTRY main { operand = u32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = u32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + 
slice_sizes={1,1} ROOT result = u32[6]{0} reshape(gather) } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -379,17 +375,17 @@ ENTRY main { operand = s32[2,3,2]{2,1,0} parameter(0) index = s32[] parameter(1) ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), - output_window_dims={0,1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0}, + offset_dims={0,1,2}, + collapsed_slice_dims={}, + start_index_map={0}, index_vector_dim=0, - window_bounds={1,3,2} + slice_sizes={1,3,2} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, ScalarResult) { @@ -400,16 +396,16 @@ ENTRY main { operand = s32[4]{0} parameter(0) index = s32[] parameter(1) ROOT gather = s32[] gather(operand, index), - output_window_dims={}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=0, - window_bounds={1} + slice_sizes={1} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4}); - std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = 
LiteralUtil::CreateR0<int32>(1); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { @@ -420,17 +416,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[0] parameter(1) ROOT gather = s32[0,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { @@ -441,11 +437,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[3,2] broadcast(one), dimensions={} ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) @@ -453,9 +449,8 @@ ENTRY main { )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { @@ -466,11 +461,11 @@ ENTRY main { operand = s32[3,3] 
parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) @@ -478,9 +473,9 @@ ENTRY main { )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { @@ -491,11 +486,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -503,9 +498,9 @@ ENTRY main { )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { @@ -516,11 +511,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) 
gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -530,9 +525,9 @@ ENTRY main { LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, @@ -544,11 +539,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -558,9 +553,9 @@ ENTRY main { LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { @@ -571,11 +566,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[1,1] gather(operand, indices), - 
output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[1,1] broadcast(one), dimensions={} ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) @@ -583,9 +578,8 @@ ENTRY main { )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = - LiteralUtil::CreateR1<int32>({1, 1}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { @@ -596,11 +590,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) @@ -608,9 +602,9 @@ ENTRY main { )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = + std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } class GatherClientLibraryTest : public ClientLibraryTestBase {}; @@ -622,11 +616,11 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { // operand = s32[3,3] parameter(0) // indices = s32[2] parameter(1) // ROOT gather = s32[2,3] gather(operand, 
indices), - // output_window_dims={1}, - // elided_window_dims={0}, - // gather_dims_to_operand_dims={0}, + // offset_dims={1}, + // collapsed_slice_dims={0}, + // start_index_map={0}, // index_vector_dim=1, - // window_bounds={1, 3} + // slice_sizes={1, 3} // } XlaBuilder builder("gather_basic"); @@ -637,9 +631,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { auto operand = Parameter(&builder, 0, operand_shape, "operand"); auto indices = Parameter(&builder, 1, indices_shape, "indices"); GatherDimensionNumbers dim_numbers; - dim_numbers.add_output_window_dims(1); - dim_numbers.add_elided_window_dims(0); - dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.add_offset_dims(1); + dim_numbers.add_collapsed_slice_dims(0); + dim_numbers.add_start_index_map(0); dim_numbers.set_index_vector_dim(1); Gather(operand, indices, dim_numbers, {1, 3}); diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 4c35e93d38..27aa94c2cb 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -424,25 +424,25 @@ message GatherDimensionNumbers { // "Window indices" is a term for a set of indices that index into the // interior of a dynamic-slice from the input tensor, the starting indices for // which were computed from output_gather_dims (see the operation semantic for - // how this is defined) and the gather_indices tensor. + // how this is defined) and the start_indices tensor. 
// // The window indices for a specific output index Out is computed as: // // i = 0 // for (k : [0, input_tensor_shape.rank)) // window_indices[k] = - // if k in elided_window_dims + // if k in collapsed_slice_dims // then 0 - // else Out[output_window_dims[i++]] - repeated int64 output_window_dims = 1; - repeated int64 elided_window_dims = 2; + // else Out[offset_dims[i++]] + repeated int64 offset_dims = 1; + repeated int64 collapsed_slice_dims = 2; - // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It - // transforms the gather index looked up from the gather_indices tensor into + // This is interpreted as a map from i to start_index_map[i]. It + // transforms the gather index looked up from the start_indices tensor into // the starting index in the input space. - repeated int64 gather_dims_to_operand_dims = 3; + repeated int64 start_index_map = 3; - // The dimension in the gather_indices input that contains the starting + // The dimension in the start_indices input that contains the starting // indices. int64 index_vector_dim = 4; } diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 16dd3c5bf3..2de30d1b3d 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -1138,7 +1138,7 @@ array with the same shape. It is allowed for `operand` to be a scalar (rank 0). ## Gather The XLA gather operation stitches together several slices (each slice at a -potentially different runtime offset) of an input tensor into an output tensor. +potentially different runtime offset) of an input array. ### General Semantics @@ -1146,151 +1146,141 @@ See also [`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). For a more intuitive description, see the "Informal Description" section below. 
-<b> `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` </b> +<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b> |Arguments | Type | Semantics | |----------------- | ----------------------- | --------------------------------| -|`operand` | `XlaOp` | The tensor we’re gathering | +|`operand` | `XlaOp` | The array we’re gathering | : : : from. : -|`gather_indices` | `XlaOp` | Tensor containing the starting | -: : : indices of the slices we're : -: : : stitching together into the : -: : : output tensor. : -|`index_vector_dim` | `int64` | The dimension in | -: : : `gather_indices` that contains : -: : : the starting indices. : -|`output_window_dims` | `ArraySlice<int64>` | The set of dimensions in the | -: : : output shape that are _window : -: : : dimensions_ (defined below). : -: : : Not all window dimensions may : -: : : be present in the output shape. : -|`elided_window_dims` | `ArraySlice<int64>` | The set of _window dimensions_ | -: : : that are not present in the output shape. : -: : : `window_bounds[i]` must be `1` for all `i` : -: : : in `elided_window_dims`. : -|`window_bounds` | `ArraySlice<int64>` | `window_bounds[i]` is the bounds | -: : : for window dimension `i`. This includes : -: : : both the window dimensions that are : -: : : explicitly part of the output shape (via : -: : : `output_window_dims`) and the window : -: : : dimensions that are elided (via : -: : : `elided_window_dims`). : -|`gather_dims_to_operand_dims` | `ArraySlice<int64>` | A dimension map (the | -: : : array is interpreted as mapping `i` to : -: : : `gather_dims_to_operand_dims[i]`) from : -: : : the gather indices in `gather_indices` to : -: : : the operand index space. It has to be : -: : : one-to-one and total. 
: - -For every index `Out` in the output tensor, we compute two things (more -precisely described later): - - - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`, - which gives us a starting index of a slice, _operand slice_, in the operand - tensor. These `gather_indices.rank` - `1` dimensions are all the dimensions - in `gather_indices` except `index_vector_dim`. - - - A _window index_ that has the same rank as the operand. This index is - composed of the values in `Out` at dimensions `output_window_dims`, embedded - with zeroes according to `elided_window_dims`. - -The _window index_ is the relative index of the element in _operand slice_ that -should be present in the output at index `Out`. - -The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank` -- `1`. Additionally, as a shorthand, we define `output_gather_dims` of type -`ArraySlice<int64>` as the set of dimensions in the output shape but not in -`output_window_dims`, in ascending order. E.g. if the output tensor has rank -`5`, `output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`, -`3`} - -If `index_vector_dim` is equal to `gather_indices.rank` we implicitly -consider `gather_indices` to have a trailing `1` dimension (i.e. if -`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then -we implicitly consider the shape of `gather_indices` to be `[6,7,1]`). - -The bounds for the output tensor along dimension `i` is computed as follows: - - 1. If `i` is present in `output_gather_dims` (i.e. is equal to - `output_gather_dims[k]` for some `k`) then we pick the corresponding - dimension bounds out of `gather_indices.shape`, skipping - `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k` - < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`] - otherwise). - 2. If `i` is present in `output_window_dims` (i.e. 
equal to - `output_window_dims`[`k`] for some `k`) then we pick the corresponding - bound out of `window_bounds` after accounting for `elided_window_dims` - (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds` - is `window_bounds` with the bounds at indices `elided_window_dims` - removed). - -The operand index `In` corresponding to an output index `Out` is computed as -follows: - - 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice - out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)] - where Combine(A, b) inserts b at position `index_vector_dim` into A. - Note that this is well defined even if `G` is empty -- if `G` is empty then - `S` = `gather_indices`. - 2. Create an index, `S`<sub>`in`</sub>, into `operand` using `S` by - scattering `S` using the `gather_dims_to_operand_dims` map - (`S`<sub>`in`</sub> is the starting indices for _operand slice_ mentioned - above). More precisely: - 1. `S`<sub>`in`</sub>[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` < - `gather_dims_to_operand_dims.size`. +|`start_indices` | `XlaOp` | Array containing the starting | +: : : indices of the slices we gather.: +|`index_vector_dim` | `int64` | The dimension in | +: : : `start_indices` that "contains" : +: : : the starting indices. See : +: : : below for a detailed : +: : : description. : +|`offset_dims` | `ArraySlice<int64>` | The set of dimensions in the : +: : : output shape that offset into a : +: : : array sliced from operand. : +|`slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the bounds | +: : : for the slice on dimension `i`.: +|`collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each : +| : | slice that are collapsed away. : +| : | These dimensions must have size: +| : | 1. | +|`start_index_map` | `ArraySlice<int64>` | A map that describes how to map| +: : : indices in `start_indices` to : +: : : to legal indices into operand. 
: + +For convenience, we label dimensions in the output array not in `offset_dims` +as `batch_dims`. + +The output is an array of rank `batch_dims.size` + `operand.rank` - +`collapsed_slice_dims`.size. + +If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider +`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of +shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the +shape of `start_indices` to be `[6,7,1]`). + +The bounds for the output array along dimension `i` is computed as follows: + + 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for + some `k`) then we pick the corresponding dimension bounds out of + `start_indices.shape`, skipping `index_vector_dim` (i.e. pick + `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and + `start_indices.shape.dims`[`k`+`1`] otherwise). + + 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for + some `k`) then we pick the corresponding bound out of `slice_sizes` after + accounting for `collapsed_slice_dims` (i.e. we pick + `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` + with the bounds at indices `collapsed_slice_dims` removed). + +Formally, the operand index `In` corresponding to an output index `Out` is +computed as follows: + + 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out + vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where + Combine(A, b) inserts b at position `index_vector_dim` into A. Note that + this is well defined even if `G` is empty -- if `G` is empty then `S` = + `start_indices`. + + 2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by + scattering `S` using `start_index_map`. More precisely: + 1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` < + `start_index_map.size`. 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise. - 3. 
Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices - at the output window dimensions in `Out` according to - the `elided_window_dims` set (`W`<sub>`in`</sub> is the _window index_ - mentioned above). More precisely: - 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `Out`[`k`] if - `k` < `output_window_dims.size` (`window_dims_to_operand_dims` is - defined below). - 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise. - 4. `In` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise + + 3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices + at the offset dimensions in `Out` according to the `collapsed_slice_dims` + set. More precisely: + 1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] = + `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` + (`expand_offset_dims` is defined below). + 2. `O`<sub>`in`</sub>[`_`] = `0` otherwise. + 4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise addition. -`window_dims_to_operand_dims` is the monotonic function with domain [`0`, -`output_window_dims.size`) and range [`0`, `operand.rank`) \ -`elided_window_dims`. So if, e.g., `output_window_dims.size` is `4`, -`operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then -`window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. +`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`) +and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., +`offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`, +`2`} then `expand_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. ### Informal Description and Examples -`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the -examples that follow. More interesting values for `index_vector_dim` -does not change the operation fundamentally, but makes the visual representation -more cumbersome. 
+Informally, every index `Out` in the output array corresponds to an element `E`
+in the operand array, computed as follows:
+
+ - We use the batch dimensions in `Out` to look up a starting index from
+   `start_indices`.
+
+ - We use `start_index_map` to map the starting index (which may have size less
+   than operand.rank) to a "full" starting index into operand.
+
+ - We dynamic-slice out a slice with size `slice_sizes` using the full starting
+   index.
+
+ - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
+   Since all collapsed slice dimensions have to have bound 1 this reshape is
+   always legal.
+
+ - We use the offset dimensions in `Out` to index into this slice to get the
+   input element, `E`, corresponding to output index `Out`.
+
+`index_vector_dim` is set to `start_indices.rank` - `1` in all of the
+examples that follow. More interesting values for `index_vector_dim` do not
+change the operation fundamentally, but make the visual representation more
+cumbersome.
+
 To get an intuition on how all of the above fits together, let's look at an
-example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The
-position of a slice into the `[16,11]` tensor can be represented as an index
+example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array.  The
+position of a slice into the `[16,11]` array can be represented as an index
 vector of shape `S64[2]`, so the set of 5 positions can be represented as a
-`S64[5,2]` tensor.
+`S64[5,2]` array.
The behavior of the gather operation can then be depicted as an index -transformation that takes [`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>], an index in -the output shape, and maps it to an element in the input tensor in the following +transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in +the output shape, and maps it to an element in the input array in the following way: <div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> <img style="width:100%" src="../../images/ops_xla_gather_0.svg"> </div> -We first select an (`X`,`Y`) vector from the gather indices tensor using `G`. -The element in the output tensor at index -[`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>] is then the element in the input -tensor at index [`X`+`W`<sub>`0`</sub>,`Y`+`W`<sub>`1`</sub>]. +We first select an (`X`,`Y`) vector from the gather indices array using `G`. +The element in the output array at index +[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input +array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>]. -`window_bounds` is `[8,6]`, which decides the range of W<sub>`0`</sub> and +`slice_sizes` is `[8,6]`, which decides the range of W<sub>`0`</sub> and W<sub>`1`</sub>, and this in turn decides the bounds of the slice. This gather operation acts as a batch dynamic slice with `G` as the batch dimension. The gather indices may be multidimensional. For instance, a more general -version of the example above using a "gather indices" tensor of shape `[4,5,2]` +version of the example above using a "gather indices" array of shape `[4,5,2]` would translate indices like this: <div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> @@ -1298,25 +1288,25 @@ would translate indices like this: </div> Again, this acts as a batch dynamic slice `G`<sub>`0`</sub> and -`G`<sub>`1`</sub> as the batch dimensions. The window bounds are still `[8,6]`. +`G`<sub>`1`</sub> as the batch dimensions. 
The slice size is still `[8,6]`. The gather operation in XLA generalizes the informal semantics outlined above in the following ways: - 1. We can configure which dimensions in the output shape are the window - dimensions (dimensions containing `W`<sub>`0`</sub>, `W`<sub>`1`</sub> in - the last example). The output gather dimensions (dimensions containing + 1. We can configure which dimensions in the output shape are the offset + dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in + the last example). The output batch dimensions (dimensions containing `G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be - the output dimensions that are not window dimensions. + the output dimensions that are not offset dimensions. - 2. The number of output window dimensions explicitly present in the output + 2. The number of output offset dimensions explicitly present in the output shape may be smaller than the input rank. These "missing" dimensions, which - are listed explicitly as `elided_window_dims`, must have a window bound of - `1`. Since they have a window bound of `1` the only valid index for them is + are listed explicitly as `collapsed_slice_dims`, must have a slice size of + `1`. Since they have a slice size of `1` the only valid index for them is `0` and eliding them does not introduce ambiguity. - 3. The slice extracted from the "Gather Indices" tensor ((`X`, `Y`) in the last - example) may have fewer elements than the input tensor rank, and an explicit + 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last + example) may have fewer elements than the input array rank, and an explicit mapping dictates how the index should be expanded to have the same rank as the input. 
@@ -1327,20 +1317,19 @@ As a final example, we use (2) and (3) to implement `tf.gather_nd`: </div> `G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index -from the gather indices tensor as usual, except the starting index has only one -element, `X`. Similarly, there is only one output window index with the value -`W`<sub>`0`</sub>. However, before being used as indices into the input tensor, -these are expanded in accordance to "Gather Index Mapping" -(`gather_dims_to_operand_dims` in the formal description) and "Window Mapping" -(`window_dims_to_operand_dims` in the formal description) into -[`0`,`W`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up to -[`X`,`W`<sub>`0`</sub>]. In other words, the output index -[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`W`<sub>`0`</sub>] maps to the input index +from the gather indices array as usual, except the starting index has only one +element, `X`. Similarly, there is only one output offset index with the value +`O`<sub>`0`</sub>. However, before being used as indices into the input array, +these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in +the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal +description) into [`0`,`O`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up +to [`X`,`O`<sub>`0`</sub>]. In other words, the output index +[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index [`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us the semantics for `tf.gather_nd`. -`window_bounds` for this case is `[1,11]`. Intuitively this means that every -index `X` in the gather indices tensor picks an entire row and the result is the +`slice_sizes` for this case is `[1,11]`. Intuitively this means that every +index `X` in the gather indices array picks an entire row and the result is the concatenation of all these rows. ## GetTupleElement |