From d43820b9eff0cc863de2bbfb142afe92bf5afd00 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 16 Aug 2018 14:44:15 -0700 Subject: Improve gather ergonomics by renaming fields. This CL renames the various inputs to the Gather HLO to be more mnemonic by making it more obviously a batch dynamic-slice. The replacements I made are: s/elided_window_dims/collapsed_slice_dims/g s/window_bounds/slice_sizes/g s/gather_dims_to_operand_dims/start_index_map/g s/gather_indices/start_indices/g s/output_window_dims/offset_dims/g PiperOrigin-RevId: 209051067 --- tensorflow/compiler/xla/service/shape_inference.cc | 201 ++++++++++----------- 1 file changed, 98 insertions(+), 103 deletions(-) (limited to 'tensorflow/compiler/xla/service/shape_inference.cc') diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index ec5743a777..cc1ec1704e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2492,201 +2492,196 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, static Status ValidateGatherDimensionNumbers( const Shape& input_shape, - tensorflow::gtl::ArraySlice gather_indices_shape, + tensorflow::gtl::ArraySlice start_indices_shape, const GatherDimensionNumbers& dim_numbers) { - if (!c_is_sorted(dim_numbers.output_window_dims())) { + if (!c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + Join(dim_numbers.offset_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.output_window_dims()) != - dim_numbers.output_window_dims().end()) { + if (c_adjacent_find(dim_numbers.offset_dims()) != + dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + Join(dim_numbers.offset_dims(), ", ").c_str()); } - const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); const int64 output_shape_rank = - output_window_dim_count + gather_indices_shape.size() - 1; + output_offset_dim_count + start_indices_shape.size() - 1; - for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { - int64 window_index = dim_numbers.output_window_dims(i); - if (window_index < 0 || window_index >= output_shape_rank) { + for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) { + int64 offset_dim = dim_numbers.offset_dims(i); + if (offset_dim < 0 || offset_dim >= output_shape_rank) { return InvalidArgument( - "Window index %d in gather op is out of bounds; got %lld, but should " + "Offset dimension %d in gather op is out of bounds; got %lld, but " + "should " "have been in [0,%lld).", - i, window_index, output_shape_rank); + i, offset_dim, output_shape_rank); } } - if (dim_numbers.gather_dims_to_operand_dims_size() != - gather_indices_shape[dim_numbers.index_vector_dim()]) { + if (dim_numbers.start_index_map_size() != + start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( - "Gather op has %d elements in gather_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of gather_indices is " + "Gather op has %d elements in start_index_map and the " + "bound of dimension index_vector_dim=%lld of start_indices is " "%lld. These two numbers must be equal.", - dim_numbers.gather_dims_to_operand_dims_size(), - dim_numbers.index_vector_dim(), - gather_indices_shape[dim_numbers.index_vector_dim()]); + dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), + start_indices_shape[dim_numbers.index_vector_dim()]); } - for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { - int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); - if (gather_dim_to_input_dim < 0 || - gather_dim_to_input_dim >= input_shape.dimensions_size()) { + for (int i = 0; i < dim_numbers.start_index_map_size(); i++) { + int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i); + if (operand_dim_for_start_index_i < 0 || + operand_dim_for_start_index_i >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", - input_shape.dimensions_size(), i, gather_dim_to_input_dim); + "Invalid start_index_map; domain is [0, %d), got: %d->%lld.", + input_shape.dimensions_size(), i, operand_dim_for_start_index_i); } } - std::vector sorted_gather_dims_to_operand_dims( - dim_numbers.gather_dims_to_operand_dims().begin(), - dim_numbers.gather_dims_to_operand_dims().end()); + std::vector sorted_start_index_map( + dim_numbers.start_index_map().begin(), + dim_numbers.start_index_map().end()); - c_sort(sorted_gather_dims_to_operand_dims); + c_sort(sorted_start_index_map); - if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != - sorted_gather_dims_to_operand_dims.end()) { + if (c_adjacent_find(sorted_start_index_map) != sorted_start_index_map.end()) { return InvalidArgument( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + Join(dim_numbers.start_index_map(), ", ").c_str()); } - for (int64 elided_dim : dim_numbers.elided_window_dims()) { - if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) { + if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid elided_window_dims set in gather op; valid range is [0, " + "Invalid collapsed_slice_dims set in gather op; valid range is [0, " "%d), got: %lld.", - input_shape.dimensions_size(), elided_dim); + input_shape.dimensions_size(), collapsed_dim); } } - if (!c_is_sorted(dim_numbers.elided_window_dims())) { + if (!c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( - "elided_window_dims in gather op must be sorted; got: %s", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + "collapsed_slice_dims in gather op must be sorted; got: %s", + Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.elided_window_dims()) != - dim_numbers.elided_window_dims().end()) { + if (c_adjacent_find(dim_numbers.collapsed_slice_dims()) != + dim_numbers.collapsed_slice_dims().end()) { return InvalidArgument( - "Repeated dimensions not allowed in elided_window_dims in gather op; " + "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } return Status::OK(); } /*static*/ StatusOr ShapeInference::InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) { + tensorflow::gtl::ArraySlice slice_sizes) { TF_RETURN_IF_ERROR( ExpectArray(input_shape, "input tensor operand gather op")); TF_RETURN_IF_ERROR( - ExpectArray(gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(start_indices_shape, "gather indices operand of gather op")); - if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( "Gather indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(gather_indices_shape).c_str()); + ShapeUtil::HumanString(start_indices_shape).c_str()); } // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if // index_vector_dim is rank(P). The bounds of this expanded shape is - // stored in expanded_gather_indices_shape. + // stored in expanded_start_indices_shape. - if (gather_indices_shape.dimensions_size() < + if (start_indices_shape.dimensions_size() < gather_dim_numbers.index_vector_dim() || gather_dim_numbers.index_vector_dim() < 0) { return InvalidArgument( - "Gather index leaf dimension must be within [0, rank(gather_indices) + " - "1). rank(gather_indices) is %d and gather index leaf dimension is " + "Gather index leaf dimension must be within [0, rank(start_indices) + " + "1). rank(start_indices) is %d and gather index leaf dimension is " "%lld.", - gather_indices_shape.dimensions_size(), + start_indices_shape.dimensions_size(), gather_dim_numbers.index_vector_dim()); } - std::vector expanded_gather_indices_shape; - expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); - c_copy(gather_indices_shape.dimensions(), - std::back_inserter(expanded_gather_indices_shape)); - if (expanded_gather_indices_shape.size() == + std::vector expanded_start_indices_shape; + expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); + c_copy(start_indices_shape.dimensions(), + std::back_inserter(expanded_start_indices_shape)); + if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { - expanded_gather_indices_shape.push_back(1); + expanded_start_indices_shape.push_back(1); } TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( - input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + input_shape, expanded_start_indices_shape, gather_dim_numbers)); - if (window_bounds.size() != input_shape.dimensions_size()) { + if (slice_sizes.size() != input_shape.dimensions_size()) { return InvalidArgument( - "Gather op must have one window bound for every input dimension; got: " - "len(window_bounds)=%lu, input_shape.rank=%d.", - window_bounds.size(), input_shape.dimensions_size()); + "Gather op must have one slice size for every input dimension; got: " + "len(slice_sizes)=%lu, input_shape.rank=%d.", + slice_sizes.size(), input_shape.dimensions_size()); } - if (window_bounds.size() != - gather_dim_numbers.output_window_dims_size() + - gather_dim_numbers.elided_window_dims_size()) { + if (slice_sizes.size() != + gather_dim_numbers.offset_dims_size() + + gather_dim_numbers.collapsed_slice_dims_size()) { return InvalidArgument( - "All components of the window index in a gather op must either be a " - "output window index or explicitly elided; got len(window_bounds)=%lu, " - "output_window_bounds=%s, elided_window_bounds=%s.", - window_bounds.size(), - Join(gather_dim_numbers.output_window_dims(), ",").c_str(), - Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + "All components of the offset index in a gather op must either be a " + "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " + "output_slice_sizes=%s, collapsed_slice_dims=%s.", + slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(), + Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str()); } - for (int i = 0; i < window_bounds.size(); i++) { - int64 window_bound = window_bounds[i]; - int64 corresponding_input_bound = input_shape.dimensions(i); - if (window_bound < 0 || window_bound > corresponding_input_bound) { + for (int i = 0; i < slice_sizes.size(); i++) { + int64 slice_size = slice_sizes[i]; + int64 corresponding_input_size = input_shape.dimensions(i); + if (slice_size < 0 || slice_size > corresponding_input_size) { return InvalidArgument( - "Window bound at index %d in gather op is out of range, must be " - "within " - "[0, %lld), got %lld.", - i, corresponding_input_bound + 1, window_bound); + "Slice size at index %d in gather op is out of range, must be " + "within [0, %lld), got %lld.", + i, corresponding_input_size + 1, slice_size); } } - for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { - if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) { + if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { return InvalidArgument( - "Gather op can only elide window indices with bound 1, but bound is " + "Gather op can only collapse slice dims with bound 1, but bound is " "%lld for index %lld at position %d.", - window_bounds[gather_dim_numbers.elided_window_dims(i)], - gather_dim_numbers.elided_window_dims(i), i); + slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], + gather_dim_numbers.collapsed_slice_dims(i), i); } } - int64 result_rank = gather_dim_numbers.output_window_dims_size() + - (expanded_gather_indices_shape.size() - 1); - int64 window_dims_seen = 0; + int64 result_rank = gather_dim_numbers.offset_dims_size() + + (expanded_start_indices_shape.size() - 1); + int64 offset_dims_seen = 0; int64 gather_dims_seen = 0; std::vector output_dim_bounds; output_dim_bounds.reserve(result_rank); for (int64 i = 0; i < result_rank; i++) { int64 current_bound; - bool is_window_index = - c_binary_search(gather_dim_numbers.output_window_dims(), i); + bool is_window_index = c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { - while (c_binary_search(gather_dim_numbers.elided_window_dims(), - window_dims_seen)) { - window_dims_seen++; + while (c_binary_search(gather_dim_numbers.collapsed_slice_dims(), + offset_dims_seen)) { + offset_dims_seen++; } - current_bound = window_bounds[window_dims_seen++]; + current_bound = slice_sizes[offset_dims_seen++]; } else { if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { gather_dims_seen++; } - current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + current_bound = expanded_start_indices_shape[gather_dims_seen++]; } output_dim_bounds.push_back(current_bound); @@ -2837,25 +2832,25 @@ Status ValidateScatterDimensionNumbers( scatter_dim_numbers)); int64 inserted_dims_seen = 0; - std::vector max_update_window_bounds; + std::vector max_update_slice_sizes; for (int i = 0; i < operand_shape.dimensions_size(); ++i) { if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { ++inserted_dims_seen; } else { - max_update_window_bounds.push_back(operand_shape.dimensions(i)); + max_update_slice_sizes.push_back(operand_shape.dimensions(i)); } } for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { auto update_window_dim = scatter_dim_numbers.update_window_dims(i); if (updates_shape.dimensions(update_window_dim) > - max_update_window_bounds[i]) { + max_update_slice_sizes[i]) { return InvalidArgument( "Bounds of the window dimensions of updates must not exceed the " "bounds of the corresponding dimensions of operand. For dimension " "%lld, updates bound is %lld, operand bound is %lld.", update_window_dim, updates_shape.dimensions(update_window_dim), - max_update_window_bounds[i]); + max_update_slice_sizes[i]); } } -- cgit v1.2.3