diff options
author | 2018-08-01 21:44:02 -0700 | |
---|---|---|
committer | 2018-08-01 21:48:45 -0700 | |
commit | e3bcc0aa6e52867d9a12d9efded921325ecc5966 (patch) | |
tree | c4fe57f956fe9334f1e27342a4846250e915fe4a /tensorflow/compiler/xla/service/shape_inference.cc | |
parent | 3379bae787d73d6db67d66a284bd1a076b2cbdba (diff) |
[XLA] Add Scatter HLO.
PiperOrigin-RevId: 207045468
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 35df792b07..20314ca482 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2568,4 +2568,192 @@ static Status ValidateGatherDimensionNumbers( return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); } +namespace { + +Status ValidateScatterDimensionNumbers( + const Shape& operand_shape, + tensorflow::gtl::ArraySlice<int64> scatter_indices_shape, + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + // Validate update_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.update_window_dims())) { + return InvalidArgument( + "update_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.update_window_dims()) != + dim_numbers.update_window_dims().end()) { + return InvalidArgument( + "update_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + const int64 updates_rank = ShapeUtil::Rank(updates_shape); + for (int64 window_dim : dim_numbers.update_window_dims()) { + if (window_dim < 0 || window_dim >= updates_rank) { + return InvalidArgument( + "Invalid update_window_dims set in scatter op; valid range is [0, " + "%lld). got: %lld.", + updates_rank, window_dim); + } + } + + // Validate inserted_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + return InvalidArgument( + "inserted_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + dim_numbers.inserted_window_dims().end()) { + return InvalidArgument( + "inserted_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { + if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid inserted_window_dims set in scatter op; valid range is [0, " + "%d), got: %lld.", + operand_shape.dimensions_size(), inserted_dim); + } + } + + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. + if (dim_numbers.scatter_dims_to_operand_dims_size() != + scatter_indices_shape[dim_numbers.index_vector_dim()]) { + return InvalidArgument( + "Scatter op has %d elements in scatter_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "These two numbers must be equal.", + dim_numbers.scatter_dims_to_operand_dims_size(), + dim_numbers.index_vector_dim(), + scatter_indices_shape[dim_numbers.index_vector_dim()]); + } + for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) { + int64 scatter_dim_to_operand_dim = + dim_numbers.scatter_dims_to_operand_dims(i); + if (scatter_dim_to_operand_dim < 0 || + scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld.", + operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); + } + } + std::vector<int64> sorted_scatter_dims_to_operand_dims( + dim_numbers.scatter_dims_to_operand_dims().begin(), + dim_numbers.scatter_dims_to_operand_dims().end()); + c_sort(sorted_scatter_dims_to_operand_dims); + if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + sorted_scatter_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " + "got: %s.", + Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +} // namespace + +/*static*/ StatusOr<Shape> ShapeInference::InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers) { + TF_RETURN_IF_ERROR( + ExpectArray(operand_shape, "operand tensor of scatter op")); + TF_RETURN_IF_ERROR( + ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); + TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); + + if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { + return InvalidArgument( + "Scatter indices parameter must be an integral tensor; got %s.", + ShapeUtil::HumanString(scatter_indices_shape).c_str()); + } + + if (scatter_indices_shape.dimensions_size() < + scatter_dim_numbers.index_vector_dim() || + scatter_dim_numbers.index_vector_dim() < 0) { + return InvalidArgument( + "Scatter index leaf dimension must be within [0, rank(scatter_indices)" + " + 1). rank(scatter_indices) is %d and scatter index leaf dimension " + "is %lld.", + scatter_indices_shape.dimensions_size(), + scatter_dim_numbers.index_vector_dim()); + } + + // Check if the update computation has a proper shape as a reduction. + TF_RETURN_IF_ERROR(VerifyReducerShape( + to_apply_shape, ShapeUtil::MakeShape(operand_shape.element_type(), {}), + updates_shape.element_type())); + + std::vector<int64> expanded_scatter_indices_shape = + ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions())); + if (expanded_scatter_indices_shape.size() == + scatter_dim_numbers.index_vector_dim()) { + expanded_scatter_indices_shape.push_back(1); + } + + int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + + scatter_dim_numbers.update_window_dims_size(); + if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + expected_updates_rank, + ShapeUtil::Rank(updates_shape)); + } + + TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( + operand_shape, expanded_scatter_indices_shape, updates_shape, + scatter_dim_numbers)); + + int64 inserted_dims_seen = 0; + std::vector<int64> max_update_window_bounds; + for (int i = 0; i < operand_shape.dimensions_size(); ++i) { + if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && + scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { + ++inserted_dims_seen; + } else { + max_update_window_bounds.push_back(operand_shape.dimensions(i)); + } + } + for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { + auto update_window_dim = scatter_dim_numbers.update_window_dims(i); + if (updates_shape.dimensions(update_window_dim) > + max_update_window_bounds[i]) { + return InvalidArgument( + "Bounds of the window dimensions of updates must not exceed the " + "bounds of the corresponding dimensions of operand. For dimension " + "%lld, updates bound is %lld, operand bound is %lld.", + update_window_dim, updates_shape.dimensions(update_window_dim), + max_update_window_bounds[i]); + } + } + + int64 scatter_dims_seen = 0; + for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + bool is_update_window_dim = + c_binary_search(scatter_dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + continue; + } + if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) { + ++scatter_dims_seen; + } + if (updates_shape.dimensions(i) != + expanded_scatter_indices_shape[scatter_dims_seen]) { + return InvalidArgument( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices. For " + "scatter dimension %lld, updates bound is %lld, scatter_indices " + "bound is %lld.", + i, updates_shape.dimensions(i), + expanded_scatter_indices_shape[scatter_dims_seen]); + } + ++scatter_dims_seen; + } + + return operand_shape; +} + } // namespace xla |