aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-16 14:44:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 14:47:49 -0700
commitd43820b9eff0cc863de2bbfb142afe92bf5afd00 (patch)
treee019bdc0bb52496436e0ff92d76b728165e3a420 /tensorflow/compiler/xla/service/shape_inference.cc
parent8235e83c442744a1285ce97e5dfc2a6556f9f667 (diff)
Improve gather ergonomics by renaming fields.
This CL renames the various inputs to the Gather HLO to be more mnemonic by making it more obviously a batch dynamic-slice. The replacements I made are: s/elided_window_dims/collapsed_slice_dims/g s/window_bounds/slice_sizes/g s/gather_dims_to_operand_dims/start_index_map/g s/gather_indices/start_indices/g s/output_window_dims/offset_dims/g PiperOrigin-RevId: 209051067
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc201
1 files changed, 98 insertions, 103 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index ec5743a777..cc1ec1704e 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2492,201 +2492,196 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
static Status ValidateGatherDimensionNumbers(
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
+ tensorflow::gtl::ArraySlice<int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
- if (!c_is_sorted(dim_numbers.output_window_dims())) {
+ if (!c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
"Output window dimensions in gather op must be ascending; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ Join(dim_numbers.offset_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.output_window_dims()) !=
- dim_numbers.output_window_dims().end()) {
+ if (c_adjacent_find(dim_numbers.offset_dims()) !=
+ dim_numbers.offset_dims().end()) {
return InvalidArgument(
"Output window dimensions in gather op must not repeat; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ Join(dim_numbers.offset_dims(), ", ").c_str());
}
- const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
+ const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
const int64 output_shape_rank =
- output_window_dim_count + gather_indices_shape.size() - 1;
+ output_offset_dim_count + start_indices_shape.size() - 1;
- for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
- int64 window_index = dim_numbers.output_window_dims(i);
- if (window_index < 0 || window_index >= output_shape_rank) {
+ for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
+ int64 offset_dim = dim_numbers.offset_dims(i);
+ if (offset_dim < 0 || offset_dim >= output_shape_rank) {
return InvalidArgument(
- "Window index %d in gather op is out of bounds; got %lld, but should "
+ "Offset dimension %d in gather op is out of bounds; got %lld, but "
+ "should "
"have been in [0,%lld).",
- i, window_index, output_shape_rank);
+ i, offset_dim, output_shape_rank);
}
}
- if (dim_numbers.gather_dims_to_operand_dims_size() !=
- gather_indices_shape[dim_numbers.index_vector_dim()]) {
+ if (dim_numbers.start_index_map_size() !=
+ start_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
- "Gather op has %d elements in gather_dims_to_operand_dims and the "
- "bound of dimension index_vector_dim=%lld of gather_indices is "
+ "Gather op has %d elements in start_index_map and the "
+ "bound of dimension index_vector_dim=%lld of start_indices is "
"%lld. These two numbers must be equal.",
- dim_numbers.gather_dims_to_operand_dims_size(),
- dim_numbers.index_vector_dim(),
- gather_indices_shape[dim_numbers.index_vector_dim()]);
+ dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
+ start_indices_shape[dim_numbers.index_vector_dim()]);
}
- for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
- int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
- if (gather_dim_to_input_dim < 0 ||
- gather_dim_to_input_dim >= input_shape.dimensions_size()) {
+ for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
+ int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
+ if (operand_dim_for_start_index_i < 0 ||
+ operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
- "got: %d->%lld.",
- input_shape.dimensions_size(), i, gather_dim_to_input_dim);
+ "Invalid start_index_map; domain is [0, %d), got: %d->%lld.",
+ input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
}
}
- std::vector<int64> sorted_gather_dims_to_operand_dims(
- dim_numbers.gather_dims_to_operand_dims().begin(),
- dim_numbers.gather_dims_to_operand_dims().end());
+ std::vector<int64> sorted_start_index_map(
+ dim_numbers.start_index_map().begin(),
+ dim_numbers.start_index_map().end());
- c_sort(sorted_gather_dims_to_operand_dims);
+ c_sort(sorted_start_index_map);
- if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
- sorted_gather_dims_to_operand_dims.end()) {
+ if (c_adjacent_find(sorted_start_index_map) != sorted_start_index_map.end()) {
return InvalidArgument(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
+ "Repeated dimensions are not allowed in start_index_map; "
"got: %s.",
- Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
+ Join(dim_numbers.start_index_map(), ", ").c_str());
}
- for (int64 elided_dim : dim_numbers.elided_window_dims()) {
- if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
+ for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
+ if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid elided_window_dims set in gather op; valid range is [0, "
+ "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
"%d), got: %lld.",
- input_shape.dimensions_size(), elided_dim);
+ input_shape.dimensions_size(), collapsed_dim);
}
}
- if (!c_is_sorted(dim_numbers.elided_window_dims())) {
+ if (!c_is_sorted(dim_numbers.collapsed_slice_dims())) {
return InvalidArgument(
- "elided_window_dims in gather op must be sorted; got: %s",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ "collapsed_slice_dims in gather op must be sorted; got: %s",
+ Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
- dim_numbers.elided_window_dims().end()) {
+ if (c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
+ dim_numbers.collapsed_slice_dims().end()) {
return InvalidArgument(
- "Repeated dimensions not allowed in elided_window_dims in gather op; "
+ "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
"got: %s.",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
return Status::OK();
}
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
TF_RETURN_IF_ERROR(
ExpectArray(input_shape, "input tensor operand gather op"));
TF_RETURN_IF_ERROR(
- ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(start_indices_shape, "gather indices operand of gather op"));
- if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+ if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
return InvalidArgument(
"Gather indices parameter must be an integral tensor; got %s.",
- ShapeUtil::HumanString(gather_indices_shape).c_str());
+ ShapeUtil::HumanString(start_indices_shape).c_str());
}
// We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
// index_vector_dim is rank(P). The bounds of this expanded shape is
- // stored in expanded_gather_indices_shape.
+ // stored in expanded_start_indices_shape.
- if (gather_indices_shape.dimensions_size() <
+ if (start_indices_shape.dimensions_size() <
gather_dim_numbers.index_vector_dim() ||
gather_dim_numbers.index_vector_dim() < 0) {
return InvalidArgument(
- "Gather index leaf dimension must be within [0, rank(gather_indices) + "
- "1). rank(gather_indices) is %d and gather index leaf dimension is "
+ "Gather index leaf dimension must be within [0, rank(start_indices) + "
+ "1). rank(start_indices) is %d and gather index leaf dimension is "
"%lld.",
- gather_indices_shape.dimensions_size(),
+ start_indices_shape.dimensions_size(),
gather_dim_numbers.index_vector_dim());
}
- std::vector<int64> expanded_gather_indices_shape;
- expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
- c_copy(gather_indices_shape.dimensions(),
- std::back_inserter(expanded_gather_indices_shape));
- if (expanded_gather_indices_shape.size() ==
+ std::vector<int64> expanded_start_indices_shape;
+ expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
+ c_copy(start_indices_shape.dimensions(),
+ std::back_inserter(expanded_start_indices_shape));
+ if (expanded_start_indices_shape.size() ==
gather_dim_numbers.index_vector_dim()) {
- expanded_gather_indices_shape.push_back(1);
+ expanded_start_indices_shape.push_back(1);
}
TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
- input_shape, expanded_gather_indices_shape, gather_dim_numbers));
+ input_shape, expanded_start_indices_shape, gather_dim_numbers));
- if (window_bounds.size() != input_shape.dimensions_size()) {
+ if (slice_sizes.size() != input_shape.dimensions_size()) {
return InvalidArgument(
- "Gather op must have one window bound for every input dimension; got: "
- "len(window_bounds)=%lu, input_shape.rank=%d.",
- window_bounds.size(), input_shape.dimensions_size());
+ "Gather op must have one slice size for every input dimension; got: "
+ "len(slice_sizes)=%lu, input_shape.rank=%d.",
+ slice_sizes.size(), input_shape.dimensions_size());
}
- if (window_bounds.size() !=
- gather_dim_numbers.output_window_dims_size() +
- gather_dim_numbers.elided_window_dims_size()) {
+ if (slice_sizes.size() !=
+ gather_dim_numbers.offset_dims_size() +
+ gather_dim_numbers.collapsed_slice_dims_size()) {
return InvalidArgument(
- "All components of the window index in a gather op must either be a "
- "output window index or explicitly elided; got len(window_bounds)=%lu, "
- "output_window_bounds=%s, elided_window_bounds=%s.",
- window_bounds.size(),
- Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
- Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
+ "All components of the offset index in a gather op must either be a "
+ "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
+ "output_slice_sizes=%s, collapsed_slice_dims=%s.",
+ slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(),
+ Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str());
}
- for (int i = 0; i < window_bounds.size(); i++) {
- int64 window_bound = window_bounds[i];
- int64 corresponding_input_bound = input_shape.dimensions(i);
- if (window_bound < 0 || window_bound > corresponding_input_bound) {
+ for (int i = 0; i < slice_sizes.size(); i++) {
+ int64 slice_size = slice_sizes[i];
+ int64 corresponding_input_size = input_shape.dimensions(i);
+ if (slice_size < 0 || slice_size > corresponding_input_size) {
return InvalidArgument(
- "Window bound at index %d in gather op is out of range, must be "
- "within "
- "[0, %lld), got %lld.",
- i, corresponding_input_bound + 1, window_bound);
+ "Slice size at index %d in gather op is out of range, must be "
+ "within [0, %lld), got %lld.",
+ i, corresponding_input_size + 1, slice_size);
}
}
- for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
- if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
+ for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
+ if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) {
return InvalidArgument(
- "Gather op can only elide window indices with bound 1, but bound is "
+ "Gather op can only collapse slice dims with bound 1, but bound is "
"%lld for index %lld at position %d.",
- window_bounds[gather_dim_numbers.elided_window_dims(i)],
- gather_dim_numbers.elided_window_dims(i), i);
+ slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
+ gather_dim_numbers.collapsed_slice_dims(i), i);
}
}
- int64 result_rank = gather_dim_numbers.output_window_dims_size() +
- (expanded_gather_indices_shape.size() - 1);
- int64 window_dims_seen = 0;
+ int64 result_rank = gather_dim_numbers.offset_dims_size() +
+ (expanded_start_indices_shape.size() - 1);
+ int64 offset_dims_seen = 0;
int64 gather_dims_seen = 0;
std::vector<int64> output_dim_bounds;
output_dim_bounds.reserve(result_rank);
for (int64 i = 0; i < result_rank; i++) {
int64 current_bound;
- bool is_window_index =
- c_binary_search(gather_dim_numbers.output_window_dims(), i);
+ bool is_window_index = c_binary_search(gather_dim_numbers.offset_dims(), i);
if (is_window_index) {
- while (c_binary_search(gather_dim_numbers.elided_window_dims(),
- window_dims_seen)) {
- window_dims_seen++;
+ while (c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
+ offset_dims_seen)) {
+ offset_dims_seen++;
}
- current_bound = window_bounds[window_dims_seen++];
+ current_bound = slice_sizes[offset_dims_seen++];
} else {
if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
gather_dims_seen++;
}
- current_bound = expanded_gather_indices_shape[gather_dims_seen++];
+ current_bound = expanded_start_indices_shape[gather_dims_seen++];
}
output_dim_bounds.push_back(current_bound);
@@ -2837,25 +2832,25 @@ Status ValidateScatterDimensionNumbers(
scatter_dim_numbers));
int64 inserted_dims_seen = 0;
- std::vector<int64> max_update_window_bounds;
+ std::vector<int64> max_update_slice_sizes;
for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
++inserted_dims_seen;
} else {
- max_update_window_bounds.push_back(operand_shape.dimensions(i));
+ max_update_slice_sizes.push_back(operand_shape.dimensions(i));
}
}
for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
if (updates_shape.dimensions(update_window_dim) >
- max_update_window_bounds[i]) {
+ max_update_slice_sizes[i]) {
return InvalidArgument(
"Bounds of the window dimensions of updates must not exceed the "
"bounds of the corresponding dimensions of operand. For dimension "
"%lld, updates bound is %lld, operand bound is %lld.",
update_window_dim, updates_shape.dimensions(update_window_dim),
- max_update_window_bounds[i]);
+ max_update_slice_sizes[i]);
}
}