author     Tim Shen <timshen@google.com>                        2018-08-30 16:03:10 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>      2018-08-30 16:07:27 -0700
commit     6f879f891abe2e267c5cf512d034d7c3641cfdb0 (patch)
tree       33dfda2aa13bdec06d3aa330dd5816441d449fa7 /tensorflow/compiler/xla/service
parent     5d5591fbd4624ff7e50f305464667315f2d41ebb (diff)
[XLA] Rename all (Mutable)ArraySlice to absl::Span.
PiperOrigin-RevId: 210998142
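
For illustration only (not part of the commit): the rename is mechanical, and the mapping applied throughout the diff below is roughly the minimal sketch shown here. The SumDims/FillDims helpers are hypothetical stand-ins, not XLA functions; int64_t is used in place of XLA's int64 alias to keep the sketch self-contained.

// Hypothetical sketch of the ArraySlice -> absl::Span mapping; not XLA code.
#include <cstdint>
#include "absl/types/span.h"

// tensorflow::gtl::ArraySlice<int64>           ->  absl::Span<const int64>
int64_t SumDims(absl::Span<const int64_t> dims) {
  int64_t total = 0;
  for (int64_t d : dims) total += d;  // read-only view over the elements
  return total;
}

// tensorflow::gtl::MutableArraySlice<int64>    ->  absl::Span<int64>
void FillDims(absl::Span<int64_t> dims, int64_t value) {
  for (int64_t& d : dims) d = value;  // the span permits writing the elements
}

// tensorflow::gtl::ArraySlice<HloInstruction*> ->  absl::Span<HloInstruction* const>
// For pointer element types, the pointer itself becomes const, preserving the
// old ArraySlice semantics of a read-only view over mutable pointees.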
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc  | 10
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc  | 6
-rw-r--r--  tensorflow/compiler/xla/service/backend.cc  | 10
-rw-r--r--  tensorflow/compiler/xla/service/backend.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization.cc  | 6
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.h  | 4
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/buffer_info_util.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/buffer_info_util.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.cc  | 14
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.h  | 17
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc  | 7
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h  | 5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc  | 24
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h  | 32
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.cc  | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.h  | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/vector_support_library.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/vector_support_library.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/xfeed_manager.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/xfeed_manager.h  | 3
-rw-r--r--  tensorflow/compiler/xla/service/device_memory_allocator.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/device_memory_allocator.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc  | 6
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc  | 3
-rw-r--r--  tensorflow/compiler/xla/service/executable.cc  | 7
-rw-r--r--  tensorflow/compiler/xla/service/executable.h  | 13
-rw-r--r--  tensorflow/compiler/xla/service/gather_expander.cc  | 5
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.cc  | 5
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.h  | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc  | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc  | 26
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h  | 30
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.cc  | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.cc  | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.h  | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc  | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h  | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc  | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.h  | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc  | 68
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h  | 62
-rw-r--r--  tensorflow/compiler/xla/service/gpu/kernel_thunk.cc  | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/kernel_thunk.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc  | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h  | 9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tuple_thunk.h  | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_buffer.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc  | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h  | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding_test.cc  | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc  | 42
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.h  | 39
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils_test.cc  | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc  | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc  | 86
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h  | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc  | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h  | 100
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  | 60
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h  | 60
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc  | 178
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h  | 205
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_util.cc  | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_util.h  | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_test.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc  | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_reachability.cc  | 8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_reachability.h  | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc  | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.cc  | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.h  | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling_test.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc  | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h  | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_test.cc  | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_value.cc  | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_value.h  | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc  | 3
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc  | 36
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.h  | 12
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executable.cc  | 4
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executable.h  | 4
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executor.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc  | 16
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h  | 13
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h  | 4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_array.cc  | 20
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_array.h  | 31
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc  | 5
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h  | 4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc  | 9
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.h  | 11
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h  | 3
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc  | 3
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h  | 3
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/local_service.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/scatter_expander.cc  | 3
-rw-r--r--  tensorflow/compiler/xla/service/service.cc  | 33
-rw-r--r--  tensorflow/compiler/xla/service/service.h  | 24
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc  | 89
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h  | 64
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc  | 9
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.h  | 5
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc  | 12
-rw-r--r--  tensorflow/compiler/xla/service/tuple_util.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/tuple_util.h  | 2
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_analysis.cc  | 3
-rw-r--r--  tensorflow/compiler/xla/service/while_util.cc  | 2
-rw-r--r--  tensorflow/compiler/xla/service/while_util.h  | 2
133 files changed, 836 insertions, 1044 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 196865f333..a7a0044308 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -449,8 +449,7 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
Status AlgebraicSimplifierVisitor::HandleConcatenate(
HloInstruction* concatenate) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(
- concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
if (operands.size() == 1) {
// Unary concatenates are useless.
ReplaceInstructionIfSameShape(concatenate, operands[0]);
@@ -588,7 +587,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
namespace {
template <typename T>
Status InvertConstant(const HloInstruction& constant, Literal* result) {
- return result->Populate<T>([&](tensorflow::gtl::ArraySlice<int64> indices) {
+ return result->Populate<T>([&](absl::Span<const int64> indices) {
return T{1.0} / constant.literal().Get<T>(indices);
});
}
@@ -1249,8 +1248,7 @@ namespace {
//
// Precondition: input_dim_indices is sorted.
absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
- const HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
+ const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
@@ -1853,7 +1851,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
ShapeUtil::IsZeroElementArray(reduce->shape())) {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index cbce98ef13..182c581ad8 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2226,7 +2226,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
auto out_dims = in_dims;
out_dims[in_channel_idx] = options.f_output_channels;
- auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
+ auto make_shape = [](absl::Span<const int64> dims,
bool minor_to_major_layout) {
if (minor_to_major_layout) {
return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
@@ -2838,8 +2838,8 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
// a and b are parallel bounds we can either turn into a B F S0 S1 or
// `B S0 S1 F` kind of pattern.
- auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice<int64> spatials,
- int64 a, int64 b) {
+ auto decorate_spatials = [&param](absl::Span<const int64> spatials, int64 a,
+ int64 b) {
std::vector<int64> result;
if (param.prepend_a) {
result.push_back(a);
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index a6889cb171..5c180cbdd4 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -112,11 +112,11 @@ StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
return stream_pools_.at(executor).BorrowStream(executor);
}
-Backend::Backend(
- se::Platform* platform, Compiler* compiler,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
- TransferManager* transfer_manager, ComputationPlacer* computation_placer,
- int intra_op_parallelism_threads)
+Backend::Backend(se::Platform* platform, Compiler* compiler,
+ absl::Span<se::StreamExecutor* const> stream_executors,
+ TransferManager* transfer_manager,
+ ComputationPlacer* computation_placer,
+ int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 4a6a78daf0..fdf8d9cab2 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -149,7 +149,7 @@ class Backend {
private:
struct EigenThreadPoolWrapper;
Backend(se::Platform* platform, Compiler* compiler,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
+ absl::Span<se::StreamExecutor* const> stream_executors,
TransferManager* transfer_manager,
ComputationPlacer* computation_placer,
int intra_op_parallelism_threads);
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
index a6f77db3b0..b5cf245af6 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -69,8 +69,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
// Inserts conversion HLOs to replace the called computations' BF16
// operands/outputs to F32.
Status ConvertCalledComputations(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps);
+ HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps);
HloComputation* computation_;
const BFloat16Support* bfloat16_support_;
@@ -114,8 +113,7 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
}
Status BFloat16NormalizationVisitor::ConvertCalledComputations(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps) {
+ HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps) {
std::map<HloComputation*, HloComputation*> cloned_computations;
for (auto& comp : bf16_called_comps) {
auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 2fb401c428..545a6ecfb1 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -407,7 +407,7 @@ void BFloat16Propagation::AdjustCalledComputationParameters(
HloInstruction* hlo) {
auto adjust_computation =
[this, hlo](HloComputation* computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
// Adjust parameters.
CHECK_EQ(operands.size(), computation->num_parameters());
for (int64 i = 0; i < operands.size(); ++i) {
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index e9751cc269..8bd1533972 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -118,7 +118,7 @@ class BufferAssignmentTest : public HloVerifiedTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
HloModule* module,
- tensorflow::gtl::ArraySlice<const HloInstruction*> instruction_sequence,
+ absl::Span<const HloInstruction* const> instruction_sequence,
int64 alignment = 1) {
SequentialHloOrdering::HloModuleSequence module_sequence;
module_sequence[module->entry_computation()] =
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 3079695e96..e5a6c28478 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -62,7 +62,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<std::unique_ptr<HloModule>> hlo_modules;
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index 1ac950bdd6..61136a3e11 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -50,12 +50,12 @@ class CompileOnlyService : public Service {
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options);
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata);
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 1b7a7b36ea..b65dfef9c9 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -479,7 +479,7 @@ class CopyRemover {
// 'values' an entry is created in value_to_node which indicates the
// respective ValueNode representing that value.
void AddValueList(
- tensorflow::gtl::ArraySlice<const HloValue*> values,
+ absl::Span<const HloValue* const> values,
tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
ValueNode* tail = nullptr;
ValueNode* head = nullptr;
diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
index 408fe0f5bf..1942ea1a2a 100644
--- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
+++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
@@ -40,7 +40,7 @@ std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
}
std::vector<int32> CreateArgIndexTableFromBufferInfos(
- tensorflow::gtl::ArraySlice<BufferInfo> buffer_infos) {
+ absl::Span<const BufferInfo> buffer_infos) {
std::vector<int32> result;
for (int64 i = 0; i < buffer_infos.size(); i++) {
if (buffer_infos[i].is_entry_parameter()) {
diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
index 05de70c726..0c5a60f13f 100644
--- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
+++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
@@ -34,7 +34,7 @@ CreateBufferInfosFromBufferAssignment(
// If this function returns V then entry parameter i has buffer allocation index
// V[i].
std::vector<int32> CreateArgIndexTableFromBufferInfos(
- tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo>
+ absl::Span<const ::tensorflow::cpu_function_runtime::BufferInfo>
buffer_infos);
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index bf2efc4d14..9b00f2eaa5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -77,7 +77,7 @@ StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
CpuExecutable::CreateTempArray(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size());
std::vector<OwningDeviceMemory> owning_buffers(
@@ -136,7 +136,7 @@ CpuExecutable::CreateTempArray(
Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
+ absl::Span<const se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// The calling convention for JITed functions is:
//
@@ -207,7 +207,7 @@ Status CpuExecutable::ExecuteComputeFunction(
StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
+ absl::Span<OwningDeviceMemory> buffers) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
/*on_host_shape=*/result_shape(),
@@ -245,7 +245,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
TF_ASSIGN_OR_RETURN(
auto result,
@@ -256,7 +256,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
if (hlo_profiling_enabled()) {
return Unimplemented(
"Asynchronous execution on stream with hlo profiling is not yet "
@@ -267,7 +267,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
if (GetRootPointsToSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -299,7 +299,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
//
// We also need to change the types of some of the variables we capture:
// run_options needs to change from a pointer to a value type, and arguments
- // needs to change from an ArraySlice into a vector. We use a struct instead
+ // needs to change from a Span into a vector. We use a struct instead
// of a lambda to make this explicit.
struct AsyncRunTask {
CpuExecutable* executable;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 96e53de57e..236de8f14f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -57,12 +57,12 @@ class CpuExecutable : public Executable {
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override;
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
// This should be called after set_ir_module_string.
const string& ir_module_string() const { return ir_module_string_; }
@@ -92,7 +92,7 @@ class CpuExecutable : public Executable {
// exists) must out-live the task.
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile);
// Creates an array suitable for passing as the "temps" argument to the JIT
@@ -112,21 +112,20 @@ class CpuExecutable : public Executable {
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ absl::Span<const ShapedBuffer* const> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
- Status ExecuteComputeFunction(
- const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
- HloExecutionProfile* hlo_execution_profile);
+ Status ExecuteComputeFunction(const ExecutableRunOptions* run_options,
+ absl::Span<const se::DeviceMemoryBase> buffers,
+ HloExecutionProfile* hlo_execution_profile);
// Creates a ScopedShapedBuffer for holding the result of the computation,
// moving buffers out of allocated_buffers and into the result as appropriate.
// The addresses are set according to buffer assignment.
StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers);
+ absl::Span<OwningDeviceMemory> buffers);
// Returns the points-to set of the root instruction of the entry
// computation. Uses points-to analysis from buffer assignment.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 0df2abf001..5519a43b2f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -179,7 +179,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
int64 size = GetByteSizeRequirement(literal_shape);
// Note: OSS build didn't like implicit conversion from
// literal_shape.dimensions() to the array slice on 2017-07-10.
- tensorflow::gtl::ArraySlice<int64> dimensions(
+ absl::Span<const int64> dimensions(
tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
literal_shape.dimensions().size());
TF_ASSIGN_OR_RETURN(
@@ -225,7 +225,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
StatusOr<Shape> CpuTransferManager::TransferTupleBuffersFromOutfeed(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data) {
+ absl::Span<const std::pair<void*, int64>> buffer_data) {
return TransferBuffersFromOutfeedInternal(executor, buffer_data,
/*is_tuple=*/true);
}
@@ -238,8 +238,7 @@ StatusOr<Shape> CpuTransferManager::TransferArrayBufferFromOutfeed(
StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data,
- bool is_tuple) {
+ absl::Span<const std::pair<void*, int64>> buffer_data, bool is_tuple) {
std::vector<std::unique_ptr<CpuOutfeedBuffer>> buffers;
for (auto b : buffer_data) {
int64 size = b.second;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
index 7b938e9fd7..6927edff86 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
@@ -56,7 +56,7 @@ class CpuTransferManager : public GenericTransferManager {
// Helper that transfers a tuple of element buffers from the device's outfeed.
StatusOr<Shape> TransferTupleBuffersFromOutfeed(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data);
+ absl::Span<const std::pair<void*, int64>> buffer_data);
// Helper that transfers an array buffer from the device's outfeed.
StatusOr<Shape> TransferArrayBufferFromOutfeed(se::StreamExecutor* executor,
@@ -68,8 +68,7 @@ class CpuTransferManager : public GenericTransferManager {
// for the given buffers.
StatusOr<Shape> TransferBuffersFromOutfeedInternal(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data,
- bool is_tuple);
+ absl::Span<const std::pair<void*, int64>> buffer_data, bool is_tuple);
TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager);
};
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index dd060f54a2..99fa707c95 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -80,7 +80,7 @@ class MemoryTile {
// `minor_dim_offset`}.
//
// Note: `major_dim_offset` is a parameter to the constructor.
- void StoreTile(tensorflow::gtl::ArraySlice<llvm::Value*> tile,
+ void StoreTile(absl::Span<llvm::Value* const> tile,
llvm::Value* minor_dim_offset) const {
CHECK_EQ(tile.size(), pointers_.size());
for (int64 i = 0; i < pointers_.size(); i++) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 1c828cc02c..7839d97317 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -506,8 +506,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
llvm::Value* IrEmitter::EmitElementalMap(
const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- absl::string_view name) {
+ absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}
@@ -1455,7 +1454,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
const ReductionGenerator& reduction_generator,
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
- HloInstruction* arg, gtl::ArraySlice<int64> dimensions,
+ HloInstruction* arg, absl::Span<const int64> dimensions,
unsigned element_alignment) {
ShardedVector accumulator;
accumulator.reserve(accumulator_type.size());
@@ -1551,7 +1550,7 @@ void IrEmitter::EmitShardedVectorStore(
StatusOr<bool> IrEmitter::EmitVectorizedReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
- gtl::ArraySlice<int64> dimensions, HloComputation* function,
+ absl::Span<const int64> dimensions, HloComputation* function,
string* failure_reason) {
if (!ReductionPreservesLayout(*reduce)) {
return false;
@@ -1701,7 +1700,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) {
const HloInstruction* arg = reduce->mutable_operand(0);
const HloInstruction* init_value = reduce->mutable_operand(1);
- gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
@@ -1758,7 +1757,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
- gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
string vectorization_failure_reason;
@@ -2113,7 +2112,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
}
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
- gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
+ absl::Span<HloInstruction* const> operands(custom_call->operands());
absl::string_view custom_call_target(custom_call->custom_call_target());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =
@@ -2233,7 +2232,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
}
StatusOr<bool> IrEmitter::EmitFastConcatenate(
- HloInstruction* concatenate, gtl::ArraySlice<HloInstruction*> operands,
+ HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
string* failure_reason) {
if (ShouldEmitParallelLoopFor(*concatenate)) {
*failure_reason =
@@ -2369,7 +2368,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
}
Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
- gtl::ArraySlice<HloInstruction*> operands(concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
string failure_reason;
TF_ASSIGN_OR_RETURN(
bool successful,
@@ -2800,8 +2799,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
Status IrEmitter::ElementTypesSameAndSupported(
const HloInstruction& instruction,
- gtl::ArraySlice<const HloInstruction*> operands,
- gtl::ArraySlice<PrimitiveType> supported_types) {
+ absl::Span<const HloInstruction* const> operands,
+ absl::Span<const PrimitiveType> supported_types) {
for (auto operand : operands) {
TF_RET_CHECK(
ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
@@ -2831,8 +2830,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
}
llvm::Value* IrEmitter::EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
+ const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view name) {
const Shape& return_shape = callee.root_instruction()->shape();
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index f98891246b..015724b65d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -111,7 +111,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Emit code to map one element according to `map_instr`.
llvm::Value* EmitElementalMap(
const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
+ absl::Span<llvm::Value* const> elemental_operands,
absl::string_view name);
protected:
@@ -252,10 +252,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
//
// `parameters` holds the *scalar values* that need to be passed to the
// callee. The return value is the scalar returned by the callee.
- llvm::Value* EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- absl::string_view name);
+ llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
+ absl::Span<llvm::Value* const> parameters,
+ absl::string_view name);
// Emits a call to a "global" function (e.g. to the computation nested within
// a kWhile or a kCall). Buffer assignment unambiguously assigns buffers to
@@ -271,8 +270,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// match and are of one of the given supported types.
Status ElementTypesSameAndSupported(
const HloInstruction& instruction,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> supported_types);
+ absl::Span<const HloInstruction* const> operands,
+ absl::Span<const PrimitiveType> supported_types);
// Emit IR to perform a computation for every element in the given target op.
// This produces a series of nested loops (one for each dimension of the op's
@@ -319,10 +318,12 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// concepts that generalize over other vectorizable operations. We should
// consider pulling out these abstractions into a VectorizingIrEmitter or
// something similar.
- StatusOr<bool> EmitVectorizedReduce(
- HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function,
- string* failure_reason);
+ StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
+ HloInstruction* arg,
+ HloInstruction* init_value,
+ absl::Span<const int64> dimensions,
+ HloComputation* function,
+ string* failure_reason);
// We'd like to keep one or two one cache-line's worth of data in registers
// without generating IR with illegal (e.g. excessively large or
@@ -372,16 +373,15 @@ class IrEmitter : public DfsHloVisitorWithDefault,
const ReductionGenerator& reduction_generator,
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
- HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloInstruction* arg, absl::Span<const int64> dimensions,
unsigned element_alignment);
// Tries to emit a fast concatenate operation using memcpy. Returns true if
// successful, and false on failure. On failure, sets "failure_reason" to a
// string describing why it could not emit a fast concatenate.
- StatusOr<bool> EmitFastConcatenate(
- HloInstruction* concatenate,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- string* failure_reason);
+ StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
+ absl::Span<HloInstruction* const> operands,
+ string* failure_reason);
// Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
// from the address "source" to the address "target".
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 784045313d..3ecf4b69b7 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -200,10 +200,10 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
// Returns an array of compute function call arguments (including parameter
// address buffer).
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, absl::string_view name,
- llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
- llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
+ absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
+ absl::string_view name, llvm::Value* return_value_buffer,
+ llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg,
+ llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer;
if (parameter_addresses.empty()) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index ee7595f6e9..076ca219bc 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -115,10 +115,10 @@ class IrFunction {
// Returns an array of compute function call argument ir values.
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, absl::string_view name,
- llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
- llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg);
+ absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
+ absl::string_view name, llvm::Value* return_value_buffer,
+ llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg,
+ llvm::Value* profile_counters_arg);
// Emits a call to a runtime fork/join function which dispatches parallel
// calls to 'parallel_function' (and joins threads before returning).
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index 962ea69c09..1bd4b59dd6 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -428,7 +428,7 @@ std::vector<llvm::Value*> TileVariable::Get() const {
return result;
}
-void TileVariable::Set(tensorflow::gtl::ArraySlice<llvm::Value*> value) {
+void TileVariable::Set(absl::Span<llvm::Value* const> value) {
CHECK_EQ(value.size(), storage_.size());
for (int64 i = 0, e = value.size(); i < e; i++) {
storage_[i].Set(value[i]);
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index c728f6df0a..3dfe941a3a 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -324,7 +324,7 @@ class TileVariable {
std::vector<llvm::Value*> initial_value);
std::vector<llvm::Value*> Get() const;
- void Set(tensorflow::gtl::ArraySlice<llvm::Value*> value);
+ void Set(absl::Span<llvm::Value* const> value);
private:
std::vector<VectorVariable> storage_;
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
index 47543b2082..b9e47f5aad 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
@@ -37,7 +37,7 @@ void XfeedQueueManager::Reset() {
}
void XfeedQueueManager::EnqueueBuffersAtomically(
- tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers) {
+ absl::Span<XfeedBuffer* const> buffers) {
tensorflow::mutex_lock l(mu_);
bool was_empty = enqueued_buffers_.empty();
for (XfeedBuffer* b : buffers) {
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
index b4ace23260..fac1722b10 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
@@ -63,8 +63,7 @@ class XfeedQueueManager {
// called when the buffer will no longer be accessed by the XfeedManager,
// either as a result of a call to Reset or because the runtime has dequeued
// and used the buffer.
- void EnqueueBuffersAtomically(
- tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers);
+ void EnqueueBuffersAtomically(absl::Span<XfeedBuffer* const> buffers);
// Blocks until the queue is non-empty, then returns the buffer at the head of
// the queue. Sets the current buffer to be the returned buffer. It is an
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc
index 1d0297cfbf..edbcb25247 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.cc
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc
@@ -25,7 +25,7 @@ namespace xla {
StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
const se::Platform* platform,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors)
+ absl::Span<se::StreamExecutor* const> stream_executors)
: DeviceMemoryAllocator(platform),
stream_executors_(stream_executors.begin(), stream_executors.end()) {}
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h
index d87b86caf0..28a3539373 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.h
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.h
@@ -80,7 +80,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
public:
StreamExecutorMemoryAllocator(
const se::Platform* platform,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
+ absl::Span<se::StreamExecutor* const> stream_executors);
StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
bool retry_on_failure) override;
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index def42f9c77..4bb1e071d8 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -856,7 +856,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
auto getFloat = [&](const float f) {
return llvm::ConstantFP::get(b_->getFloatTy(), f);
};
- auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
+ auto multiply_add = [&](absl::Span<const float> coefficients,
llvm::Value* w) {
llvm::Value* p = getFloat(coefficients.front());
coefficients.remove_prefix(1);
@@ -893,7 +893,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
SetToFirstInsertPoint(if_data.true_block, b_);
{
llvm::Value* lw = FSub(w, getFloat(2.5f));
- tensorflow::gtl::ArraySlice<float> lq{
+ absl::Span<const float> lq{
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-0.00417768164f, 0.246640727f, 1.50140941f};
@@ -908,7 +908,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
- tensorflow::gtl::ArraySlice<float> gq{
+ absl::Span<const float> gq{
-0.000200214257f, 0.000100950558f, 0.00134934322f,
-0.00367342844f, 0.00573950773f, -0.0076224613f,
0.00943887047f, 1.00167406f, 2.83297682f};
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index 5ab0756219..1b3be199f6 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -28,8 +28,7 @@ using absl::nullopt;
class ElementalIrEmitterExecutionTest : public HloTestBase {
protected:
- void RunTest(const string& hlo_text,
- tensorflow::gtl::ArraySlice<Literal*> args) {
+ void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 78edf918a4..47c56e2f7f 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -26,13 +26,12 @@ limitations under the License.
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
-using tensorflow::gtl::ArraySlice;
namespace xla {
StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
- ArraySlice<const ServiceExecutableRunOptions> run_options,
- ArraySlice<ArraySlice<const ShapedBuffer*>> arguments) {
+ absl::Span<const ServiceExecutableRunOptions> run_options,
+ absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {
TF_RET_CHECK(run_options.size() == arguments.size());
std::vector<ScopedShapedBuffer> return_values;
@@ -63,7 +62,7 @@ StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
- ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
se::Stream* stream = run_options->stream();
std::unique_ptr<se::Timer> timer;
if (profile != nullptr) {
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 6e055edc03..4b8d955b28 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -81,14 +81,14 @@ class Executable {
// Returns a shaped buffer containing the result of the computation.
virtual StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) = 0;
// Same as ExecuteOnStream(), but this call is non-blocking and returns as
// soon as all of the operations are enqueued for launch on the stream.
virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) = 0;
+ absl::Span<const ShapedBuffer* const> arguments) = 0;
// Starts the given program executing on the given stream/executor.
//
@@ -119,11 +119,8 @@ class Executable {
// run_options[i]->stream() and the returned value is at index i of the
// returned vector.
virtual StatusOr<std::vector<ScopedShapedBuffer>> ExecuteOnStreams(
- tensorflow::gtl::ArraySlice<const ServiceExecutableRunOptions>
- run_options,
- tensorflow::gtl::ArraySlice<
- tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- arguments);
+ absl::Span<const ServiceExecutableRunOptions> run_options,
+ absl::Span<const absl::Span<const ShapedBuffer* const>> arguments);
// Populates `hlo_execution_profile` from `executor`. This is implicit in any
// Execute* API call that takes a hlo_execution_profile argument, but must be
@@ -139,7 +136,7 @@ class Executable {
// given ExecutionProfile if non-null.
StatusOr<ScopedShapedBuffer> ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ absl::Span<const ShapedBuffer* const> arguments);
// Returns the ExecutionProfile from executing on the device. This includes
// the number of cycles taken for the computation or the compilation time.
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 3f1a881372..cb86c98579 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
namespace xla {
-using tensorflow::gtl::ArraySlice;
static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
HloInstruction* start_indices, int64 index_vector_dim) {
@@ -225,7 +224,7 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> slice_sizes, int64 gather_loop_trip_count,
+ absl::Span<const int64> slice_sizes, int64 gather_loop_trip_count,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> accumulator_state_shape_dims;
accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
@@ -244,7 +243,7 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
// are the major dimensions and the offset dimensions are the minor dimensions.
// Fix this up with a transpose.
static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
- HloInstruction* accumulator, ArraySlice<int64> offset_dims,
+ HloInstruction* accumulator, absl::Span<const int64> offset_dims,
int64 output_rank) {
std::vector<int64> permutation;
permutation.reserve(output_rank);
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 0ce2db907b..4ed91ef187 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -42,8 +42,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const {
}
Status GenericTransferManager::WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) {
TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));
@@ -163,7 +162,7 @@ Status GenericTransferManager::TransferLiteralFromOutfeed(
}
Status GenericTransferManager::ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*>
+ absl::Span<se::StreamExecutor* const>
/*executors*/) {
return Unimplemented(
"Device reset is not yet supported on this platform (b/30481585)");
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 6c1a21587a..86c8b1c145 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -55,15 +55,13 @@ class GenericTransferManager : public TransferManager {
const Shape& literal_shape,
MutableBorrowingLiteral literal) override;
- Status ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
+ Status ResetDevices(absl::Span<se::StreamExecutor* const> executors) override;
int64 GetByteSizeRequirement(const Shape& shape) const override;
protected:
Status WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) override;
private:
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index dbdf8e7a0e..2af31a52f9 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -204,9 +204,8 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
if (allocator_ != nullptr) {
allocator = allocator_;
} else {
- se_allocator.emplace(
- stream_exec_->platform(),
- tensorflow::gtl::ArraySlice<se::StreamExecutor*>({stream_exec_}));
+ se_allocator.emplace(stream_exec_->platform(),
+ absl::Span<se::StreamExecutor* const>({stream_exec_}));
allocator = &*se_allocator;
}
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 57a3a43a6f..c1aaa4bf04 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -74,10 +74,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter(
compute_nested_(std::move(compute_nested)) {}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// The libdevice math functions differentiate between "double" and "float" by
// appending an 'f' to the function's name. libdevice doesn't have f16 math
// functions, so we convert the operands to f32 before calling the function
@@ -119,10 +117,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// llvm intrinsics differentiate between half/float/double functions via
// the suffixes ".f16", ".f32" and ".f64".
string munged_callee = callee_name;
@@ -144,10 +140,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// Binary math functions transform are of type [T] -> T.
for (PrimitiveType input_type : input_types) {
if (output_type != input_type) {
@@ -290,11 +284,9 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+ absl::Span<const llvm::Attribute::AttrKind> attributes) {
std::vector<llvm::Type*> ir_input_types;
for (PrimitiveType input_type : input_types) {
ir_input_types.push_back(
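
The comments in these hunks describe how callee names are munged per element type: libdevice appends a trailing 'f' for float (and has no f16 variants, so f16 operands go through f32), while LLVM intrinsics use ".f16"/".f32"/".f64" suffixes. A small standalone sketch of those naming conventions; the enum and helper names below are illustrative stand-ins, not XLA or LLVM APIs:

#include <cassert>
#include <string>

enum class Prim { F16, F32, F64 };

std::string LibdeviceName(const std::string& callee, Prim type) {
  // libdevice has no f16 variants: f16 operands are converted to f32 and
  // the f32 function (trailing 'f') is called instead.
  return type == Prim::F64 ? callee : callee + "f";
}

std::string IntrinsicName(const std::string& callee, Prim type) {
  switch (type) {
    case Prim::F16: return callee + ".f16";
    case Prim::F32: return callee + ".f32";
    case Prim::F64: return callee + ".f64";
  }
  return callee;
}

int main() {
  assert(LibdeviceName("__nv_exp", Prim::F32) == "__nv_expf");
  assert(IntrinsicName("llvm.sqrt", Prim::F64) == "llvm.sqrt.f64");
}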
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 91942785d2..43f1f208bf 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -38,9 +38,9 @@ namespace gpu {
class GpuElementalIrEmitter : public ElementalIrEmitter {
public:
// A NestedComputer computes an element of the output of the given computation
- // given an ArraySlice of its input elements.
+ // given a Span of its input elements.
using NestedComputer = std::function<StatusOr<llvm::Value*>(
- const HloComputation&, tensorflow::gtl::ArraySlice<llvm::Value*>)>;
+ const HloComputation&, absl::Span<llvm::Value* const>)>;
GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
llvm::Module* module, llvm::IRBuilder<>* b,
@@ -96,37 +96,29 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
// Emits IR to call a device function named "callee_name" on the given
// operand. Returns the IR value that represents the return value.
llvm::Value* EmitDeviceFunctionCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_type,
- PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes);
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_type, PrimitiveType output_type,
+ absl::Span<const llvm::Attribute::AttrKind> attributes);
// Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
// return value of the function.
StatusOr<llvm::Value*> EmitLlvmIntrinsicMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type);
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
// Emits IR to call a libdevice function of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
// return value of the function.
StatusOr<llvm::Value*> EmitLibdeviceMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type);
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
// Returns the IR value that represents the return value of the function.
StatusOr<llvm::Value*> EmitMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type);
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
const HloModuleConfig& hlo_module_config_;
NestedComputer compute_nested_;
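
NestedComputer is now expressed in terms of absl::Span. A minimal sketch of a callback with that shape, assuming int stands in for llvm::Value* and dropping the StatusOr wrapper and HloComputation argument for brevity:

#include <functional>
#include <numeric>
#include <vector>
#include "absl/types/span.h"

// int stands in for llvm::Value*; the real alias also takes the
// HloComputation being evaluated and returns StatusOr<llvm::Value*>.
using FakeNestedComputer = std::function<int(absl::Span<const int>)>;

int main() {
  FakeNestedComputer sum = [](absl::Span<const int> elements) {
    return std::accumulate(elements.begin(), elements.end(), 0);
  };
  std::vector<int> inputs = {1, 2, 3};
  return sum(inputs) == 6 ? 0 : 1;  // std::vector converts to Span implicitly
}

Because containers such as std::vector convert to absl::Span implicitly, callers of these callbacks generally need no changes beyond the signature rename.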
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index 11549cdac5..ca4a605af5 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -92,8 +92,7 @@ string FftTypeToString(se::fft::Type type) {
} // namespace
-FftThunk::FftThunk(FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length,
+FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index 4adec7ee54..2be50e08bd 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -62,7 +62,7 @@ class FftThunk : public Thunk {
public:
// Constructs a thunk for launching an FFT on a stream.
// Semantics of null hlo_instruction argument are as in Thunk.
- FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length,
+ FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
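
FftThunk now receives its FFT lengths as absl::Span<const int64>. A hedged sketch of the pass-and-copy idiom, with FakeFftThunk standing in for the real thunk:

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

// FakeFftThunk is a stand-in for FftThunk; it only records the lengths,
// copying them because the span does not own its storage.
struct FakeFftThunk {
  explicit FakeFftThunk(absl::Span<const int64_t> fft_length)
      : lengths(fft_length.begin(), fft_length.end()) {}
  std::vector<int64_t> lengths;
};

int main() {
  std::vector<int64_t> dims = {256, 256};
  FakeFftThunk thunk(dims);  // the vector converts to absl::Span implicitly
  return thunk.lengths.size() == 2 ? 0 : 1;
}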
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 71a02e70df..31a9f9b1be 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -234,7 +234,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
@@ -325,7 +325,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
// TODO(b/30671675): Implement asynchronous execution mode.
return Unimplemented(
"Asynchronous execution on stream is not yet supported on GPU.");
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 627a05e240..b3765adf5e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -78,12 +78,12 @@ class GpuExecutable : public Executable {
// match the compute capability passed to this object's constructor.
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override;
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
private:
// If `block_host_until_done` is false, execution will not block the host
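
The executable entry points now take absl::Span<const ShapedBuffer* const>: the element immutability that ArraySlice applied implicitly is spelled out as a trailing const on the Span element type. A standalone sketch, with a stand-in ShapedBuffer type:

#include <cstddef>
#include <vector>
#include "absl/types/span.h"

struct ShapedBuffer {};  // stand-in for xla::ShapedBuffer

// Element type is "const ShapedBuffer* const": a const pointer to a
// const buffer, matching the new ExecuteOnStream signature above.
std::size_t CountArguments(absl::Span<const ShapedBuffer* const> arguments) {
  return arguments.size();
}

int main() {
  ShapedBuffer a, b;
  std::vector<const ShapedBuffer*> arguments = {&a, &b};
  return CountArguments(arguments) == 2 ? 0 : 1;  // vector converts to Span
}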
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 0e205b9c02..51627402b4 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -35,8 +35,8 @@ using absl::StrAppend;
using absl::StrCat;
void HloToIrBindings::EmitBasePointersForHlos(
- tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
- tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos) {
+ absl::Span<const HloInstruction* const> io_hlos,
+ absl::Span<const HloInstruction* const> non_io_hlos) {
// I/O HLOs are bound to the arguments of the current IR function. I.e.,
//
// void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
index eee40b0e91..5b05ed812e 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
@@ -45,8 +45,8 @@ class HloToIrBindings {
alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {}
void EmitBasePointersForHlos(
- tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
- tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos);
+ absl::Span<const HloInstruction* const> io_hlos,
+ absl::Span<const HloInstruction* const> non_io_hlos);
// Rebinds the given HLO to the LLVM IR value that represent its address.
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index f544bcc919..9c90f4d46b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -216,7 +216,7 @@ bool IsReductionToVector(const HloInstruction& reduce) {
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(absl::string_view fmt,
- tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
+ absl::Span<llvm::Value* const> arguments,
llvm::IRBuilder<>* builder) {
std::vector<llvm::Type*> argument_types;
for (auto argument : arguments) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index a35e250101..d242897e16 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -127,7 +127,7 @@ bool IsReductionToVector(const HloInstruction& reduce);
// Emits call to "vprintf" with given format and arguments.
llvm::Value* EmitPrintf(absl::string_view fmt,
- tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
+ absl::Span<llvm::Value* const> arguments,
llvm::IRBuilder<>* builder);
// Emits code to shuffle data between threads of a warp. This has the same
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index bdf6aadde6..ffca5d6549 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -141,7 +141,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
Status IrEmitter::EmitCallToNestedComputation(
const HloComputation& nested_computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output) {
+ absl::Span<llvm::Value* const> operands, llvm::Value* output) {
TF_RET_CHECK(nested_computation.num_parameters() > 0);
llvm::Function*& emitted_function =
computation_to_ir_function_[&nested_computation];
@@ -633,7 +633,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
auto arg = reduce->operand(0);
auto init_value = reduce->operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
return EmitTargetElementLoop(
*reduce,
@@ -748,7 +748,7 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
const HloComputation& computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
+ absl::Span<llvm::Value* const> parameter_elements) {
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(
computation.root_instruction()->shape().element_type(), module_),
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 3673b9f58d..bc2b04ace5 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -143,9 +143,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Emits a call in IR to the given nested computation with the given operands
// and output. If no IR function has been previously emitted for the
// computation, also emits such a function.
- Status EmitCallToNestedComputation(
- const HloComputation& nested_computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output);
+ Status EmitCallToNestedComputation(const HloComputation& nested_computation,
+ absl::Span<llvm::Value* const> operands,
+ llvm::Value* output);
// Emits an atomic operation that implements `nested_computation` in the
// sequentially consistent memory model. `output_address` and `source_address`
@@ -199,7 +199,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
StatusOr<llvm::Value*> ComputeNestedElement(
const HloComputation& computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);
+ absl::Span<llvm::Value* const> parameter_elements);
// Emits an atomic operation that implements `nested_computation` in the
// sequentially consistent memory model. `output_address` and `source_address`
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 860dd0b50f..3ab79197e2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -94,7 +94,6 @@ using absl::optional;
using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
-using tensorflow::gtl::ArraySlice;
// If a dimension is smaller than this, untiled transposition may be more
// efficient.
@@ -176,7 +175,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
const HloInstruction& inst,
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args) {
+ absl::Span<const BufferAllocation* const> args) {
// Compute the kernel name. The opcode string may contain "-" which cannot be
// in a PTX function name, so sanitize the name before uniquifying it.
string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
@@ -556,10 +555,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
std::vector<std::unique_ptr<Thunk>> thunks;
- ArraySlice<HloInstruction*> output_instructions =
+ absl::Span<HloInstruction* const> output_instructions =
root->opcode() == HloOpcode::kTuple
? root->operands()
- : ArraySlice<HloInstruction*>(&root, 1);
+ : absl::Span<HloInstruction* const>(&root, 1);
// For multi-output fusion emit an initializer for each tuple element.
// Otherwise it's sufficient to just initialize the single output.
@@ -718,8 +717,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
const HloInstruction* reduce, const IrArray::Index& index,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
for (int i = 0; i != extra_output_gens.size(); ++i) {
const HloInstruction* output = reduce->parent()->FusionInstruction();
@@ -736,12 +734,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
Status IrEmitterUnnested::EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// Number of elements processed by a single thread.
constexpr int64 kTileSize = 16;
@@ -951,12 +948,11 @@ Status IrEmitterUnnested::EmitReductionToScalar(
Status IrEmitterUnnested::EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// Divide the input matrix into tiles of size KxL. For example, when the
// input matrix is 4x4, K=2, and L=1 the tiled matrix looks like
@@ -1240,12 +1236,11 @@ static std::pair<int64, int64> ComputeTilingSchemeForReduction(
Status IrEmitterUnnested::EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// A naive algorithm is:
// 1. Divide the x dimension of the input tensor into tiles of size 1x1xX.
@@ -1593,13 +1588,12 @@ Status IrEmitterUnnested::EmitRowReduction(
// elementwise.
Status IrEmitterUnnested::EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<const int64> dimensions_to_reduce,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// This emission requires "reduce" to have an input layout. It is either set
// by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
@@ -1694,7 +1688,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
}
auto input = reduce->operand(0);
auto init_value = reduce->operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce(reduce->dimensions());
+ absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
HloComputation* reducer = reduce->to_apply();
// HandleReduce specializes reduction from a multi-dimensional array to a 1D
// array. The specialized version requires an initializer thunk that
@@ -2570,7 +2564,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
// Are all the bytes of this scalar equal to 0? If so, we can create a
// MemzeroThunk.
- ArraySlice<uint8> literal_bytes(
+ absl::Span<const uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
@@ -2880,7 +2874,7 @@ int IrEmitterUnnested::ConstructIrArrayForInputs(
int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
const HloInstruction& hlo, const std::vector<IrArray>& output_arrays,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* output_reduced_shapes,
std::vector<IrArray>* output_in_reduced_shape_arrays) {
int64 num_outputs = 1;
@@ -2907,7 +2901,7 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
const std::vector<llvm::Value*>& param_buffers,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* param_reduced_shapes,
std::vector<IrArray>* param_in_reduced_shape_arrays) {
int64 num_params = hlo.operands().size();
@@ -3048,8 +3042,8 @@ void EmitTiledElementalCodeWithBoundsCheck(
// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
// to launch fewer blocks so each transposes many tiles.
LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
- HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
- tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
+ HloInstruction* hlo, absl::Span<const int64> reduced_output_dims,
+ absl::Span<const int64> tiled_param_ids) {
// Parameters for the tiling algorithm.
constexpr int64 kTileSize = 32;
constexpr int64 kNumRows = 4;
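
BuildInitializerThunk above checks whether every byte of the literal is zero before choosing a MemzeroThunk. A standalone sketch of that check over a Span of bytes, with the literal contents faked by a vector:

#include <cstdint>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/types/span.h"

// Mirrors the "all bytes zero -> MemzeroThunk" test performed above.
bool AllBytesZero(absl::Span<const uint8_t> literal_bytes) {
  return absl::c_all_of(literal_bytes,
                        [](uint8_t byte) { return byte == 0; });
}

int main() {
  std::vector<uint8_t> zeros(16, 0);
  std::vector<uint8_t> mixed = {0, 1, 0};
  return (AllBytesZero(zeros) && !AllBytesZero(mixed)) ? 0 : 1;
}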
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 5254419907..084462330e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -105,13 +105,12 @@ class IrEmitterUnnested : public IrEmitter {
// This kernel takes as arguments pointers to the given buffer allocations.
llvm::Function* BuildKernelPrototype(
const HloInstruction& inst,
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args);
+ absl::Span<const BufferAllocation* const> args);
// Helper for writing extra outputs from inside a reduce kernel.
Status EmitExtraOutputsForReduce(
const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// EmitColumnReduction and EmitRowReduction emit code for column and row
@@ -127,12 +126,11 @@ class IrEmitterUnnested : public IrEmitter {
Status EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Emits code that reduces a 3D tensor of shape [depth x height x width] to a
@@ -143,23 +141,21 @@ class IrEmitterUnnested : public IrEmitter {
Status EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Emits code that reduces a tensor of arbitrary rank to a scalar.
Status EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Figures out whether `reduce` is a row or column reduction, and which
@@ -180,13 +176,12 @@ class IrEmitterUnnested : public IrEmitter {
// Prerequisite: `IsReductionToVector(*reduce)`
Status EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<const int64> dimensions_to_reduce,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
@@ -195,10 +190,9 @@ class IrEmitterUnnested : public IrEmitter {
// Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
// returns the launch dimensions for the kernel. This is a helper to support
// the implementation of CheckAndEmitHloWithTile021.
- LaunchDimensions EmitHlo021Tile(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
- tensorflow::gtl::ArraySlice<int64> tiled_param_ids);
+ LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
+ absl::Span<const int64> reduced_output_dims,
+ absl::Span<const int64> tiled_param_ids);
// Generates the IrArray for each output of hlo and returns the number of
// outputs.
int ConstructIrArrayForOutputs(const HloInstruction& hlo,
@@ -214,7 +208,7 @@ class IrEmitterUnnested : public IrEmitter {
int ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
const HloInstruction& hlo,
const std::vector<llvm_ir::IrArray>& output_arrays,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* output_reduced_shapes,
std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays);
// For each input of the `hlo` instruction, checks its value in
@@ -226,7 +220,7 @@ class IrEmitterUnnested : public IrEmitter {
const HloInstruction& hlo,
const std::vector<llvm_ir::IrArray>& param_arrays,
const std::vector<llvm::Value*>& param_buffers,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* param_reduced_shapes,
std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index 3259eaa2a2..878b0b96a1 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -27,10 +27,10 @@ limitations under the License.
namespace xla {
namespace gpu {
-KernelThunk::KernelThunk(
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
- const string& kernel_name, const HloInstruction* hlo_instruction,
- int unroll_factor)
+KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
+ const string& kernel_name,
+ const HloInstruction* hlo_instruction,
+ int unroll_factor)
: Thunk(Kind::kKernel, hlo_instruction),
args_(args.begin(), args.end()),
kernel_name_(kernel_name),
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index d751de50ad..480f473037 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -47,7 +47,7 @@ class KernelThunk : public Thunk {
// Constructs a thunk for the given kernel.
//
// `hlo_instruction` is as in Thunk. Other arguments are as the class members.
- KernelThunk(tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
+ KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name, const HloInstruction* hlo_instruction,
int unroll_factor);
KernelThunk(const KernelThunk&) = delete;
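
KernelThunk's constructor copies the incoming span into an owning vector (args_(args.begin(), args.end())), since a Span, like ArraySlice, does not own its storage. A sketch of that copy-into-member idiom with stand-in types:

#include <cstddef>
#include <string>
#include <vector>
#include "absl/types/span.h"

struct BufferAllocation {};  // stand-in

// A Span parameter is copied into an owning vector member, because the
// span's backing storage need not outlive the constructed object.
class FakeKernelThunk {
 public:
  FakeKernelThunk(absl::Span<const BufferAllocation* const> args,
                  const std::string& kernel_name)
      : args_(args.begin(), args.end()), kernel_name_(kernel_name) {}

  std::size_t num_args() const { return args_.size(); }

 private:
  std::vector<const BufferAllocation*> args_;
  std::string kernel_name_;
};

int main() {
  BufferAllocation a, b;
  std::vector<const BufferAllocation*> args = {&a, &b};
  FakeKernelThunk thunk(args, "fusion_kernel");
  return thunk.num_args() == 2 ? 0 : 1;
}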
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index 79f7d31816..fa84d77223 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -23,7 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {
-using tensorflow::gtl::ArraySlice;
// We want the input/output feature counts of an f16 conv to be factors of 8,
// because without this cudnn can't use tensor cores on the conv.
@@ -42,7 +41,7 @@ static constexpr double kMaxBytesTouchedIncrease = 1.2;
// Pads the given dimensions in the given shape up to a multiple of
// kDesiredNumFeaturesFactor.
-static Shape PadShape(Shape s, ArraySlice<int64> dims) {
+static Shape PadShape(Shape s, absl::Span<const int64> dims) {
for (int64 dim : dims) {
int64 dim_to_pad_size = s.dimensions(dim);
int64 new_dim_to_pad_size =
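
PadShape rounds the listed dimensions up to a multiple of kDesiredNumFeaturesFactor (8) so cudnn can use tensor cores. A standalone sketch of the rounding arithmetic on a plain vector of dimension sizes (the real helper operates on xla::Shape):

#include <cassert>
#include <cstdint>
#include <vector>
#include "absl/types/span.h"

constexpr int64_t kDesiredNumFeaturesFactor = 8;

std::vector<int64_t> PadDims(std::vector<int64_t> dims,
                             absl::Span<const int64_t> dims_to_pad) {
  for (int64_t dim : dims_to_pad) {
    int64_t size = dims[dim];
    int64_t remainder = size % kDesiredNumFeaturesFactor;
    if (remainder != 0) {
      dims[dim] = size + (kDesiredNumFeaturesFactor - remainder);
    }
  }
  return dims;
}

int main() {
  std::vector<int64_t> padded = PadDims({1, 3, 224, 224}, {1});
  assert(padded[1] == 8);  // the feature dimension is padded from 3 to 8
}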
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index ca57cacb98..8154d75d23 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -40,7 +40,7 @@ ParallelLoopEmitter::ParallelLoopEmitter(
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
+ absl::Span<const llvm_ir::IrArray> target_arrays,
const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
int unroll_factor)
: LoopEmitter(target_element_generator, target_arrays, b),
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index cc7da2e73b..f32ea1ce4c 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -47,11 +47,10 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
//
// This is used in multi-output fusion. target_element_generator should
// produce a struct with N elements, one for each of target_arrays.
- ParallelLoopEmitter(
- const llvm_ir::ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
- int unroll_factor = 1);
+ ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
+ absl::Span<const llvm_ir::IrArray> target_arrays,
+ const LaunchDimensions& launch_dimensions,
+ llvm::IRBuilder<>* b, int unroll_factor = 1);
ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
index 2d5735d6c4..a3a03b53f8 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
@@ -34,8 +34,7 @@ namespace gpu {
// issue (b/31336476).
class TupleThunk : public Thunk {
public:
- TupleThunk(tensorflow::gtl::ArraySlice<BufferAllocation::Slice>
- tuple_element_buffers,
+ TupleThunk(absl::Span<const BufferAllocation::Slice> tuple_element_buffers,
const BufferAllocation::Slice& dest_buffer,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kTuple, hlo_instruction),
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h
index 4873463b2e..a88c87e46c 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.h
+++ b/tensorflow/compiler/xla/service/hlo_buffer.h
@@ -84,7 +84,7 @@ class HloBuffer {
return a->id() == b->id();
}
- HloBuffer(Id id, tensorflow::gtl::ArraySlice<const HloValue*> values)
+ HloBuffer(Id id, absl::Span<const HloValue* const> values)
: id_(id), values_(values.begin(), values.end()) {}
// Return the unique identifier for this HloBuffer.
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c2d0673f49..fe7f2be888 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -558,7 +558,7 @@ HloComputation::CreateFromProto(
}
void HloComputation::FuseInstructionsInto(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction* fusion_instruction) {
CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
HloInstruction* root = instructions_to_fuse.front();
@@ -577,7 +577,7 @@ void HloComputation::FuseInstructionsInto(
}
HloInstruction* HloComputation::CreateFusionInstruction(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction::FusionKind fusion_kind) {
HloInstruction* root = instructions_to_fuse.front();
HloInstruction* fusion_instruction = AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 59016624f7..daafb711fd 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -237,7 +237,7 @@ class HloComputation {
// removed if they have no uses after fusion (this is necessarily true for at
// least the root).
HloInstruction* CreateFusionInstruction(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction::FusionKind fusion_kind);
// Create a deep copy of the given instruction and return the instruction
@@ -385,7 +385,7 @@ class HloComputation {
//
// Pre-condition: fusion_instruction's opcode is kFusion.
void FuseInstructionsInto(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction* fusion_instruction);
// Internal helper for recursive copying of an instruction. Creates and
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 7cd1481a8a..07cd1efc12 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -105,8 +105,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
TEST_F(HloConstantFoldingTest, Concatenate) {
const struct TestConfig {
int concat_dimension;
- tensorflow::gtl::ArraySlice<int64> dimensions;
- tensorflow::gtl::ArraySlice<int64> concat_sizes;
+ absl::Span<const int64> dimensions;
+ absl::Span<const int64> concat_sizes;
} test_configs[] = {
{1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
{3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
@@ -196,7 +196,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
bool matched = true;
root->literal().EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+ [&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
});
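
The TestConfig struct above stores Span members, and a Span is only a view: its backing storage must outlive every read. A sketch of one safe arrangement, where the spans reference arrays with static storage duration; the names below are illustrative, not part of the test:

#include <cstdint>
#include <cstdio>
#include "absl/types/span.h"

// The spans view the static arrays below, so their backing storage
// outlives every use of the config.
struct TestConfig {
  int concat_dimension;
  absl::Span<const int64_t> dimensions;
  absl::Span<const int64_t> concat_sizes;
};

static constexpr int64_t kDims0[] = {11, 0, 7, 5, 9};
static constexpr int64_t kSizes0[] = {2, 5, 7, 11};

int main() {
  TestConfig config{1, kDims0, kSizes0};
  std::printf("concat dim %d, %zu sizes\n", config.concat_dimension,
              config.concat_sizes.size());
}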
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 131846794d..19ffb465c0 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -24,7 +24,6 @@ limitations under the License.
namespace xla {
using absl::StrCat;
-using tensorflow::gtl::ArraySlice;
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
HloInstruction* rhs) {
@@ -50,9 +49,9 @@ StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
}
StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
- ArraySlice<int64> start_indices,
- ArraySlice<int64> limit_indices,
- ArraySlice<int64> strides) {
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
operand->shape(), start_indices,
@@ -74,7 +73,7 @@ StatusOr<HloInstruction*> MakeConvolveHlo(
}
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
- ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(
Shape transpose_shape,
@@ -91,15 +90,15 @@ StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
}
StatusOr<HloInstruction*> MakeReshapeHlo(
- ArraySlice<int64> result_shape_dim_bounds, HloInstruction* operand) {
+ absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
result_shape_dim_bounds);
return MakeReshapeHlo(new_shape, operand);
}
-StatusOr<HloInstruction*> MakeDynamicSliceHlo(HloInstruction* operand,
- HloInstruction* start_indices,
- ArraySlice<int64> slice_sizes) {
+StatusOr<HloInstruction*> MakeDynamicSliceHlo(
+ HloInstruction* operand, HloInstruction* start_indices,
+ absl::Span<const int64> slice_sizes) {
HloComputation* computation = operand->parent();
CHECK_EQ(computation, start_indices->parent());
TF_ASSIGN_OR_RETURN(
@@ -125,8 +124,8 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
}
StatusOr<HloInstruction*> MakeBroadcastHlo(
- HloInstruction* operand, ArraySlice<int64> broadcast_dimensions,
- ArraySlice<int64> result_shape_bounds) {
+ HloInstruction* operand, absl::Span<const int64> broadcast_dimensions,
+ absl::Span<const int64> result_shape_bounds) {
HloComputation* computation = operand->parent();
Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
result_shape_bounds);
@@ -146,8 +145,8 @@ StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
}
-StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
- int64 dimension) {
+StatusOr<HloInstruction*> MakeConcatHlo(
+ absl::Span<HloInstruction* const> operands, int64 dimension) {
CHECK_GT(operands.size(), 0);
HloComputation* computation = operands[0]->parent();
@@ -176,9 +175,8 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
}
-StatusOr<HloInstruction*> MakeMapHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation) {
+StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation) {
CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
HloComputation* computation = operands.front()->parent();
std::vector<const Shape*> operand_shapes;
@@ -235,7 +233,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
}
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
- HloInstruction* operand, ArraySlice<int64> expanded_dims) {
+ HloInstruction* operand, absl::Span<const int64> expanded_dims) {
CHECK_GT(operand->shape().dimensions_size(), 0);
CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));
@@ -251,8 +249,8 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
return MakeReshapeHlo(new_shape, operand);
}
-StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
- ArraySlice<int64> dims_to_elide) {
+StatusOr<HloInstruction*> ElideDegenerateDims(
+ HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
CHECK(absl::c_is_sorted(dims_to_elide));
const Shape& input_shape = operand->shape();
@@ -277,7 +275,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
}
StatusOr<HloInstruction*> InsertDegenerateDims(
- HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
+ HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
CHECK(absl::c_is_sorted(dims_to_insert));
const Shape& operand_shape = operand->shape();
@@ -327,7 +325,7 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
@@ -336,7 +334,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
}
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
- ArraySlice<const Shape*> domain, const Shape& range,
+ absl::Span<const Shape* const> domain, const Shape& range,
absl::string_view name) {
HloComputation::Builder b{string(name)};
int64 param_idx = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 1bc6d09b45..a1c4b374d1 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -40,10 +40,10 @@ StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
// Creates a slice HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> MakeSliceHlo(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
@@ -53,8 +53,8 @@ StatusOr<HloInstruction*> MakeConvolveHlo(
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> MakeTransposeHlo(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
+ absl::Span<const int64> dimensions);
// Creates a reshape HLO instruction and adds it to the computation containing
// `operand`.
@@ -62,15 +62,14 @@ StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
HloInstruction* operand);
StatusOr<HloInstruction*> MakeReshapeHlo(
- tensorflow::gtl::ArraySlice<int64> result_shape_dim_bounds,
- HloInstruction* operand);
+ absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand);
// Creates a dynamic-slice HLO instruction and adds it to the computation
// containing `operand` and `start_indices` (`operand` and `start_indices` must
// be in the same computation).
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Creates a dynamic-update-slice HLO instruction and adds it to the computation
// containing `operand`, `update` and `start_indices` (`operand`, `update` and
@@ -82,9 +81,8 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
// Creates a broadcast HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeBroadcastHlo(
- HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions,
- tensorflow::gtl::ArraySlice<int64> result_shape_bounds);
+ HloInstruction* operand, absl::Span<const int64> broadcast_dimensions,
+ absl::Span<const int64> result_shape_bounds);
// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
@@ -95,7 +93,7 @@ StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
// containing `operands` (`operands` must be non-empty and every element must be
// contained in the same computation).
StatusOr<HloInstruction*> MakeConcatHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands, int64 dimension);
+ absl::Span<HloInstruction* const> operands, int64 dimension);
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
@@ -104,9 +102,8 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
-StatusOr<HloInstruction*> MakeMapHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation);
+StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation);
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
@@ -138,7 +135,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
// For instance if `operand` has shape f32[200,9,7] and expanded_dims is
// {2,5,20} the result is `operand` reshaped to [2,5,20,9,7].
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> expanded_dims);
+ HloInstruction* operand, absl::Span<const int64> expanded_dims);
// Elides (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_elide` from `operand`. Every dimension in
@@ -148,7 +145,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
// For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide
// is {1,5} then the result is `operand` reshaped to [19,20,1,7,9].
StatusOr<HloInstruction*> ElideDegenerateDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_elide);
+ HloInstruction* operand, absl::Span<const int64> dims_to_elide);
// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_insert` into `operand`. The dimensions in
@@ -158,7 +155,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(
// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
StatusOr<HloInstruction*> InsertDegenerateDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_insert);
+ HloInstruction* operand, absl::Span<const int64> dims_to_insert);
// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
@@ -171,12 +168,12 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
// broadcast instruction is emitted into `computation`.
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Creates a HLO computation that takes arguments of type `domain` and produces
// a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
- tensorflow::gtl::ArraySlice<const Shape*> domain, const Shape& range,
+ absl::Span<const Shape* const> domain, const Shape& range,
absl::string_view name);
} // namespace xla
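
The comments above give concrete shape examples for ExpandFirstDimIntoNDims and ElideDegenerateDims. A standalone sketch of just the dimension bookkeeping, checked against those examples (the real helpers emit reshape HLOs rather than returning vectors):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>
#include "absl/types/span.h"

std::vector<int64_t> ExpandFirstDim(absl::Span<const int64_t> dims,
                                    absl::Span<const int64_t> expanded) {
  std::vector<int64_t> result(expanded.begin(), expanded.end());
  result.insert(result.end(), dims.begin() + 1, dims.end());
  return result;
}

std::vector<int64_t> ElideDims(absl::Span<const int64_t> dims,
                               absl::Span<const int64_t> dims_to_elide) {
  std::vector<int64_t> result;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (std::find(dims_to_elide.begin(), dims_to_elide.end(), i) ==
        dims_to_elide.end()) {
      result.push_back(dims[i]);
    }
  }
  return result;
}

int main() {
  // f32[200,9,7] expanded by {2,5,20} -> [2,5,20,9,7]
  assert((ExpandFirstDim({200, 9, 7}, {2, 5, 20}) ==
          std::vector<int64_t>{2, 5, 20, 9, 7}));
  // f32[19,1,20,1,7,1,9] with dims_to_elide {1,5} -> [19,20,1,7,9]
  assert((ElideDims({19, 1, 20, 1, 7, 1, 9}, {1, 5}) ==
          std::vector<int64_t>{19, 20, 1, 7, 9}));
}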
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index 662f008205..eb6affadc8 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -24,15 +24,13 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
class HloCreationUtilsTest : public HloVerifiedTestBase {
protected:
- HloModule* CreateModuleWithProgramShape(PrimitiveType primitive_type,
- ArraySlice<int64> input_shape_dims,
- ArraySlice<int64> output_shape_dims,
- HloInstruction** param,
- HloComputation** entry_computation) {
+ HloModule* CreateModuleWithProgramShape(
+ PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
+ absl::Span<const int64> output_shape_dims, HloInstruction** param,
+ HloComputation** entry_computation) {
Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
Shape output_shape =
ShapeUtil::MakeShape(primitive_type, output_shape_dims);
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 3376d170e6..6a63681996 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -46,8 +46,7 @@ namespace {
//
// In this case, we should be able to reuse p0 and output, although p0 has
// multiple uses.
-bool MultiDynamicSliceUseShareSameIndices(
- tensorflow::gtl::ArraySlice<HloUse> uses) {
+bool MultiDynamicSliceUseShareSameIndices(absl::Span<const HloUse> uses) {
if (uses.empty()) {
return false;
}
@@ -221,7 +220,7 @@ string HloDataflowAnalysis::ToString() const {
bool HloDataflowAnalysis::Phi(
HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
+ absl::Span<const InstructionValueSet* const> inputs) {
CHECK(ssa_form_);
VLOG(4) << "Phi(" << instruction->name() << ")";
VLOG(5) << "instruction value set = "
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index a1678d4943..6d5c375d6d 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -202,7 +202,7 @@ class HloDataflowAnalysis {
// the given instruction. If skip_top_level is true, then the top level of the
// value set of 'instruction' is not modified.
bool Phi(HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
+ absl::Span<const InstructionValueSet* const> inputs);
// Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index c25869f87b..d316645a0b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -53,7 +53,6 @@ namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
@@ -97,10 +96,11 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
}
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
- return compare_op(lhs_literal.Get<OperandT>(multi_index),
- rhs_literal.Get<OperandT>(multi_index));
- }));
+ TF_RETURN_IF_ERROR(
+ result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ return compare_op(lhs_literal.Get<OperandT>(multi_index),
+ rhs_literal.Get<OperandT>(multi_index));
+ }));
return std::move(result);
}
@@ -127,10 +127,11 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
}
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
- return compare_op(lhs_literal.Get<complex64>(multi_index),
- rhs_literal.Get<complex64>(multi_index));
- }));
+ TF_RETURN_IF_ERROR(
+ result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ return compare_op(lhs_literal.Get<complex64>(multi_index),
+ rhs_literal.Get<complex64>(multi_index));
+ }));
return std::move(result);
}
@@ -194,7 +195,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- const HloModule& module, ArraySlice<LiteralPtr> arg_literals) {
+ const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
evaluated_.clear();
@@ -211,7 +212,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- const HloComputation& computation, ArraySlice<LiteralPtr> arg_literals) {
+ const HloComputation& computation,
+ absl::Span<const LiteralPtr> arg_literals) {
CHECK(computation.parent() != nullptr);
XLA_VLOG_LINES(
2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
@@ -228,7 +230,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) {
+ HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
evaluated_.clear();
@@ -390,7 +392,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
}
Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
- ArraySlice<HloInstruction*> operands(concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
// The result concatenate dimension is going to be the sum of all
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
@@ -588,7 +590,7 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice
// dimensions while keeping the rest of the output dimensions clamped to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
- int64 output_rank, ArraySlice<int64> slice_sizes,
+ int64 output_rank, absl::Span<const int64> slice_sizes,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count(output_rank, 1);
@@ -661,11 +663,12 @@ class OutputBatchIndexToInputIndex {
// same storage for all invocations.
//
// This returns an arrayslice into memory owned by the class.
- StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> output_index) {
PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
PropagateIndexVectorToInputIndex();
- return ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
private:
@@ -674,7 +677,7 @@ class OutputBatchIndexToInputIndex {
// update the dim_numbers.index_vector_dim() dimension -- that's the dimension
// we iterate over in FetchIndexVector.
void PropagateOutputIndexGatherDimsToIndexVectorIndex(
- ArraySlice<int64> output_index) {
+ absl::Span<const int64> output_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = output_index.size(); i < e; i++) {
if (!output_dim_is_batch_dims_[i]) {
@@ -729,7 +732,7 @@ class OutputBatchIndexToInputIndex {
// The index vector fetched from start_indices_.
std::vector<int64> index_vector_;
- // The result computed by this functor. operator() returns an ArraySlice into
+ // The result computed by this functor. operator() returns a Span into
// this vector.
std::vector<int64> input_index_;
@@ -779,9 +782,10 @@ class OutputOffsetIndexToInputIndex {
// result (input_index_), mutating it in place.
//
// This returns a span into memory owned by the class.
- StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> output_index) {
PropagateOutputIndexWindowDimsToInputIndex(output_index);
- return ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
// Returns for a given 'input_dim' the corresponding output dimension index,
@@ -794,7 +798,7 @@ class OutputOffsetIndexToInputIndex {
// Propagates window dimensions from the output index to input_index_ by
// mutating input_index_ in place.
void PropagateOutputIndexWindowDimsToInputIndex(
- ArraySlice<int64> output_index) {
+ absl::Span<const int64> output_index) {
for (int64 i = 0, e = input_index_.size(); i < e; i++) {
if (input_dim_value_to_output_index_[i] != -1) {
input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
@@ -810,7 +814,7 @@ class OutputOffsetIndexToInputIndex {
// PropagateOutputIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_output_index_;
- // The result computed by this functor. operator() returns an ArraySlice into
+ // The result computed by this functor. operator() returns a Span into
// this vector.
std::vector<int64> input_index_;
};
@@ -872,11 +876,11 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
const Shape& operand_shape = operand.shape();
auto gather_inner_loop_body =
- [&](ArraySlice<int64> output_window_index,
- ArraySlice<int64> input_gather_index,
- ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
+ [&](absl::Span<const int64> output_window_index,
+ absl::Span<const int64> input_gather_index,
+ absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- ArraySlice<int64> input_window_index,
+ absl::Span<const int64> input_window_index,
output_offset_index_to_input_index(output_window_index));
for (int i = 0, e = output_index.size(); i < e; i++) {
output_index[i] = output_gather_index[i] + output_window_index[i];
@@ -909,8 +913,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
};
auto gather_outer_loop_body =
- [&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
- TF_ASSIGN_OR_RETURN(ArraySlice<int64> input_gather_index,
+ [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
+ TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
output_batch_index_to_input_index(output_gather_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
shape, offset_indices_iteration_space,
@@ -1170,12 +1174,11 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_values.push_back(key_value.second);
}
auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_keys_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<KeyType>(result_keys));
+ result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
auto result_values_literal =
absl::make_unique<Literal>(values_literal.shape());
result_values_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ValueType>(result_values));
+ absl::Span<const ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
std::move(result_values_literal));
};
@@ -1311,26 +1314,27 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
// Explicit instantiation of templatized Evaluate* methods.
//
template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(const HloModule& module,
- ArraySlice<const Literal*> arg_literals);
+HloEvaluator::Evaluate<const Literal*>(
+ const HloModule& module, absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- const HloModule& module, ArraySlice<std::unique_ptr<Literal>> arg_literals);
+ const HloModule& module,
+ absl::Span<const std::unique_ptr<Literal>> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(const HloComputation& computation,
- ArraySlice<const Literal*> arg_literals);
+template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
+ const Literal*>(const HloComputation& computation,
+ absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
const HloComputation& computation,
- ArraySlice<std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const std::unique_ptr<Literal>> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(HloInstruction* instruction,
- ArraySlice<const Literal*> arg_literals);
+HloEvaluator::Evaluate<const Literal*>(
+ HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
HloInstruction* instruction,
- ArraySlice<std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const std::unique_ptr<Literal>> arg_literals);
} // namespace xla
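The gather index functors above (OutputBatchIndexToInputIndex, OutputOffsetIndexToInputIndex) hand back a span that aliases a vector owned by the functor, so every call reuses the same storage. A minimal sketch of that pattern, assuming only absl::Span; OffsetIndex and its members are illustrative names, not XLA code:

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

class OffsetIndex {
 public:
  explicit OffsetIndex(std::vector<int64_t> offsets)
      : offsets_(std::move(offsets)), scratch_(offsets_.size(), 0) {}

  // Adds a fixed offset to each coordinate. The returned span aliases
  // scratch_ and is invalidated by the next call, so callers must consume
  // (or copy) it before invoking operator() again.
  absl::Span<const int64_t> operator()(absl::Span<const int64_t> index) {
    for (size_t i = 0; i < scratch_.size(); ++i) {
      scratch_[i] = index[i] + offsets_[i];
    }
    return scratch_;
  }

 private:
  std::vector<int64_t> offsets_;
  std::vector<int64_t> scratch_;
};

The same lifetime caveat applies to the StatusOr<absl::Span<const int64>> results in the diff: the span is a view, not an owner.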
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 980a7fb9fa..3feb4e626f 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -51,8 +51,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// type.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloModule& module,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
// Evaluates an HLO computation, given an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -75,7 +74,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloComputation& computation,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction, given an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -87,8 +86,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// type.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
- HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as a literal if successful.
@@ -229,8 +227,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
}
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
}));
return std::move(result);
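The Evaluate overloads above take absl::Span<const LiteralPtr>, which the explicit instantiations in hlo_evaluator.cc spell out as absl::Span<const Literal* const>: the element type is itself a pointer, and the trailing const makes the view read-only. A self-contained sketch of that shape, using a placeholder Widget type rather than Literal:

#include <vector>

#include "absl/types/span.h"

struct Widget {
  int id;
};

// Read-only view over pointers: neither the pointers nor the view can be
// modified through the span, mirroring absl::Span<const Literal* const>.
int CountNonNull(absl::Span<const Widget* const> widgets) {
  int n = 0;
  for (const Widget* w : widgets) {
    if (w != nullptr) ++n;
  }
  return n;
}

void Example() {
  Widget a{1}, b{2};
  std::vector<const Widget*> ptrs = {&a, nullptr, &b};
  int non_null = CountNonNull(ptrs);  // std::vector converts implicitly.
  (void)non_null;
}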
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index e3eb60a851..626daa527b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -60,7 +60,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
}
std::unique_ptr<Literal> Evaluate(
- tensorflow::gtl::ArraySlice<const Literal*> arg_literals = {}) {
+ absl::Span<const Literal* const> arg_literals = {}) {
if (use_bfloat16_) {
// In BF16 mode, we convert all F32 types to BF16 and evaluate the module.
auto type_converter = HloElementTypeConverter(F32, BF16);
@@ -344,7 +344,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
result->EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+ [&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
});
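The test's EachCell callback above now receives its multi-index as absl::Span<const int64>. A rough stand-in for that visitor pattern; VisitAllIndices is illustrative, not ShapeUtil's or Literal's actual implementation:

#include <cstdint>
#include <functional>
#include <vector>

#include "absl/types/span.h"

// Walks a row-major index space and hands each multi-index to the visitor
// as a read-only span backed by one reusable buffer.
void VisitAllIndices(absl::Span<const int64_t> dimensions,
                     const std::function<void(absl::Span<const int64_t>)>& visit) {
  for (int64_t dim : dimensions) {
    if (dim == 0) return;  // Empty index space: nothing to visit.
  }
  std::vector<int64_t> index(dimensions.size(), 0);
  while (true) {
    visit(index);  // The span argument is only valid during this call.
    int64_t d = static_cast<int64_t>(index.size()) - 1;
    for (; d >= 0; --d) {
      if (++index[d] < dimensions[d]) break;
      index[d] = 0;
    }
    if (d < 0) return;
  }
}

// Usage: VisitAllIndices({2, 3}, [](absl::Span<const int64_t> i) { /* ... */ });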
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index d35163ebb8..980e343035 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -97,7 +97,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
double GetAsDouble(const Literal& literal,
- tensorflow::gtl::ArraySlice<int64> input_index) {
+ absl::Span<const int64> input_index) {
return static_cast<double>(literal.Get<NativeT>(input_index));
}
@@ -109,7 +109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
double GetAsDouble(const Literal& literal,
- tensorflow::gtl::ArraySlice<int64> input_index) {
+ absl::Span<const int64> input_index) {
LOG(FATAL) << "Trying to get complex literal as double: "
<< literal.ToString();
}
@@ -980,8 +980,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -1048,8 +1048,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
- rhs_literal_data](
- tensorflow::gtl::ArraySlice<int64> out_index) {
+ rhs_literal_data](absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1211,8 +1210,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result = absl::make_unique<Literal>(dot->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> result_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@@ -1261,9 +1260,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
auto result = absl::make_unique<Literal>(pad->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
- return scalar;
- }));
+ [&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
parent_->GetEvaluatedLiteralFor(pad->operand(0));
@@ -1276,7 +1273,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// corresponding index of the resulting padded literal.
const PaddingConfig& pad_config = pad->padding_config();
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ auto func = [&](absl::Span<const int64> input_index) {
for (auto i = 0; i < input_index.size(); ++i) {
// Interior padding occurs logically before edge padding, so in the case
// of negative edge padding elements are removed from the
@@ -1427,8 +1424,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = absl::make_unique<Literal>(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
std::vector<std::unique_ptr<Literal>> arg_literals;
arg_literals.reserve(operands.size());
@@ -1539,8 +1536,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return SafeLess<ReturnT>(a, b);
});
auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ReturnT>(result_data));
+ result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
return result_literal;
};
@@ -1582,7 +1578,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
int64 num_args = reduce->inputs().size();
bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape());
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
absl::InlinedVector<const Shape*, 1> operand_shapes;
@@ -1650,7 +1646,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 input = 0; input < num_args; ++input) {
TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ [&](absl::Span<const int64> multi_index) {
if (!eval_status.ok()) {
return init_scalars[input];
}
@@ -1668,7 +1664,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
IsScalarAdd(function)) {
CHECK_EQ(num_args, 1);
double computed_result = 0;
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ auto func = [&](absl::Span<const int64> input_index) {
computed_result +=
GetAsDouble<ReturnT>(*arg_literals[0], input_index);
return true;
@@ -1677,8 +1673,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_dim_counts, arg_dim_steps, func);
return static_cast<ReturnT>(computed_result);
}
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
- -> StatusOr<bool> {
+ auto func =
+ [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
absl::InlinedVector<ReturnT, 1> arg_values(num_args);
for (int64 i = 0; i < num_args; ++i) {
arg_values[i] = arg_literals[i]->Get<ReturnT>(input_index);
@@ -1767,9 +1763,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize result array with the init value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> output_index) {
- return init_scalar;
- }));
+ [&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
for (const auto& window_dimension : window.dimensions()) {
@@ -1902,8 +1896,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
auto result = absl::make_unique<Literal>(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@@ -2049,12 +2043,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// same storage for all invocations.
//
// This returns a span into memory owned by the class.
- StatusOr<tensorflow::gtl::ArraySlice<int64>> operator()(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> update_index) {
PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
PropagateIndexVectorToInputIndex();
- return tensorflow::gtl::ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
private:
@@ -2063,7 +2057,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// update the dim_numbers.index_vector_dim() dimension -- that's the
// dimension we iterate over in FetchIndexVector.
void PropagateUpdateIndexScatterDimsToIndexVectorIndex(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ absl::Span<const int64> update_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = update_index.size(); i < e; i++) {
if (!update_dim_is_scatter_dims_[i]) {
@@ -2118,7 +2112,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// The index vector fetched from scatter_indices_.
std::vector<int64> index_vector_;
- // The result computed by this functor. operator() returns an ArraySlice
+ // The result computed by this functor. operator() returns a Span
// into this vector.
std::vector<int64> input_index_;
@@ -2172,10 +2166,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// result (input_index_), mutating it in place.
//
// This returns a span into memory owned by the class.
- StatusOr<tensorflow::gtl::ArraySlice<int64>> operator()(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> update_index) {
PropagateUpdateIndexWindowDimsToInputIndex(update_index);
- return tensorflow::gtl::ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
// Returns for a given 'input_dim' the corresponding update dimension index,
@@ -2188,7 +2182,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Propagates window dimensions from the update index to input_index_ by
// mutating input_index_ in place.
void PropagateUpdateIndexWindowDimsToInputIndex(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ absl::Span<const int64> update_index) {
for (int64 i = 0, e = input_index_.size(); i < e; i++) {
if (input_dim_value_to_update_index_[i] != -1) {
input_index_[i] = update_index[input_dim_value_to_update_index_[i]];
@@ -2204,7 +2198,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// PropagateUpdateIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_update_index_;
- // The result computed by this functor. operator() returns an ArraySlice
+ // The result computed by this functor. operator() returns a Span
// into this vector.
std::vector<int64> input_index_;
};
@@ -2247,12 +2241,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::unique_ptr<Literal> result = operand.CloneToUnique();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
- [&](tensorflow::gtl::ArraySlice<int64> update_window_index,
- tensorflow::gtl::ArraySlice<int64> input_scatter_index,
- tensorflow::gtl::ArraySlice<int64> update_scatter_index)
- -> StatusOr<bool> {
+ [&](absl::Span<const int64> update_window_index,
+ absl::Span<const int64> input_scatter_index,
+ absl::Span<const int64> update_scatter_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- tensorflow::gtl::ArraySlice<int64> input_window_index,
+ absl::Span<const int64> input_window_index,
update_window_index_to_input_index(update_window_index));
for (int i = 0, e = update_index.size(); i < e; i++) {
update_index[i] = update_scatter_index[i] + update_window_index[i];
@@ -2301,14 +2294,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
};
auto scatter_outer_loop_body =
- [&](tensorflow::gtl::ArraySlice<int64> update_scatter_index)
- -> StatusOr<bool> {
+ [&](absl::Span<const int64> update_scatter_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- tensorflow::gtl::ArraySlice<int64> input_scatter_index,
+ absl::Span<const int64> input_scatter_index,
update_scatter_index_to_input_index(update_scatter_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
updates_shape, window_indices_iteration_space,
- [&](tensorflow::gtl::ArraySlice<int64> update_window_index) {
+ [&](absl::Span<const int64> update_window_index) {
return scatter_inner_loop_body(
update_window_index, input_scatter_index, update_scatter_index);
}));
@@ -2336,7 +2328,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 rank = ShapeUtil::Rank(operand->shape());
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ auto func = [&](absl::Span<const int64> out_index) {
DimensionVector operand_index(rank);
for (int64 i = 0; i < rank; ++i) {
operand_index[i] =
@@ -2607,7 +2599,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// bound, call `f` with the base index.
static void IterateThroughWindow(
const Shape& window_shape, const Window& window, const Shape& base_shape,
- const tensorflow::gtl::ArraySlice<int64>& window_count_index,
+ const absl::Span<const int64>& window_count_index,
const std::function<void(const std::vector<int64>&)>& f) {
const int64 rank = ShapeUtil::Rank(base_shape);
DimensionVector window_index(rank);
@@ -2647,8 +2639,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> operand_indices(start.size());
auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@@ -2679,7 +2671,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> result_index(rank, 0);
- auto func = [&](tensorflow::gtl::ArraySlice<int64> update_index) {
+ auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
result->Set<ReturnT>(result_index,
@@ -2733,8 +2725,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@@ -2770,8 +2762,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));
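One detail visible in IterateThroughWindow above: its window_count_index parameter stays a const absl::Span<const int64>&. A Span is just a pointer/length pair, so the usual guidance is to pass it by value; the extra reference is harmless but buys nothing. A minimal illustration:

#include <cstdint>

#include "absl/types/span.h"

// Preferred: take the span by value.
int64_t FirstOrMinusOne(absl::Span<const int64_t> values) {
  return values.empty() ? -1 : values.front();
}

// Also compiles, but the indirection adds nothing for a two-word view type.
int64_t FirstOrMinusOneByRef(const absl::Span<const int64_t>& values) {
  return values.empty() ? -1 : values.front();
}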
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index b747a4ea5f..bd0b6af10d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
std::vector<int64> fft_length(proto.fft_length().begin(),
proto.fft_length().end());
instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
- tensorflow::gtl::ArraySlice<int64>(fft_length));
+ absl::Span<const int64>(fft_length));
break;
}
case HloOpcode::kSend:
@@ -519,13 +519,13 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
+ absl::Span<HloInstruction* const> parameters) {
return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
if (opcode == HloOpcode::kCopy) {
// It is impossible to copy an opaque shape; we don't know how big it is.
CHECK(!ShapeUtil::IsOpaque(shape));
@@ -627,13 +627,13 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
CHECK_EQ(HloOpcode::kTuple, opcode);
return CreateNary(shape, opcode, operands);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* map_computation) {
return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
}
@@ -648,7 +648,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length) {
+ absl::Span<const int64> fft_length) {
return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
fft_length);
}
@@ -692,7 +692,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
const absl::optional<int64>& all_reduce_id) {
@@ -702,7 +702,7 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups) {
return absl::make_unique<HloAllToAllInstruction>(shape, operands,
replica_groups);
@@ -764,12 +764,12 @@ HloInstruction::CreateCollectivePermute(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
CHECK(!operands.empty());
auto instruction = absl::WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
@@ -815,16 +815,15 @@ HloInstruction::CreateCollectivePermute(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
limit_indices, strides);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return absl::make_unique<HloDynamicSliceInstruction>(
shape, operand, start_indices, slice_sizes);
}
@@ -843,7 +842,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension) {
return absl::make_unique<HloConcatenateInstruction>(shape, operands,
dimension);
@@ -868,7 +867,7 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
auto instruction = absl::WrapUnique(new HloReduceInstruction(
shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
@@ -876,9 +875,9 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::Span<HloInstruction* const> init_values,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
std::vector<HloInstruction*> all_args;
all_args.reserve(operands.size() * 2);
@@ -936,7 +935,7 @@ HloInstruction::CreateSelectAndScatter(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return absl::make_unique<HloBroadcastInstruction>(shape, operand,
broadcast_dimensions);
}
@@ -1014,7 +1013,7 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
}
@@ -1032,7 +1031,7 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation) {
return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
fusion_computation);
@@ -1090,7 +1089,7 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* computation) {
std::unique_ptr<HloInstruction> instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
@@ -1102,14 +1101,14 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target) {
return absl::make_unique<HloCustomCallInstruction>(shape, operands,
custom_call_target);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
+ absl::Span<HloInstruction* const> elements) {
std::vector<Shape> element_shapes;
for (auto element : elements) {
element_shapes.push_back(element->shape());
@@ -1121,7 +1120,7 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return absl::make_unique<HloGatherInstruction>(
shape, operand, start_indices, gather_dim_numbers, slice_sizes);
}
@@ -1149,8 +1148,7 @@ bool HloInstruction::HasSideEffect() const {
}
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
@@ -1501,7 +1499,7 @@ void HloInstruction::AppendOperand(HloInstruction* operand) {
}
void HloInstruction::RemoveOperandsAtAscendingIndices(
- tensorflow::gtl::ArraySlice<int> ascending_indices) {
+ absl::Span<const int> ascending_indices) {
if (ascending_indices.empty()) {
return;
}
@@ -1997,7 +1995,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
string operands;
- tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
+ absl::Span<HloInstruction* const> slice(operands_);
const int64 kMaxOperandsToShowIfCompact = 4;
if (options.compact_operands() &&
slice.size() > kMaxOperandsToShowIfCompact) {
@@ -3310,7 +3308,7 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}
-tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_slice_sizes() const {
+absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}
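The Create* factories above forward their spans into instruction constructors, which copy the contents into owned vectors, as the hlo_instructions.cc hunks below show; the span itself never owns or extends the lifetime of the caller's buffer. A compact sketch of that constructor pattern, using a placeholder FakeFft class rather than a real HLO type:

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

class FakeFft {
 public:
  // Accepts any contiguous int64 sequence and copies it into owned storage.
  explicit FakeFft(absl::Span<const int64_t> fft_length)
      : fft_length_(fft_length.begin(), fft_length.end()) {}

  // Hands back a read-only view of the owned copy.
  absl::Span<const int64_t> fft_length() const { return fft_length_; }

 private:
  std::vector<int64_t> fft_length_;
};

// Callers can pass a vector, a C array, or a braced list:
//   FakeFft fft({16, 16});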
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index f3fd287d88..88cb5d8acf 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -365,7 +365,7 @@ class HloInstruction {
// random numbers from a given distribution.
static std::unique_ptr<HloInstruction> CreateRng(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
+ absl::Span<HloInstruction* const> parameters);
// Creates a unary instruction (one operand).
// Precondition: opcode must be a legitimate unary operation.
@@ -392,13 +392,13 @@ class HloInstruction {
// Precondition: opcode must be a legitimate variadic operation.
static std::unique_ptr<HloInstruction> CreateVariadic(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Creates a map instruction, where the computation (given by the handle) is
// applied element-wise to every element in operands (across the operands,
// at a given index)
static std::unique_ptr<HloInstruction> CreateMap(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* map_computation);
// Creates a convolution op, where rhs is the convolutional filter
@@ -412,7 +412,7 @@ class HloInstruction {
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
// dimensions specified in 'dimension_numbers'.
@@ -449,7 +449,7 @@ class HloInstruction {
//
// TODO(b/79737069): Rename this to AllReduce.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
@@ -468,7 +468,7 @@ class HloInstruction {
// be concatenated in the order of 1, 2, 3; another AllToAll will be applied
// within replicas 4, 5, 0, and the concatenation order is 4, 5, 0.
static std::unique_ptr<HloInstruction> CreateAllToAll(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups);
// Creates a communication instruction that permutes data across replicas.
@@ -536,17 +536,15 @@ class HloInstruction {
// start/limit indices.
static std::unique_ptr<HloInstruction> CreateSlice(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices, absl::Span<const int64> strides);
// Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specified in
// 'slice_sizes'.
static std::unique_ptr<HloInstruction> CreateDynamicSlice(
const Shape& shape, HloInstruction* operand,
- HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ HloInstruction* start_indices, absl::Span<const int64> slice_sizes);
// Creates a dynamic update slice instruction, which updates a slice
// of 'operand' with 'update' and 'start_indices'.
@@ -557,7 +555,7 @@ class HloInstruction {
// Creates a concatenate instruction, where the operands are concatenated on
// the provided dimension.
static std::unique_ptr<HloInstruction> CreateConcatenate(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension);
// Creates a reduce instruction, where the computation (given by the handle)
@@ -569,7 +567,7 @@ class HloInstruction {
// f(f(init, value0), value1), ...)
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// A more general, multiple-argument version of the above.
@@ -584,9 +582,9 @@ class HloInstruction {
// ...
// TODO(b/112040122): Add support to this in HLO passes and in backends.
static std::unique_ptr<HloInstruction> CreateReduce(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::Span<HloInstruction* const> init_values,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Creates a reduce-window instruction, where the computation (given
@@ -623,7 +621,7 @@ class HloInstruction {
// Creates a broadcast instruction.
static std::unique_ptr<HloInstruction> CreateBroadcast(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Creates a sequence of instructions that performs an explicit broadcast of
// the operand to the target shape.
@@ -653,7 +651,7 @@ class HloInstruction {
// Creates a transpose instruction which permutes the operand dimensions.
static std::unique_ptr<HloInstruction> CreateTranspose(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Creates a sort op, with a keys operand, and an optional values operand.
static std::unique_ptr<HloInstruction> CreateSort(
@@ -679,7 +677,7 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
static std::unique_ptr<HloInstruction> CreateScatter(
const Shape& shape, HloInstruction* operand,
@@ -703,37 +701,37 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateFusion(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation);
// Creates a call instruction that applies the given computation on the given
// operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* computation);
// Creates a custom call instruction that applies the given custom call target
// to the given operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target);
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*> elements);
+ absl::Span<HloInstruction* const> elements);
// Creates a reverse instruction, which reverses the order of the elements
// in the specified dimensions.
static std::unique_ptr<HloInstruction> CreateReverse(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Creates an AfterAll instruction used for joining or creating new values of
// token type which thread through side-effecting operations. Operands must
// all be tokens, and there must be at least one operand.
static std::unique_ptr<HloInstruction> CreateAfterAll(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Creates an AfterAll instruction which creates a token type out of thin air
// (no operands). This is a separate method from CreateAfterAll to facilitate
@@ -1124,8 +1122,7 @@ class HloInstruction {
// Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context = nullptr) const;
// Returns the computations this instruction directly calls (if any).
@@ -1505,7 +1502,7 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_dimension_numbers.
const GatherDimensionNumbers& gather_dimension_numbers() const;
// Delegates to HloGatherInstruction::gather_slice_sizes.
- tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const;
+ absl::Span<const int64> gather_slice_sizes() const;
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
@@ -1531,7 +1528,7 @@ class HloInstruction {
// Removes a list of operands with the given indices in ascending order.
void RemoveOperandsAtAscendingIndices(
- tensorflow::gtl::ArraySlice<int> ascending_indices);
+ absl::Span<const int> ascending_indices);
void AppendComputation(HloComputation* computation) {
called_computations_.push_back(computation);
@@ -1561,8 +1558,7 @@ class HloInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
// TODO(b/80131774): This should be pure virtual.
LOG(FATAL) << "Unimplemented method.";
@@ -1608,7 +1604,7 @@ class HloInstruction {
// Creates an n-ary elementwise operation.
static std::unique_ptr<HloInstruction> CreateNary(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Adds a user for this instruction.
void AddUser(HloInstruction* user);
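Call sites of the Create* declarations above mostly need no changes: std::vector, std::array, plain C arrays, and braced lists all convert implicitly to the const span parameters, which is what keeps this rename largely mechanical. A small demonstration, assuming only absl::Span (Product and Callers are illustrative names):

#include <array>
#include <cstdint>
#include <vector>

#include "absl/types/span.h"

int64_t Product(absl::Span<const int64_t> dims) {
  int64_t p = 1;
  for (int64_t d : dims) p *= d;
  return p;
}

void Callers() {
  std::vector<int64_t> v = {2, 3};
  std::array<int64_t, 2> a = {4, 5};
  const int64_t c[] = {6, 7};

  // Each argument binds to absl::Span<const int64_t> without a wrapper.
  int64_t p1 = Product(v);
  int64_t p2 = Product(a);
  int64_t p3 = Product(c);
  int64_t p4 = Product({8, 9});
  (void)p1; (void)p2; (void)p3; (void)p4;
}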
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e1c884d856..6871953755 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -91,8 +91,7 @@ HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
return absl::make_unique<HloBatchNormTrainingInstruction>(
@@ -113,8 +112,7 @@ HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
return absl::make_unique<HloBatchNormInferenceInstruction>(
@@ -135,8 +133,7 @@ HloBatchNormGradInstruction::HloBatchNormGradInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
return absl::make_unique<HloBatchNormGradInstruction>(
@@ -144,9 +141,9 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
new_operands[4], epsilon(), feature_index());
}
-HloFftInstruction::HloFftInstruction(
- const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length)
+HloFftInstruction::HloFftInstruction(const Shape& shape,
+ HloInstruction* operand, FftType fft_type,
+ absl::Span<const int64> fft_length)
: HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
fft_length_.assign(fft_length.begin(), fft_length.end());
AppendOperand(operand);
@@ -177,8 +174,7 @@ bool HloFftInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
@@ -232,8 +228,7 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand,
}
std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloSendInstruction>(
@@ -250,8 +245,7 @@ HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
std::unique_ptr<HloInstruction>
HloSendDoneInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloSendDoneInstruction>(
@@ -271,8 +265,7 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape,
}
std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloRecvInstruction>(
@@ -293,8 +286,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
std::unique_ptr<HloInstruction>
HloRecvDoneInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloRecvDoneInstruction>(
@@ -303,7 +295,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
HloCollectiveInstruction::HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups)
: HloInstruction(opcode, shape), replica_groups_(replica_groups) {
for (auto operand : operands) {
@@ -344,7 +336,7 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
}
HloAllReduceInstruction::HloAllReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
const absl::optional<int64>& all_reduce_id)
@@ -392,8 +384,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloAllReduceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllReduceInstruction>(
shape, new_operands, to_apply(), replica_groups(),
@@ -401,15 +392,14 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
}
HloAllToAllInstruction::HloAllToAllInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups)
: HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
replica_groups) {}
std::unique_ptr<HloInstruction>
HloAllToAllInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllToAllInstruction>(shape, new_operands,
replica_groups());
@@ -459,16 +449,15 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloCollectivePermuteInstruction>(
shape, new_operands[0], source_target_pairs());
}
-HloReverseInstruction::HloReverseInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions)
+HloReverseInstruction::HloReverseInstruction(const Shape& shape,
+ HloInstruction* operand,
+ absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kReverse, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
AppendOperand(operand);
@@ -496,8 +485,7 @@ bool HloReverseInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
@@ -505,7 +493,7 @@ std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
}
HloConcatenateInstruction::HloConcatenateInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension)
: HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
for (auto operand : operands) {
@@ -537,16 +525,15 @@ bool HloConcatenateInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloConcatenateInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
dimensions(0));
}
HloReduceInstruction::HloReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> args,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation)
: HloInstruction(HloOpcode::kReduce, shape),
dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
@@ -581,8 +568,7 @@ bool HloReduceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size() % 2, 0);
return absl::make_unique<HloReduceInstruction>(shape, new_operands,
@@ -621,8 +607,7 @@ bool HloSortInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
HloInstruction* keys = new_operands[0];
HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
@@ -632,7 +617,7 @@ std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
HloTransposeInstruction::HloTransposeInstruction(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions)
+ absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kTranspose, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
CHECK_EQ(shape.dimensions().size(), dimensions.size());
@@ -676,8 +661,7 @@ bool HloTransposeInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloTransposeInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
@@ -686,7 +670,7 @@ HloTransposeInstruction::CloneWithNewOperandsImpl(
HloBroadcastInstruction::HloBroadcastInstruction(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimension)
+ absl::Span<const int64> broadcast_dimension)
: HloInstruction(HloOpcode::kBroadcast, shape),
dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
AppendOperand(operand);
@@ -715,17 +699,16 @@ bool HloBroadcastInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloBroadcastInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
dimensions());
}
-HloMapInstruction::HloMapInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation)
+HloMapInstruction::HloMapInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation)
: HloInstruction(HloOpcode::kMap, shape) {
for (auto operand : operands) {
AppendOperand(operand);
@@ -774,17 +757,16 @@ bool HloMapInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
}
-HloSliceInstruction::HloSliceInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides)
+HloSliceInstruction::HloSliceInstruction(const Shape& shape,
+ HloInstruction* operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides)
: HloInstruction(HloOpcode::kSlice, shape),
slice_starts_(start_indices.begin(), start_indices.end()),
slice_limits_(limit_indices.begin(), limit_indices.end()),
@@ -835,8 +817,7 @@ bool HloSliceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloSliceInstruction>(
@@ -889,8 +870,7 @@ bool HloConstantInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
}
@@ -947,8 +927,7 @@ bool HloTraceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
}
@@ -966,7 +945,7 @@ HloFusionInstruction::HloFusionInstruction(const Shape& shape,
HloFusionInstruction::HloFusionInstruction(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation)
: HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
for (auto operand : operands) {
@@ -1373,8 +1352,7 @@ bool HloFusionInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
HloModule* module = context != nullptr ? context->module() : GetModule();
HloComputation* new_fused_computation = nullptr;
@@ -1412,7 +1390,7 @@ Status HloFusionInstruction::DeduplicateFusionOperands() {
HloRngInstruction::HloRngInstruction(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters)
+ absl::Span<HloInstruction* const> parameters)
: HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
for (HloInstruction* param : parameters) {
AppendOperand(param);
@@ -1443,8 +1421,7 @@ bool HloRngInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloRngInstruction>(shape, distribution_,
new_operands);
@@ -1480,8 +1457,7 @@ bool HloParameterInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloParameterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
name());
@@ -1516,8 +1492,7 @@ bool HloGetTupleElementInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloGetTupleElementInstruction>(
@@ -1559,8 +1534,7 @@ bool HloReducePrecisionInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloReducePrecisionInstruction>(
@@ -1600,8 +1574,7 @@ bool HloInfeedInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloInfeedInstruction>(
@@ -1646,8 +1619,7 @@ bool HloOutfeedInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloOutfeedInstruction>(
@@ -1717,8 +1689,7 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloConvolutionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
@@ -1762,8 +1733,7 @@ bool HloReduceWindowInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloReduceWindowInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloReduceWindowInstruction>(
@@ -1811,8 +1781,7 @@ bool HloSelectAndScatterInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
return absl::make_unique<HloSelectAndScatterInstruction>(
@@ -1821,7 +1790,7 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
}
HloCustomCallInstruction::HloCustomCallInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(),
@@ -1887,8 +1856,7 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloCustomCallInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
auto cloned = absl::make_unique<HloCustomCallInstruction>(
shape, new_operands, custom_call_target());
@@ -1931,8 +1899,7 @@ bool HloPadInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
@@ -1941,7 +1908,7 @@ std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes)
+ absl::Span<const int64> slice_sizes)
: HloInstruction(HloOpcode::kDynamicSlice, shape),
dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
AppendOperand(operand);
@@ -1971,8 +1938,7 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloDynamicSliceInstruction>(
@@ -1982,7 +1948,7 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
HloGatherInstruction::HloGatherInstruction(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes)
+ absl::Span<const int64> slice_sizes)
: HloInstruction(HloOpcode::kGather, shape) {
AppendOperand(operand);
AppendOperand(start_indices);
@@ -2011,10 +1977,9 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const {
}
/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> offset_dims,
- tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
- tensorflow::gtl::ArraySlice<int64> start_index_map,
- int64 index_vector_dim) {
+ absl::Span<const int64> offset_dims,
+ absl::Span<const int64> collapsed_slice_dims,
+ absl::Span<const int64> start_index_map, int64 index_vector_dim) {
GatherDimensionNumbers gather_dim_numbers;
for (int64 output_window_dim : offset_dims) {
gather_dim_numbers.add_offset_dims(output_window_dim);
@@ -2057,8 +2022,7 @@ bool HloGatherInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloGatherInstruction>(
@@ -2102,9 +2066,9 @@ string HloScatterInstruction::ScatterDimensionNumbersToString() const {
/* static */ ScatterDimensionNumbers
HloScatterInstruction::MakeScatterDimNumbers(
- tensorflow::gtl::ArraySlice<int64> update_window_dims,
- tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
- tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ absl::Span<const int64> update_window_dims,
+ absl::Span<const int64> inserted_window_dims,
+ absl::Span<const int64> scatter_dims_to_operand_dims,
int64 index_vector_dim) {
ScatterDimensionNumbers scatter_dim_numbers;
for (int64 update_window_dim : update_window_dims) {
@@ -2144,8 +2108,7 @@ bool HloScatterInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
return absl::make_unique<HloScatterInstruction>(
@@ -2177,8 +2140,7 @@ bool HloIotaInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
}
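
The hlo_instructions.cc changes above repeat two flavors of the rename: tensorflow::gtl::ArraySlice<int64> becomes absl::Span<const int64> (a read-only view of integers), while ArraySlice<HloInstruction*> becomes absl::Span<HloInstruction* const> (a read-only view of pointers whose pointees stay mutable). A minimal standalone sketch of both spellings, assuming stand-in names (Node, SumDims, TagAll) and int64_t in place of XLA's int64; this is illustrative code, not part of the commit:

// Illustrative sketch of the two Span spellings used in the diff above.
#include <cstdint>
#include <vector>
#include "absl/types/span.h"

struct Node { int64_t id; };  // stand-in for HloInstruction

// Read-only view of integers: the caller's storage is neither copied nor
// modified.
int64_t SumDims(absl::Span<const int64_t> dims) {
  int64_t total = 0;
  for (int64_t d : dims) total += d;
  return total;
}

// Read-only view of pointers: the elements (the pointers) are const, but the
// Nodes they point to can still be mutated through them.
void TagAll(absl::Span<Node* const> nodes, int64_t tag) {
  for (Node* n : nodes) n->id = tag;  // allowed: the pointee is non-const
  // nodes[0] = nullptr;              // would not compile: the element is const
}

int main() {
  std::vector<int64_t> dims = {0, 2, 3};
  Node a{1}, b{2};
  std::vector<Node*> nodes = {&a, &b};
  // Both vectors convert to the matching Span implicitly, which is what lets
  // the constructors above keep their call sites unchanged.
  int64_t total = SumDims(dims);
  TagAll(nodes, total);
  return 0;
}
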
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 4fe5144aca..45a648bbe4 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -67,8 +67,7 @@ class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -82,8 +81,7 @@ class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -97,8 +95,7 @@ class HloBatchNormGradInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -106,7 +103,7 @@ class HloFftInstruction : public HloInstruction {
public:
explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
FftType fft_type() const { return fft_type_; }
const std::vector<int64>& fft_length() const { return fft_length_; }
@@ -124,8 +121,7 @@ class HloFftInstruction : public HloInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes FFT type for an FFT instruction.
@@ -174,8 +170,7 @@ class HloSendInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -187,8 +182,7 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -200,8 +194,7 @@ class HloRecvInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -213,8 +206,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -227,7 +219,7 @@ class HloCollectiveInstruction : public HloInstruction {
protected:
explicit HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups);
HloInstructionProto ToProto() const override;
@@ -245,7 +237,7 @@ class HloCollectiveInstruction : public HloInstruction {
class HloAllReduceInstruction : public HloCollectiveInstruction {
public:
explicit HloAllReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
@@ -274,8 +266,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The string representation of the barrier config used for CrossReplicaSum.
@@ -290,14 +281,13 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
class HloAllToAllInstruction : public HloCollectiveInstruction {
public:
explicit HloAllToAllInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups);
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -324,8 +314,7 @@ class HloCollectivePermuteInstruction : public HloInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
const std::vector<std::pair<int64, int64>> source_target_pairs_;
@@ -334,7 +323,7 @@ class HloCollectivePermuteInstruction : public HloInstruction {
class HloReverseInstruction : public HloInstruction {
public:
explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -350,8 +339,7 @@ class HloReverseInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -359,9 +347,9 @@ class HloReverseInstruction : public HloInstruction {
class HloConcatenateInstruction : public HloInstruction {
public:
- explicit HloConcatenateInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- int64 dimension);
+ explicit HloConcatenateInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ int64 dimension);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -379,8 +367,7 @@ class HloConcatenateInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -388,10 +375,10 @@ class HloConcatenateInstruction : public HloInstruction {
class HloReduceInstruction : public HloInstruction {
public:
- explicit HloReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- HloComputation* reduce_computation);
+ explicit HloReduceInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> args,
+ absl::Span<const int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -403,12 +390,12 @@ class HloReduceInstruction : public HloInstruction {
int64 input_count() const { return operand_count() / 2; }
// Returns the input tensors to be reduced.
- tensorflow::gtl::ArraySlice<HloInstruction*> inputs() const {
+ absl::Span<HloInstruction* const> inputs() const {
return absl::MakeSpan(operands()).subspan(0, input_count());
}
// Returns the init values of the reduction.
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values() const {
+ absl::Span<HloInstruction* const> init_values() const {
return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
}
@@ -421,8 +408,7 @@ class HloReduceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -450,8 +436,7 @@ class HloSortInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -459,9 +444,8 @@ class HloSortInstruction : public HloInstruction {
class HloTransposeInstruction : public HloInstruction {
public:
- explicit HloTransposeInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
+ absl::Span<const int64> dimensions);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -479,8 +463,7 @@ class HloTransposeInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -488,9 +471,8 @@ class HloTransposeInstruction : public HloInstruction {
class HloBroadcastInstruction : public HloInstruction {
public:
- explicit HloBroadcastInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimension);
+ explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
+ absl::Span<const int64> broadcast_dimension);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -506,8 +488,7 @@ class HloBroadcastInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -515,9 +496,9 @@ class HloBroadcastInstruction : public HloInstruction {
class HloMapInstruction : public HloInstruction {
public:
- explicit HloMapInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation);
+ explicit HloMapInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -535,8 +516,7 @@ class HloMapInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -545,9 +525,9 @@ class HloMapInstruction : public HloInstruction {
class HloSliceInstruction : public HloInstruction {
public:
explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
HloInstructionProto ToProto() const override;
@@ -586,8 +566,7 @@ class HloSliceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes the [begin, end) index range for a slice.
@@ -629,8 +608,7 @@ class HloConstantInstruction : public HloInstruction {
CanonicalNameMap* canonical_name_map) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// TODO(b/36360764): Remove unique_ptr wrapping.
std::unique_ptr<Literal> literal_;
@@ -651,8 +629,7 @@ class HloTraceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// TODO(b/36360764): Remove unique_ptr wrapping.
std::unique_ptr<Literal> literal_;
@@ -663,10 +640,9 @@ class HloFusionInstruction : public HloInstruction {
explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
HloInstruction* fused_root);
- explicit HloFusionInstruction(
- const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* fusion_computation);
+ explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* fusion_computation);
string ToCategory() const override;
// Returns a serialized representation of this instruction.
@@ -779,8 +755,7 @@ class HloFusionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The type of the fusion. Used by kFusion only.
@@ -789,9 +764,9 @@ class HloFusionInstruction : public HloInstruction {
class HloRngInstruction : public HloInstruction {
public:
- explicit HloRngInstruction(
- const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
+ explicit HloRngInstruction(const Shape& shape,
+ RandomDistribution distribution,
+ absl::Span<HloInstruction* const> parameters);
// Returns the random distribution for this rng node.
RandomDistribution random_distribution() const { return distribution_; }
// Returns a serialized representation of this instruction.
@@ -808,8 +783,7 @@ class HloRngInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The distribution requested for random number generation.
@@ -834,8 +808,7 @@ class HloParameterInstruction : public HloInstruction {
CanonicalNameMap* canonical_name_map) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
int64 parameter_number_ = 0;
@@ -859,8 +832,7 @@ class HloGetTupleElementInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
int64 tuple_index_ = -1;
@@ -888,8 +860,7 @@ class HloReducePrecisionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The bit sizes for a reduce-precision operation.
@@ -926,8 +897,7 @@ class HloInfeedInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The string representation of the infeed configuration.
@@ -959,8 +929,7 @@ class HloOutfeedInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Shape of outfeed request.
@@ -1001,8 +970,7 @@ class HloConvolutionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
// Describes the dimension numbers used for a convolution.
@@ -1033,8 +1001,7 @@ class HloReduceWindowInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
};
@@ -1082,17 +1049,16 @@ class HloSelectAndScatterInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
};
class HloCustomCallInstruction : public HloInstruction {
public:
- explicit HloCustomCallInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- absl::string_view custom_call_target);
+ explicit HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1125,8 +1091,7 @@ class HloCustomCallInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Name of a global symbol to call, only present for kCustomCall.
string custom_call_target_;
@@ -1155,8 +1120,7 @@ class HloPadInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The padding configuration that describes the edge padding and interior
@@ -1166,10 +1130,10 @@ class HloPadInstruction : public HloInstruction {
class HloDynamicSliceInstruction : public HloInstruction {
public:
- explicit HloDynamicSliceInstruction(
- const Shape& shape, HloInstruction* operand,
- HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ explicit HloDynamicSliceInstruction(const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* start_indices,
+ absl::Span<const int64> slice_sizes);
// Old methods kept for smooth subclassing transition END.
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
@@ -1191,8 +1155,7 @@ class HloDynamicSliceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes the [start, start + size) range size for a dynamic slice
@@ -1206,12 +1169,12 @@ class HloGatherInstruction : public HloInstruction {
const Shape& shape, HloInstruction* operand,
HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
const GatherDimensionNumbers& gather_dimension_numbers() const {
CHECK(gather_dimension_numbers_ != nullptr);
return *gather_dimension_numbers_;
}
- tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const {
+ absl::Span<const int64> gather_slice_sizes() const {
return gather_slice_sizes_;
}
// Returns the dump string of the gather dimension numbers.
@@ -1221,10 +1184,9 @@ class HloGatherInstruction : public HloInstruction {
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> offset_dims,
- tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
- tensorflow::gtl::ArraySlice<int64> start_index_map,
- int64 index_vector_dim);
+ absl::Span<const int64> offset_dims,
+ absl::Span<const int64> collapsed_slice_dims,
+ absl::Span<const int64> start_index_map, int64 index_vector_dim);
private:
std::vector<string> ExtraAttributesToStringImpl(
@@ -1234,8 +1196,7 @@ class HloGatherInstruction : public HloInstruction {
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
@@ -1260,9 +1221,9 @@ class HloScatterInstruction : public HloInstruction {
// Creates an instance of ScatterDimensionNumbers.
static ScatterDimensionNumbers MakeScatterDimNumbers(
- tensorflow::gtl::ArraySlice<int64> update_window_dims,
- tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
- tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ absl::Span<const int64> update_window_dims,
+ absl::Span<const int64> inserted_window_dims,
+ absl::Span<const int64> scatter_dims_to_operand_dims,
int64 index_vector_dim);
private:
@@ -1274,8 +1235,7 @@ class HloScatterInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
@@ -1298,8 +1258,7 @@ class HloIotaInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
const int64 iota_dimension_;
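
The HloReduceInstruction::inputs()/init_values() accessors above carve the operand list into two non-owning views with absl::MakeSpan(...).subspan(...). A small sketch of that pattern under simplified assumptions (a plain std::vector<std::string> stands in for the operand vector); not code from the commit:

// Illustrative sketch: splitting one vector into two views, as
// inputs()/init_values() do above.
#include <cassert>
#include <string>
#include <vector>
#include "absl/types/span.h"

int main() {
  // Layout mirrors HloReduceInstruction: first half inputs, second half
  // init values.
  std::vector<std::string> operands = {"in0", "in1", "init0", "init1"};
  const size_t input_count = operands.size() / 2;

  absl::Span<const std::string> all = absl::MakeSpan(operands);
  // subspan(pos, len) is a view of len elements starting at pos; len is
  // trimmed to what remains, so passing the full size for the tail works.
  absl::Span<const std::string> inputs = all.subspan(0, input_count);
  absl::Span<const std::string> inits =
      all.subspan(input_count, operands.size());

  assert(inputs.size() == 2 && inputs[1] == "in1");
  assert(inits.size() == 2 && inits[0] == "init0");
  return 0;
}
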
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 78167335c8..3a1bc4e328 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -353,7 +353,7 @@ bool IsUsedOutsideSubcomputation(
} // anonymous namespace
HloInstruction* HloModule::OutlineExpressionFromComputation(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
+ absl::Span<HloInstruction* const> instructions_to_outline,
const string& outlined_computation_name, HloComputation* computation) {
auto builder = HloComputation::Builder(outlined_computation_name);
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index cf129b835d..ee5601beec 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -192,7 +192,7 @@ class HloModule {
// order (root of outlined instructions last). TODO(jingyue): takes a set of
// instructions and topologically sorts them.
HloInstruction* OutlineExpressionFromComputation(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
+ absl::Span<HloInstruction* const> instructions_to_outline,
const string& outlined_computation_name, HloComputation* computation);
// Returns a randomly generated uint64.
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index d70328c8a3..d83ee71490 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -193,7 +193,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
}
std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
std::vector<HloInstruction*> roots;
for (HloComputation* computation : computations) {
for (HloInstruction* instruction : computation->instructions()) {
@@ -293,7 +293,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
}
Status HloModuleGroupUtil::VerifyComputations(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
auto visit_function =
[&](HloInstruction* instruction,
const std::vector<HloInstruction*>& instruction_group) {
@@ -324,7 +324,7 @@ Status HloModuleGroupUtil::VerifyComputations(
StatusOr<std::unique_ptr<HloReachabilityMap>>
HloModuleGroupUtil::ComputeReachability(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
std::vector<HloInstruction*> post_order;
auto visit_function =
[&](HloInstruction* instruction,
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h
index c25ca1aff5..fe11fe1818 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h
@@ -56,7 +56,7 @@ class HloModuleGroupUtil {
// Returns the root instructions of the computations.
std::vector<HloInstruction*> RootInstructions(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ absl::Span<HloComputation* const> computations);
// Visit state of each instruction during DFS traversal.
enum VisitState {
@@ -93,15 +93,14 @@ class HloModuleGroupUtil {
HloInstruction* root);
// Verifies that the computations are well-formed (e.g., no cycles).
- Status VerifyComputations(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ Status VerifyComputations(absl::Span<HloComputation* const> computations);
// Below Reachability utils resemble those in HloComputation, except that
// they can handle instructions across multiple computations.
//
// Creates the reachability map for the instructions in the computations.
StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ absl::Span<HloComputation* const> computations);
// Updates the reachability of the given instruction, taking the global
// predecessors and successors into account.
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 209ad5e58c..80009c7f7e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -44,7 +44,7 @@ class HloModuleTest : public HloTestBase {
// Creates a computation which calls the given zero-parameter computations.
std::unique_ptr<HloComputation> CreateCallComputation(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
auto builder = HloComputation::Builder("Call");
for (auto computation : computations) {
builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index b93e4f24f6..02201d4542 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -306,7 +306,7 @@ bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
// Creates replica groups from the provided nested array. groups[i] represents
// the replica ids for group 'i'.
std::vector<ReplicaGroup> CreateReplicaGroups(
- tensorflow::gtl::ArraySlice<std::vector<int64>> groups) {
+ absl::Span<const std::vector<int64>> groups) {
std::vector<ReplicaGroup> replica_groups;
absl::c_transform(groups, std::back_inserter(replica_groups),
[](const std::vector<int64>& ids) {
@@ -997,10 +997,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
shape, /*operands=*/
- tensorflow::gtl::ArraySlice<HloInstruction*>(operands).subspan(
+ absl::Span<HloInstruction* const>(operands).subspan(
0, operands.size() / 2),
/*init_values=*/
- tensorflow::gtl::ArraySlice<HloInstruction*>(operands).subspan(
+ absl::Span<HloInstruction* const>(operands).subspan(
operands.size() / 2, operands.size()),
*dimensions_to_reduce, *reduce_computation));
break;
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc
index 01b088a957..961930f0a8 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace xla {
HloReachabilityMap::HloReachabilityMap(
- tensorflow::gtl::ArraySlice<const HloInstruction*> instructions)
+ absl::Span<const HloInstruction* const> instructions)
: size_(instructions.size()) {
bit_vectors_.reserve(size_);
for (const HloInstruction* hlo : instructions) {
@@ -29,7 +29,7 @@ HloReachabilityMap::HloReachabilityMap(
}
bool HloReachabilityMap::SetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction) {
BitVector& bit_vector = GetBitVector(instruction);
tmp_bit_vector_ = bit_vector;
@@ -38,13 +38,13 @@ bool HloReachabilityMap::SetReachabilityToUnion(
}
void HloReachabilityMap::FastSetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction) {
SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction));
}
void HloReachabilityMap::SetReachabilityToUnionHelper(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction, BitVector* bit_vector) {
// If instruction is part of inputs, don't reset the bit_vector.
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index 48215d32a8..2c8ebc8e6c 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -42,7 +42,7 @@ class HloReachabilityMap {
// Sets up a graph with no edges and where the nodes correspond to the given
// instructions.
explicit HloReachabilityMap(
- tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
+ absl::Span<const HloInstruction* const> instructions);
// Set the reachability set of 'instruction' to the union of the reachability
// sets of 'inputs'. Upon return, IsReachable(x, instruction) where
@@ -54,13 +54,12 @@ class HloReachabilityMap {
// vector in the internal graph of this HloReachabilityMap for the given
// instruction and does not transitively update any other part of the
// adjacency matrix.
- bool SetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
- const HloInstruction* instruction);
+ bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs,
+ const HloInstruction* instruction);
// As above, but faster because it does not check if the reachability changed.
void FastSetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction);
// Sets entry so that IsReachable(a, b) will return true
@@ -141,7 +140,7 @@ class HloReachabilityMap {
// Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
void SetReachabilityToUnionHelper(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction, BitVector* bit_vector);
// Return the index of the given instruction. The value is used to index into
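
HloReachabilityMap's parameters above use absl::Span<const HloInstruction* const>, which is const at both levels: the elements (the pointers) cannot be reseated and the pointed-to instructions cannot be mutated. A short sketch of the distinction, with Node and CountNonZero as hypothetical stand-ins; not code from the commit:

// Illustrative sketch of a Span whose elements and pointees are both const.
#include <vector>
#include "absl/types/span.h"

struct Node { int value = 0; };  // stand-in for HloInstruction

int CountNonZero(absl::Span<const Node* const> nodes) {
  int n = 0;
  for (const Node* node : nodes) {
    if (node != nullptr && node->value != 0) ++n;
    // node->value = 1;  // would not compile: the pointee is const
  }
  return n;
}

int main() {
  Node a{1}, b{0};
  std::vector<const Node*> nodes = {&a, &b, nullptr};
  // The vector converts to the Span implicitly, as the call sites above rely on.
  return CountNonZero(nodes) == 1 ? 0 : 1;
}
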
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 569d2e5d2d..c9629926ea 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -202,8 +202,8 @@ class InstructionList {
// On object construction this ordinal is precisely the instruction's index
// in the list. Later, instructions inserted via InsertBefore receive
// duplicate values. However, monotonicity is preserved.
- void InsertBeforeInstructions(
- Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
+ void InsertBeforeInstructions(Item* to_insert,
+ absl::Span<Item* const> before_instructions) {
VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
<< " before {"
<< absl::StrJoin(before_instructions, ", ",
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 7bd8a4a544..66ac1f66fd 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -106,7 +106,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<const Literal*> literals) {
+ const absl::Span<const Literal* const> literals) {
std::vector<ScopedShapedBuffer> buffers;
for (const Literal* literal : literals) {
CHECK(literal != nullptr);
@@ -118,7 +118,7 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> literals) {
+ const absl::Span<const std::unique_ptr<Literal>> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
@@ -137,8 +137,8 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const Literal*> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
TransferLiteralsToDevice(arguments));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
@@ -152,7 +152,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
+ const absl::Span<const std::unique_ptr<Literal>> arguments,
bool run_hlo_passes, ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
@@ -169,8 +169,8 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
// Get service run options.
se::Stream stream(backend().default_stream_executor());
stream.Init();
@@ -190,8 +190,8 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<ScopedShapedBuffer> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
std::vector<const ShapedBuffer*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
@@ -226,8 +226,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
// no arguments.
std::vector<const ShapedBuffer*> argument_buffer_ptrs(
options.num_replicas * options.arguments.size() + 1);
- std::vector<tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- argument_buffer_slices;
+ std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
int64 index = 0;
for (int64 i = 0; i < options.num_replicas; ++i) {
int64 device = device_assignment(i, 0);
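
ExecuteReplicated above now collects std::vector<absl::Span<const ShapedBuffer* const>>; each stored Span is only a view into the flat argument_buffer_ptrs vector, which therefore has to outlive every Span built from it. A sketch of that aliasing pattern, assuming a stand-in Buffer type and made-up replica counts; not code from the commit:

// Illustrative sketch: per-replica Spans that all alias one backing vector.
#include <cstddef>
#include <vector>
#include "absl/types/span.h"

struct Buffer { int id = 0; };  // stand-in for ShapedBuffer

int main() {
  const size_t num_replicas = 3;
  const size_t args_per_replica = 2;

  // One flat pool of pointers, filled replica by replica.
  std::vector<Buffer> storage(num_replicas * args_per_replica);
  std::vector<const Buffer*> pool(storage.size());
  for (size_t i = 0; i < pool.size(); ++i) {
    storage[i].id = static_cast<int>(i);
    pool[i] = &storage[i];
  }

  // Each replica gets a non-owning view of its slice of the pool; no copies.
  std::vector<absl::Span<const Buffer* const>> per_replica;
  for (size_t r = 0; r < num_replicas; ++r) {
    per_replica.push_back(absl::MakeConstSpan(pool).subspan(
        r * args_per_replica, args_per_replica));
  }

  // Replica 1's first argument is the pool entry at index 2.
  return per_replica[1][0]->id == 2 ? 0 : 1;
}
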
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index cfc519063e..547b5fc1bb 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -104,9 +104,9 @@ class HloRunner {
// Transfers data between the host and device.
StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<const Literal*> literals);
+ const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> literals);
+ const absl::Span<const std::unique_ptr<Literal>> literals);
StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
const ShapedBuffer& buffer);
@@ -117,24 +117,24 @@ class HloRunner {
// optimization.
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const Literal*> arguments,
+ const absl::Span<const Literal* const> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
+ const absl::Span<const std::unique_ptr<Literal>> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<ScopedShapedBuffer> arguments,
+ const absl::Span<const ScopedShapedBuffer> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
// Executes a given HLO module into a set of replicas, and returns a map
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 930801288a..d49d09d459 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -269,7 +269,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto abs_abs1 = builder.AddInstruction(
HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*>({abs_abs1})));
+ absl::Span<HloInstruction* const>({abs_abs1})));
auto tuple_elm = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
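Illustrative sketch (not part of this patch): the test change above builds a span directly from a braced list of pointers. Spans of const elements accept an initializer list for the duration of the call, so both the implicit and the fully spelled-out form work; SumThrough is a hypothetical stand-in for CreateTuple.

  #include "absl/types/span.h"

  int SumThrough(absl::Span<int* const> ptrs) {
    int total = 0;
    for (int* p : ptrs) total += *p;
    return total;
  }

  int main() {
    int a = 1, b = 2;
    int s1 = SumThrough({&a, &b});                          // braced list binds for the call
    int s2 = SumThrough(absl::Span<int* const>({&a, &b}));  // explicit, as in the test above
    return s1 == s2 ? 0 : 1;
  }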
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 1235259764..de7e6b53d4 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -54,9 +54,8 @@ HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
return HloSharding(flattened_list);
}
-HloSharding HloSharding::Tuple(
- const Shape& tuple_shape,
- tensorflow::gtl::ArraySlice<HloSharding> shardings) {
+HloSharding HloSharding::Tuple(const Shape& tuple_shape,
+ absl::Span<const HloSharding> shardings) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
for (auto& sharding : shardings) {
CHECK(!sharding.IsTuple()) << sharding.ToString();
@@ -142,7 +141,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!maximal_);
CHECK(!IsTuple());
std::vector<int64> ret_index;
- tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
+ tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
if (d == device) {
ret_index = {index.begin(), index.end()};
}
@@ -151,8 +150,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
return ret_index;
}
-int64 HloSharding::DeviceForTileIndex(
- tensorflow::gtl::ArraySlice<int64> index) const {
+int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
CHECK(!replicated_);
CHECK(!IsTuple());
if (maximal_) {
@@ -319,7 +317,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
Status status = Status::OK();
std::set<int64> seen_cores;
tile_assignment_.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, int32 core) {
+ [&](absl::Span<const int64> indices, int32 core) {
// Don't overwrite a bad status, so we report the first error.
if (status.ok()) {
if (core >= num_devices) {
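Illustrative sketch (not part of this patch): the Each() callbacks above now receive absl::Span<const int64> by value, which is cheap because a span is only a pointer and a length. VisitRows below is a hypothetical stand-in for Array::Each, using int instead of int64.

  #include <functional>
  #include <vector>
  #include "absl/types/span.h"

  // Hypothetical visitor: calls `visit` with a view of each row and its index.
  void VisitRows(const std::vector<std::vector<int>>& rows,
                 const std::function<void(absl::Span<const int>, int)>& visit) {
    for (int i = 0; i < static_cast<int>(rows.size()); ++i) {
      visit(rows[i], i);  // each row converts implicitly to Span<const int>
    }
  }

  int main() {
    std::vector<std::vector<int>> rows = {{0, 0}, {0, 1}};
    int matches = 0;
    VisitRows(rows, [&](absl::Span<const int> index, int device) {
      if (!index.empty() && index.back() == device) ++matches;
    });
    return matches == 2 ? 0 : 1;
  }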
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index be51c3f55b..01fd9f215d 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -66,7 +66,7 @@ class HloSharding {
// shardings must match the number of leaf nodes in tuple_shape. For
// empty tuples, the shardings array must have one element.
static HloSharding Tuple(const Shape& tuple_shape,
- tensorflow::gtl::ArraySlice<HloSharding> shardings);
+ absl::Span<const HloSharding> shardings);
// Creates a new sharding for a tuple type, with a single input sharding
// repeated on each leaf.
@@ -132,7 +132,7 @@ class HloSharding {
// Returns the device that should execute the given tile.
// It is an error to call this if is_replicated() is true.
// REQUIRES: !IsTuple()
- int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
+ int64 DeviceForTileIndex(absl::Span<const int64> index) const;
// Given a device ID, returns the offset within the specified shape of the
// tile that should be executed on the given core. This returns the lower
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 6e9b96488c..34cba6136f 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -372,7 +372,7 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
}
StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
// If we are here, all the instructions being passed had the same sharding
// (or no sharding), by the means of the ShardingMatches() API.
// As such, no kDomain was inserted, and here we are asked to extract the
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 2341f8ada0..80634677e7 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -29,8 +29,8 @@ limitations under the License.
namespace xla {
namespace {
-Array<int64> MakeArray(tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> contents) {
+Array<int64> MakeArray(absl::Span<const int64> dimensions,
+ absl::Span<const int64> contents) {
Array<int64> a(dimensions);
std::copy(contents.begin(), contents.end(), a.begin());
return a;
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index e0c1326177..773fc7d225 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -149,7 +149,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
} // namespace
void HloValue::SetPositionsAndComputeUses(
- tensorflow::gtl::ArraySlice<HloPosition> positions) {
+ absl::Span<const HloPosition> positions) {
CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
// The positions must be unique and should not contain the defining position
@@ -222,8 +222,7 @@ string HloValueSet::ToString() const {
}));
}
-bool HloValueSet::AssignUnionOf(
- tensorflow::gtl::ArraySlice<const HloValueSet*> inputs) {
+bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
HloValueSet union_set;
for (const HloValueSet* input : inputs) {
for (const HloValue* value : input->values()) {
@@ -254,7 +253,7 @@ std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
}
bool InstructionValueSet::AssignUnionOf(
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
+ absl::Span<const InstructionValueSet* const> inputs) {
CHECK_GT(inputs.size(), 0);
for (int i = 1; i < inputs.size(); ++i) {
DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h
index a1151f65e0..6f2ad214f6 100644
--- a/tensorflow/compiler/xla/service/hlo_value.h
+++ b/tensorflow/compiler/xla/service/hlo_value.h
@@ -108,8 +108,7 @@ class HloValue : public BufferValue {
// Sets the positions in the module at which the HloValue appears. Updates
// uses. Should be called once and only once. The defining position should not
// be included in 'positions' as this is set at construction time.
- void SetPositionsAndComputeUses(
- tensorflow::gtl::ArraySlice<HloPosition> positions);
+ void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions);
// Returns whether this value is a phi value.
bool is_phi() const { return is_phi_; }
@@ -186,14 +185,14 @@ class HloValueSet {
public:
HloValueSet() = default;
- explicit HloValueSet(tensorflow::gtl::ArraySlice<const HloValue*> values)
+ explicit HloValueSet(absl::Span<const HloValue* const> values)
: values_(values.begin(), values.end()) {
SortAndUniquifyValues();
}
// Sets this value set to the union of the given value sets. Returns whether
// this value set changed.
- bool AssignUnionOf(tensorflow::gtl::ArraySlice<const HloValueSet*> inputs);
+ bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs);
// Return the vector of HloValues in the set. Values in the vector are unique
// and stably sorted by value id.
@@ -247,8 +246,7 @@ class InstructionValueSet : public ShapeTree<HloValueSet> {
// Sets this value set to the union of the given value sets. Returns whether
// this value set changed.
- bool AssignUnionOf(
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
+ bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs);
string ToString() const;
};
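Illustrative sketch (not part of this patch): HloValueSet above copies the incoming span into owning storage via begin()/end(), so the set does not depend on the caller's container staying alive. IdSet below is a hypothetical class showing the same pattern with ints.

  #include <algorithm>
  #include <vector>
  #include "absl/types/span.h"

  class IdSet {
   public:
    explicit IdSet(absl::Span<const int> ids)
        : ids_(ids.begin(), ids.end()) {  // copy the viewed elements
      std::sort(ids_.begin(), ids_.end());
      ids_.erase(std::unique(ids_.begin(), ids_.end()), ids_.end());
    }
    const std::vector<int>& ids() const { return ids_; }

   private:
    std::vector<int> ids_;  // owning, sorted, deduplicated
  };

  int main() {
    std::vector<int> raw = {3, 1, 3, 2};
    IdSet set(raw);  // implicit conversion to Span<const int>
    return set.ids().size() == 3 ? 0 : 1;
  }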
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 744cd64bc5..95516dec74 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -699,8 +699,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
instruction->opcode(), instruction->operands()));
}
-string ComputationsToString(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+string ComputationsToString(absl::Span<HloComputation* const> computations) {
return absl::StrJoin(computations, ",",
[](string* s, const HloComputation* computation) {
s->append(computation->name());
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 4d4f681c8a..a4de02a890 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -35,7 +35,6 @@ using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
using absl::StrJoin;
-using tensorflow::gtl::ArraySlice;
} // namespace
string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
@@ -186,7 +185,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
- tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape) {
+ absl::Span<const int64> output_dims, Shape shape) {
// We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
// `source` is the inner Gather(A, X).
@@ -252,8 +251,7 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
- Array* indices) {
+ absl::Span<const int64> slice_sizes, Array* source, Array* indices) {
if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
VLOG(3) << "ComputeArrayForGather: indices are not scalar";
return nullptr;
@@ -314,7 +312,7 @@ namespace {
// Returns an index into `values` such that the product of the range
// [values.begin()+index, values.end()) is equal to `product`. If there is no
// such index, return -1. All integers in `values` must be positive.
-int64 FindSuffixWithProduct(ArraySlice<int64> values, int64 product) {
+int64 FindSuffixWithProduct(absl::Span<const int64> values, int64 product) {
DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; }));
int64 current_product = 1;
@@ -343,7 +341,8 @@ struct ReshapePassthroughDimPair {
// The returned vector of pairs is sorted in both the result_dim and the
// operand_dim components.
std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
- ArraySlice<int64> operand_shape, ArraySlice<int64> result_shape) {
+ absl::Span<const int64> operand_shape,
+ absl::Span<const int64> result_shape) {
// A reshape can be seen as an index mapping from output index to input index:
//
// (i_0, ..., i_n) = f(o_0, ..., o_m)
@@ -420,7 +419,7 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
// Returns true if `dim` is stated as a passthrough operand dim in
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
- ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
+ absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
return absl::c_any_of(passthrough_dims,
[&](ReshapePassthroughDimPair passthrough_dim_pair) {
return passthrough_dim_pair.operand_dim == dim;
@@ -430,7 +429,8 @@ bool IsReshapePassthroughOperandDim(
// Maps `operand_dim`, which must be a passthrough operand dimension, to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64 MapPassthroughOperandDimToResultDim(
- ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 operand_dim) {
+ absl::Span<const ReshapePassthroughDimPair> passthrough_dims,
+ int64 operand_dim) {
auto it = absl::c_find_if(
passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
return passthrough_dim_pair.operand_dim == operand_dim;
@@ -439,9 +439,9 @@ int64 MapPassthroughOperandDimToResultDim(
return it->result_dim;
}
-int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
- ArraySlice<int64> result_shape,
- int64 source_passthrough_dim) {
+int64 FindSourcePositionForPassthroughResultDim(
+ absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape,
+ int64 source_passthrough_dim) {
VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
<< StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
<< "], " << source_passthrough_dim << ")";
@@ -519,8 +519,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
}
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
- ScalarIndexedArray* operand,
- tensorflow::gtl::ArraySlice<int64> degenerate_dims) {
+ ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims) {
if (degenerate_dims.empty()) {
return operand;
}
@@ -873,7 +872,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
return nullptr;
}
- ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions();
+ absl::Span<const int64> broadcast_dims = broadcast_instr->dimensions();
auto is_broadcasted_dim = [&](int64 output_dim) {
return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
};
@@ -896,7 +895,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// The scalar-indexed node "removes" the source dim and "inserts" the output
// dims. We do the opposite here to undo the scalar-indexed operation.
- ArraySlice<int64> output_dims = scalar_indexed_const->output_dims();
+ absl::Span<const int64> output_dims = scalar_indexed_const->output_dims();
for (int64 i = output_dims.size() - 1; i >= 0; --i) {
CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
EraseAt(&simulated_index, output_dims[i]);
@@ -973,8 +972,8 @@ namespace {
// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
absl::optional<int64> GetOnlyNonContractingNonBatchDim(
- int64 rank, ArraySlice<int64> contracting_dims,
- ArraySlice<int64> batch_dims) {
+ int64 rank, absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) {
absl::optional<int64> result;
for (int64 dim = 0; dim < rank; dim++) {
if (!absl::c_linear_search(contracting_dims, dim) &&
@@ -998,7 +997,8 @@ absl::optional<int64> GetOnlyNonContractingNonBatchDim(
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
- ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) {
+ absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) {
absl::optional<int64> non_contracting_non_batch_dim =
GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
contracting_dims, batch_dims);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 3fa7d749e1..dcfb725535 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -188,9 +188,7 @@ class IndexedArrayAnalysis {
// `output_dims` are the dimensions in the output array that are being used
// to compute an index into the `indices` array. See the class
// documentation and the overview for more details.
- tensorflow::gtl::ArraySlice<int64> output_dims() const {
- return output_dims_;
- }
+ absl::Span<const int64> output_dims() const { return output_dims_; }
private:
explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim,
@@ -265,8 +263,7 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
- Array* indices);
+ absl::Span<const int64> slice_sizes, Array* source, Array* indices);
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
@@ -303,7 +300,7 @@ class IndexedArrayAnalysis {
// G1 = [Arr[i] for i in I2]
StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
- tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
+ absl::Span<const int64> output_dims, Shape shape);
// Reshapes a scalar-indexed node to remove the degenerate dimensions in its
// output. The result is always a scalar-indexed node.
@@ -313,8 +310,7 @@ class IndexedArrayAnalysis {
// Reshapes a scalar-indexed node such that the result has the degenerate
// dimensions `degenerate_dims`. The result is always a scalar-indexed node.
StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims(
- ScalarIndexedArray* operand,
- tensorflow::gtl::ArraySlice<int64> degenerate_dims);
+ ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims);
StatusOr<ScalarIndexedArray*> FoldReshapeOfGather(
const Shape& shape, ScalarIndexedConstantArray* operand);
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 4b5285031b..8c907eae0c 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -219,7 +219,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
InstructionFusion::HloInstructionSet
InstructionFusion::ComputeGloballyUnfusible(
- tensorflow::gtl::ArraySlice<HloInstruction*> post_order) {
+ absl::Span<HloInstruction* const> post_order) {
// Forbid fusion of producers that:
// a) Need to be duplicated, unless they can be fused into all consumers
// via all paths.
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 9802d4cfc1..00b658959a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -123,7 +123,7 @@ class InstructionFusion : public HloPassInterface {
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
HloInstructionSet ComputeGloballyUnfusible(
- tensorflow::gtl::ArraySlice<HloInstruction*> post_order);
+ absl::Span<HloInstruction* const> post_order);
// Used to determine if an HLO is expensive. Expensive operations will not be
// duplicated.
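Illustrative sketch (not part of this patch): throughout this change, ArraySlice<T*> becomes absl::Span<T* const>, i.e. the pointers themselves are immutable through the view while the objects they point to stay mutable, matching the old contract. Node is a hypothetical stand-in for HloInstruction.

  #include <vector>
  #include "absl/types/span.h"

  struct Node { int visits = 0; };  // hypothetical stand-in

  void MarkVisited(absl::Span<Node* const> post_order) {
    for (Node* n : post_order) {
      n->visits++;                 // allowed: mutates the pointee
      // post_order[0] = nullptr;  // would not compile: the elements are const
    }
  }

  int main() {
    Node a, b;
    std::vector<Node*> order = {&a, &b};
    MarkVisited(order);  // vector<Node*> converts implicitly
    return a.visits + b.visits == 2 ? 0 : 1;
  }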
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 2259dc1083..5dea124768 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -47,7 +47,7 @@ InterpreterExecutable::~InterpreterExecutable() {}
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
@@ -111,7 +111,7 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
return tensorflow::errors::Unimplemented(
"ExecuteAsyncOnStream is not yet supported on Interpreter.");
}
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index 91d8148d26..588787d445 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -48,13 +48,13 @@ class InterpreterExecutable : public Executable {
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override
LOCKS_EXCLUDED(evaluator_lock_);
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
static int64 ShapeSizeBytes(const Shape& shape);
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index db6b910b32..f600b14c6c 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -47,7 +47,7 @@ limitations under the License.
namespace stream_executor {
namespace interpreter {
-using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
+using Args = absl::Span<const DeviceMemoryBase>;
class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
public:
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
index ad350613dd..cc2e862f2e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
@@ -99,9 +99,10 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
}
-Status EmitDynamicUpdateSliceInPlace(
- tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) {
+Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
+ const IrArray& output_array,
+ absl::string_view name,
+ llvm::IRBuilder<>* b) {
VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;
// No need to use operand_arrays[0], the input array of the
@@ -129,8 +130,7 @@ Status EmitDynamicUpdateSliceInPlace(
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
@@ -173,8 +173,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
}
Status EmitFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
llvm::IRBuilder<>* b) {
return EmitFusedDynamicUpdateSliceInPlaceImpl(
@@ -183,8 +182,7 @@ Status EmitFusedDynamicUpdateSliceInPlace(
}
Status EmitParallelFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
return EmitFusedDynamicUpdateSliceInPlaceImpl(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
index e1631a62ae..fb3e4eb97c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
@@ -63,25 +63,24 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace(
// Emits IR for running the given dynamic-update-slice op in-place -- that is,
// where the input and output buffers share the same slice, so we can simply
// modify the input/output buffer without touching any of the other elements.
-Status EmitDynamicUpdateSliceInPlace(
- tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b);
+Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
+ const IrArray& output_array,
+ absl::string_view name,
+ llvm::IRBuilder<>* b);
// Given a loop-fusion node whose root is a dynamic-update-slice op whose
// array-to-be-updated and output share the same buffer slice, emits
// (sequential) code for a fusion node that does the dynamic-update-slice in
// place.
Status EmitFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
llvm::IRBuilder<>* b);
// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with
// the given launch dimensions.
Status EmitParallelFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index 6d637cad6d..b606c993a2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -147,7 +147,7 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
}
Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
+ absl::Span<HloInstruction* const> operands(tuple->operands());
std::vector<llvm::Type*> operand_elemental_ir_types;
for (HloInstruction* operand : operands) {
operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index 30471480c4..25ec458160 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -54,7 +54,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
public:
using Generator = llvm_ir::ElementGenerator;
- FusedIrEmitter(tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays,
+ FusedIrEmitter(absl::Span<const llvm_ir::IrArray> parameter_arrays,
ElementalIrEmitter* elemental_emitter)
: parameter_arrays_(parameter_arrays),
tiled_parameter_info_(nullptr),
@@ -94,7 +94,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
private:
// Arrays of parameters of fusion instruction
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays_;
+ absl::Span<const llvm_ir::IrArray> parameter_arrays_;
const llvm_ir::TiledParameterInfo* tiled_parameter_info_;
ElementalIrEmitter* elemental_emitter_;
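Illustrative sketch (not part of this patch): like the ArraySlice member it replaces, the absl::Span member above is a non-owning view, so the backing storage must outlive the object. Averager is a hypothetical class showing that contract.

  #include <vector>
  #include "absl/types/span.h"

  class Averager {
   public:
    explicit Averager(absl::Span<const double> samples) : samples_(samples) {}
    double Mean() const {
      double sum = 0;
      for (double s : samples_) sum += s;
      return samples_.empty() ? 0 : sum / samples_.size();
    }

   private:
    absl::Span<const double> samples_;  // view only; does not own the data
  };

  int main() {
    std::vector<double> samples = {1.0, 2.0, 3.0};  // must outlive `avg`
    Averager avg(samples);
    return avg.Mean() == 2.0 ? 0 : 1;
  }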
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 36e713d1ac..67f7423121 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -73,7 +73,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
Delinearize(&multidim_, linear, shape, b);
}
-IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
llvm::Value* linear, const Shape& shape)
: multidim_(multidim.begin(), multidim.end()),
linear_(linear),
@@ -92,7 +92,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
<< " should have a layout.";
}
-IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
const Shape& shape, llvm::IRBuilder<>* b)
: multidim_(multidim.begin(), multidim.end()),
layout_(shape.layout()),
@@ -147,7 +147,7 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
// indices in the same common factor.
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
llvm::Value* logical_linear_index =
- Index(tensorflow::gtl::ArraySlice<llvm::Value*>(multidim_).subspan(
+ Index(absl::Span<llvm::Value* const>(multidim_).subspan(
common_factors[k].second,
common_factors[k + 1].second - common_factors[k].second),
index_type_)
@@ -184,9 +184,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
}
IrArray::Index IrArray::Index::SourceIndexOfSlice(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> strides,
- llvm::IRBuilder<>* builder) const {
+ const Shape& shape, absl::Span<const int64> starts,
+ absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
Index source_index(index_type_, multidim_.size());
for (int i = 0; i < multidim_.size(); ++i) {
int64 stride = strides[i];
@@ -207,7 +206,7 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice(
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ absl::Span<const int64> dimension_mapping,
llvm::IRBuilder<>* builder) const {
std::vector<llvm::Value*> operand_multidim_index =
Permute(dimension_mapping, multidim());
@@ -256,7 +255,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast(
IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ absl::Span<const int64> dimension_mapping,
llvm::IRBuilder<>* builder) const {
int64 rank = ShapeUtil::Rank(operand_shape);
std::vector<llvm::Value*> source_index(rank);
@@ -321,9 +320,8 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
return Index(source_index, linear, operand_shape);
}
-llvm::Value* IrArray::Index::Linearize(
- tensorflow::gtl::ArraySlice<int64> dimensions,
- llvm::IRBuilder<>* builder) const {
+llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
+ llvm::IRBuilder<>* builder) const {
// Each dimension is multiplied by the product of the sizes of all
// earlier dimensions and added to the accumulator logical_linear_index.
CHECK_EQ(size(), dimensions.size());
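Illustrative sketch (not part of this patch): SourceIndexOfReshape above wraps an existing vector in a span only to take a window of it; Span::subspan(pos, len) is the replacement for ArraySlice's subslicing.

  #include <vector>
  #include "absl/types/span.h"

  int Sum(absl::Span<const int> xs) {
    int total = 0;
    for (int x : xs) total += x;
    return total;
  }

  int main() {
    std::vector<int> dims = {2, 3, 4, 5};
    absl::Span<const int> view(dims);
    int middle = Sum(view.subspan(1, 2));  // views elements {3, 4}
    return middle == 7 ? 0 : 1;
  }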
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index e913c109b3..7629806a36 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -70,7 +70,7 @@ class IrArray {
// Constructs an index from multi-dimensional index "multidim". The linear
// index is set to nullptr.
- explicit Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+ explicit Index(absl::Span<llvm::Value* const> multidim,
llvm::Type* index_ty = nullptr)
: multidim_(multidim.begin(), multidim.end()) {
if (size() == 0) {
@@ -99,14 +99,14 @@ class IrArray {
// that it indexes into.
//
// Precondition: "shape" has a layout.
- Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
- const Shape& shape, llvm::IRBuilder<>* b);
+ Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
+ llvm::IRBuilder<>* b);
// Constructs an index from both a multi-dimensional index and a linear
// index. "shape" has the same meaning as that in the constructor that takes
// only a linear index.
- Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
- llvm::Value* linear, const Shape& shape);
+ Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
+ const Shape& shape);
const std::vector<llvm::Value*>& multidim() const { return multidim_; }
llvm::Value* linear() const { return linear_; }
@@ -145,17 +145,15 @@ class IrArray {
// by starting indices `starts` and stride values `strides`.
//
// Precondition: "this" is an index into a slice whose shape is `shape`.
- Index SourceIndexOfSlice(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> strides,
+ Index SourceIndexOfSlice(const Shape& shape, absl::Span<const int64> starts,
+ absl::Span<const int64> strides,
llvm::IRBuilder<>* builder) const;
// Given that "this" is the target index of a transpose from `operand_shape`
// to `shape` with the given dimension mapping, returns the source index.
- Index SourceIndexOfTranspose(
- const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
- llvm::IRBuilder<>* builder) const;
+ Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape,
+ absl::Span<const int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const;
// Given that "this" is the target index of a bitcast from `operand_shape`
// to `shape`, returns the source index.
@@ -164,14 +162,13 @@ class IrArray {
// Given that "this" is the target index of a broadcast from `operand_shape`
// to `shape` with the given dimension mapping, returns the source index.
- Index SourceIndexOfBroadcast(
- const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
- llvm::IRBuilder<>* builder) const;
+ Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
+ absl::Span<const int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const;
// Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
// returns the index into the sole dimension 0 of the new shape.
- llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,
+ llvm::Value* Linearize(absl::Span<const int64> dimensions,
llvm::IRBuilder<>* builder) const;
llvm::Type* GetType() const { return index_type_; }
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index b152cf9275..43fec311f1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -235,7 +235,7 @@ class KernelSupportLibrary {
}));
}
- using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;
+ using ArgumentVector = absl::Span<llvm::Value* const>;
// Generates the following control flow structure:
//
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
index cb4d1db997..e5fbdbd51b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
@@ -28,7 +28,7 @@ namespace {
// Returns the indices of the first elements of all consecutive subarrays of the
// given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
+std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
std::vector<size_t> is = {0};
for (size_t i = 1; i < xs.size(); ++i) {
if (1 != xs[i] - xs[i - 1]) {
@@ -40,8 +40,7 @@ std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`.
-Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
- const Shape& shape) {
+Shape MergeDimensions(absl::Span<const size_t> segs, const Shape& shape) {
std::vector<int64> dimensions;
for (size_t i = 1; i <= segs.size(); ++i) {
dimensions.push_back(std::accumulate(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
index 8bd06c42c3..5ea05b3188 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -50,7 +50,7 @@ IrArray::Index GetUnreducedOutputIndex(
// for 021 transpose.
class TiledParameterInfo {
public:
- TiledParameterInfo(tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers,
+ TiledParameterInfo(absl::Span<llvm::Value* const> param_buffers,
llvm::Value* y, llvm::Value* x)
: param_buffers_(param_buffers), y_(y), x_(x) {}
@@ -67,7 +67,7 @@ class TiledParameterInfo {
private:
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
// if the parameter is not tiled.
- tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers_;
+ absl::Span<llvm::Value* const> param_buffers_;
// The y coordinate within a tile.
llvm::Value* y_;
// The x coordinate within a tile.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index 9f3329e7f0..219a9f221f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -241,7 +241,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
}
IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
+ const Shape& shape, absl::Span<const int64> dimensions,
absl::string_view suffix) {
llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size());
for (int64 dimension : dimensions) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index 0a406bd90b..2be7bbd0de 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -242,7 +242,7 @@ class ForLoopNest {
// size equals the rank of shape and there is a null for each
// dimension that is not in "dimensions".
IrArray::Index AddLoopsForShapeOnDimensions(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
+ const Shape& shape, absl::Span<const int64> dimensions,
absl::string_view suffix);
// Emits a series of nested loops for iterating over an operand array. Loops
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index f0db2a3761..1a53c026be 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -83,11 +83,10 @@ string DumpModuleToString(const llvm::Module& module) {
return AsString(buffer_string);
}
-llvm::Value* EmitCallToIntrinsic(
- llvm::Intrinsic::ID intrinsic_id,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
- llvm::IRBuilder<>* b) {
+llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,
+ absl::Span<llvm::Value* const> operands,
+ absl::Span<llvm::Type* const> overloaded_types,
+ llvm::IRBuilder<>* b) {
llvm::Module* module = ModuleFromIRBuilder(b);
llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
module, intrinsic_id, AsArrayRef(overloaded_types));
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index dde50e19d1..61b029eb08 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -59,7 +59,7 @@ llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
}
template <typename T>
-llvm::ArrayRef<T> AsArrayRef(const tensorflow::gtl::ArraySlice<T>& slice) {
+llvm::ArrayRef<T> AsArrayRef(const absl::Span<const T>& slice) {
return llvm::ArrayRef<T>(slice.data(), slice.size());
}
@@ -101,11 +101,10 @@ string SanitizeFunctionName(string function_name);
// intrinsics (for example, "minnum") must include a type in overloaded_types
// for each overloaded type. Typically, overloaded intrinsics have only a single
// overloaded type.
-llvm::Value* EmitCallToIntrinsic(
- llvm::Intrinsic::ID intrinsic_id,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
- llvm::IRBuilder<>* b);
+llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,
+ absl::Span<llvm::Value* const> operands,
+ absl::Span<llvm::Type* const> overloaded_types,
+ llvm::IRBuilder<>* b);
// Emits float max. Emits the maxnum intrinsic if fast math is disabled, or
// fcmp+select otherwise
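Illustrative sketch (not part of this patch): the templated AsArrayRef overload above needs nothing beyond data() and size(), which any Span provides, so bridging to other (pointer, length) view types stays a one-liner. SimpleRef is a hypothetical stand-in for llvm::ArrayRef.

  #include <cstddef>
  #include <vector>
  #include "absl/types/span.h"

  template <typename T>
  struct SimpleRef {  // hypothetical (pointer, length) view
    const T* ptr;
    size_t len;
  };

  template <typename T>
  SimpleRef<T> AsSimpleRef(absl::Span<const T> slice) {
    return SimpleRef<T>{slice.data(), slice.size()};
  }

  int main() {
    std::vector<int> values = {7, 8, 9};
    SimpleRef<int> ref = AsSimpleRef(absl::Span<const int>(values));
    return (ref.len == 3 && ref.ptr[0] == 7) ? 0 : 1;
  }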
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 1553b4fc91..0dc120e0b0 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -69,7 +69,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
}
LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<IrArray> target_arrays,
+ absl::Span<const IrArray> target_arrays,
llvm::IRBuilder<>* b)
: body_emitter_(MakeBodyEmitterForMultiOutputFusion(
target_element_generator,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index 57d9d8bbc6..a537c00066 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -53,8 +53,7 @@ class LoopEmitter {
// This is used for multi-output fusion. target_element_generator must
// produce an LLVM struct with N elements.
LoopEmitter(const ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<IrArray> target_arrays,
- llvm::IRBuilder<>* b);
+ absl::Span<const IrArray> target_arrays, llvm::IRBuilder<>* b);
LoopEmitter(const LoopEmitter&) = delete;
LoopEmitter& operator=(const LoopEmitter&) = delete;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 11ed6ee59f..7d49b8d6c2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -64,8 +64,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
}
}
-void EmitTuple(const IrArray& tuple,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module) {
for (size_t i = 0; i < operands.size(); ++i) {
auto* store = b->CreateStore(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
index cf6bf5d0b1..cee211d66f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
@@ -65,8 +65,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand.
-void EmitTuple(const IrArray& tuple,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module);
// A tuple is an array of pointers, one for each operand. Each pointer points to
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 768105d9e1..0d0fb7946a 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -141,7 +141,7 @@ ExecutionOptions CreateExecutionOptions(
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options) {
const HloModuleProto& proto = computation.proto();
TF_RET_CHECK(proto.has_program_shape());
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 8f707ea904..acc8c6d2e0 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -48,7 +48,7 @@ class LocalService : public Service {
// compiler is responsible for freeing any memory it allocates this way.
StatusOr<std::unique_ptr<Executable>> CompileExecutable(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options);
// Returns the device ordinal that corresponds to the given replica number.
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 4166ef5baf..b9ec31c497 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -262,7 +262,7 @@ void MultiOutputFusion::RecomputeReachability() {
void MultiOutputFusion::UpdateReachability(
HloInstruction* instr1, HloInstruction* instr2,
- tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ absl::Span<HloInstruction* const> instrs_to_update,
const std::function<bool(HloInstruction*)>& skip) {
for (auto instr : instrs_to_update) {
if (skip != nullptr && skip(instr)) {
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 4c8cb7d379..d2c52651c4 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -92,7 +92,7 @@ class MultiOutputFusion : public HloPassInterface {
// Update the reachability map after fusing instr1 and instr2.
void UpdateReachability(
HloInstruction* instr1, HloInstruction* instr2,
- tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ absl::Span<HloInstruction* const> instrs_to_update,
const std::function<bool(HloInstruction*)>& skip = nullptr);
// Hook for multi-output fusion along producer-consumer edges.
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index 2077b57c05..2f4b2667c4 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -26,7 +26,6 @@ limitations under the License.
namespace xla {
-using tensorflow::gtl::ArraySlice;
// Transposes the given scatter_indices such that the index_vector_dim becomes
// the most-minor dimension.
@@ -87,7 +86,7 @@ static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
// major dimensions and all the window dimensions appear in the minor
// dimensions.
static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
- HloInstruction* updates, ArraySlice<int64> update_window_dims) {
+ HloInstruction* updates, absl::Span<const int64> update_window_dims) {
std::vector<int64> permutation;
const int64 updates_rank = ShapeUtil::Rank(updates->shape());
permutation.reserve(updates_rank);
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index e10c1d9927..f0e2566a3f 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -62,10 +62,9 @@ using absl::StrCat;
using absl::StrFormat;
// Records the arguments used to invoke a computation in an HloSnapshot proto.
-Status RecordArguments(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- se::Stream* stream, TransferManager* transfer_manager,
- HloSnapshot* module) {
+Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
+ se::Stream* stream, TransferManager* transfer_manager,
+ HloSnapshot* module) {
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
@@ -207,8 +206,8 @@ Status Service::ValidateResultShape(const Shape& client_shape,
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors) {
+ absl::Span<const GlobalDataHandle* const> arguments,
+ absl::Span<se::StreamExecutor* const> stream_executors) {
CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
replicated_arguments.resize(options_.number_of_replicas());
@@ -242,7 +241,7 @@ Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+ absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options) {
auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
@@ -299,7 +298,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
@@ -367,12 +366,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
- tensorflow::gtl::ArraySlice<Executable*> executables,
- tensorflow::gtl::ArraySlice<std::vector<std::vector<const ShapedBuffer*>>>
- arguments,
- Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
- tensorflow::gtl::ArraySlice<string> result_tags,
- ExecutionProfile* profile) {
+ absl::Span<Executable* const> executables,
+ absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
+ Backend* backend, absl::Span<const DeviceHandle> device_handles,
+ absl::Span<const string> result_tags, ExecutionProfile* profile) {
// Streams where the computations are launched, so we can wait on the streams
// to complete.
std::vector<StreamPool::Ptr> streams;
@@ -511,8 +508,7 @@ Service::ExecuteParallelAndRegisterResult(
StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
Executable* executable,
- const tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>>
- arguments,
+ const absl::Span<const std::vector<const ShapedBuffer*>> arguments,
Backend* backend, const string& result_tag, ExecutionProfile* profile) {
// Set up streams.
std::vector<StreamPool::Ptr> streams;
@@ -555,8 +551,7 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
// TODO(b/69985541): Support profiling also on this path.
- std::vector<tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- replicated_arguments;
+ std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
for (const auto& arg : arguments) {
replicated_arguments.push_back(arg);
}
@@ -595,7 +590,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
const ExecutionOptions& execution_options,
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments) {
+ absl::Span<const GlobalDataHandle* const> arguments) {
// Resolve the allocations for the arguments of the computation, and create
// a vector of device memory offsets for the arguments from the allocations.
// In the case of partitioned computations, assume all arguments go on the
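Illustrative sketch (not part of this patch): the replicated_arguments vector above collects lightweight views over argument vectors that stay owned elsewhere; each per-replica vector must outlive its span. Plain ints stand in for const ShapedBuffer* below.

  #include <vector>
  #include "absl/types/span.h"

  int main() {
    std::vector<std::vector<int>> per_replica = {{1, 2}, {3, 4, 5}};

    std::vector<absl::Span<const int>> views;
    views.reserve(per_replica.size());
    for (const auto& args : per_replica) {
      views.push_back(args);  // vector<int> converts implicitly to Span<const int>
    }

    int total = 0;
    for (absl::Span<const int> view : views) {
      for (int x : view) total += x;
    }
    return total == 15 ? 0 : 1;
  }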
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 47d196fb2a..173300d8b6 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -176,7 +176,7 @@ class Service : public ServiceInterface {
// class.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options);
// Picks a parallel response and fills the result.
@@ -191,7 +191,7 @@ class Service : public ServiceInterface {
// Prepares the arguments for parallel execution.
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
const ExecutionOptions& execution_options,
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
+ absl::Span<const GlobalDataHandle* const> arguments);
protected:
friend class LocalExecutable;
@@ -207,14 +207,14 @@ class Service : public ServiceInterface {
// the corresponding replica.
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
ResolveAndValidateArguments(
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
+ absl::Span<const GlobalDataHandle* const> arguments,
+ absl::Span<se::StreamExecutor* const> stream_executors);
// Creates an HLO module config for the given program shape and arguments.
// execution_options is optional; if not given a default is used.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+ absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options);
// Builds an Executable for the given parameters.
@@ -242,21 +242,17 @@ class Service : public ServiceInterface {
// ExecutionProfile object which will be filled in with profile data.
StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
Executable* executable,
- const tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>>
- arguments,
+ const absl::Span<const std::vector<const ShapedBuffer*>> arguments,
Backend* backend, const string& result_tag, ExecutionProfile* profile);
// Runs the given executables with the given arguments and registers the result
// from each executable in the allocation tracker. The handles of the result
// from the tracker are returned.
StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
- tensorflow::gtl::ArraySlice<Executable*> executables,
- tensorflow::gtl::ArraySlice<std::vector<std::vector<const ShapedBuffer*>>>
- arguments,
- Backend* backend,
- tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
- tensorflow::gtl::ArraySlice<string> result_tags,
- ExecutionProfile* profile);
+ absl::Span<Executable* const> executables,
+ absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
+ Backend* backend, absl::Span<const DeviceHandle> device_handles,
+ absl::Span<const string> result_tags, ExecutionProfile* profile);
// Executes a single computation which has more than one target device.
// The N devices are expected to all return an empty tuple, but one, which
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 45427bba25..2611749862 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -45,7 +45,7 @@ using absl::StrFormat;
using absl::StrJoin;
// Returns true if no element is present in slice more than once.
-bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
+bool AllUnique(absl::Span<const int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
@@ -57,11 +57,10 @@ Status ExpectArray(const Shape& shape, absl::string_view op_type) {
return Status::OK();
}
-Status VerifyReducerShape(
- const ProgramShape& reducer_shape,
- tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types,
- int64 inputs) {
+Status VerifyReducerShape(const ProgramShape& reducer_shape,
+ absl::Span<const Shape* const> init_value_shapes,
+ absl::Span<const PrimitiveType> input_element_types,
+ int64 inputs) {
if (reducer_shape.parameters_size() != inputs * 2) {
return InvalidArgument(
"Reduction function must take %d parameters, but "
@@ -335,8 +334,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const int64 dimension) {
+ absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
if (arg_shapes.empty()) {
return InvalidArgument("Concatenate expects at least one argument.");
}
@@ -394,7 +392,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
/* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+ absl::Span<const Shape* const> arg_shapes) {
for (const Shape* arg_shape : arg_shapes) {
if (arg_shape->element_type() != TOKEN) {
return InvalidArgument(
@@ -550,22 +548,22 @@ Status ValidateDotDimensionNumbers(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
// Check that dimension numbers are in range.
- auto dims_in_range =
- [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims,
- tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto dims_in_range = [](const int64 rank,
+ absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) -> bool {
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
in_range) &&
std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
};
- tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions =
+ absl::Span<const int64> lhs_contracting_dimensions =
AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
- tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions =
+ absl::Span<const int64> rhs_contracting_dimensions =
AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
- tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions =
+ absl::Span<const int64> lhs_batch_dimensions =
AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
- tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions =
+ absl::Span<const int64> rhs_batch_dimensions =
AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
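Editor's note: `AsInt64Slice` adapts the proto's repeated dimension fields into a contiguous view, and the locals that receive it are now declared as spans, so no elements are copied. As a generic, hedged sketch of that wrapping pattern over a plain buffer (not the actual helper), using absl::MakeConstSpan:

#include <cstddef>
#include <cstdint>
#include "absl/types/span.h"

// Wrap any contiguous buffer (here a raw pointer plus length standing in for
// a proto repeated field) as a read-only span; nothing is copied.
absl::Span<const int64_t> AsInt64SpanSketch(const int64_t* data, size_t size) {
  return absl::MakeConstSpan(data, size);
}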
if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
@@ -577,8 +575,8 @@ Status ValidateDotDimensionNumbers(
}
// Check that dimension numbers are unique.
- auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims,
- tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto dims_unique = [](absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) -> bool {
tensorflow::gtl::FlatSet<int64> dim_set;
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
@@ -748,7 +746,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
const Shape& smaller_shape, const Shape& larger_shape,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
// Reject "magic" inference for binops on different shapes, requiring
// the user to provide an explicit broadcast dimension in this case.
@@ -849,7 +847,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
@@ -906,7 +904,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
VLOG(2) << StrFormat(
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
@@ -1005,8 +1003,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
+ HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
std::vector<const Shape*> operand_shapes;
operand_shapes.reserve(operands.size());
for (const HloInstruction* operand : operands) {
@@ -1016,8 +1013,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
for (const Shape* shape : operand_shapes) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
}
@@ -1053,9 +1049,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+ absl::Span<const int64> dimensions) {
if (arg_shapes.empty()) {
return InvalidArgument("Map expects at least one argument.");
}
@@ -1711,7 +1706,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
const Shape& in, const FftType fft_type,
- const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ const absl::Span<const int64> fft_length) {
const int64 fft_rank = fft_length.size();
if (fft_rank < 1 || fft_rank > 3) {
return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
@@ -1792,7 +1787,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ absl::Span<const Shape* const> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
ExpectArray(*operand_shape, "operand of cross replica sum"));
@@ -1835,7 +1830,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ absl::Span<const Shape* const> operand_shapes) {
// An Alltoall HLO instruction receives N operands (with the same shape) and
// returns a tuple that contains N array shapes.
TF_RET_CHECK(!operand_shapes.empty());
@@ -1859,8 +1854,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const Shape* const> arg_shapes,
+ absl::Span<const int64> dimensions_to_reduce,
const ProgramShape& to_apply) {
if (arg_shapes.empty()) {
return InvalidArgument("Reduce must have at least 2 arguments, has 0");
@@ -1998,9 +1993,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
- const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ const Shape& arg, absl::Span<const int64> starts,
+ absl::Span<const int64> limits, absl::Span<const int64> strides) {
auto error = [&](const string& message) {
return InvalidArgument(
"%s in slice operation; argument shape: %s; starts: {%s}; limits: "
@@ -2062,7 +2056,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
TF_RETURN_IF_ERROR(
ExpectArray(start_indices_shape, "start indices of dynamic slice"));
@@ -2189,7 +2183,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
- const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ const Shape& operand_shape, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
if (!AllUnique(dimensions)) {
return InvalidArgument("a dimension number is duplicated in reverse");
@@ -2315,7 +2309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ const Shape& operand, absl::Span<const int64> broadcast_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
@@ -2333,8 +2327,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ const Shape& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
Shape inferred_shape =
@@ -2366,7 +2360,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ const Shape& operand, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
@@ -2471,8 +2465,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply) {
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
// The applied function's arity equals the number of arguments.
if (arg_shapes.size() != to_apply.parameters_size()) {
string computation_signature = ShapeUtil::HumanString(to_apply);
@@ -2505,8 +2498,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
static Status ValidateGatherDimensionNumbers(
- const Shape& input_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices_shape,
+ const Shape& input_shape, absl::Span<const int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
@@ -2599,7 +2591,7 @@ static Status ValidateGatherDimensionNumbers(
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
TF_RETURN_IF_ERROR(
ExpectArray(input_shape, "input tensor operand gather op"));
TF_RETURN_IF_ERROR(
@@ -2709,8 +2701,7 @@ static Status ValidateGatherDimensionNumbers(
namespace {
Status ValidateScatterDimensionNumbers(
- const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
+ const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
// Validate update_window_dims in ScatterDimensionNumbers.
if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 235b1a4cf3..072ada2d8f 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -55,7 +55,7 @@ class ShapeInference {
// given input shapes.
static StatusOr<Shape> InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
const HloInstruction* lhs,
const HloInstruction* rhs);
@@ -73,18 +73,15 @@ class ShapeInference {
// Infers the shape produced by applying the given variadic operation to the
// given input operand shapes.
static StatusOr<Shape> InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
static StatusOr<Shape> InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
+ HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
// Infers the shape produced by applying the given mapping computation shape
// to the given operand shapes.
static StatusOr<Shape> InferMapShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+ absl::Span<const int64> dimensions);
// Infers the shape produced by InferBatchNormTraining with the given
// operands.
@@ -116,14 +113,13 @@ class ShapeInference {
int64 feature_group_count = 1);
// Infers the shape produced by the given FFT type on the given operand.
- static StatusOr<Shape> InferFftShape(
- const Shape& in, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
+ absl::Span<const int64> fft_length);
// Infers the shape produced by a cross replica sum with the given operand
// shapes.
static StatusOr<Shape> InferCrossReplicaSumShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ absl::Span<const Shape* const> operand_shapes);
// Infers final shape of an Alltoall operation that is created by the xla
// builder.
@@ -134,7 +130,7 @@ class ShapeInference {
// Infers the shape of an HLO all-to-all instruction.
static StatusOr<Shape> InferAllToAllTupleShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ absl::Span<const Shape* const> operand_shapes);
// Infers the shape of a collective permute operation.
static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape);
@@ -146,8 +142,8 @@ class ShapeInference {
// index as the leading parameter, and the program shape should match
// accordingly (or an error will result).
static StatusOr<Shape> InferReduceShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const Shape* const> arg_shapes,
+ absl::Span<const int64> dimensions_to_reduce,
const ProgramShape& to_apply);
// Infers the shape produced by applying the given computation to the operand
@@ -165,24 +161,23 @@ class ShapeInference {
// Infers the shape produced by a reverse operation that reverses the order
// of the elements in the given dimensions.
- static StatusOr<Shape> InferReverseShape(
- const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
+ absl::Span<const int64> dimensions);
// Infers the shape produced by a slice operation spanning from the starts to
// the limits in the original shape's dimensions.
//
// e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
- static StatusOr<Shape> InferSliceShape(
- const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides);
+ static StatusOr<Shape> InferSliceShape(const Shape& arg,
+ absl::Span<const int64> starts,
+ absl::Span<const int64> limits,
+ absl::Span<const int64> strides);
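Editor's note: to make the f32[32x32] example in the comment above concrete, here is an illustrative call with the new span parameters. It is not part of the patch and assumes the existing xla::ShapeUtil::MakeShape helper; braced lists bind directly to the absl::Span<const int64> parameters.

#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

// slice f32[32x32] 0:16 0:16 -> f32[16x16]
xla::StatusOr<xla::Shape> InferSliceExample() {
  xla::Shape arg = xla::ShapeUtil::MakeShape(xla::F32, {32, 32});
  return xla::ShapeInference::InferSliceShape(
      arg, /*starts=*/{0, 0}, /*limits=*/{16, 16}, /*strides=*/{1, 1});
}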
// Infers the shape produced by a dynamic slice operation of size specified
// in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
static StatusOr<Shape> InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Infers the shape produced by a dynamic update slice operation based
// on the shape of operand and update.
@@ -213,30 +208,30 @@ class ShapeInference {
// Infers the shape produced by a broadcast operation.
static StatusOr<Shape> InferBroadcastShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ const Shape& operand, absl::Span<const int64> broadcast_sizes);
// Infers the shape produced by a reshape operation from the element type of
// its operand and the new dimension sizes specified.
- static StatusOr<Shape> InferReshapeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ static StatusOr<Shape> InferReshapeShape(const Shape& operand,
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
// Infers the shape produced by a transpose operation from the element type of
// its operand and its dimensions field.
static StatusOr<Shape> InferTransposeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+ const Shape& operand, absl::Span<const int64> dimensions);
// Helper that infers the shape produced by performing a concatenate operation
// with the given operand shapes.
static StatusOr<Shape> InferConcatOpShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
+ absl::Span<const Shape* const> arg_shapes, int64 dimension);
// Infers the shape produced by a kAfterAll. Trivially this shape is always a
// TOKEN shape. However, ShapeInference serves two purposes: inferring shapes
// and checking operand shapes. This method verifies that the operand shapes
// are all TOKENs.
static StatusOr<Shape> InferAfterAllShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
+ absl::Span<const Shape* const> arg_shapes);
// Helper that validates the given operand shape can be converted to the
// target output_shape via a convert instruction -- the requirement is that
@@ -266,8 +261,7 @@ class ShapeInference {
// Helper that validates the given arg_shapes are compatible with the shape of
// the to_apply parameters, and returns the to_apply result shape.
static StatusOr<Shape> InferCallShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply);
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
// Helper that infers the shape produced by performing a dot operation with
// the given LHS and RHS shapes.
@@ -281,7 +275,7 @@ class ShapeInference {
static StatusOr<Shape> InferGatherShape(
const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Helper that validates the given input shape, scatter indices shape, updates
// shape, and scatter dimension numbers that constitute a scatter operation,
@@ -299,7 +293,7 @@ class ShapeInference {
// even in the presence of broadcasting of one of the operands over the other.
static StatusOr<Shape> InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Helper for inferring the shape of Clamp ops.
static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
@@ -327,7 +321,7 @@ class ShapeInference {
// smaller_shape is broadcast to.
static StatusOr<Shape> InferInDimBroadcastShape(
const Shape& smaller_shape, const Shape& larger_shape,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
};
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 4ed8fc6b86..5dbe5a1611 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -28,7 +28,6 @@ limitations under the License.
namespace xla {
namespace {
-using ::tensorflow::gtl::ArraySlice;
using ::testing::ContainsRegex;
using ::testing::HasSubstr;
@@ -58,9 +57,9 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest {
// Helper that runs reduce shape inference with the input 'arg' and given
// dimensions to reduce, and checks the inferred shape is as expected. The
// element type here is hard-coded to F32.
- void ExpectInferredReduceShape(
- const Shape& expected_inferred_shape, const Shape& arg,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
+ const Shape& arg,
+ absl::Span<const int64> dimensions_to_reduce) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
auto inferred_status = ShapeInference::InferReduceShape(
{&arg, &f32_}, dimensions_to_reduce, to_apply);
@@ -252,7 +251,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) {
TEST_F(ShapeInferenceTest, Complex) {
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
- const tensorflow::gtl::ArraySlice<int64>& bcast) {
+ const absl::Span<const int64>& bcast) {
return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
bcast);
};
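Editor's note: the lambda above takes its span by const reference, which works, though absl::Span is a small pointer-plus-length view that is normally passed by value; both forms accept the same arguments. A minimal sketch of the two spellings, assuming only absl/types/span.h:

#include <cstdint>
#include "absl/types/span.h"

int64_t SumByValue(absl::Span<const int64_t> v) {
  int64_t total = 0;
  for (int64_t x : v) total += x;
  return total;
}

// Equivalent to the form used in the test above; the extra indirection is
// harmless but unnecessary for a trivially copyable view type.
int64_t SumByConstRef(const absl::Span<const int64_t>& v) {
  return SumByValue(v);
}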
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index f77690a462..0c393c53a1 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -130,7 +130,7 @@ class TransferManager {
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> executor) = 0;
+ absl::Span<se::StreamExecutor* const> executor) = 0;
// Given an allocated ShapedBuffer, constructs the tuple index table(s) in
// each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
@@ -211,8 +211,7 @@ class TransferManager {
// to construct a tuple index table in the platform-specific tuple
// representation.
virtual Status WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) = 0;
private:
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index cf00ca102b..6fed7c76d0 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -360,7 +360,7 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
}
Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
+ absl::Span<HloInstruction* const> operands(tuple->operands());
PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
points_to_set.AddPointedToBuffer(
logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
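Editor's note: the changed line above views the tuple's operand vector directly, copying nothing; because the element type is HloInstruction* const, the pointer sequence is read-only while the pointed-to instructions stay mutable. An illustrative sketch with a stand-in Node type, not the real HloInstruction:

#include <vector>
#include "absl/types/span.h"

struct Node { int id; };  // stand-in type

void Renumber(absl::Span<Node* const> nodes) {
  int next = 0;
  for (Node* n : nodes) n->id = next++;  // mutating pointees is allowed
  // nodes[0] = nullptr;                 // would not compile: elements are const
}

void Demo() {
  Node a{0}, b{0};
  std::vector<Node*> ops = {&a, &b};
  Renumber(ops);  // std::vector<Node*> converts to absl::Span<Node* const>
}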
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 10d382e8ab..a32d1f9026 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -72,9 +72,8 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Checks that the given points-to set contains exactly (unordered) the given
// LogicalBuffers.
- void ExpectHasBuffers(
- const PointsToSet::BufferList& points_to_set,
- tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers) {
+ void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set,
+ absl::Span<const LogicalBuffer* const> buffers) {
std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
}
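Editor's note: the copy into `vec` above works because spans expose begin() and end(), so standard iterator-pair constructors and algorithms operate on them unchanged. A minimal standalone sketch of that copy, assuming only absl/types/span.h:

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

std::vector<int64_t> ToVector(absl::Span<const int64_t> values) {
  return std::vector<int64_t>(values.begin(), values.end());
}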
@@ -83,7 +82,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// top-level buffers of the given instructions.
void ExpectHasTopLevelBuffers(
const PointsToSet::BufferList& points_to_set,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
PointsToSet::BufferList buffers;
for (auto instruction : instructions) {
buffers.push_back(GetBuffer(instruction, /*index=*/{}));
@@ -94,7 +93,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Overload which takes a set instead of a vector.
void ExpectHasTopLevelBuffers(
const PointsToSet::BufferSet& points_to_set,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
ExpectHasTopLevelBuffers(
PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
instructions);
@@ -104,8 +103,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// aliases which are exactly (unordered) the given instruction/index pairs.
void ExpectHasBufferAliases(
const HloInstruction* instruction, const ShapeIndex& index,
- tensorflow::gtl::ArraySlice<std::pair<HloInstruction*, ShapeIndex>>
- expected) {
+ absl::Span<const std::pair<HloInstruction*, ShapeIndex>> expected) {
const LogicalBuffer* buffer =
points_to_analysis_->GetBufferDefinedAt(instruction, index)
.ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc
index 4a530bb0b2..9ba01ef7a6 100644
--- a/tensorflow/compiler/xla/service/tuple_util.cc
+++ b/tensorflow/compiler/xla/service/tuple_util.cc
@@ -40,7 +40,7 @@ namespace xla {
/*static*/ HloInstruction* TupleUtil::AppendSuffix(
HloInstruction* input_tuple,
- tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values) {
+ absl::Span<HloInstruction* const> trailing_values) {
CHECK(ShapeUtil::IsTuple(input_tuple->shape()));
HloComputation* computation = input_tuple->parent();
diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h
index e5ff9aaa83..bc5aac09f2 100644
--- a/tensorflow/compiler/xla/service/tuple_util.h
+++ b/tensorflow/compiler/xla/service/tuple_util.h
@@ -38,7 +38,7 @@ class TupleUtil {
// `input_tuple`.
static HloInstruction* AppendSuffix(
HloInstruction* input_tuple,
- tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values);
+ absl::Span<HloInstruction* const> trailing_values);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index 7e4ac92a7c..c3c2603c7e 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -211,8 +211,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- if (result.ValueOrDie()->data<bool>() ==
- tensorflow::gtl::ArraySlice<bool>{false}) {
+ if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
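Editor's note: the comparison above relies on absl::Span's element-wise relational operators: the literal's data<bool>() span compares equal to {false} only when it holds exactly one element whose value is false. A standalone sketch of the same idiom, not from the patch:

#include "absl/types/span.h"

bool IsSingleFalse(absl::Span<const bool> values) {
  return values == absl::Span<const bool>{false};
}

void Demo() {
  const bool one_false[] = {false};
  const bool two_false[] = {false, false};
  bool a = IsSingleFalse(one_false);   // true: same length, same elements
  bool b = IsSingleFalse(two_false);   // false: lengths differ
  (void)a; (void)b;
}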
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index e8f76ff745..f90ac91f9d 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -94,7 +94,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
/*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
WhileUtil::MakeInstructionsLiveIn(
HloInstruction* while_instr,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
CHECK(ShapeUtil::IsTuple(while_instr->shape()));
int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h
index e67636d80f..b1c4486887 100644
--- a/tensorflow/compiler/xla/service/while_util.h
+++ b/tensorflow/compiler/xla/service/while_util.h
@@ -55,7 +55,7 @@ class WhileUtil {
// that contains `while_instr`.
static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn(
HloInstruction* while_instr,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
+ absl::Span<HloInstruction* const> instructions);
using LoopStateTy = std::vector<HloInstruction*>;
using LoopBodyGeneratorTy = std::function<StatusOr<LoopStateTy>(