aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-08-30 16:03:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 16:07:27 -0700
commit6f879f891abe2e267c5cf512d034d7c3641cfdb0 (patch)
tree33dfda2aa13bdec06d3aa330dd5816441d449fa7 /tensorflow/compiler/xla/service/shape_inference.cc
parent5d5591fbd4624ff7e50f305464667315f2d41ebb (diff)
[XLA] Rename all (Mutable)ArraySlice to absl::Span.
PiperOrigin-RevId: 210998142
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc89
1 files changed, 40 insertions, 49 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 45427bba25..2611749862 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -45,7 +45,7 @@ using absl::StrFormat;
using absl::StrJoin;
// Returns true if no element is present in slice more than once.
-bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
+bool AllUnique(absl::Span<const int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
@@ -57,11 +57,10 @@ Status ExpectArray(const Shape& shape, absl::string_view op_type) {
return Status::OK();
}
-Status VerifyReducerShape(
- const ProgramShape& reducer_shape,
- tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types,
- int64 inputs) {
+Status VerifyReducerShape(const ProgramShape& reducer_shape,
+ absl::Span<const Shape* const> init_value_shapes,
+ absl::Span<const PrimitiveType> input_element_types,
+ int64 inputs) {
if (reducer_shape.parameters_size() != inputs * 2) {
return InvalidArgument(
"Reduction function must take %d parameters, but "
@@ -335,8 +334,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const int64 dimension) {
+ absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
if (arg_shapes.empty()) {
return InvalidArgument("Concatenate expects at least one argument.");
}
@@ -394,7 +392,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
/* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+ absl::Span<const Shape* const> arg_shapes) {
for (const Shape* arg_shape : arg_shapes) {
if (arg_shape->element_type() != TOKEN) {
return InvalidArgument(
@@ -550,22 +548,22 @@ Status ValidateDotDimensionNumbers(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
// Check that dimension numbers are in range.
- auto dims_in_range =
- [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims,
- tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto dims_in_range = [](const int64 rank,
+ absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) -> bool {
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
in_range) &&
std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
};
- tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions =
+ absl::Span<const int64> lhs_contracting_dimensions =
AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
- tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions =
+ absl::Span<const int64> rhs_contracting_dimensions =
AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
- tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions =
+ absl::Span<const int64> lhs_batch_dimensions =
AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
- tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions =
+ absl::Span<const int64> rhs_batch_dimensions =
AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
@@ -577,8 +575,8 @@ Status ValidateDotDimensionNumbers(
}
// Check that dimension numbers are unique.
- auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims,
- tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto dims_unique = [](absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) -> bool {
tensorflow::gtl::FlatSet<int64> dim_set;
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
@@ -748,7 +746,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
const Shape& smaller_shape, const Shape& larger_shape,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
// Reject "magic" inference for binops on different shapes, requiring
// the user to provide an explicit broadcast dimension in this case.
@@ -849,7 +847,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
@@ -906,7 +904,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
VLOG(2) << StrFormat(
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
@@ -1005,8 +1003,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
+ HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
std::vector<const Shape*> operand_shapes;
operand_shapes.reserve(operands.size());
for (const HloInstruction* operand : operands) {
@@ -1016,8 +1013,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
for (const Shape* shape : operand_shapes) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
}
@@ -1053,9 +1049,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+ absl::Span<const int64> dimensions) {
if (arg_shapes.empty()) {
return InvalidArgument("Map expects at least one argument.");
}
@@ -1711,7 +1706,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
const Shape& in, const FftType fft_type,
- const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ const absl::Span<const int64> fft_length) {
const int64 fft_rank = fft_length.size();
if (fft_rank < 1 || fft_rank > 3) {
return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
@@ -1792,7 +1787,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ absl::Span<const Shape* const> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
ExpectArray(*operand_shape, "operand of cross replica sum"));
@@ -1835,7 +1830,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ absl::Span<const Shape* const> operand_shapes) {
// An Alltoall HLO instruction receives N operands (with the same shape) and
// returns a tuple that contains N array shapes.
TF_RET_CHECK(!operand_shapes.empty());
@@ -1859,8 +1854,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const Shape* const> arg_shapes,
+ absl::Span<const int64> dimensions_to_reduce,
const ProgramShape& to_apply) {
if (arg_shapes.empty()) {
return InvalidArgument("Reduce must have at least 2 arguments, has 0");
@@ -1998,9 +1993,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
- const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ const Shape& arg, absl::Span<const int64> starts,
+ absl::Span<const int64> limits, absl::Span<const int64> strides) {
auto error = [&](const string& message) {
return InvalidArgument(
"%s in slice operation; argument shape: %s; starts: {%s}; limits: "
@@ -2062,7 +2056,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
TF_RETURN_IF_ERROR(
ExpectArray(start_indices_shape, "start indices of dynamic slice"));
@@ -2189,7 +2183,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
- const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ const Shape& operand_shape, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
if (!AllUnique(dimensions)) {
return InvalidArgument("a dimension number is duplicated in reverse");
@@ -2315,7 +2309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ const Shape& operand, absl::Span<const int64> broadcast_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
@@ -2333,8 +2327,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ const Shape& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
Shape inferred_shape =
@@ -2366,7 +2360,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ const Shape& operand, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
@@ -2471,8 +2465,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply) {
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
// The applied function's arity equals the number of arguments.
if (arg_shapes.size() != to_apply.parameters_size()) {
string computation_signature = ShapeUtil::HumanString(to_apply);
@@ -2505,8 +2498,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
static Status ValidateGatherDimensionNumbers(
- const Shape& input_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices_shape,
+ const Shape& input_shape, absl::Span<const int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
@@ -2599,7 +2591,7 @@ static Status ValidateGatherDimensionNumbers(
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
TF_RETURN_IF_ERROR(
ExpectArray(input_shape, "input tensor operand gather op"));
TF_RETURN_IF_ERROR(
@@ -2709,8 +2701,7 @@ static Status ValidateGatherDimensionNumbers(
namespace {
Status ValidateScatterDimensionNumbers(
- const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
+ const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
// Validate update_window_dims in ScatterDimensionNumbers.
if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {