author     Tim Shen <timshen@google.com>                    2018-08-30 16:03:10 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-30 16:07:27 -0700
commit     6f879f891abe2e267c5cf512d034d7c3641cfdb0 (patch)
tree       33dfda2aa13bdec06d3aa330dd5816441d449fa7 /tensorflow/compiler/xla/service/shape_inference.cc
parent     5d5591fbd4624ff7e50f305464667315f2d41ebb (diff)
[XLA] Rename all (Mutable)ArraySlice to absl::Span.
PiperOrigin-RevId: 210998142
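For readers unfamiliar with the change: `absl::Span<const T>` is Abseil's non-owning, read-only view over a contiguous sequence, and is the drop-in replacement for `tensorflow::gtl::ArraySlice<T>` used throughout this diff. The sketch below is a minimal, self-contained illustration of the before/after calling convention, assuming Abseil is available; `SumDims` and its caller are hypothetical names, not code from this commit.

    #include <cstdint>
    #include <numeric>
    #include <vector>

    #include "absl/types/span.h"

    // Old spelling (pre-commit):  tensorflow::gtl::ArraySlice<int64> dims
    // New spelling (post-commit): absl::Span<const int64> dims
    //
    // `SumDims` is a hypothetical helper for illustration only. A Span
    // accepts a vector, a raw array, or a braced list without copying.
    int64_t SumDims(absl::Span<const int64_t> dims) {
      return std::accumulate(dims.begin(), dims.end(), int64_t{0});
    }

    int main() {
      std::vector<int64_t> v = {2, 3, 5};
      int64_t total = SumDims(v);   // implicit conversion from std::vector
      total += SumDims({7, 11});    // and from a brace-enclosed list
      return total == 28 ? 0 : 1;
    }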
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 89
1 file changed, 40 insertions(+), 49 deletions(-)
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 45427bba25..2611749862 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -45,7 +45,7 @@ using absl::StrFormat;
 using absl::StrJoin;
 
 // Returns true if no element is present in slice more than once.
-bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
+bool AllUnique(absl::Span<const int64> slice) {
   return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
 }
 
@@ -57,11 +57,10 @@ Status ExpectArray(const Shape& shape, absl::string_view op_type) {
   return Status::OK();
 }
 
-Status VerifyReducerShape(
-    const ProgramShape& reducer_shape,
-    tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes,
-    tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types,
-    int64 inputs) {
+Status VerifyReducerShape(const ProgramShape& reducer_shape,
+                          absl::Span<const Shape* const> init_value_shapes,
+                          absl::Span<const PrimitiveType> input_element_types,
+                          int64 inputs) {
   if (reducer_shape.parameters_size() != inputs * 2) {
     return InvalidArgument(
         "Reduction function must take %d parameters, but "
@@ -335,8 +334,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
-    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
-    const int64 dimension) {
+    absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
   if (arg_shapes.empty()) {
     return InvalidArgument("Concatenate expects at least one argument.");
   }
@@ -394,7 +392,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape(
-    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+    absl::Span<const Shape* const> arg_shapes) {
   for (const Shape* arg_shape : arg_shapes) {
     if (arg_shape->element_type() != TOKEN) {
       return InvalidArgument(
@@ -550,22 +548,22 @@ Status ValidateDotDimensionNumbers(
     const Shape& lhs, const Shape& rhs,
     const DotDimensionNumbers& dimension_numbers) {
   // Check that dimension numbers are in range.
-  auto dims_in_range =
-      [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims,
-         tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+  auto dims_in_range = [](const int64 rank,
+                          absl::Span<const int64> contracting_dims,
+                          absl::Span<const int64> batch_dims) -> bool {
     auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
     return std::all_of(contracting_dims.begin(), contracting_dims.end(),
                        in_range) &&
           std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
   };
 
-  tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions =
+  absl::Span<const int64> lhs_contracting_dimensions =
       AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
-  tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions =
+  absl::Span<const int64> rhs_contracting_dimensions =
       AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
-  tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions =
+  absl::Span<const int64> lhs_batch_dimensions =
      AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
-  tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions =
+  absl::Span<const int64> rhs_batch_dimensions =
      AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
 
   if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
@@ -577,8 +575,8 @@ Status ValidateDotDimensionNumbers(
   }
 
   // Check that dimension numbers are unique.
-  auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims,
-                        tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+  auto dims_unique = [](absl::Span<const int64> contracting_dims,
+                        absl::Span<const int64> batch_dims) -> bool {
     tensorflow::gtl::FlatSet<int64> dim_set;
     auto is_unique = [&dim_set](int64 i) -> bool {
       return dim_set.insert(i).second;
@@ -748,7 +746,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 
 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
     const Shape& smaller_shape, const Shape& larger_shape,
-    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+    absl::Span<const int64> broadcast_dimensions) {
   if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
     // Reject "magic" inference for binops on different shapes, requiring
     // the user to provide an explicit broadcast dimension in this case.
@@ -849,7 +847,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 
 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
     HloOpcode operation, const Shape& lhs, const Shape& rhs,
-    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+    absl::Span<const int64> broadcast_dimensions) {
   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
 
@@ -906,7 +904,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 
 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
     HloOpcode opcode, const Shape& lhs, const Shape& rhs,
-    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+    absl::Span<const int64> broadcast_dimensions) {
   VLOG(2) << StrFormat(
       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
       HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
@@ -1005,8 +1003,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
-    HloOpcode opcode,
-    tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
+    HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
   std::vector<const Shape*> operand_shapes;
   operand_shapes.reserve(operands.size());
   for (const HloInstruction* operand : operands) {
@@ -1016,8 +1013,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
-    HloOpcode opcode,
-    tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+    HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
   for (const Shape* shape : operand_shapes) {
     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
   }
@@ -1053,9 +1049,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
-    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
-    const ProgramShape& to_apply,
-    tensorflow::gtl::ArraySlice<int64> dimensions) {
+    absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+    absl::Span<const int64> dimensions) {
   if (arg_shapes.empty()) {
     return InvalidArgument("Map expects at least one argument.");
   }
@@ -1711,7 +1706,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 
 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
     const Shape& in, const FftType fft_type,
-    const tensorflow::gtl::ArraySlice<int64> fft_length) {
+    const absl::Span<const int64> fft_length) {
   const int64 fft_rank = fft_length.size();
   if (fft_rank < 1 || fft_rank > 3) {
     return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
@@ -1792,7 +1787,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
-    tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+    absl::Span<const Shape* const> operand_shapes) {
   for (const Shape* operand_shape : operand_shapes) {
     TF_RETURN_IF_ERROR(
         ExpectArray(*operand_shape, "operand of cross replica sum"));
@@ -1835,7 +1830,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
-    tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+    absl::Span<const Shape* const> operand_shapes) {
   // An Alltoall HLO instruction receives N operands (with the same shape) and
   // returns a tuple that contains N array shapes.
   TF_RET_CHECK(!operand_shapes.empty());
@@ -1859,8 +1854,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
-    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
-    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+    absl::Span<const Shape* const> arg_shapes,
+    absl::Span<const int64> dimensions_to_reduce,
     const ProgramShape& to_apply) {
   if (arg_shapes.empty()) {
     return InvalidArgument("Reduce must have at least 2 arguments, has 0");
@@ -1998,9 +1993,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
-    const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
-    tensorflow::gtl::ArraySlice<int64> limits,
-    tensorflow::gtl::ArraySlice<int64> strides) {
+    const Shape& arg, absl::Span<const int64> starts,
+    absl::Span<const int64> limits, absl::Span<const int64> strides) {
   auto error = [&](const string& message) {
     return InvalidArgument(
         "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
@@ -2062,7 +2056,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 
 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
     const Shape& operand_shape, const Shape& start_indices_shape,
-    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+    absl::Span<const int64> slice_sizes) {
   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
   TF_RETURN_IF_ERROR(
       ExpectArray(start_indices_shape, "start indices of dynamic slice"));
@@ -2189,7 +2183,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
-    const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
+    const Shape& operand_shape, absl::Span<const int64> dimensions) {
   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
   if (!AllUnique(dimensions)) {
     return InvalidArgument("a dimension number is duplicated in reverse");
@@ -2315,7 +2309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
-    const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+    const Shape& operand, absl::Span<const int64> broadcast_sizes) {
   TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
   for (int64 size : broadcast_sizes) {
     if (size < 0) {
@@ -2333,8 +2327,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
-    const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
-    tensorflow::gtl::ArraySlice<int64> new_sizes) {
+    const Shape& operand, absl::Span<const int64> dimensions,
+    absl::Span<const int64> new_sizes) {
   TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
 
   Shape inferred_shape =
@@ -2366,7 +2360,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
-    const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+    const Shape& operand, absl::Span<const int64> dimensions) {
   TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
 
   std::vector<int64> indices(ShapeUtil::Rank(operand));
@@ -2471,8 +2465,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
-    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
-    const ProgramShape& to_apply) {
+    absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
   // The applied function's arity equals the number of arguments.
   if (arg_shapes.size() != to_apply.parameters_size()) {
     string computation_signature = ShapeUtil::HumanString(to_apply);
@@ -2505,8 +2498,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 }
 
 static Status ValidateGatherDimensionNumbers(
-    const Shape& input_shape,
-    tensorflow::gtl::ArraySlice<int64> start_indices_shape,
+    const Shape& input_shape, absl::Span<const int64> start_indices_shape,
     const GatherDimensionNumbers& dim_numbers) {
   if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
     return InvalidArgument(
@@ -2599,7 +2591,7 @@ static Status ValidateGatherDimensionNumbers(
 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
     const Shape& input_shape, const Shape& start_indices_shape,
     const GatherDimensionNumbers& gather_dim_numbers,
-    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+    absl::Span<const int64> slice_sizes) {
   TF_RETURN_IF_ERROR(
       ExpectArray(input_shape, "input tensor operand gather op"));
   TF_RETURN_IF_ERROR(
@@ -2709,8 +2701,7 @@ static Status ValidateGatherDimensionNumbers(
 namespace {
 
 Status ValidateScatterDimensionNumbers(
-    const Shape& operand_shape,
-    tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
+    const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
     const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
   // Validate update_window_dims in ScatterDimensionNumbers.
   if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
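A pattern that recurs throughout the hunks above is `tensorflow::gtl::ArraySlice<const Shape*>` becoming `absl::Span<const Shape* const>`. Because the span's elements are the pointers themselves, the const must be stated twice to keep the view fully read-only. The snippet below is an illustrative sketch of that rule, not code from this file; `CountArgs` and the stand-in `Shape` struct are hypothetical.

    #include <vector>

    #include "absl/types/span.h"

    struct Shape {};  // stand-in for xla::Shape, for illustration only

    // `const Shape*` makes the pointed-to shapes immutable, and the trailing
    // `const` makes the pointer elements themselves immutable through the
    // span. `CountArgs` is a hypothetical helper, not part of this commit.
    int CountArgs(absl::Span<const Shape* const> arg_shapes) {
      return static_cast<int>(arg_shapes.size());
    }

    int main() {
      Shape a, b;
      std::vector<const Shape*> args = {&a, &b};
      return CountArgs(args) == 2 ? 0 : 1;  // the vector converts implicitly
    }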