diff options
author | 2018-08-30 16:03:10 -0700 | |
---|---|---|
committer | 2018-08-30 16:07:27 -0700 | |
commit | 6f879f891abe2e267c5cf512d034d7c3641cfdb0 (patch) | |
tree | 33dfda2aa13bdec06d3aa330dd5816441d449fa7 /tensorflow/compiler/xla/service/shape_inference.h | |
parent | 5d5591fbd4624ff7e50f305464667315f2d41ebb (diff) |
[XLA] Rename all (Mutable)ArraySlice to absl::Span.
PiperOrigin-RevId: 210998142
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 64 |
1 files changed, 29 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 235b1a4cf3..072ada2d8f 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -55,7 +55,7 @@ class ShapeInference { // given input shapes. static StatusOr<Shape> InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + absl::Span<const int64> broadcast_dimensions); static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs); @@ -73,18 +73,15 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. static StatusOr<Shape> InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + HloOpcode opcode, absl::Span<const Shape* const> operand_shapes); static StatusOr<Shape> InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice<const HloInstruction*> operands); + HloOpcode opcode, absl::Span<const HloInstruction* const> operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. static StatusOr<Shape> InferMapShape( - tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice<int64> dimensions); + absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply, + absl::Span<const int64> dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. @@ -116,14 +113,13 @@ class ShapeInference { int64 feature_group_count = 1); // Infers the shape produced by the given FFT type on the given operand. - static StatusOr<Shape> InferFftShape( - const Shape& in, FftType fft_type, - tensorflow::gtl::ArraySlice<int64> fft_length); + static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type, + absl::Span<const int64> fft_length); // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr<Shape> InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + absl::Span<const Shape* const> operand_shapes); // Infers final shape of an Alltoall operation that is created by the xla // builder. @@ -134,7 +130,7 @@ class ShapeInference { // Infers the shape of an HLO all-to-all instruction. static StatusOr<Shape> InferAllToAllTupleShape( - tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + absl::Span<const Shape* const> operand_shapes); // Infers the shape of a collective permute operation. static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape); @@ -146,8 +142,8 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr<Shape> InferReduceShape( - tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, - tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, + absl::Span<const Shape* const> arg_shapes, + absl::Span<const int64> dimensions_to_reduce, const ProgramShape& to_apply); // Infers the shape produced by applying the given computation to the operand @@ -165,24 +161,23 @@ class ShapeInference { // Infers the shape produced by a reverse operation that reverses the order // of the elements in the given dimensions. - static StatusOr<Shape> InferReverseShape( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice<int64> dimensions); + static StatusOr<Shape> InferReverseShape(const Shape& operand_shape, + absl::Span<const int64> dimensions); // Infers the shape produced by a slice operation spanning from the starts to // the limits in the original shape's dimensions. // // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] - static StatusOr<Shape> InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts, - tensorflow::gtl::ArraySlice<int64> limits, - tensorflow::gtl::ArraySlice<int64> strides); + static StatusOr<Shape> InferSliceShape(const Shape& arg, + absl::Span<const int64> starts, + absl::Span<const int64> limits, + absl::Span<const int64> strides); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr<Shape> InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice<int64> slice_sizes); + absl::Span<const int64> slice_sizes); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. @@ -213,30 +208,30 @@ class ShapeInference { // Infers the shape produced by a broadcast operation. static StatusOr<Shape> InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes); + const Shape& operand, absl::Span<const int64> broadcast_sizes); // Infers the shape produced by a reshape operation from the element type of // its operand and the new dimension sizes specified. - static StatusOr<Shape> InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::gtl::ArraySlice<int64> new_sizes); + static StatusOr<Shape> InferReshapeShape(const Shape& operand, + absl::Span<const int64> dimensions, + absl::Span<const int64> new_sizes); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. static StatusOr<Shape> InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions); + const Shape& operand, absl::Span<const int64> dimensions); // Helper that infers the shape produced by performing a concatenate operation // with the given operand shapes. static StatusOr<Shape> InferConcatOpShape( - tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension); + absl::Span<const Shape* const> arg_shapes, int64 dimension); // Infers the shape produced by a kAfterAll. Trivially this shape is always a // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes // and checking operand shapes. This method verifies that the operand shapes // are all TOKENs. static StatusOr<Shape> InferAfterAllShape( - tensorflow::gtl::ArraySlice<const Shape*> arg_shapes); + absl::Span<const Shape* const> arg_shapes); // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that @@ -266,8 +261,7 @@ class ShapeInference { // Helper that validates the given arg_shapes are compatible with the shape of // the to_apply parameters, and returns the to_apply result shape. static StatusOr<Shape> InferCallShape( - tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, - const ProgramShape& to_apply); + absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply); // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. @@ -281,7 +275,7 @@ class ShapeInference { static StatusOr<Shape> InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice<int64> slice_sizes); + absl::Span<const int64> slice_sizes); // Helper that validates the given input shape, scatter indices shape, updates // shape, and scatter dimension numbers that constitute a scatter operation, @@ -299,7 +293,7 @@ class ShapeInference { // even in the presence of broadcasting of one of the operands over the other. static StatusOr<Shape> InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + absl::Span<const int64> broadcast_dimensions); // Helper for inferring the shape of Clamp ops. static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand, @@ -327,7 +321,7 @@ class ShapeInference { // smaller_shape is broadcast to. static StatusOr<Shape> InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + absl::Span<const int64> broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); }; |