aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.h
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-08-30 16:03:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 16:07:27 -0700
commit6f879f891abe2e267c5cf512d034d7c3641cfdb0 (patch)
tree33dfda2aa13bdec06d3aa330dd5816441d449fa7 /tensorflow/compiler/xla/service/shape_inference.h
parent5d5591fbd4624ff7e50f305464667315f2d41ebb (diff)
[XLA] Rename all (Mutable)ArraySlice to absl::Span.
PiperOrigin-RevId: 210998142
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h64
1 files changed, 29 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 235b1a4cf3..072ada2d8f 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -55,7 +55,7 @@ class ShapeInference {
// given input shapes.
static StatusOr<Shape> InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
const HloInstruction* lhs,
const HloInstruction* rhs);
@@ -73,18 +73,15 @@ class ShapeInference {
// Infers the shape produced by applying the given variadic operation to the
// given input operand shapes.
static StatusOr<Shape> InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
static StatusOr<Shape> InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
+ HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
// Infers the shape produced by applying the given mapping computation shape
// to the given operand shapes.
static StatusOr<Shape> InferMapShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+ absl::Span<const int64> dimensions);
// Infers the shape produced by InferBatchNormTraining with the given
// operands.
@@ -116,14 +113,13 @@ class ShapeInference {
int64 feature_group_count = 1);
// Infers the shape produced by the given FFT type on the given operand.
- static StatusOr<Shape> InferFftShape(
- const Shape& in, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
+ absl::Span<const int64> fft_length);
// Infers the shape produced by a cross replica sum with the given operand
// shapes.
static StatusOr<Shape> InferCrossReplicaSumShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ absl::Span<const Shape* const> operand_shapes);
// Infers final shape of an Alltoall operation that is created by the xla
// builder.
@@ -134,7 +130,7 @@ class ShapeInference {
// Infers the shape of an HLO all-to-all instruction.
static StatusOr<Shape> InferAllToAllTupleShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ absl::Span<const Shape* const> operand_shapes);
// Infers the shape of a collective permute operation.
static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape);
@@ -146,8 +142,8 @@ class ShapeInference {
// index as the leading parameter, and the program shape should match
// accordingly (or an error will result).
static StatusOr<Shape> InferReduceShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const Shape* const> arg_shapes,
+ absl::Span<const int64> dimensions_to_reduce,
const ProgramShape& to_apply);
// Infers the shape produced by applying the given computation to the operand
@@ -165,24 +161,23 @@ class ShapeInference {
// Infers the shape produced by a reverse operation that reverses the order
// of the elements in the given dimensions.
- static StatusOr<Shape> InferReverseShape(
- const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
+ absl::Span<const int64> dimensions);
// Infers the shape produced by a slice operation spanning from the starts to
// the limits in the original shape's dimensions.
//
// e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
- static StatusOr<Shape> InferSliceShape(
- const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides);
+ static StatusOr<Shape> InferSliceShape(const Shape& arg,
+ absl::Span<const int64> starts,
+ absl::Span<const int64> limits,
+ absl::Span<const int64> strides);
// Infers the shape produced by a dynamic slice operation of size specified
// in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
static StatusOr<Shape> InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Infers the shape produced by a dynamic update slice operation based
// on the shape of operand and update.
@@ -213,30 +208,30 @@ class ShapeInference {
// Infers the shape produced by a broadcast operation.
static StatusOr<Shape> InferBroadcastShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ const Shape& operand, absl::Span<const int64> broadcast_sizes);
// Infers the shape produced by a reshape operation from the element type of
// its operand and the new dimension sizes specified.
- static StatusOr<Shape> InferReshapeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ static StatusOr<Shape> InferReshapeShape(const Shape& operand,
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
// Infers the shape produced by a transpose operation from the element type of
// its operand and its dimensions field.
static StatusOr<Shape> InferTransposeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+ const Shape& operand, absl::Span<const int64> dimensions);
// Helper that infers the shape produced by performing a concatenate operation
// with the given operand shapes.
static StatusOr<Shape> InferConcatOpShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
+ absl::Span<const Shape* const> arg_shapes, int64 dimension);
// Infers the shape produced by a kAfterAll. Trivially this shape is always a
// TOKEN shape. However, ShapeInference serves two purposes: inferring shapes
// and checking operand shapes. This method verifies that the operand shapes
// are all TOKENs.
static StatusOr<Shape> InferAfterAllShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
+ absl::Span<const Shape* const> arg_shapes);
// Helper that validates the given operand shape can be converted to the
// target output_shape via a convert instruction -- the requirement is that
@@ -266,8 +261,7 @@ class ShapeInference {
// Helper that validates the given arg_shapes are compatible with the shape of
// the to_apply parameters, and returns the to_apply result shape.
static StatusOr<Shape> InferCallShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply);
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
// Helper that infers the shape produced by performing a dot operation with
// the given LHS and RHS shapes.
@@ -281,7 +275,7 @@ class ShapeInference {
static StatusOr<Shape> InferGatherShape(
const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Helper that validates the given input shape, scatter indices shape, updates
// shape, and scatter dimension numbers that constitute a scatter operation,
@@ -299,7 +293,7 @@ class ShapeInference {
// even in the presence of broadcasting of one of the operands over the other.
static StatusOr<Shape> InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Helper for inferring the shape of Clamp ops.
static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
@@ -327,7 +321,7 @@ class ShapeInference {
// smaller_shape is broadcast to.
static StatusOr<Shape> InferInDimBroadcastShape(
const Shape& smaller_shape, const Shape& larger_shape,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
};