aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-08-30 11:20:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 11:29:03 -0700
commit09f2c342c0e74834bebf8045d5e77dcef8323539 (patch)
tree3a969695f54840366629df7f39e4da7f9ac103cd /tensorflow
parent9e12f1df3270b5e0b310645e6c3cae9fbd3f5dfc (diff)
Remove (Mutable)ArraySlice implementation and alias them to absl::Span.
There are several API migrations happening: * ArraySlice's sub-slice constructor => .subspan * MutableArraySlice's container pointer constructor => absl::MakeSpan PiperOrigin-RevId: 210946124
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc3
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc7
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc14
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric.cc8
-rw-r--r--tensorflow/compiler/xla/index_util_test.cc8
-rw-r--r--tensorflow/compiler/xla/literal.cc6
-rw-r--r--tensorflow/compiler/xla/literal.h2
-rw-r--r--tensorflow/compiler/xla/literal_comparison.cc27
-rw-r--r--tensorflow/compiler/xla/literal_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h9
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc8
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.cc15
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc6
-rw-r--r--tensorflow/compiler/xla/sparse_index_array.h4
-rw-r--r--tensorflow/compiler/xla/sparse_index_array_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc8
-rw-r--r--tensorflow/compiler/xla/util.h18
-rw-r--r--tensorflow/core/BUILD4
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc6
-rw-r--r--tensorflow/core/kernels/candidate_sampler_ops.cc6
-rw-r--r--tensorflow/core/kernels/range_sampler_test.cc22
-rw-r--r--tensorflow/core/kernels/set_kernels.cc14
-rw-r--r--tensorflow/core/kernels/sparse_softmax_op.cc2
-rw-r--r--tensorflow/core/lib/gtl/array_slice.h281
-rw-r--r--tensorflow/core/lib/gtl/array_slice_internal.h269
-rw-r--r--tensorflow/core/lib/gtl/array_slice_test.cc664
-rw-r--r--tensorflow/core/platform/default/build_config.bzl5
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc2
32 files changed, 122 insertions, 1323 deletions
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index c0ea242044..1b8198dba8 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -1884,7 +1884,8 @@ TEST_F(OpTest, DynamicStitch) {
for (int i = 0; i < n; ++i) {
TensorShape shape(index_dims[i]);
Tensor t = test::AsTensor<int32>(
- gtl::ArraySlice<int32>(indices, pos, shape.num_elements()), shape);
+ gtl::ArraySlice<int32>(indices).subspan(pos, shape.num_elements()),
+ shape);
builder.Input(t);
pos += t.NumElements();
}
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 67fb56510c..ff3de75ad2 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -56,9 +56,10 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int n_dims = xla::ShapeUtil::Rank(a_shape);
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - 2);
+ auto major_dims = xla::AsInt64Slice(a_shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - 2);
xla::XlaOp l = xla::ZerosLike(a);
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 8b5beba383..24e5dbbc6d 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -124,9 +124,10 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - n_minor_dims);
+ auto major_dims = xla::AsInt64Slice(shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - n_minor_dims);
// Prepends 0s in the major dim
std::vector<int64> padded_start(n_dims, 0);
@@ -161,9 +162,10 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - sizes.size());
+ auto major_dims = xla::AsInt64Slice(shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - sizes.size());
auto padded_starts = PrependZerosInMajorDims(x, starts);
auto padded_sizes = ConcatVectors(major_dims, sizes);
return xla::DynamicSlice(x, padded_starts, padded_sizes);
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
index 38e440c68d..7f90d6c197 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -39,8 +39,8 @@ XlaOp GetMatrixDiagonal(XlaOp x) {
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
- tensorflow::gtl::ArraySlice<int64> major_dims(
- AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
+ tensorflow::gtl::ArraySlice<int64> major_dims =
+ AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);
auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
@@ -66,8 +66,8 @@ XlaOp Triangle(XlaOp x, bool lower) {
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
- tensorflow::gtl::ArraySlice<int64> major_dims(
- AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
+ tensorflow::gtl::ArraySlice<int64> major_dims =
+ AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);
xla::XlaOp indicator;
diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc
index 7c4efdee48..93522d2ca8 100644
--- a/tensorflow/compiler/xla/index_util_test.cc
+++ b/tensorflow/compiler/xla/index_util_test.cc
@@ -142,13 +142,13 @@ TEST(IndexUtilTest, LinearToMultiToLinear) {
TEST(IndexUtilTest, BumpIndices2x2) {
auto shape = ShapeUtil::MakeShape(S32, {2, 2});
std::vector<int64> indices = {0, 0};
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(0, 1));
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(1, 0));
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(1, 1));
- EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_FALSE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
}
} // namespace
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 3dd0abee79..2fc3613650 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -366,7 +366,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
do {
dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
- } while (IndexUtil::BumpIndices(dest_shape, &index));
+ } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
} // namespace
@@ -1192,7 +1192,7 @@ void LiteralBase::EachCellAsString(
shape(), /*linear_index=*/0);
do {
per_cell(indices, GetAsString(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
+ } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
namespace {
@@ -1392,7 +1392,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
elements.push_back(std::move(*new_element));
}
auto converted = absl::make_unique<Literal>();
- *converted = MutableLiteralBase::MoveIntoTuple(&elements);
+ *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
return std::move(converted);
}
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 8370043da1..c6ef99db0f 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -992,7 +992,7 @@ void LiteralBase::EachCell(
std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
do {
per_cell(indices, Get<NativeT>(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
+ } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index 14ad08a681..f6ce69eaee 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -353,11 +353,11 @@ class NearComparator {
// bound is exceeded and vice versa.
if (is_abs_mismatch) {
num_abs_mismatches_++;
- UpdateErrorBucket(rel_error, &rel_error_buckets_);
+ UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
}
if (is_rel_mismatch) {
num_rel_mismatches_++;
- UpdateErrorBucket(abs_error, &abs_error_buckets_);
+ UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
}
UpdateAbsValueBucket(actual, is_mismatch);
@@ -579,40 +579,41 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ auto index = absl::MakeSpan(multi_index);
Status result;
switch (expected.shape().element_type()) {
case PRED:
- result = Equal<bool>(expected, actual, &multi_index, 0);
+ result = Equal<bool>(expected, actual, index, 0);
break;
case U8:
- result = Equal<uint8>(expected, actual, &multi_index, 0);
+ result = Equal<uint8>(expected, actual, index, 0);
break;
case S32:
- result = Equal<int32>(expected, actual, &multi_index, 0);
+ result = Equal<int32>(expected, actual, index, 0);
break;
case S64:
- result = Equal<int64>(expected, actual, &multi_index, 0);
+ result = Equal<int64>(expected, actual, index, 0);
break;
case U32:
- result = Equal<uint32>(expected, actual, &multi_index, 0);
+ result = Equal<uint32>(expected, actual, index, 0);
break;
case U64:
- result = Equal<uint64>(expected, actual, &multi_index, 0);
+ result = Equal<uint64>(expected, actual, index, 0);
break;
case BF16:
- result = Equal<bfloat16>(expected, actual, &multi_index, 0);
+ result = Equal<bfloat16>(expected, actual, index, 0);
break;
case F16:
- result = Equal<half>(expected, actual, &multi_index, 0);
+ result = Equal<half>(expected, actual, index, 0);
break;
case F32:
- result = Equal<float>(expected, actual, &multi_index, 0);
+ result = Equal<float>(expected, actual, index, 0);
break;
case F64:
- result = Equal<double>(expected, actual, &multi_index, 0);
+ result = Equal<double>(expected, actual, index, 0);
break;
case C64:
- result = Equal<complex64>(expected, actual, &multi_index, 0);
+ result = Equal<complex64>(expected, actual, index, 0);
break;
case TUPLE: {
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index e08a9d6e41..5d7d4dbb36 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -1561,7 +1561,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) {
));
- Literal literal = Literal::MoveIntoTuple(&elements);
+ Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 08773693fb..bf2efc4d14 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -284,8 +284,9 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
arguments));
- TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
- CreateResultShapedBuffer(run_options, &owning_buffers));
+ TF_ASSIGN_OR_RETURN(
+ ScopedShapedBuffer result,
+ CreateResultShapedBuffer(run_options, absl::MakeSpan(owning_buffers)));
// At this point, `unowning_buffers` contains unowning pointers to all of our
// buffers, and `buffers` contains owning pointers to the non-live-out
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 4edcb05f83..f1cdd03404 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1130,7 +1130,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
}
cnt : {}
- } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
+ } while (IndexUtil::BumpIndices(window_shape,
+ absl::MakeSpan(rhs_spatial_index)));
return static_cast<ReturnT>(result_val);
};
@@ -1854,7 +1855,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_evaluator.ResetVisitStates();
}
});
- } while (IndexUtil::BumpIndices(source->shape(), &source_index));
+ } while (
+ IndexUtil::BumpIndices(source->shape(), absl::MakeSpan(source_index)));
parent_->evaluated_[select_and_scatter] = std::move(result);
return Status::OK();
@@ -2624,7 +2626,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!out_of_bound) {
f(base_index);
}
- } while (IndexUtil::BumpIndices(window_shape, &window_index));
+ } while (
+ IndexUtil::BumpIndices(window_shape, absl::MakeSpan(window_index)));
}
template <typename IndexT>
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8a497e6edf..b747a4ea5f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -167,11 +167,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.called_computation_ids_size();
{
const auto reduce_operands = all_operands();
- tensorflow::gtl::ArraySlice<HloInstruction*> inputs(
- reduce_operands, 0, reduce_operands.size() / 2);
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values(
- reduce_operands, reduce_operands.size() / 2,
- reduce_operands.size());
+ auto inputs = absl::MakeSpan(reduce_operands)
+ .subspan(0, reduce_operands.size() / 2);
+ auto init_values =
+ absl::MakeSpan(reduce_operands)
+ .subspan(reduce_operands.size() / 2, reduce_operands.size());
instruction =
CreateReduce(proto.shape(), inputs, init_values,
std::vector<int64>(proto.dimensions().begin(),
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index c2d551fb25..4fe5144aca 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -404,14 +404,12 @@ class HloReduceInstruction : public HloInstruction {
// Returns the input tensors to be reduced.
tensorflow::gtl::ArraySlice<HloInstruction*> inputs() const {
- return tensorflow::gtl::ArraySlice<HloInstruction*>(operands(), 0,
- input_count());
+ return absl::MakeSpan(operands()).subspan(0, input_count());
}
// Returns the init values of the reduction.
tensorflow::gtl::ArraySlice<HloInstruction*> init_values() const {
- return tensorflow::gtl::ArraySlice<HloInstruction*>(
- operands(), input_count(), operand_count());
+ return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
}
private:
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index eae4508b24..b93e4f24f6 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -997,11 +997,11 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
shape, /*operands=*/
- tensorflow::gtl::ArraySlice<HloInstruction*>(operands, 0,
- operands.size() / 2),
+ tensorflow::gtl::ArraySlice<HloInstruction*>(operands).subspan(
+ 0, operands.size() / 2),
/*init_values=*/
- tensorflow::gtl::ArraySlice<HloInstruction*>(
- operands, operands.size() / 2, operands.size()),
+ tensorflow::gtl::ArraySlice<HloInstruction*>(operands).subspan(
+ operands.size() / 2, operands.size()),
*dimensions_to_reduce, *reduce_computation));
break;
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 6971220022..36e713d1ac 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -147,16 +147,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
// indices in the same common factor.
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
llvm::Value* logical_linear_index =
- Index(tensorflow::gtl::ArraySlice<llvm::Value*>(
- multidim_, common_factors[k].second,
+ Index(tensorflow::gtl::ArraySlice<llvm::Value*>(multidim_).subspan(
+ common_factors[k].second,
common_factors[k + 1].second - common_factors[k].second),
index_type_)
- .Linearize(
- tensorflow::gtl::ArraySlice<int64>(
- AsInt64Slice(output_shape.dimensions()),
- common_factors[k].second,
- common_factors[k + 1].second - common_factors[k].second),
- builder);
+ .Linearize(AsInt64Slice(output_shape.dimensions())
+ .subspan(common_factors[k].second,
+ common_factors[k + 1].second -
+ common_factors[k].second),
+ builder);
// Delinearizes logical_linear_index for the source array in row-major
// collapsed order. The first rank-1 indices are the remainder of the
// linear index by each dimension size.
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index f5217c5a11..45427bba25 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1872,8 +1872,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
int64 num_reduced_args = arg_shapes.size() / 2;
- tensorflow::gtl::ArraySlice<const Shape*> reduced_args(arg_shapes, 0,
- num_reduced_args);
+ auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
// Check that all of the reduced tensors have the same dimensions. The element
// types may be different.
for (int64 i = 1; i < num_reduced_args; ++i) {
@@ -1897,8 +1896,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
- tensorflow::gtl::ArraySlice<const Shape*> init_values(
- arg_shapes, num_reduced_args, arg_shapes.size());
+ auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
std::vector<PrimitiveType> element_types;
for (const Shape* arg : reduced_args) {
element_types.push_back(arg->element_type());
diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h
index 70fab3bea5..7291705b61 100644
--- a/tensorflow/compiler/xla/sparse_index_array.h
+++ b/tensorflow/compiler/xla/sparse_index_array.h
@@ -96,7 +96,9 @@ class SparseIndexArray {
int64 max_indices() const { return max_indices_; }
// Returns a pointer to the int64 array that holds the sparse indices.
- tensorflow::gtl::MutableArraySlice<int64> mutable_data() { return &indices_; }
+ tensorflow::gtl::MutableArraySlice<int64> mutable_data() {
+ return absl::MakeSpan(indices_);
+ }
tensorflow::gtl::ArraySlice<int64> data() const { return indices_; }
// Sorts this sparse index array along with the set of corresponding values.
diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc
index 7377f88958..e54057c400 100644
--- a/tensorflow/compiler/xla/sparse_index_array_test.cc
+++ b/tensorflow/compiler/xla/sparse_index_array_test.cc
@@ -33,7 +33,7 @@ TEST(SparseIndexArrayTest, Sort) {
std::vector<double> values = {
12.0, 13.0, 11.0, 15.0, 14.0, 16.0,
};
- a.SortWithValues<double>(&values);
+ a.SortWithValues<double>(absl::MakeSpan(values));
ASSERT_EQ(a.data(), std::vector<int64>({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5,
6, 7, 6, 7, 8}));
ASSERT_EQ(values, std::vector<double>({11.0, 12.0, 13.0, 14.0, 15.0, 16.0}));
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 84bae05c38..15a9d55bfe 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -113,7 +113,7 @@ class FusionTest : public HloTestBase {
hlos[0] = builder.AddInstruction(std::move(root_hlo));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(
- ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
+ ArraySlice<HloInstruction*>(hlos).subspan(0, Arity + 1),
HloInstruction::FusionKind::kLoop);
auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 16b77e965d..b0a324f6fc 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -96,8 +96,8 @@ class MultiOutputFusionTest : public HloTestBase {
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
- auto tuple = computation->AddInstruction(HloInstruction::CreateTuple(
- ArraySlice<HloInstruction*>({sub, add2}, 0, 2)));
+ auto tuple =
+ computation->AddInstruction(HloInstruction::CreateTuple({sub, add2}));
auto gte0 = computation->AddInstruction(
HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0));
auto gte1 = computation->AddInstruction(
@@ -159,8 +159,8 @@ class MultiOutputFusionTest : public HloTestBase {
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
- auto tuple = computation->AddInstruction(HloInstruction::CreateTuple(
- ArraySlice<HloInstruction*>({sub_U8, add}, 0, 2)));
+ auto tuple = computation->AddInstruction(
+ HloInstruction::CreateTuple({sub_U8, add}));
auto gte0 = computation->AddInstruction(
HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0));
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 62f486369f..b3343b1506 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -293,9 +293,10 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
// Precondition:
// 1. `permutation` is a permutation of 0..permutation.size()-1.
// 2. permutation.size() == input.size().
-template <template <typename...> class C, typename T>
-std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- C<T> input) {
+template <typename Container>
+std::vector<typename Container::value_type> Permute(
+ tensorflow::gtl::ArraySlice<int64> permutation, const Container& input) {
+ using T = typename Container::value_type;
tensorflow::gtl::ArraySlice<T> data(input);
CHECK(IsPermutation(permutation, data.size()));
std::vector<T> output(data.size());
@@ -305,17 +306,6 @@ std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
return output;
}
-// Override of the above that works around compile failures with gcc 7.1.1.
-// For details see https://github.com/tensorflow/tensorflow/issues/10843
-// Hide this workaround from MSVC as it causes ambiguous error.
-#ifndef _MSC_VER
-template <typename T>
-std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- const std::vector<T>& input) {
- return Permute<std::vector, T>(permutation, input);
-}
-#endif
-
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
std::vector<int64> InversePermutation(
tensorflow::gtl::ArraySlice<int64> input_permutation);
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 7648db9c12..608b08efba 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2471,6 +2471,7 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_internal_headers_lib",
+ includes = ["../../external/com_google_absl"],
deps = [
":lib",
":lib_internal",
@@ -2559,6 +2560,7 @@ cc_header_only_library(
# ABSL headers get dropped, so we add them back here.
"@com_google_absl//absl/strings",
],
+ includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":framework",
@@ -2568,6 +2570,7 @@ cc_header_only_library(
cc_header_only_library(
name = "stream_executor_headers_lib",
+ includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":stream_executor",
@@ -3219,7 +3222,6 @@ tf_cc_tests(
"lib/core/status_test.cc",
"lib/core/stringpiece_test.cc",
"lib/core/threadpool_test.cc",
- "lib/gtl/array_slice_test.cc",
"lib/gtl/cleanup_test.cc",
"lib/gtl/compactptrset_test.cc",
"lib/gtl/edit_distance_test.cc",
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 21c6940b62..20a07d86a2 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -432,9 +432,9 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
DimensionHandle batch_size_dim;
DimensionHandle input_depth_dim;
gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
- TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
- &batch_size_dim, &input_spatial_dims,
- &input_depth_dim, c));
+ TF_RETURN_IF_ERROR(DimensionsFromShape(
+ conv_input_shape, data_format, &batch_size_dim,
+ absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
DimensionHandle output_depth_dim = c->Dim(
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc
index 654d99301a..663bff3657 100644
--- a/tensorflow/core/kernels/candidate_sampler_ops.cc
+++ b/tensorflow/core/kernels/candidate_sampler_ops.cc
@@ -89,9 +89,9 @@ class BaseCandidateSamplerOp : public OpKernel {
// Pick sampled candidates.
auto local_gen = generator_.ReserveSamples32(samples32);
random::SimplePhilox random(&local_gen);
- sampler_->SampleBatchGetExpectedCount(&random, unique_, &sampled_candidate,
- &sampled_expected_count,
- true_candidate, &true_expected_count);
+ sampler_->SampleBatchGetExpectedCount(&random, unique_, sampled_candidate,
+ sampled_expected_count,
+ true_candidate, true_expected_count);
if (sampler_->NeedsUpdates()) {
sampler_->Update(true_candidate);
diff --git a/tensorflow/core/kernels/range_sampler_test.cc b/tensorflow/core/kernels/range_sampler_test.cc
index 9020121169..3d49af7cb1 100644
--- a/tensorflow/core/kernels/range_sampler_test.cc
+++ b/tensorflow/core/kernels/range_sampler_test.cc
@@ -45,7 +45,7 @@ class RangeSamplerTest : public ::testing::Test {
// Using a fixed random seed to make the test deterministic.
random::PhiloxRandom philox(123, 17);
random::SimplePhilox rnd(&philox);
- sampler_->SampleBatch(&rnd, false, &a);
+ sampler_->SampleBatch(&rnd, false, absl::MakeSpan(a));
for (int i = 0; i < num_samples; i++) {
int64 val = a[i];
ASSERT_GE(val, 0);
@@ -251,8 +251,9 @@ TEST_F(RangeSamplerTest, All) {
extras[0] = 0;
extras[1] = batch_size - 1;
sampler_->SampleBatchGetExpectedCount(nullptr, // no random numbers needed
- false, &batch, &batch_expected, extras,
- &extras_expected);
+ false, absl::MakeSpan(batch),
+ absl::MakeSpan(batch_expected), extras,
+ absl::MakeSpan(extras_expected));
for (int i = 0; i < batch_size; i++) {
EXPECT_EQ(i, batch[i]);
EXPECT_EQ(1, batch_expected[i]);
@@ -281,17 +282,18 @@ TEST_F(RangeSamplerTest, Unique) {
std::vector<float> expected(range);
// Sample one batch and get the expected counts of all values
- sampler_->SampleBatchGetExpectedCount(
- &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected);
+ sampler_->SampleBatchGetExpectedCount(&rnd, true, absl::MakeSpan(batch),
+ MutableArraySlice<float>(), all_values,
+ absl::MakeSpan(expected));
// Check that all elements are unique
std::set<int64> s(batch.begin(), batch.end());
CHECK_EQ(batch_size, s.size());
for (int trial = 0; trial < num_batches; trial++) {
std::vector<float> trial_expected(range);
- sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch,
- MutableArraySlice<float>(),
- all_values, &trial_expected);
+ sampler_->SampleBatchGetExpectedCount(
+ &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
+ all_values, absl::MakeSpan(trial_expected));
for (int i = 0; i < range; i++) {
EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
}
@@ -318,8 +320,8 @@ TEST_F(RangeSamplerTest, Avoid) {
// We expect to pick all elements of [0, 100) except the avoided two.
sampler_->SampleBatchGetExpectedCountAvoid(
- &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(),
- MutableArraySlice<float>(), avoided);
+ &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
+ ArraySlice<int64>(), MutableArraySlice<float>(), avoided);
int sum = 0;
for (auto val : batch) {
diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc
index f893d4e945..0428909145 100644
--- a/tensorflow/core/kernels/set_kernels.cc
+++ b/tensorflow/core/kernels/set_kernels.cc
@@ -269,7 +269,7 @@ void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
// Group by all but last dimension, create a set of group values, and add set
// size to output.
- VarDimArray group_ix(set_st.order(), 0, set_st.order().size() - 1);
+ VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
std::set<T> group_set;
for (const auto& group : set_st.group(group_ix)) {
PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);
@@ -500,8 +500,8 @@ void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
std::set<T> set1_group_set;
std::set<T> set2_group_set;
- auto set2_grouper = set2_st.group(
- VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
+ auto set2_grouper =
+ set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
auto set2_group_it = set2_grouper.begin();
std::vector<int64> group_indices;
int64 num_elements;
@@ -621,11 +621,11 @@ void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
std::set<T> set1_group_set;
std::set<T> set2_group_set;
- auto set1_grouper = set1_st.group(
- VarDimArray(set1_st.order(), 0, set1_st.order().size() - 1));
+ auto set1_grouper =
+ set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
auto set1_group_it = set1_grouper.begin();
- auto set2_grouper = set2_st.group(
- VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
+ auto set2_grouper =
+ set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
auto set2_group_it = set2_grouper.begin();
// Group by rows, and iterate over rows of both sets in parallel, creating a
diff --git a/tensorflow/core/kernels/sparse_softmax_op.cc b/tensorflow/core/kernels/sparse_softmax_op.cc
index dc3119bba4..37664fe8df 100644
--- a/tensorflow/core/kernels/sparse_softmax_op.cc
+++ b/tensorflow/core/kernels/sparse_softmax_op.cc
@@ -90,7 +90,7 @@ class SparseSoftmaxOp : public OpKernel {
// { 0, ..., rank-1 }.
const ArraySlice<int64> kReorderDims(dims);
// All but the last dim -- the class dimension to be max-reduced along.
- const ArraySlice<int64> kGroupByDims(kReorderDims, 0, rank - 1);
+ const ArraySlice<int64> kGroupByDims = kReorderDims.subspan(0, rank - 1);
st.Reorder<T>(kReorderDims);
int count = 0;
diff --git a/tensorflow/core/lib/gtl/array_slice.h b/tensorflow/core/lib/gtl/array_slice.h
index b773a65569..8f47faf89e 100644
--- a/tensorflow/core/lib/gtl/array_slice.h
+++ b/tensorflow/core/lib/gtl/array_slice.h
@@ -13,293 +13,22 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// An ArraySlice<T> represents an immutable array of elements of type
-// T. It has a length "length", and a base pointer "ptr", and the
-// array it represents contains the elements "ptr[0] .. ptr[len-1]".
-// The backing store for the array is *not* owned by the ArraySlice
-// object, and clients must arrange for the backing store to remain
-// live while the ArraySlice object is in use.
-//
-// An ArraySlice<T> is somewhat analogous to a StringPiece, but for
-// array elements of type T.
-//
-// Implicit conversion operations are provided from types such as
-// std::vector<T> and util::gtl::InlinedVector<T, N>. Note that ArraySlice
-// objects constructed from types in this way may be invalidated by
-// any operations that mutate the underlying vector.
-//
-// One common use for ArraySlice is when passing arguments to a
-// routine where you want to be able to accept a variety of array
-// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array,
-// etc.). The usual approach here is to have the client explicitly
-// pass in a pointer and a length, as in:
-//
-// void MyRoutine(const int* elems, int N) {
-// for (int i = 0; i < N; i++) { .. do something with elems[i] .. }
-// }
-//
-// Unfortunately, this leads to ugly and error-prone code at the call site:
-//
-// std::vector<int> my_vector;
-// MyRoutine(vector_as_array(&my_vector), my_vector.size());
-//
-// util::gtl::InlinedVector<int, 4> my_inline_vector;
-// MyRoutine(my_inline_vector.array(), my_inline_vector.size());
-//
-// int my_array[10];
-// MyRoutine(my_array, 10);
-//
-// Instead, you can use an ArraySlice as the argument to the routine:
-//
-// void MyRoutine(ArraySlice<int> a) {
-// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. }
-// }
-//
-// This makes the call sites cleaner, for the most part:
-//
-// std::vector<int> my_vector;
-// MyRoutine(my_vector);
-//
-// util::gtl::InlinedVector<int, 4> my_inline_vector;
-// MyRoutine(my_inline_vector);
-//
-// int my_array[10];
-// MyRoutine(my_array);
-//
-// int* my_array = new int[10];
-// MyRoutine(gtl::ArraySlice<int>(my_array, 10));
-//
-// MutableArraySlice<T> represents a mutable array of elements, and, like
-// ArraySlice, does not own the backing store. The implicit constructors it
-// provides allow functions not to worry about whether their mutable arguments
-// refer to vectors, arrays, proto2::RepeatedFields, etc.:
-//
-// void MyMutatingRoutine(MutableArraySlice<int> a) {
-// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. }
-// }
-//
-// std::vector<int> my_vector;
-// MyMutatingRoutine(&my_vector);
-//
-// int my_array[10];
-// MyMutatingRoutine(my_array);
-//
-// int* my_array = new int[10];
-// MyMutatingRoutine(gtl::MutableArraySlice<int>(my_array, 10));
-//
-// MyProto my_proto;
-// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); }
-// MyMutatingRoutine(my_proto.mutable_value());
-
#ifndef TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
#define TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
-#include <initializer_list>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/array_slice_internal.h"
+#include "absl/types/span.h"
+// TODO(timshen): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace tensorflow {
namespace gtl {
template <typename T>
-class ArraySlice {
- private:
- typedef array_slice_internal::ArraySliceImpl<T> Impl;
-
- public:
- typedef T value_type;
- typedef typename Impl::pointer pointer;
- typedef typename Impl::const_pointer const_pointer;
- typedef typename Impl::reference reference;
- typedef typename Impl::const_reference const_reference;
- typedef typename Impl::iterator iterator;
- typedef typename Impl::const_iterator const_iterator;
- typedef typename Impl::reverse_iterator reverse_iterator;
- typedef typename Impl::const_reverse_iterator const_reverse_iterator;
- typedef typename Impl::size_type size_type;
- typedef typename Impl::difference_type difference_type;
-
- static const size_type npos = Impl::npos;
-
- ArraySlice() : impl_(nullptr, 0) {}
- ArraySlice(const_pointer array, size_type length) : impl_(array, length) {}
-
- // Implicit conversion constructors
- ArraySlice(const std::vector<value_type>& v) // NOLINT(runtime/explicit)
- : impl_(v.data(), v.size()) {}
-
- template <size_t N>
- ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit)
- : impl_(a, N) {}
-
- template <int N>
- ArraySlice(const InlinedVector<value_type, N>& v) // NOLINT(runtime/explicit)
- : impl_(v.data(), v.size()) {}
-
- // The constructor for any class supplying 'data() const' that returns either
- // const T* or a less const-qualified version of it, and 'some_integral_type
- // size() const'. proto2::RepeatedField<T>, string and (since C++11)
- // std::vector<T,A> and std::array<T, N> are examples of this. See
- // array_slice_internal.h for details.
- template <typename V,
- typename = typename Impl::template EnableIfConvertibleFrom<V>>
- ArraySlice(const V& v) // NOLINT(runtime/explicit)
- : impl_(v) {}
-
- // Implicitly constructs an ArraySlice from an initializer list. This makes it
- // possible to pass a brace-enclosed initializer list to a function expecting
- // an ArraySlice:
- // void Process(ArraySlice<int> x);
- // Process({1, 2, 3});
- // The data referenced by the initializer_list must outlive this
- // ArraySlice. For example, "ArraySlice<int> s={1,2};" and "return
- // ArraySlice<int>({3,4});" are errors, as the resulting ArraySlice may
- // reference data that is no longer valid.
- ArraySlice(std::initializer_list<value_type> v) // NOLINT(runtime/explicit)
- : impl_(v.begin(), v.size()) {}
-
- // Substring of another ArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- // If len==npos, the substring continues till the end of x.
- ArraySlice(const ArraySlice& x, size_type pos, size_type len)
- : impl_(x.impl_, pos, len) {}
-
- const_pointer data() const { return impl_.data(); }
- size_type size() const { return impl_.size(); }
- size_type length() const { return size(); }
- bool empty() const { return size() == 0; }
-
- void clear() { impl_.clear(); }
-
- const_reference operator[](size_type i) const { return impl_[i]; }
- const_reference at(size_type i) const { return impl_.at(i); }
- const_reference front() const { return impl_.front(); }
- const_reference back() const { return impl_.back(); }
-
- const_iterator begin() const { return impl_.begin(); }
- const_iterator end() const { return impl_.end(); }
- const_reverse_iterator rbegin() const { return impl_.rbegin(); }
- const_reverse_iterator rend() const { return impl_.rend(); }
-
- void remove_prefix(size_type n) { impl_.remove_prefix(n); }
- void remove_suffix(size_type n) { impl_.remove_suffix(n); }
-
- // These relational operators have the same semantics as the
- // std::vector<T> relational operators: they do deep (element-wise)
- // comparisons. Array slices are equal iff their size is the same
- // and all their elements are equal.
- bool operator==(ArraySlice<T> other) const { return impl_ == other.impl_; }
- bool operator!=(ArraySlice<T> other) const { return impl_ != other.impl_; }
-
- private:
- Impl impl_;
-};
-
-// Mutable version of ArraySlice, which allows the clients to mutate the
-// underlying data. It is implicitly convertible to ArraySlice since it provides
-// the data() and size() methods with correct signatures. When a
-// MutableArraySlice is created from a pointer to a container (as opposed to raw
-// memory pointer), the pointer must not be null.
-//
-// A note on const-ness: "mutable" here refers to the mutability of the
-// underlying data, not of the slice itself. It is perfectly reasonable to have
-// a variable of type "const MutableArraySlice<T>"; this means that the bounds
-// of the view on the array cannot be changed, but the underlying data in the
-// array still may be modified. This is akin to a "T* const" pointer, as opposed
-// to a "const T*" pointer (corresponding to a non-const ArraySlice<T>).
-template <typename T>
-class MutableArraySlice {
- private:
- typedef array_slice_internal::MutableArraySliceImpl<T> Impl;
-
- public:
- typedef T value_type;
- typedef typename Impl::pointer pointer;
- typedef typename Impl::const_pointer const_pointer;
- typedef typename Impl::reference reference;
- typedef typename Impl::const_reference const_reference;
- typedef typename Impl::iterator iterator;
- typedef typename Impl::const_iterator const_iterator;
- typedef typename Impl::reverse_iterator reverse_iterator;
- typedef typename Impl::const_reverse_iterator const_reverse_iterator;
- typedef typename Impl::size_type size_type;
- typedef typename Impl::difference_type difference_type;
-
- static const size_type npos = Impl::npos;
-
- MutableArraySlice() : impl_(nullptr, 0) {}
- MutableArraySlice(pointer array, size_type length) : impl_(array, length) {}
-
- // Implicit conversion constructors
- MutableArraySlice(std::vector<value_type>* v) // NOLINT(runtime/explicit)
- : impl_(v->data(), v->size()) {}
-
- template <size_t N>
- MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit)
- : impl_(a, N) {}
-
- template <int N>
- MutableArraySlice(
- InlinedVector<value_type, N>* v) // NOLINT(runtime/explicit)
- : impl_(v->data(), v->size()) {}
-
- // The constructor for any class supplying 'T* data()' or 'T* mutable_data()'
- // (the former is called if both exist), and 'some_integral_type size()
- // const'. proto2::RepeatedField is an example of this. Also supports string
- // arguments, when T==char. The appropriate ctor is selected using SFINAE. See
- // array_slice_internal.h for details.
- template <typename V,
- typename = typename Impl::template EnableIfConvertibleFrom<V>>
- MutableArraySlice(V* v) // NOLINT(runtime/explicit)
- : impl_(v) {}
+using ArraySlice = absl::Span<const T>;
- // Substring of another MutableArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- // If len==npos, the substring continues till the end of x.
- MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len)
- : impl_(x.impl_, pos, len) {}
-
- // Accessors.
- pointer data() const { return impl_.data(); }
- size_type size() const { return impl_.size(); }
- size_type length() const { return size(); }
- bool empty() const { return size() == 0; }
-
- void clear() { impl_.clear(); }
-
- reference operator[](size_type i) const { return impl_[i]; }
- reference at(size_type i) const { return impl_.at(i); }
- reference front() const { return impl_.front(); }
- reference back() const { return impl_.back(); }
-
- iterator begin() const { return impl_.begin(); }
- iterator end() const { return impl_.end(); }
- reverse_iterator rbegin() const { return impl_.rbegin(); }
- reverse_iterator rend() const { return impl_.rend(); }
-
- void remove_prefix(size_type n) { impl_.remove_prefix(n); }
- void remove_suffix(size_type n) { impl_.remove_suffix(n); }
-
- bool operator==(ArraySlice<T> other) const {
- return ArraySlice<T>(*this) == other;
- }
- bool operator!=(ArraySlice<T> other) const {
- return ArraySlice<T>(*this) != other;
- }
-
- private:
- Impl impl_;
-};
-
-template <typename T>
-const typename ArraySlice<T>::size_type ArraySlice<T>::npos;
template <typename T>
-const typename MutableArraySlice<T>::size_type MutableArraySlice<T>::npos;
+using MutableArraySlice = absl::Span<T>;
} // namespace gtl
} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h
deleted file mode 100644
index 689dd8a646..0000000000
--- a/tensorflow/core/lib/gtl/array_slice_internal.h
+++ /dev/null
@@ -1,269 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by
-// array_slice.h.
-
-// Helper functions and templates for ArraySlice.
-
-#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
-#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
-
-#include <stddef.h>
-#include <algorithm>
-#include <iterator>
-#include <memory>
-#include <string>
-#include <type_traits>
-#include <utility>
-#include <vector>
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-namespace gtl {
-namespace array_slice_internal {
-
-// Template logic for generic constructors.
-
-// Wrappers whose Get() delegates to the appropriate method of a container, and
-// is defined when this method exists. Delegates to the const method if C is a
-// const type.
-struct Data {
- template <typename C>
- static decltype(std::declval<C>().data()) Get(C* v) {
- return v->data();
- }
-};
-
-struct MutableData {
- template <typename C>
- static decltype(std::declval<C>().mutable_data()) Get(C* v) {
- return v->mutable_data();
- }
-};
-
-struct Size {
- template <typename C>
- static decltype(std::declval<C>().size()) Get(C* v) {
- return v->size();
- }
-};
-
-struct MutableStringData {
- // Defined only for string.
- static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); }
-};
-
-// Checks whether M::Get(C*) is defined and has a return type R such that
-// Checker::valid<R>()==true.
-template <typename M, typename Checker, typename C>
-struct HasGetHelper : public M {
- private:
- struct None {};
- // M::Get is selected when it is viable. Get(...) is selected otherwise.
- using M::Get;
- static None Get(...);
-
- public:
- static constexpr bool HasGet() {
- using Result = decltype(Get(std::declval<C*>()));
- return !std::is_same<Result, None>() && Checker::template valid<Result>();
- }
-};
-
-// Defines HasGet() for a particular method, container, and checker. If
-// HasGet()==true, provides Get() that delegates to the method.
-template <typename M, typename Checker, typename C,
- bool /*has_get*/ = HasGetHelper<M, Checker, C>::HasGet()>
-struct Wrapper {
- static constexpr bool HasGet() { return false; }
-};
-
-template <typename M, typename Checker, typename C>
-struct Wrapper<M, Checker, C, true> {
- static constexpr bool HasGet() { return true; }
- static decltype(M::Get(std::declval<C*>())) Get(C* v) { return M::Get(v); }
-};
-
-// Type checker for a method returning an integral value.
-struct SizeChecker {
- template <typename R>
- static constexpr bool valid() {
- return std::is_integral<R>::value;
- }
-};
-
-// Type checker for a method returning either a pointer to T or a less const
-// version of that.
-template <typename T>
-struct DataChecker {
- // We want to enable conversion from std::vector<T*> to ArraySlice<const T*>
- // but
- // disable conversion from std::vector<Derived> to ArraySlice<Base>. Here we
- // use
- // the fact that U** is convertible to Q* const* if and only if Q is the same
- // type or a more cv-qualified version of U.
- template <typename R>
- static constexpr bool valid() {
- return std::is_convertible<R*, T* const*>::value;
- }
-};
-
-// Aliases to A if A::HasGet()==true, or to B otherwise.
-template <typename A, typename B>
-using FirstWithGet = typename std::conditional<A::HasGet(), A, B>::type;
-
-// Wraps C::data() const, returning a pointer to const data.
-template <typename T, typename C>
-using ContainerData = Wrapper<Data, DataChecker<const T>, const C>;
-
-// Wraps a method returning a pointer to mutable data. Prefers data() over
-// mutable_data(), and handles strings when T==char. If data() returns a pointer
-// to mutable data, it is most likely overloaded, but may also be a single
-// method 'T* C::data() const' in a non-STL-compliant container.
-template <typename T, typename C>
-using ContainerMutableData =
- FirstWithGet<Wrapper<Data, DataChecker<T>, C>,
- FirstWithGet<Wrapper<MutableData, DataChecker<T>, C>,
- Wrapper<MutableStringData, DataChecker<T>, C>>>;
-
-// Wraps C::size() const.
-template <typename C>
-using ContainerSize = Wrapper<Size, SizeChecker, const C>;
-
-// Implementation class for ArraySlice and MutableArraySlice. In the case of
-// ArraySlice, T will be a const type; for MutableArraySlice, T will be a
-// mutable type.
-template <typename T>
-class ArraySliceImplBase {
- public:
- typedef T* pointer;
- typedef const T* const_pointer;
- typedef T& reference;
- typedef const T& const_reference;
- typedef pointer iterator;
- typedef const_pointer const_iterator;
- typedef std::reverse_iterator<iterator> reverse_iterator;
- typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
- typedef size_t size_type;
- typedef ptrdiff_t difference_type;
-
- static const size_type npos = static_cast<size_type>(-1);
-
- ArraySliceImplBase(pointer array, size_type length)
- : ptr_(array), length_(length) {}
-
- // Substring of another ArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len)
- : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {}
-
- // Some of the const methods below return pointers and references to mutable
- // data. This is only the case in this internal class; ArraySlice and
- // MutableArraySlice provide deep-constness.
-
- pointer data() const { return ptr_; }
- size_type size() const { return length_; }
-
- void clear() {
- ptr_ = nullptr;
- length_ = 0;
- }
-
- reference operator[](size_type i) const { return ptr_[i]; }
- reference at(size_type i) const {
- DCHECK_LT(i, length_);
- return ptr_[i];
- }
- reference front() const {
- DCHECK_GT(length_, 0);
- return ptr_[0];
- }
- reference back() const {
- DCHECK_GT(length_, 0);
- return ptr_[length_ - 1];
- }
-
- void remove_prefix(size_type n) {
- DCHECK_GE(length_, n);
- ptr_ += n;
- length_ -= n;
- }
- void remove_suffix(size_type n) {
- DCHECK_GE(length_, n);
- length_ -= n;
- }
-
- iterator begin() const { return ptr_; }
- iterator end() const { return ptr_ + length_; }
- reverse_iterator rbegin() const { return reverse_iterator(end()); }
- reverse_iterator rend() const { return reverse_iterator(begin()); }
-
- bool operator==(const ArraySliceImplBase& other) const {
- if (size() != other.size()) return false;
- if (data() == other.data()) return true;
- return std::equal(data(), data() + size(), other.data());
- }
- bool operator!=(const ArraySliceImplBase& other) const {
- return !(*this == other);
- }
-
- private:
- pointer ptr_;
- size_type length_;
-};
-
-template <typename T>
-class ArraySliceImpl : public ArraySliceImplBase<const T> {
- public:
- using ArraySliceImplBase<const T>::ArraySliceImplBase;
-
- // Defined iff the data and size accessors for the container C have been
- // defined.
- template <typename C>
- using EnableIfConvertibleFrom =
- typename std::enable_if<ContainerData<T, C>::HasGet() &&
- ContainerSize<C>::HasGet()>::type;
-
- // Constructs from a container when EnableIfConvertibleFrom is
- // defined. std::addressof handles types with overloaded operator&.
- template <typename C>
- explicit ArraySliceImpl(const C& v)
- : ArraySliceImplBase<const T>(ContainerData<T, C>::Get(std::addressof(v)),
- ContainerSize<C>::Get(std::addressof(v))) {}
-};
-
-template <typename T>
-class MutableArraySliceImpl : public ArraySliceImplBase<T> {
- public:
- using ArraySliceImplBase<T>::ArraySliceImplBase;
-
- template <typename C>
- using EnableIfConvertibleFrom =
- typename std::enable_if<ContainerMutableData<T, C>::HasGet() &&
- ContainerSize<C>::HasGet()>::type;
-
- template <typename C>
- explicit MutableArraySliceImpl(C* v)
- : ArraySliceImplBase<T>(ContainerMutableData<T, C>::Get(v),
- ContainerSize<C>::Get(v)) {}
-};
-
-} // namespace array_slice_internal
-} // namespace gtl
-} // namespace tensorflow
-
-#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
diff --git a/tensorflow/core/lib/gtl/array_slice_test.cc b/tensorflow/core/lib/gtl/array_slice_test.cc
deleted file mode 100644
index c798a488cb..0000000000
--- a/tensorflow/core/lib/gtl/array_slice_test.cc
+++ /dev/null
@@ -1,664 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/array_slice.h"
-
-#include <algorithm>
-#include <array>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace gtl {
-namespace {
-
-typedef ArraySlice<int> IntSlice;
-typedef ArraySlice<char> CharSlice;
-typedef MutableArraySlice<int> MutableIntSlice;
-typedef MutableArraySlice<char> MutableCharSlice;
-typedef std::vector<int> IntVec;
-
-// Append 0..len-1 to *v
-template <typename Vector>
-static void Fill(Vector* v, int len, int offset = 0) {
- for (int i = 0; i < len; i++) {
- v->push_back(i + offset);
- }
-}
-
-static void TestHelper(const IntSlice& vorig, const IntVec& vec) {
- IntSlice other; // To test the assignment return value.
- IntSlice v = other = vorig;
- const int len = vec.size();
- EXPECT_EQ(v.size(), vec.size());
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(v[i], vec[i]);
- EXPECT_EQ(v.at(i), vec[i]);
- }
- EXPECT_EQ(v.begin(), gtl::vector_as_array(&vec));
-
- int counter = 0;
- for (IntSlice::iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(counter, *it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- counter = 0;
- for (IntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(counter, *it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- EXPECT_EQ(0, v.front());
- EXPECT_EQ(len - 1, v.back());
- v.remove_suffix(1);
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i, v[i]);
- }
- if (len > 1) {
- v.remove_prefix(1);
- EXPECT_EQ(len - 2, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i + 1, v[i]);
- }
- }
- }
-}
-
-// The element access test that is applicable both when MutableArraySlice is
-// const and when it's not.
-template <class V>
-void MutableTestHelperTemplated(V v, int* ptr, const int len) {
- CHECK_EQ(v.size(), len);
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(ptr + i, &v[i]);
- EXPECT_EQ(ptr + i, &v.at(i));
- }
- EXPECT_EQ(ptr, v.begin());
- EXPECT_EQ(ptr + len, v.end());
- EXPECT_EQ(ptr, v.data());
-
- int counter = 0;
- for (MutableIntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(ptr + counter, &*it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- EXPECT_EQ(len, std::distance(v.rbegin(), v.rend()));
-
- if (len > 0) {
- EXPECT_EQ(ptr, &v.front());
- EXPECT_EQ(ptr + len - 1, &v.back());
- EXPECT_EQ(ptr + len - 1, &*v.rbegin());
- EXPECT_EQ(ptr, &*(v.rend() - 1));
- }
-}
-
-static void MutableTestHelper(const MutableIntSlice& vorig, int* ptr,
- const int len) {
- // Test the data accessors both when the MutableArraySlice is declared const,
- // and when it is not.
- MutableTestHelperTemplated<const MutableIntSlice&>(vorig, ptr, len);
- MutableTestHelperTemplated<MutableIntSlice>(vorig, ptr, len);
-
- MutableIntSlice other; // To test the assignment return value.
- MutableIntSlice v = other = vorig;
- EXPECT_EQ(ptr, v.data());
-
- int counter = 0;
- for (MutableIntSlice::iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(ptr + counter, &*it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- // Test that elements are assignable.
- v[0] = 1;
- v.front() = 2;
- v.back() = 5;
- *v.data() = 4;
- std::fill(v.begin(), v.end(), 5);
- std::fill(v.rbegin(), v.rend(), 6);
- // Test size-changing methods.
- v.remove_suffix(1);
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(ptr + i, &v[i]);
- }
- if (len > 1) {
- v.remove_prefix(1);
- EXPECT_EQ(len - 2, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(ptr + i + 1, &v[i]);
- }
- }
- }
-}
-
-template <typename Vector>
-static void TestImplicitConversion(const IntSlice& v, const Vector& vec) {
- EXPECT_EQ(v.size(), vec.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(v[i], vec[i]);
- }
-}
-
-template <typename Vector>
-static void TestImplicitConversion(const CharSlice& v, const Vector& vec) {
- TestImplicitConversion(IntVec(v.begin(), v.end()), vec);
-}
-
-static void TestImplicitConversion(const MutableIntSlice& v, const int* data,
- int size) {
- EXPECT_EQ(size, v.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(data + i, &v[i]);
- }
-}
-
-static void TestImplicitConversion(const MutableCharSlice& v, const char* data,
- int size) {
- EXPECT_EQ(size, v.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(data + i, &v[i]);
- }
-}
-// A struct supplying the data(), mutable_data() and size() methods, just like
-// e.g. proto2::RepeatedField.
-struct RepeatedField {
- std::vector<int> storage;
- const int* data() const { return storage.data(); }
- int* mutable_data() { return storage.data(); }
- int size() const { return storage.size(); }
-};
-
-// A struct supplying the data() (both mutable and const versions) and
-// size(). It also supplies mutable_data() but we test that data() is selected
-// instead.
-struct ContainerWithOverloads {
- std::vector<int> storage;
- std::vector<int> wrong_storage;
- const int* data() const { return storage.data(); }
- int* data() { return storage.data(); }
- // MutableArraySlice should not call mutable_data(), preferring data()
- // instead.
- int* mutable_data() { return wrong_storage.data(); }
- int size() const { return storage.size(); }
-};
-
-// A struct supplying data() and size() methods.
-struct ContainerWithShallowConstData {
- std::vector<int> storage;
- int* data() const { return const_cast<int*>(storage.data()); }
- int size() const { return storage.size(); }
-};
-
-TEST(IntSlice, Simple) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- TestHelper(IntSlice(vec), vec);
- TestHelper(IntSlice(vec.data(), vec.size()), vec);
- }
-}
-
-TEST(IntSlice, WithPosAndLen) {
- IntVec vec;
- Fill(&vec, 20);
- for (size_t len = 0; len < vec.size(); len++) {
- IntVec subvec(vec.begin(), vec.begin() + len);
- TestImplicitConversion(IntSlice(vec, 0, len), subvec);
- TestImplicitConversion(IntSlice(IntSlice(vec), 0, len), subvec);
- }
- EXPECT_EQ(0, IntSlice(vec, 0, 0).size());
- EXPECT_EQ(0, IntSlice(IntSlice(vec), 0, 0).size());
- TestImplicitConversion(IntSlice(vec, 0, IntSlice::npos), vec);
-}
-
-TEST(IntSlice, Clear) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- IntSlice v(vec);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(IntSlice, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- IntVec avec, bvec;
- Fill(&avec, l1);
- Fill(&bvec, l2, 100);
- IntSlice a(avec), b(bvec);
- using std::swap;
- swap(a, b);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(i, b[i]);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(100 + i, a[i]);
- }
- }
- }
-}
-
-TEST(IntSlice, ImplicitConversion) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- IntSlice slice;
- slice = vec;
- TestImplicitConversion(vec, vec);
- TestImplicitConversion(slice, vec);
- TestImplicitConversion(IntSlice(vec.data(), vec.size()), vec);
- }
-}
-
-TEST(IntSlice, InlinedVectorConversion) {
- for (int len = 0; len < 20; len++) {
- InlinedVector<int, 4> inline_vec;
- for (int i = 0; i < len; i++) {
- inline_vec.push_back(i);
- }
- IntVec vec;
- Fill(&vec, len);
- IntSlice v = inline_vec; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(inline_vec, vec);
- }
-}
-
-TEST(IntSlice, StaticArrayConversion) {
- int array[20];
- IntVec vec;
- Fill(&vec, TF_ARRAYSIZE(array));
- std::copy(vec.begin(), vec.end(), array);
- IntSlice v = array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(array, vec);
-}
-
-TEST(IntSlice, StdArrayConversion) {
- std::array<int, 20> array;
- IntVec vec;
- Fill(&vec, array.size());
- std::copy(vec.begin(), vec.end(), array.begin());
-
- // Check assignment.
- {
- IntSlice v = array;
- static_cast<void>(v);
- }
-
- // Check sub-slice initialization.
- {
- IntSlice v = {array, 10, 15};
- static_cast<void>(v);
- }
-
- TestImplicitConversion(array, vec);
-}
-
-// Values according to the Fill function.
-static const int test_const_array[] = {0, 1, 2};
-
-TEST(IntSlice, ConstStaticArrayConversion) {
- IntVec vec;
- Fill(&vec, TF_ARRAYSIZE(test_const_array));
- IntSlice v = test_const_array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(test_const_array, vec);
-}
-
-TEST(IntSlice, RepeatedFieldConversion) {
- RepeatedField repeated_field;
- IntVec vec;
- Fill(&vec, 20);
- repeated_field.storage = vec;
- IntSlice v = repeated_field; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(repeated_field, vec);
-}
-
-TEST(IntSlice, ContainerWithOverloadsConversion) {
- ContainerWithOverloads container;
- Fill(&container.storage, 20);
- container.wrong_storage.resize(container.size());
- IntSlice v = container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(container, container.storage);
-}
-
-TEST(IntSlice, ContainerWithShallowConstDataConversion) {
- ContainerWithShallowConstData container;
- Fill(&container.storage, 20);
- IntSlice v = container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(container, container.storage);
-}
-
-TEST(IntSlice, MutableIntSliceConversion) {
- IntVec vec(20);
- IntSlice slice = MutableIntSlice(&vec);
- EXPECT_EQ(vec.size(), slice.size());
- EXPECT_EQ(vec.data(), slice.data());
-}
-
-TEST(IntSlice, Equality) {
- IntVec vec1(20);
- IntVec vec2(20);
- // These two slices are from different vectors, but have the same
- // size and have the same elements (right now). They should
- // compare equal.
- const IntSlice from1(vec1);
- const IntSlice from2(vec2);
- EXPECT_EQ(from1, from1);
- EXPECT_EQ(from1, from2);
-
- // This verifies that MutableArraySlices can be compared freely with
- // ArraySlices.
- const MutableIntSlice mutable_from1(&vec1);
- const MutableIntSlice mutable_from2(&vec2);
- EXPECT_EQ(from1, mutable_from1);
- EXPECT_EQ(mutable_from1, from1);
- EXPECT_EQ(mutable_from1, mutable_from2);
- EXPECT_EQ(mutable_from2, mutable_from1);
-
- // With a different size, the array slices should not be equal.
- EXPECT_NE(from1, IntSlice(from1, 0, from1.size() - 1));
-
- // With different contents, the array slices should not be equal.
- ++vec2.back();
- EXPECT_NE(from1, from2);
-}
-
-// Compile-asserts that the argument has the expected type.
-template <typename Expected, typename T>
-void CheckType(const T& value) {
- ::testing::StaticAssertTypeEq<Expected, T>();
-}
-
-TEST(IntSlice, ExposesContainerTypesAndConsts) {
- IntSlice slice;
- const IntSlice const_slice;
- CheckType<IntSlice::iterator>(slice.begin());
- CheckType<IntSlice::const_iterator>(const_slice.end());
- CheckType<IntSlice::const_reverse_iterator>(const_slice.rbegin());
- CheckType<IntSlice::reverse_iterator>(slice.rend());
- ::testing::StaticAssertTypeEq<int, IntSlice::value_type>();
- ::testing::StaticAssertTypeEq<const int*, IntSlice::pointer>();
- ::testing::StaticAssertTypeEq<const int&, IntSlice::const_reference>();
- EXPECT_EQ(static_cast<IntSlice::size_type>(-1), IntSlice::npos);
-}
-
-void TestEmpty(IntSlice slice) { ASSERT_TRUE(slice.empty()); }
-
-void TestRange(IntSlice slice, int from, int to) {
- ASSERT_EQ(to - from + 1, slice.size());
- for (size_t i = 0; i < slice.size(); ++i) {
- EXPECT_EQ(from + i, slice[i]);
- }
-}
-
-TEST(IntSlice, InitializerListConversion) {
- TestEmpty({});
- TestRange({1}, 1, 1);
- TestRange({10, 11, 12, 13}, 10, 13);
-}
-
-TEST(CharSlice, StringConversion) {
- IntVec vec;
- Fill(&vec, 20);
- string str(vec.begin(), vec.end());
- CharSlice v = str; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(str, vec);
-}
-
-TEST(IntPtrSlice, ConstConversion) {
- int one = 1;
- int two = 2;
- std::vector<int*> vec;
- vec.push_back(&one);
- vec.push_back(&two);
- ArraySlice<const int*> v = vec;
- ASSERT_EQ(2, v.size());
- EXPECT_EQ(&one, v[0]);
- EXPECT_EQ(&two, v[1]);
-}
-
-TEST(MutableIntSlice, Simple) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableTestHelper(MutableIntSlice(&vec), vec.data(), len);
- MutableTestHelper(MutableIntSlice(vec.data(), vec.size()), vec.data(), len);
- }
-}
-
-TEST(MutableIntSlice, WithPosAndLen) {
- IntVec vec(20);
- for (size_t len = 0; len < vec.size(); len++) {
- TestImplicitConversion(MutableIntSlice(&vec, 0, len), vec.data(), len);
- TestImplicitConversion(MutableIntSlice(MutableIntSlice(&vec), 0, len),
- vec.data(), len);
- }
- EXPECT_EQ(0, MutableIntSlice(&vec, 0, 0).size());
- EXPECT_EQ(0, MutableIntSlice(MutableIntSlice(&vec), 0, 0).size());
- TestImplicitConversion(MutableIntSlice(&vec, 0, MutableIntSlice::npos),
- vec.data(), vec.size());
-}
-
-TEST(MutableIntSlice, Clear) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableIntSlice v(&vec);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(MutableIntSlice, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- IntVec avec(l1), bvec(l2);
- MutableIntSlice a(&avec), b(&bvec);
- using std::swap;
- swap(a, b);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(&avec[i], &b[i]);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(&bvec[i], &a[i]);
- }
- }
- }
-}
-
-TEST(MutableIntSlice, ImplicitConversion) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableIntSlice slice;
- slice = &vec;
- TestImplicitConversion(&vec, vec.data(), len);
- TestImplicitConversion(slice, vec.data(), len);
- TestImplicitConversion(MutableIntSlice(vec.data(), vec.size()), vec.data(),
- len);
- }
-}
-
-TEST(MutableIntSlice, InlinedVectorConversion) {
- for (int len = 0; len < 20; len++) {
- InlinedVector<int, 4> inline_vec;
- for (int i = 0; i < len; i++) {
- inline_vec.push_back(i);
- }
- MutableIntSlice v = &inline_vec; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&inline_vec, inline_vec.data(), inline_vec.size());
- }
-}
-
-TEST(MutableIntSlice, StaticArrayConversion) {
- int array[20];
- MutableIntSlice v = array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(array, array, TF_ARRAYSIZE(array));
-}
-
-TEST(MutableIntSlice, StdArrayConversion) {
- std::array<int, 20> array;
-
- // Check assignment.
- {
- MutableIntSlice v = &array;
- static_cast<void>(v);
- }
-
- // Check sub-slice initialization.
- {
- MutableIntSlice v = {&array, 10, 15};
- static_cast<void>(v);
- }
-
- TestImplicitConversion(&array, &array[0], array.size());
-}
-
-TEST(MutableIntSlice, RepeatedFieldConversion) {
- RepeatedField repeated_field;
- Fill(&repeated_field.storage, 20);
- MutableIntSlice v = &repeated_field; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&repeated_field, repeated_field.storage.data(),
- repeated_field.storage.size());
-}
-
-TEST(MutableIntSlice, ContainerWithOverloadsConversion) {
- ContainerWithOverloads container;
- Fill(&container.storage, 20);
- container.wrong_storage.resize(container.size());
- MutableIntSlice v = &container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&container, container.storage.data(),
- container.storage.size());
-}
-
-TEST(MutableIntSlice, ContainerWithShallowConstDataConversion) {
- ContainerWithShallowConstData container;
- Fill(&container.storage, 20);
- MutableIntSlice v = &container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&container, container.storage.data(),
- container.storage.size());
-}
-
-TEST(MutableIntSlice, TypedefsAndConstants) {
- ::testing::StaticAssertTypeEq<int, MutableIntSlice::value_type>();
- ::testing::StaticAssertTypeEq<int*, MutableIntSlice::pointer>();
- ::testing::StaticAssertTypeEq<const int*, MutableIntSlice::const_pointer>();
- ::testing::StaticAssertTypeEq<int&, MutableIntSlice::reference>();
- ::testing::StaticAssertTypeEq<const int&, MutableIntSlice::const_reference>();
-
- EXPECT_EQ(static_cast<MutableIntSlice::size_type>(-1), MutableIntSlice::npos);
-}
-
-TEST(MutableIntSlice, IteratorsAndReferences) {
- auto accept_pointer = [](int* x) {};
- auto accept_reference = [](int& x) {};
- auto accept_iterator = [](MutableIntSlice::iterator x) {};
- auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
-
- int a[1];
- MutableIntSlice s = a;
-
- accept_pointer(s.data());
- accept_iterator(s.begin());
- accept_iterator(s.end());
- accept_reverse_iterator(s.rbegin());
- accept_reverse_iterator(s.rend());
-
- accept_reference(s[0]);
- accept_reference(s.at(0));
- accept_reference(s.front());
- accept_reference(s.back());
-}
-
-TEST(MutableIntSlice, IteratorsAndReferences_Const) {
- auto accept_pointer = [](int* x) {};
- auto accept_reference = [](int& x) {};
- auto accept_iterator = [](MutableIntSlice::iterator x) {};
- auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
-
- int a[1];
- const MutableIntSlice s = a;
-
- accept_pointer(s.data());
- accept_iterator(s.begin());
- accept_iterator(s.end());
- accept_reverse_iterator(s.rbegin());
- accept_reverse_iterator(s.rend());
-
- accept_reference(s[0]);
- accept_reference(s.at(0));
- accept_reference(s.front());
- accept_reference(s.back());
-}
-
-bool TestMutableOverload(MutableIntSlice slice) { return false; }
-
-bool TestMutableOverload(MutableCharSlice slice) { return true; }
-
-TEST(MutableCharSlice, StringConversion) {
- for (int len = 0; len < 20; len++) {
- string str(len, '\0');
- MutableCharSlice v = &str; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(v, str.data(), str.size());
- }
- // Verify that only the correct overload is feasible. Note that this would
- // fail if the string ctor was declared simply as MutableArraySlice(string*),
- // since in that case both overloads would be feasible.
- string str;
- EXPECT_TRUE(TestMutableOverload(&str));
-
- // Avoid warning "unused function 'TestMutableOverload'"
- int a[1];
- EXPECT_FALSE(TestMutableOverload(a));
-}
-
-} // namespace
-} // namespace gtl
-} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 6a4ff9a1cb..0411a8c4f9 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -623,7 +623,10 @@ def tf_additional_lib_defines():
def tf_additional_lib_deps():
"""Additional dependencies needed to build TF libraries."""
- return ["@com_google_absl//absl/base:base"] + if_static(
+ return [
+ "@com_google_absl//absl/base:base",
+ "@com_google_absl//absl/types:span",
+ ] + if_static(
["@nsync//:nsync_cpp"],
["@nsync//:nsync_headers"],
) + select({
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 55408ab9ab..207f22c931 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -3894,7 +3894,7 @@ bool CudnnSupport::DoDepthConcatenate(
for (size_t i = 0; i < input_data.size(); ++i) {
const auto& dimensions = input_dimensions[i];
tmp.resize(dimensions.ElementCount());
- stream->ThenMemcpyD2H<float>(*input_data[i], &tmp);
+ stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
port::Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;