diff options
author | 2018-08-20 16:07:12 -0700 | |
---|---|---|
committer | 2018-08-20 16:18:21 -0700 | |
commit | 65b9ed5a83319830db02504d4c69e98bd07665b6 (patch) | |
tree | 7606d62d577790774274bcb9dbb09aea5b42a620 | |
parent | e687764a94abc17866213d505d1dbe5e4873e1b9 (diff) |
[XLA] Switch to absl versions of the c_foo functions.
PiperOrigin-RevId: 209502513
40 files changed, 248 insertions, 288 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index fdf13bb18c..e36429f62d 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -173,6 +173,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index ad3fcee05b..0ecf26e772 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -211,6 +211,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 4dffab3c2c..e65dd5cbb4 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -21,6 +21,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -469,8 +470,8 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( @@ -622,8 +623,8 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands, std::vector<const Shape*> operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); @@ -749,8 +750,8 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) { HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); @@ -1540,8 +1541,8 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands, HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( @@ -1945,8 +1946,8 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices)); std::vector<const Shape*> slice_shape_ptrs; - c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a65bdebf51..12ec38736e 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -175,6 +175,7 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -237,6 +238,7 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -311,6 +313,7 @@ cc_library( "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1142,6 +1145,7 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1181,6 +1185,7 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1231,6 +1236,7 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1245,6 +1251,7 @@ cc_library( ":while_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1289,6 +1296,7 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1323,8 +1331,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1582,6 +1589,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1744,6 +1752,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2565,6 +2574,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", "@llvm//:core", "@llvm//:transform_utils", ], @@ -2927,6 +2937,7 @@ cc_library( ":tuple_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2940,6 +2951,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2955,6 +2967,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2982,6 +2995,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -3036,6 +3050,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", ], ) @@ -3069,6 +3084,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index f7812d9661..2c539eb99a 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -22,6 +22,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -1752,8 +1753,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { } auto is_unstrided_slice = [](const HloInstruction* hlo) { - return c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { @@ -1930,7 +1931,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // This should make fusion easier or use less memory bandwidth in the unfused // case. if (arg->opcode() == HloOpcode::kConcatenate && - c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) { + absl::c_linear_search(reduce->dimensions(), + arg->concatenate_dimension())) { HloInstruction* old_reduce = nullptr; for (HloInstruction* operand : arg->operands()) { HloInstruction* new_reduce = computation_->AddInstruction( diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index 2099916509..b226e7ecb0 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -84,10 +85,10 @@ StatusOr<bool> BatchDotSimplification::Run(HloModule* module) { bool changed = false; std::vector<HloInstruction*> dot_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), - [](HloInstruction* instr) { - return instr->opcode() == HloOpcode::kDot; - }); + absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); } for (HloInstruction* dot_instr : dot_instrs) { TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index fe1ef78533..9cad674934 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -893,6 +893,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 3274be8d9d..962ea69c09 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "absl/algorithm/container.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -422,8 +423,8 @@ TileVariable::TileVariable(VectorSupportLibrary* vector_support, std::vector<llvm::Value*> TileVariable::Get() const { std::vector<llvm::Value*> result; - c_transform(storage_, std::back_inserter(result), - [&](VectorVariable vect_var) { return vect_var.Get(); }); + absl::c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); return result; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 891ae42141..4b19aa5df9 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. #include <vector> // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" @@ -1672,7 +1673,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1); for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { - if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { operand_index.push_back(index.GetConstantWithIndexType(0)); } else { int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); @@ -1686,7 +1687,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( { std::vector<llvm::Value*> gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.offset_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { gather_index_index.push_back(index[i]); } } diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 9370c88710..d889fd8e88 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include <utility> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -230,7 +231,7 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue( accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); for (int64 i = 0; i < slice_sizes.size(); i++) { - if (!c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { accumulator_state_shape_dims.push_back(slice_sizes[i]); } } @@ -251,7 +252,7 @@ static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims( int64 batch_idx_counter = 0; int64 offset_idx_counter = output_rank - offset_dims.size(); for (int64 i = 0; i < output_rank; i++) { - bool is_offset_dim = c_binary_search(offset_dims, i); + bool is_offset_dim = absl::c_binary_search(offset_dims, i); if (is_offset_dim) { permutation.push_back(offset_idx_counter++); } else { @@ -373,8 +374,8 @@ StatusOr<bool> GatherExpander::Run(HloModule* module) { std::vector<HloInstruction*> gather_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), - is_nontrivial_gather); + absl::c_copy_if(computation->instructions(), + std::back_inserter(gather_instrs), is_nontrivial_gather); } for (HloInstruction* inst : gather_instrs) { diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 8ef72850dc..fd1e34a547 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -180,6 +180,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", "@llvm//:core", "@llvm//:support", ], @@ -466,6 +467,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -513,6 +515,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 3cd30b754c..9b86e5315b 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -18,6 +18,7 @@ limitations under the License. #include <algorithm> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -64,10 +65,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Slice for a more accurate estimate of bytes read. double bytes = 0.0; for (auto& instruction : instructions) { - if (c_all_of(instruction->users(), [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kSlice || - instruction->opcode() == HloOpcode::kDynamicSlice; - })) { + if (absl::c_all_of( + instruction->users(), [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice; + })) { // All users are slice: accumulate bytes of all user slice instructions. for (auto& user : instruction->users()) { bytes += ShapeUtil::ByteSizeOf(user->shape()); @@ -223,7 +225,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. - if (!c_all_of(fusion->users(), [](const HloInstruction* user) { + if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || user->fusion_kind() == HloInstruction::FusionKind::kInput); @@ -241,11 +243,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // If 'fusion' has just one user, then an earlier fusion pass chose not to // fuse this producer/comsumer pair (likely because of expensive instruction // re-use by the consumer), and so we honor that choice here as well. - if (c_any_of(fusion->fused_instructions(), - [](const HloInstruction* instruction) { - return instruction->opcode() != HloOpcode::kParameter && - GpuInstructionFusion::IsExpensive(*instruction); - })) { + if (absl::c_any_of(fusion->fused_instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() != HloOpcode::kParameter && + GpuInstructionFusion::IsExpensive(*instruction); + })) { VLOG(3) << "Not merging " << fusion->name() << ": Contains one or more expensive instructions."; ++num_fail_expensive_fused_instruction_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6675dbd3f9..7111b53944 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -518,7 +519,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // We don't have to iterate over the batch dimensions in both arrays, simplify // the loop nest of the rhs. for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { - DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i)); + DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i)); rhs_index[i] = lhs_index[i]; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1e81cbde35..71c30e19a2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" +#include "absl/algorithm/container.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -314,13 +315,13 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, }; // Check the size of input tensors - if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { return i64_ty; } // Check the size of the internal result tensors if (unnested_hlo->opcode() == HloOpcode::kFusion) { - if (!c_all_of( + if (!absl::c_all_of( unnested_hlo->fused_instructions_computation()->instructions(), hlo_shape_in_range)) { return i64_ty; @@ -1738,7 +1739,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { bool all_tuple_elements_have_buffer = - c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { + absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment() .GetUniqueTopLevelSlice(tuple_element) .ok(); @@ -2322,10 +2323,10 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. std::vector<const BufferAllocation*> non_constant_buffers; - c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), - [](const BufferAllocation* allocation) { - return !allocation->is_constant(); - }); + absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { @@ -2582,7 +2583,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( // MemzeroThunk. ArraySlice<uint8> literal_bytes( reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes); - if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { return { MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)}; } @@ -3105,7 +3106,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize); } const int64 num_tiles = - c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>()); + absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>()); LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); llvm::Type* index_ty = diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index c62bae0628..34a479b289 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,6 +23,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -131,7 +132,7 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) { max_rank_layout = ¶m->shape().layout(); } } - return c_all_of(params, [&](HloInstruction* param) { + return absl::c_all_of(params, [&](HloInstruction* param) { return (ShapeUtil::Rank(param->shape()) < max_rank) || (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); }); @@ -248,7 +249,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } // Do not fuse a producer if the other operands of the fusion are // reachable from the producer, this would create a cycle. - if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + if (absl::c_any_of(consumer_operands, [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { @@ -268,7 +269,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { for (auto& fusion_pair : potential_fusion_list) { HloInstruction* producer = fusion_pair.first; HloInstruction* consumer = fusion_pair.second; - if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + if (!absl::c_any_of(consumer->operands(), [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 441288da1a..db853360f1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -23,6 +23,7 @@ limitations under the License. #include <set> #include <sstream> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -901,9 +902,9 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { HloInstruction* HloComputation::GetInstructionWithName( tensorflow::StringPiece name) { auto instructions_in_computation = instructions(); - auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) { - return instr->name() == name; - }); + auto it = absl::c_find_if( + instructions_in_computation, + [&](HloInstruction* instr) { return instr->name() == name; }); return it == instructions_in_computation.end() ? nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 858992a326..83adaddba4 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -149,13 +150,13 @@ StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands, CHECK_GT(operands.size(), 0); HloComputation* computation = operands[0]->parent(); - CHECK(c_all_of(operands, [&](HloInstruction* instr) { + CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) { return instr->parent() == computation; })); std::vector<const Shape*> operand_shapes; - c_transform(operands, std::back_inserter(operand_shapes), - [](HloInstruction* instr) { return &instr->shape(); }); + absl::c_transform(operands, std::back_inserter(operand_shapes), + [](HloInstruction* instr) { return &instr->shape(); }); TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( operand_shapes, dimension)); @@ -228,7 +229,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand, const Shape& operand_shape = operand->shape(); new_shape_dims.reserve(n + operand_shape.dimensions_size()); new_shape_dims.insert(new_shape_dims.begin(), n, 1); - c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); + absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); return MakeReshapeHlo(new_shape_dims, operand); } @@ -240,7 +241,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims( std::vector<int64> expanded_shape_dim_bounds; expanded_shape_dim_bounds.reserve(expanded_dims.size() + operand->shape().dimensions_size() - 1); - c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); + absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); std::copy(operand->shape().dimensions().begin() + 1, operand->shape().dimensions().end(), std::back_inserter(expanded_shape_dim_bounds)); @@ -251,7 +252,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims( StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand, ArraySlice<int64> dims_to_elide) { - CHECK(c_is_sorted(dims_to_elide)); + CHECK(absl::c_is_sorted(dims_to_elide)); const Shape& input_shape = operand->shape(); // First accumulate in reverse @@ -268,7 +269,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand, } } - c_reverse(new_shape_dim_bounds); + absl::c_reverse(new_shape_dim_bounds); Shape output_shape = ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); return MakeReshapeHlo(output_shape, operand); @@ -276,7 +277,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand, StatusOr<HloInstruction*> InsertDegenerateDims( HloInstruction* operand, ArraySlice<int64> dims_to_insert) { - CHECK(c_is_sorted(dims_to_insert)); + CHECK(absl::c_is_sorted(dims_to_insert)); const Shape& operand_shape = operand->shape(); int64 output_shape_rank = diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 36d6a2eed6..0455c7f41a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -23,6 +23,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -564,7 +565,8 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( std::vector<int64> index_count; index_count.reserve(output_rank); for (int64 i = 0; i < output_rank; i++) { - bool is_output_batch_dim = !c_binary_search(dim_numbers.offset_dims(), i); + bool is_output_batch_dim = + !absl::c_binary_search(dim_numbers.offset_dims(), i); index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1); } @@ -581,10 +583,11 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices( std::vector<int64> index_count(output_rank, 1); int64 slice_sizes_idx = 0; for (int64 i = 0; i < output_rank; i++) { - bool is_output_window_dim = c_binary_search(dim_numbers.offset_dims(), i); + bool is_output_window_dim = + absl::c_binary_search(dim_numbers.offset_dims(), i); if (is_output_window_dim) { - while (c_binary_search(dim_numbers.collapsed_slice_dims(), - slice_sizes_idx)) { + while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), + slice_sizes_idx)) { slice_sizes_idx++; } index_count[i] = slice_sizes[slice_sizes_idx++]; @@ -610,13 +613,13 @@ class OutputBatchIndexToInputIndex { : dim_numbers_(*dim_numbers), start_indices_(*start_indices) { for (int64 i = 0; i < output_shape.dimensions_size(); i++) { output_dim_is_batch_dims_.push_back( - !c_binary_search(dim_numbers_.offset_dims(), i)); + !absl::c_binary_search(dim_numbers_.offset_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { int64 index_of_input_dim_in_index_vector = std::distance(dim_numbers_.start_index_map().begin(), - c_find(dim_numbers_.start_index_map(), i)); + absl::c_find(dim_numbers_.start_index_map(), i)); if (index_of_input_dim_in_index_vector == dim_numbers_.start_index_map_size()) { input_dim_value_to_index_vector_.push_back(-1); @@ -736,7 +739,7 @@ class OutputOffsetIndexToInputIndex { std::vector<int64> window_index_to_output_index; int64 output_index_count = 0; for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.offset_dims(), i)) { + if (absl::c_binary_search(dim_numbers.offset_dims(), i)) { window_index_to_output_index.push_back(output_index_count++); } else { output_index_count++; @@ -745,7 +748,7 @@ class OutputOffsetIndexToInputIndex { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { input_dim_value_to_output_index_.push_back(-1); } else { input_dim_value_to_output_index_.push_back( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index f62e6b74b1..a7c5d71da0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -1825,7 +1826,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> index_count(updates_rank, 1); for (int64 i = 0; i < updates_rank; i++) { bool is_update_scatter_dim = - !c_binary_search(dim_numbers.update_window_dims(), i); + !absl::c_binary_search(dim_numbers.update_window_dims(), i); if (is_update_scatter_dim) { index_count[i] = updates_shape.dimensions(i); } @@ -1844,7 +1845,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> index_count(updates_rank, 1); for (int64 i = 0; i < updates_rank; i++) { bool is_update_window_dim = - c_binary_search(dim_numbers.update_window_dims(), i); + absl::c_binary_search(dim_numbers.update_window_dims(), i); if (is_update_window_dim) { index_count[i] = updates_shape.dimensions(i); } @@ -1871,7 +1872,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { update_dim_is_scatter_dims_.push_back( - !c_binary_search(dim_numbers_.update_window_dims(), i)); + !absl::c_binary_search(dim_numbers_.update_window_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { @@ -2001,7 +2002,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> window_index_to_update_index; int64 update_index_count = 0; for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.update_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { window_index_to_update_index.push_back(update_index_count++); } else { update_index_count++; @@ -2010,7 +2011,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.inserted_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { input_dim_value_to_update_index_.push_back(-1); } else { input_dim_value_to_update_index_.push_back( diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index c3ccbf0f0c..f554401787 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -67,11 +68,11 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData( // The profile indices were computed deterministically in // HloProfileIndexMap::HloProfileIndexMap. - c_sort(computation_and_profile_idx_list, - [](const std::pair<const HloComputation*, int64>& left, - const std::pair<const HloComputation*, int64>& right) { - return left.second < right.second; - }); + absl::c_sort(computation_and_profile_idx_list, + [](const std::pair<const HloComputation*, int64>& left, + const std::pair<const HloComputation*, int64>& right) { + return left.second < right.second; + }); for (const auto& pair : computation_and_profile_idx_list) { CHECK_LT(pair.second, profile_counters_size); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 57e75cf931..2b81213509 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -21,6 +21,7 @@ limitations under the License. #include <unordered_set> #include <utility> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -379,7 +380,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << "DynamicSlice instruction should have 2 operands but sees " << proto.operand_ids_size(); std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size()); - c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), slice_sizes); break; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 4fdf4360e6..0751aacdd6 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -17,6 +17,7 @@ limitations under the License. #include <deque> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -1973,7 +1974,7 @@ HloGatherInstruction::HloGatherInstruction( AppendOperand(start_indices); gather_dimension_numbers_ = MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); - c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); + absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } string HloGatherInstruction::GatherDimensionNumbersToString() const { diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 55ff073d3f..76f8236048 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_set> #include <utility> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -538,9 +539,9 @@ uint64 HloModule::RandomNew64() const { HloComputation* HloModule::GetComputationWithName( tensorflow::StringPiece name) { auto computations_in_module = computations(); - auto it = c_find_if(computations_in_module, [&](HloComputation* computation) { - return computation->name() == name; - }); + auto it = absl::c_find_if( + computations_in_module, + [&](HloComputation* computation) { return computation->name() == name; }); return it == computations_in_module.end() ? nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ab57a8b07f..e48c9d2c41 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -635,12 +636,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } std::vector<ReplicaGroup> replica_groups; if (tmp_groups) { - c_transform(*tmp_groups, std::back_inserter(replica_groups), - [](const std::vector<int64>& ids) { - ReplicaGroup group; - *group.mutable_replica_ids() = {ids.begin(), ids.end()}; - return group; - }); + absl::c_transform( + *tmp_groups, std::back_inserter(replica_groups), + [](const std::vector<int64>& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); } instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( shape, operands, replica_groups, barrier ? *barrier : "")); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 8d17c03afc..39dff567d4 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" @@ -290,13 +291,13 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather( int64 source_dim = dim_numbers.start_index_map(0); std::vector<int64> output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.offset_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { output_dims.push_back(i); } } if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) { - if (c_linear_search(indexed->output_dims(), source_dim)) { + if (absl::c_linear_search(indexed->output_dims(), source_dim)) { return FoldGatherOfGather(indexed, indices, source_dim, output_dims, shape); } @@ -314,7 +315,7 @@ namespace { // [values.begin()+index, values.end()) is equal to `product`. If there is no // such index, return -1. All integers in `values` must be positive. int64 FindSuffixWithProduct(ArraySlice<int64> values, int64 product) { - DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); + DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; })); int64 current_product = 1; int64 i; @@ -388,26 +389,26 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs( result_subarray_size *= result_shape[result_dim]; } - c_reverse(result); + absl::c_reverse(result); if (VLOG_IS_ON(3)) { std::vector<string> result_strings; - c_transform(result, std::back_inserter(result_strings), - [](ReshapePassthroughDimPair value) { - return tensorflow::strings::StrCat(value.result_dim, "->", - value.operand_dim); - }); + absl::c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return tensorflow::strings::StrCat( + value.result_dim, "->", value.operand_dim); + }); VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" << Join(result_shape, ",") << "] passthrough indices are [" << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; } - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.result_dim < rhs.result_dim; })); - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.operand_dim < rhs.operand_dim; })); @@ -419,20 +420,20 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs( // `passthrough_dims`. bool IsReshapePassthroughOperandDim( ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 dim) { - return c_any_of(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == dim; - }); + return absl::c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); } // Maps `operand_dim` which must be an passthrough operand dimension to its // corresponding passthrough result dimension based on `passthrough_dims`. int64 MapPassthroughOperandDimToResultDim( ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 operand_dim) { - auto it = c_find_if(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == operand_dim; - }); + auto it = absl::c_find_if( + passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); CHECK(it != passthrough_dims.end()); return it->result_dim; } @@ -453,8 +454,8 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape, Shape StripDegenerateDimensions(const Shape& shape) { DimensionVector new_dims; - c_copy_if(shape.dimensions(), std::back_inserter(new_dims), - [](int64 dim) { return dim != 1; }); + absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); return ShapeUtil::MakeShape(shape.element_type(), new_dims); } }; // namespace @@ -552,8 +553,8 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims( }(); DimensionVector new_result_shape_dims; - c_copy(operand->shape().dimensions(), - std::back_inserter(new_result_shape_dims)); + absl::c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); for (int64 degenerate_dim : degenerate_dims) { InsertAt(&new_result_shape_dims, degenerate_dim, 1); } @@ -694,8 +695,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( operand_dim); }; - if (!c_all_of(scalar_indexed->output_dims(), - is_reshape_passthrough_operand_dim)) { + if (!absl::c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { VLOG(3) << "Not all output dims are passthrough dims " << ToString(scalar_indexed); return nullptr; @@ -763,8 +764,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); - CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL, - std::multiplies<int64>()), + CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL, + std::multiplies<int64>()), ShapeUtil::ElementsIn(scalar_indexed_source_shape)); CHECK(IsReshapePassthroughOperandDim( @@ -780,9 +781,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( }; std::vector<int64> output_dims_for_new_scalar_indexed_node; - c_transform(scalar_indexed->output_dims(), - std::back_inserter(output_dims_for_new_scalar_indexed_node), - map_passthrough_operand_dim_to_result_dim); + absl::c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, TakeOwnership(scalar_indexed->literal().Reshape( @@ -873,11 +874,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions(); auto is_broadcasted_dim = [&](int64 output_dim) { - return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end(); }; // All of the output dims must be "broadcasted" dims for the other operand. - if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + if (!absl::c_all_of(scalar_indexed_const->output_dims(), + is_broadcasted_dim)) { return nullptr; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index f33942d679..2fd2214806 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -21,6 +21,7 @@ limitations under the License. #include <numeric> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -497,7 +498,7 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - return c_any_of( + return absl::c_any_of( consumer->operands(), [&](const HloInstruction* consumer_operand) { // The fusion algorithm traverses the HLO graph in reverse post order. // Thus `cosumers` is visited before its operands (including diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index cdd3daf73b..ce2d6678a5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -88,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@llvm//:core", ], ) diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 28ca793e3e..cbfd2e7012 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -19,6 +19,7 @@ limitations under the License. #include <map> #include <vector> +#include "absl/algorithm/container.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" @@ -81,7 +82,7 @@ class IrArray { } } CHECK_NE(index_type_, nullptr); - CHECK(c_all_of(multidim, [&](llvm::Value* v) { + CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) { return index_type_ == v->getType(); })); } diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index ca86c5d13e..4df746fca9 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -38,6 +38,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include <algorithm> + +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -374,7 +376,7 @@ StatusOr<bool> TryReshapeMoveOnCandidates( removed = false; for (auto operand : nontrivial_operands) { - if (c_any_of(operand->users(), [&](HloInstruction* user) { + if (absl::c_any_of(operand->users(), [&](HloInstruction* user) { return !reshape_candidates->count(user); })) { for (auto* user : operand->users()) { diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 45ca731153..338f0c09e9 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/scatter_expander.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -92,7 +93,7 @@ static StatusOr<HloInstruction*> PermuteScatterAndWindowDims( permutation.reserve(updates_rank); for (int64 i = 0; i < updates_rank; ++i) { - bool is_scatter_dim = !c_binary_search(update_window_dims, i); + bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); if (is_scatter_dim) { permutation.push_back(i); } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index cc1ec1704e..ec6aa6df55 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -21,6 +21,7 @@ limitations under the License. #include <set> #include <string> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -2494,13 +2495,13 @@ static Status ValidateGatherDimensionNumbers( const Shape& input_shape, tensorflow::gtl::ArraySlice<int64> start_indices_shape, const GatherDimensionNumbers& dim_numbers) { - if (!c_is_sorted(dim_numbers.offset_dims())) { + if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", Join(dim_numbers.offset_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.offset_dims()) != + if (absl::c_adjacent_find(dim_numbers.offset_dims()) != dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", @@ -2546,9 +2547,10 @@ static Status ValidateGatherDimensionNumbers( dim_numbers.start_index_map().begin(), dim_numbers.start_index_map().end()); - c_sort(sorted_start_index_map); + absl::c_sort(sorted_start_index_map); - if (c_adjacent_find(sorted_start_index_map) != sorted_start_index_map.end()) { + if (absl::c_adjacent_find(sorted_start_index_map) != + sorted_start_index_map.end()) { return InvalidArgument( "Repeated dimensions are not allowed in start_index_map; " "got: %s.", @@ -2564,13 +2566,13 @@ static Status ValidateGatherDimensionNumbers( } } - if (!c_is_sorted(dim_numbers.collapsed_slice_dims())) { + if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( "collapsed_slice_dims in gather op must be sorted; got: %s", Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.collapsed_slice_dims()) != + if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) != dim_numbers.collapsed_slice_dims().end()) { return InvalidArgument( "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " @@ -2613,8 +2615,8 @@ static Status ValidateGatherDimensionNumbers( std::vector<int64> expanded_start_indices_shape; expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); - c_copy(start_indices_shape.dimensions(), - std::back_inserter(expanded_start_indices_shape)); + absl::c_copy(start_indices_shape.dimensions(), + std::back_inserter(expanded_start_indices_shape)); if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { expanded_start_indices_shape.push_back(1); @@ -2670,10 +2672,11 @@ static Status ValidateGatherDimensionNumbers( output_dim_bounds.reserve(result_rank); for (int64 i = 0; i < result_rank; i++) { int64 current_bound; - bool is_window_index = c_binary_search(gather_dim_numbers.offset_dims(), i); + bool is_window_index = + absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { - while (c_binary_search(gather_dim_numbers.collapsed_slice_dims(), - offset_dims_seen)) { + while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(), + offset_dims_seen)) { offset_dims_seen++; } current_bound = slice_sizes[offset_dims_seen++]; @@ -2697,12 +2700,12 @@ Status ValidateScatterDimensionNumbers( tensorflow::gtl::ArraySlice<int64> scatter_indices_shape, const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { // Validate update_window_dims in ScatterDimensionNumbers. - if (!c_is_sorted(dim_numbers.update_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { return InvalidArgument( "update_window_dims in scatter op must be sorted; got: %s.", Join(dim_numbers.update_window_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.update_window_dims()) != + if (absl::c_adjacent_find(dim_numbers.update_window_dims()) != dim_numbers.update_window_dims().end()) { return InvalidArgument( "update_window_dims in scatter op must not repeat; got: %s.", @@ -2719,12 +2722,12 @@ Status ValidateScatterDimensionNumbers( } // Validate inserted_window_dims in ScatterDimensionNumbers. - if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) { return InvalidArgument( "inserted_window_dims in scatter op must be sorted; got: %s.", Join(dim_numbers.inserted_window_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) != dim_numbers.inserted_window_dims().end()) { return InvalidArgument( "inserted_window_dims in scatter op must not repeat; got: %s.", @@ -2764,8 +2767,8 @@ Status ValidateScatterDimensionNumbers( std::vector<int64> sorted_scatter_dims_to_operand_dims( dim_numbers.scatter_dims_to_operand_dims().begin(), dim_numbers.scatter_dims_to_operand_dims().end()); - c_sort(sorted_scatter_dims_to_operand_dims); - if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + absl::c_sort(sorted_scatter_dims_to_operand_dims); + if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) != sorted_scatter_dims_to_operand_dims.end()) { return InvalidArgument( "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " @@ -2857,7 +2860,7 @@ Status ValidateScatterDimensionNumbers( int64 scatter_dims_seen = 0; for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { bool is_update_window_dim = - c_binary_search(scatter_dim_numbers.update_window_dims(), i); + absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); if (is_update_window_dim) { continue; } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 62af45128a..aab1180662 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -32,7 +33,7 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( std::vector<HloInstruction*> users; users.reserve(old_instr->user_count()); - c_copy(old_instr->users(), std::back_inserter(users)); + absl::c_copy(old_instr->users(), std::back_inserter(users)); for (auto* user : users) { for (int64 i = 0, e = user->operand_count(); i < e; i++) { @@ -108,10 +109,10 @@ StatusOr<bool> WhileLoopConstantSinking::Run(HloModule* module) { // // This will let us sink the constant into the outer while first and then // into the inner while in a single run of this pass. - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 09ddcffb22..cb132d4f16 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" @@ -65,8 +66,8 @@ static void CreateLoopInvariantCopy( }; InlinedVector<HloInstruction*, 4> new_operands; - c_transform(old_instruction->operands(), std::back_inserter(new_operands), - get_new_operand); + absl::c_transform(old_instruction->operands(), + std::back_inserter(new_operands), get_new_operand); HloInstruction* new_instruction = parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( @@ -197,7 +198,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( op->opcode() == HloOpcode::kConstant; }; - if (!c_all_of(instruction->operands(), is_invariant)) { + if (!absl::c_all_of(instruction->operands(), is_invariant)) { continue; } @@ -257,10 +258,10 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { bool changed = false; std::vector<HloInstruction*> while_instrs; for (auto* comp : module->computations()) { - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 1ef17b9d7d..52d9c3e5ae 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -206,7 +207,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues( HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))); init_values_with_indvar.push_back(zero); - c_copy(init_values, std::back_inserter(init_values_with_indvar)); + absl::c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( HloInstruction::CreateTuple(init_values_with_indvar)); } @@ -215,8 +216,9 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { std::vector<Shape> loop_state_shape_components; loop_state_shape_components.reserve(init_values.size() + 1); loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); - c_transform(init_values, std::back_inserter(loop_state_shape_components), - [](HloInstruction* instr) { return instr->shape(); }); + absl::c_transform(init_values, + std::back_inserter(loop_state_shape_components), + [](HloInstruction* instr) { return instr->shape(); }); return ShapeUtil::MakeTupleShape(loop_state_shape_components); } diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 2ccb919acf..5e69419333 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" @@ -206,7 +207,7 @@ ENTRY main { auto is_while = [](const HloInstruction* instr) { return instr->opcode() == HloOpcode::kWhile; }; - EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); + EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e280492bd9..eac8f977fa 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -127,6 +127,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -385,6 +386,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1542,17 +1544,16 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 1adc68cc48..7a203d6873 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -447,11 +448,11 @@ std::vector<float> GetInterestingF16ConversionTestCases() { XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { std::vector<float> test_cases = GetInterestingF16ConversionTestCases(); std::vector<half> input; - c_transform(test_cases, std::back_inserter(input), - [](float f) { return Eigen::half(f); }); + absl::c_transform(test_cases, std::back_inserter(input), + [](float f) { return Eigen::half(f); }); std::vector<float> expected_output; - c_transform(input, std::back_inserter(expected_output), - [](Eigen::half h) { return static_cast<float>(h); }); + absl::c_transform(input, std::back_inserter(expected_output), + [](Eigen::half h) { return static_cast<float>(h); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> dot_lhs_handle, @@ -470,8 +471,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { std::vector<float> input = GetInterestingF16ConversionTestCases(); std::vector<half> expected_output; - c_transform(input, std::back_inserter(expected_output), - [](float f) { return Eigen::half(f); }); + absl::c_transform(input, std::back_inserter(expected_output), + [](float f) { return Eigen::half(f); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> dot_lhs_handle, diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index e2c16a9e59..b6b8c43bd9 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -215,7 +216,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector<Literal*> fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const std::unique_ptr<Literal>& literal) { return literal.get(); }); @@ -229,7 +230,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector<Literal*> fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const std::unique_ptr<Literal>& literal) { return literal.get(); }); @@ -264,7 +265,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( MakeFakeArguments(module_or_status.ValueOrDie().get()) .ConsumeValueOrDie(); std::vector<Literal*> fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const std::unique_ptr<Literal>& literal) { return literal.get(); }); return test_runner_ @@ -319,8 +320,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( HloComputation* HloTestBase::FindComputation(HloModule* module, tensorflow::StringPiece name) { auto computations = module->computations(); - auto it = c_find_if(computations, - [&](HloComputation* c) { return c->name() == name; }); + auto it = absl::c_find_if( + computations, [&](HloComputation* c) { return c->name() == name; }); if (it == computations.end()) { return nullptr; } @@ -331,8 +332,8 @@ HloInstruction* HloTestBase::FindInstruction(HloModule* module, tensorflow::StringPiece name) { for (const HloComputation* c : module->computations()) { auto instructions = c->instructions(); - auto it = c_find_if(instructions, - [&](HloInstruction* i) { return i->name() == name; }); + auto it = absl::c_find_if( + instructions, [&](HloInstruction* i) { return i->name() == name; }); if (it != instructions.end()) { return *it; } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 11f3efb1f3..e12e095ecd 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -116,7 +117,7 @@ Status ParseOneProfileOutputLine( ", Regexp: ", regexp_pattern); } - if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { + if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); } @@ -294,7 +295,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { tensorflow::str_util::Split(profile_output, '\n'); auto while_body_profile_start = - c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { + absl::c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { return tensorflow::str_util::StartsWith(s, "Execution profile for body"); }); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 5ae099a462..cc07346ee5 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -24,6 +24,7 @@ limitations under the License. #include <type_traits> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -434,122 +435,15 @@ std::vector<std::pair<int64, int64>> CommonFactors( // Removes illegal characters from filenames. string SanitizeFileName(string file_name); -template <typename Container, typename Predicate> -bool c_all_of(const Container& container, Predicate&& predicate) { - return std::all_of(std::begin(container), std::end(container), - std::forward<Predicate>(predicate)); -} - -template <typename Container, typename Predicate> -bool c_any_of(const Container& container, Predicate&& predicate) { - return std::any_of(std::begin(container), std::end(container), - std::forward<Predicate>(predicate)); -} - -template <typename InputContainer, typename OutputIterator, - typename UnaryOperation> -OutputIterator c_transform(const InputContainer& input_container, - OutputIterator output_iterator, - UnaryOperation&& unary_op) { - return std::transform(std::begin(input_container), std::end(input_container), - output_iterator, - std::forward<UnaryOperation>(unary_op)); -} - -template <class InputContainer, class OutputIterator, class UnaryPredicate> -OutputIterator c_copy_if(const InputContainer& input_container, - OutputIterator output_iterator, - UnaryPredicate&& predicate) { - return std::copy_if(std::begin(input_container), std::end(input_container), - output_iterator, std::forward<UnaryPredicate>(predicate)); -} - -template <class InputContainer, class OutputIterator> -OutputIterator c_copy(const InputContainer& input_container, - OutputIterator output_iterator) { - return std::copy(std::begin(input_container), std::end(input_container), - output_iterator); -} - -template <class InputContainer> -void c_sort(InputContainer& input_container) { - std::sort(std::begin(input_container), std::end(input_container)); -} - -template <class InputContainer, class Comparator> -void c_sort(InputContainer& input_container, Comparator&& comparator) { - std::sort(std::begin(input_container), std::end(input_container), - std::forward<Comparator>(comparator)); -} - -template <typename Sequence, typename T> -bool c_binary_search(const Sequence& sequence, T&& value) { - return std::binary_search(std::begin(sequence), std::end(sequence), - std::forward<T>(value)); -} - -template <typename C> -bool c_is_sorted(const C& c) { - return std::is_sorted(std::begin(c), std::end(c)); -} - -template <typename C, typename Compare> -bool c_is_sorted(const C& c, Compare&& comp) { - return std::is_sorted(std::begin(c), std::end(c), - std::forward<Compare>(comp)); -} - -template <typename C> -auto c_adjacent_find(C& c) -> decltype(std::begin(c)) { - return std::adjacent_find(std::begin(c), std::end(c)); -} - -template <typename C, typename Pred> -auto c_find_if(C& c, Pred&& pred) -> decltype(std::begin(c)) { - return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred)); -} - -template <typename C, typename Value> -auto c_find(C& c, Value&& value) -> decltype(std::begin(c)) { - return std::find(std::begin(c), std::end(c), std::forward<Value>(value)); -} - -template <typename Sequence> -void c_reverse(Sequence& sequence) { - std::reverse(std::begin(sequence), std::end(sequence)); -} - -template <typename Sequence, typename T, typename BinaryOp> -typename std::decay<T>::type c_accumulate(const Sequence& sequence, T&& init, - BinaryOp&& binary_op) { - return std::accumulate(std::begin(sequence), std::end(sequence), - std::forward<T>(init), - std::forward<BinaryOp>(binary_op)); -} - -template <typename C, typename Pred> -typename std::iterator_traits< - decltype(std::begin(std::declval<C>()))>::difference_type -c_count_if(const C& c, Pred&& pred) { - return std::count_if(std::begin(c), std::end(c), std::forward<Pred>(pred)); -} - -// Determines whether `value` is present in `c`. -template <typename C, typename T> -bool c_linear_search(const C& c, T&& value) { - auto last = std::end(c); - return std::find(std::begin(c), last, std::forward<T>(value)) != last; -} - template <typename C, typename Value> int64 FindIndex(const C& c, Value&& value) { - auto it = c_find(c, std::forward<Value>(value)); + auto it = absl::c_find(c, std::forward<Value>(value)); return std::distance(c.begin(), it); } template <typename T> bool ArrayContains(tensorflow::gtl::ArraySlice<T> c, const T& value) { - return c_find(c, value) != c.end(); + return absl::c_find(c, value) != c.end(); } template <typename C, typename Value> @@ -584,8 +478,8 @@ bool IsInt32(T x) { template <typename T> Status EraseElementFromVector(std::vector<T>* container, const T& value) { - // c_find returns a const_iterator which does not seem to work on gcc 4.8.4, - // and this breaks the ubuntu/xla_gpu build bot. + // absl::c_find returns a const_iterator which does not seem to work on + // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot. auto it = std::find(container->begin(), container->end(), value); TF_RET_CHECK(it != container->end()); container->erase(it); |