Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/BUILD197
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc71
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc188
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.cc22
-rw-r--r--tensorflow/compiler/xla/service/backend.cc7
-rw-r--r--tensorflow/compiler/xla/service/backend.h4
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification.cc12
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification.h2
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.cc4
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization.cc28
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization.h4
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.h4
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc98
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.cc24
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc116
-rw-r--r--tensorflow/compiler/xla/service/buffer_value.cc3
-rw-r--r--tensorflow/compiler/xla/service/call_graph.cc40
-rw-r--r--tensorflow/compiler/xla/service/call_graph.h6
-rw-r--r--tensorflow/compiler/xla/service/call_inliner.cc2
-rw-r--r--tensorflow/compiler/xla/service/call_inliner.h8
-rw-r--r--tensorflow/compiler/xla/service/call_inliner_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.cc16
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.cc8
-rw-r--r--tensorflow/compiler/xla/service/compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/computation_layout.cc9
-rw-r--r--tensorflow/compiler/xla/service/computation_placer.cc16
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier.cc3
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier.h6
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter.cc5
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter.h4
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc75
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.h26
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD24
-rw-r--r--tensorflow/compiler/xla/service/cpu/compiler_functor.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc1
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc105
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.h11
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc24
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.cc31
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.h5
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc26
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc15
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/disassembler.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc101
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc25
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc430
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h21
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.cc14
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc12
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/sample_harness.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc19
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc13
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.cc5
-rw-r--r--tensorflow/compiler/xla/service/defuser.h2
-rw-r--r--tensorflow/compiler/xla/service/defuser_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/despecializer.cc4
-rw-r--r--tensorflow/compiler/xla/service/despecializer.h2
-rw-r--r--tensorflow/compiler/xla/service/device_memory_allocator.cc9
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.cc4
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h4
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h12
-rw-r--r--tensorflow/compiler/xla/service/dot_decomposer.cc1
-rw-r--r--tensorflow/compiler/xla/service/dot_decomposer.h2
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc865
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.h119
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/executable.cc13
-rw-r--r--tensorflow/compiler/xla/service/execution_tracker.cc10
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph.h2
-rw-r--r--tensorflow/compiler/xla/service/gather_expander.cc11
-rw-r--r--tensorflow/compiler/xla/service/gather_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD47
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_comparator.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc21
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc34
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc59
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc134
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h40
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.cc33
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc15
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc12
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_thunk.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc115
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h9
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc417
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.cc14
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc11
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc36
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc17
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc39
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc160
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc59
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_manager.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/partition_assignment.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_executor_util.cc32
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/BUILD6
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc41
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk_schedule.cc11
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/graphviz_example.cc9
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc43
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h24
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc25
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto13
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_buffer.cc16
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc149
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h16
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc23
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc41
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc30
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc46
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc48
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.cc33
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc32
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.h11
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_metadata.h9
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_remover.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc233
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_verifier.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_verifier.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc132
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc51
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h192
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc210
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc430
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h126
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc419
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h150
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.cc42
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h16
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc57
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h15
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc26
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc29
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc472
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h21
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc68
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_interface.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc33
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_proto_util_test.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc138
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc192
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.h37
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc260
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc35
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc429
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.h38
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_subcomputation_unification.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc25
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc292
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h63
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc85
-rw-r--r--tensorflow/compiler/xla/service/human_readable_profile_builder.cc41
-rw-r--r--tensorflow/compiler/xla/service/human_readable_profile_builder.h18
-rw-r--r--tensorflow/compiler/xla/service/implicit_broadcast_remover.h2
-rw-r--r--tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc128
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc11
-rw-r--r--tensorflow/compiler/xla/service/inliner.h2
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc29
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h4
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/interpreter/BUILD8
-rw-r--r--tensorflow/compiler/xla/service/interpreter/compiler.cc10
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc2
-rw-r--r--tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h6
-rw-r--r--tensorflow/compiler/xla/service/interpreter/platform.cc11
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc249
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h7
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc112
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD22
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc3
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc5
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h3
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.cc8
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.h9
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h400
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc8
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h48
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc8
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h4
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc42
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h32
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc57
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.h29
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc8
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h4
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/sort_util.cc10
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/sort_util.h8
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc26
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer.cc11
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.cc3
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.h16
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc17
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.h4
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h9
-rw-r--r--tensorflow/compiler/xla/service/platform_util.cc35
-rw-r--r--tensorflow/compiler/xla/service/reduce_precision_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.cc4
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.h2
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/scatter_expander.cc5
-rw-r--r--tensorflow/compiler/xla/service/scatter_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/service.cc81
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc814
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h3
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc15
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/source_map_util.cc8
-rw-r--r--tensorflow/compiler/xla/service/source_map_util.h40
-rw-r--r--tensorflow/compiler/xla/service/stream_pool.cc4
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc25
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h20
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.cc2
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc58
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.h8
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_analysis.cc4
-rw-r--r--tensorflow/compiler/xla/service/while_loop_analysis.h6
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking.cc11
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc19
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc20
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.h4
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/while_util.cc12
-rw-r--r--tensorflow/compiler/xla/service/while_util_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h2
352 files changed, 7815 insertions, 4785 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index a65bdebf51..4aef093b04 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -99,6 +99,7 @@ cc_library(
":bfloat16_support",
":hlo",
":hlo_pass",
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
@@ -175,6 +176,9 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -226,6 +230,7 @@ cc_library(
hdrs = ["hlo_evaluator.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_query",
":shape_inference",
"//tensorflow/compiler/xla:literal",
@@ -237,6 +242,11 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -263,6 +273,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -311,6 +322,10 @@ cc_library(
"//tensorflow/core:human_readable_json",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -337,7 +352,7 @@ cc_library(
deps = [
":hlo",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -389,7 +404,8 @@ cc_library(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -449,6 +465,9 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -517,6 +536,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -552,6 +572,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -574,6 +595,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -615,6 +638,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
alwayslink = 1,
)
@@ -647,6 +673,9 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -669,6 +698,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -719,6 +749,9 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -736,6 +769,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:ptr_util",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -766,6 +800,8 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -813,6 +849,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -831,6 +869,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -847,6 +887,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -864,6 +905,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -874,6 +917,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -908,6 +952,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -917,12 +963,14 @@ tf_cc_test(
deps = [
":buffer_liveness",
":hlo",
+ ":hlo_dataflow_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -950,6 +998,9 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -977,6 +1028,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -996,6 +1048,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -1031,6 +1085,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1049,6 +1104,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1059,12 +1115,15 @@ cc_library(
deps = [
":hlo",
":hlo_casting_utils",
+ ":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1074,6 +1133,7 @@ cc_library(
hdrs = ["hlo_module_group_util.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_module_group_metadata",
":hlo_reachability",
"//tensorflow/compiler/xla:status",
@@ -1082,6 +1142,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1101,6 +1163,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
)
@@ -1108,17 +1171,18 @@ tf_cc_test(
name = "hlo_scheduling_test",
srcs = ["hlo_scheduling_test.cc"],
deps = [
- ":buffer_value",
":heap_simulator",
":hlo",
+ ":hlo_dce",
":hlo_ordering",
+ ":hlo_parser",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
],
)
@@ -1142,6 +1206,7 @@ cc_library(
":hlo_pass",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1167,6 +1232,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1181,6 +1247,9 @@ cc_library(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1198,6 +1267,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1216,6 +1286,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1231,6 +1302,7 @@ cc_library(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1245,6 +1317,7 @@ cc_library(
":while_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1267,6 +1340,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1276,6 +1350,7 @@ cc_library(
hdrs = ["algebraic_simplifier.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_creation_utils",
":hlo_pass",
":hlo_query",
@@ -1289,6 +1364,10 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1298,6 +1377,7 @@ tf_cc_test(
deps = [
":algebraic_simplifier",
":hlo",
+ ":hlo_casting_utils",
":hlo_matchers",
":hlo_pass",
"//tensorflow/compiler/xla:literal",
@@ -1312,6 +1392,8 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1323,8 +1405,7 @@ cc_library(
":hlo",
":hlo_creation_utils",
":hlo_pass",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1377,6 +1458,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1414,6 +1496,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1439,8 +1523,7 @@ cc_library(
deps = [
":hlo",
":hlo_evaluator",
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1455,6 +1538,8 @@ cc_library(
":while_loop_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1468,6 +1553,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1582,6 +1668,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1602,6 +1689,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1635,6 +1723,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -1654,6 +1743,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
alwayslink = True, # Contains per-platform computation placer registration
)
@@ -1667,6 +1758,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -1744,6 +1837,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1758,6 +1853,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1789,6 +1885,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1805,6 +1903,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1820,6 +1919,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
],
)
@@ -1847,6 +1947,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
],
)
@@ -1864,6 +1965,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1882,6 +1985,9 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1923,6 +2029,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1959,6 +2067,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1979,6 +2088,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2016,6 +2126,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -2028,7 +2139,6 @@ cc_library(
":hlo_dataflow_analysis",
":logical_buffer",
":logical_buffer_analysis",
- "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -2036,6 +2146,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -2086,6 +2200,9 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -2108,6 +2225,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2175,7 +2293,10 @@ cc_library(
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -2212,13 +2333,16 @@ cc_library(
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
- ":tuple_simplifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -2258,6 +2382,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -2339,6 +2464,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -2376,6 +2504,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2392,6 +2521,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2402,6 +2532,7 @@ tf_cc_test(
":hlo",
":hlo_constant_folding",
":hlo_matchers",
+ ":hlo_parser",
":hlo_pass",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -2423,6 +2554,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2437,6 +2569,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2497,6 +2630,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -2552,6 +2686,7 @@ cc_library(
hdrs = ["elemental_ir_emitter.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_module_config",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -2560,11 +2695,14 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:transform_utils",
],
@@ -2596,10 +2734,11 @@ cc_library(
":computation_layout",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -2612,6 +2751,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2648,8 +2788,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -2683,6 +2823,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
],
alwayslink = 1,
)
@@ -2699,6 +2842,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2780,9 +2924,9 @@ cc_library(
hdrs = ["stream_pool.h"],
deps = [
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -2880,6 +3024,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
],
)
@@ -2926,7 +3071,8 @@ cc_library(
":hlo_creation_utils",
":tuple_util",
"//tensorflow/compiler/xla:literal_util",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -2940,6 +3086,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -2955,6 +3102,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
],
)
@@ -2982,6 +3131,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -3015,13 +3165,13 @@ cc_library(
cc_library(
name = "source_map_util",
- srcs = ["source_map_util.cc"],
+ srcs = [],
hdrs = ["source_map_util.h"],
deps = [
":executable",
"//tensorflow/compiler/xla:status",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -3036,6 +3186,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -3067,8 +3221,11 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -3077,11 +3234,13 @@ tf_cc_test(
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
+ ":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main", # fixdeps: keep
+ "@com_google_absl//absl/strings",
],
)
@@ -3100,6 +3259,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index f7812d9661..19bb4da9a6 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -22,13 +22,19 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -41,7 +47,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -266,7 +271,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
- const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
+ const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);
StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
@@ -540,7 +545,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar = MakeUnique<Literal>(
+ std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -827,18 +832,18 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
TF_ASSIGN_OR_RETURN(
HloInstruction * optimized_lhs_concat,
- OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs,
+ OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
rhs_contracting_dim, /*swapped=*/false));
if (optimized_lhs_concat) {
return optimized_lhs_concat;
}
- return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs,
+ return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
lhs_contracting_dim, /*swapped=*/true);
}
StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
- const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
+ const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) {
bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
lhs->concatenate_dimension() == lhs_contracting_dim &&
@@ -937,11 +942,12 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
}
auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums));
+ dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums));
+ new_dot->set_precision_config(dot.precision_config());
if (add_result) {
add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
- dot_shape, HloOpcode::kAdd, add_result, new_dot));
+ dot.shape(), HloOpcode::kAdd, add_result, new_dot));
} else {
add_result = new_dot;
}
@@ -1040,6 +1046,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
memoized_shape, left_operand, right_operand, dnums));
+ memoized_inst->set_precision_config(dot->precision_config());
// Get pair {start, 0} or {0, start}.
HloInstruction* original_start_indices =
lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
@@ -1137,6 +1144,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
rhs->mutable_operand(0), lhs->mutable_operand(0),
dot_dimension_numbers));
+ new_dot->set_precision_config(dot->precision_config());
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
@@ -1232,7 +1240,7 @@ namespace {
// return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
-std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
+absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
const HloInstruction* hlo,
tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
@@ -1252,11 +1260,11 @@ std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
}
if (i >= unmodified_dims.size() ||
unmodified_dims[i].first != input_dim_index) {
- return std::make_pair(false, std::vector<int64>());
+ return absl::nullopt;
}
output_dim_indices.push_back(unmodified_dims[i].second);
}
- return std::make_pair(true, output_dim_indices);
+ return output_dim_indices;
}
// Returns true if the output of "instruction" is a permutation of the
@@ -1385,6 +1393,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
+ // broadcast(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ return ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateIota(
+ broadcast->shape(),
+ dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
+ }
+
// Merge two consecutive broadcasts into a single one.
if (operand->opcode() == HloOpcode::kBroadcast) {
std::vector<int64> new_dimensions;
@@ -1713,12 +1730,25 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
reshape, reshape->operand(0)->dimensions());
- if (opt_dims.first) {
+ if (opt_dims.has_value()) {
return ReplaceWithNewInstruction(
reshape,
HloInstruction::CreateBroadcast(
reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
- opt_dims.second));
+ *opt_dims));
+ }
+ }
+
+ // reshape(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ auto* iota = Cast<HloIotaInstruction>(operand);
+ auto opt_dims =
+ ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
+ if (opt_dims.has_value()) {
+ CHECK_EQ(opt_dims->size(), 1);
+ return ReplaceWithNewInstruction(
+ reshape,
+ HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
}
}
@@ -1752,8 +1782,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
}
auto is_unstrided_slice = [](const HloInstruction* hlo) {
- return c_all_of(hlo->slice_strides(),
- [](int64 stride) { return stride == 1; });
+ return absl::c_all_of(hlo->slice_strides(),
+ [](int64 stride) { return stride == 1; });
};
if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) {
@@ -1930,7 +1960,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
// This should make fusion easier or use less memory bandwidth in the unfused
// case.
if (arg->opcode() == HloOpcode::kConcatenate &&
- c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) {
+ absl::c_linear_search(reduce->dimensions(),
+ arg->concatenate_dimension())) {
HloInstruction* old_reduce = nullptr;
for (HloInstruction* operand : arg->operands()) {
HloInstruction* new_reduce = computation_->AddInstruction(
@@ -1983,9 +2014,9 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
VLOG(10) << "Considering folding Pad: " << pad->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString()
- << (convert != nullptr ? tensorflow::strings::StrCat(
- "\nvia convert: ", convert->ToString())
- : "");
+ << (convert != nullptr
+ ? absl::StrCat("\nvia convert: ", convert->ToString())
+ : "");
// Do not fold interior padding into ReduceWindow since the backends do not
// support it.
@@ -2294,6 +2325,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
+ dot->set_precision_config(convolution->precision_config());
+
return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index c48196e861..b864c372fa 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -47,7 +47,7 @@ class AlgebraicSimplifier : public HloPassInterface {
enable_dot_strength_reduction_(enable_dot_strength_reduction),
enable_conv_simplification_(enable_conv_simplification) {}
~AlgebraicSimplifier() override = default;
- tensorflow::StringPiece name() const override { return "algsimp"; }
+ absl::string_view name() const override { return "algsimp"; }
// Run algebraic simplification on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 5837391d75..1900a05750 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -18,11 +18,15 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
@@ -34,13 +38,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-using ::testing::ElementsAre;
namespace xla {
namespace {
+using ::testing::ElementsAre;
+
namespace op = xla::testing::opcode_matchers;
AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() {
@@ -51,7 +54,12 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
return [](const Shape&, const Shape&) { return false; };
}
-class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
+class AlgebraicSimplifierTest : public HloVerifiedTestBase {
+ public:
+ AlgebraicSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
@@ -1820,6 +1828,105 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
op::Reshape(op::Broadcast(param)));
}
+TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(HloInstruction::CreateIota(
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2));
+ Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
+ builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x1_3) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 1}), 1));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), iota));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(computation->root_instruction())
+ ->iota_dimension(),
+ 3);
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x1_6x1x1x1) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 1}), 2));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ const int64 iota_dim =
+ Cast<HloIotaInstruction>(computation->root_instruction())
+ ->iota_dimension();
+ EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+}
+
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
HloComputation::Builder builder(TestName());
HloInstruction* param =
@@ -2037,7 +2144,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
// Builds a convolution from <options> and runs algebraic simplification on
// the computation. Returns a string description of the result of
// simplification.
- auto build_and_simplify = [&options]() -> string {
+ auto build_and_simplify = [&]() -> string {
HloComputation::Builder b(TestName());
Window window;
@@ -2143,9 +2250,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
root->operand(0)->opcode() == HloOpcode::kDot) {
auto lhs_shape = root->operand(0)->operand(0)->shape();
auto rhs_shape = root->operand(0)->operand(1)->shape();
- return tensorflow::strings::StrCat(
- tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ",
- tensorflow::str_util::Join(rhs_shape.dimensions(), "x"));
+ return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
+ absl::StrJoin(rhs_shape.dimensions(), "x"));
}
return "UNEXPECTED CHANGE";
};
@@ -2648,6 +2754,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
}
+// Test that a broadcast of an iota can be merged to one iota.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
+ HloComputation::Builder builder(TestName());
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+ HloInstruction* iota =
+ builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
+}
+
+// Test that a broadcast of an iota can be merged to one iota.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
+ HloComputation::Builder builder(TestName());
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
+ HloInstruction* iota =
+ builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
+}
+
struct PadReduceWindowEffectiveBroadcastCase {
std::vector<int64> input_spatials;
std::vector<int64> symmetric_pad_spatials;
@@ -2660,11 +2807,10 @@ struct PadReduceWindowEffectiveBroadcastCase {
bool should_become_broadcast;
string ToTestCaseName() const {
- return tensorflow::strings::StrCat(
- tensorflow::str_util::Join(input_spatials, ","), ";",
- tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";",
- tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a,
- ";", should_become_broadcast);
+ return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
+ absl::StrJoin(symmetric_pad_spatials, ","), ";",
+ absl::StrJoin(reduce_window_spatials, ","), ";",
+ prepend_a, ";", should_become_broadcast);
}
};
@@ -2852,7 +2998,12 @@ struct DotOfConcatTestSpec {
class DotOfConcatSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
+ public ::testing::WithParamInterface<DotOfConcatTestSpec> {
+ public:
+ DotOfConcatSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Test that we transform
// dot(const, concat(A, B, C))
@@ -3025,7 +3176,12 @@ struct DotOfGatherTestSpec {
class DotOfGatherSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
+ public ::testing::WithParamInterface<DotOfGatherTestSpec> {
+ public:
+ DotOfGatherSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// input: dot(DS(ctA), ctB))
// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index 51ebc4763b..1ed6142dce 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -17,15 +17,15 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -69,8 +69,7 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
return InvalidArgument(
"AllocationTracker for platform %s cannot register buffer from "
"platform %s",
- backend_->platform()->Name().c_str(),
- shaped_buffer.platform()->Name().c_str());
+ backend_->platform()->Name(), shaped_buffer.platform()->Name());
}
}
@@ -91,8 +90,9 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
// If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
// into a regular ShapedBuffer, which is stored in
// handle_to_shaped_buffers_.
- handle_to_shaped_buffers_[handle].emplace_back(MakeUnique<ShapedBuffer>(
- ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
+ handle_to_shaped_buffers_[handle].emplace_back(
+ absl::make_unique<ShapedBuffer>(
+ ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
}
GlobalDataHandle result;
@@ -124,7 +124,7 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
// "handle does not exist".
auto it = handle_to_shaped_buffers_.find(data.handle());
if (it == handle_to_shaped_buffers_.end()) {
- return NotFound("no allocation record for global data handle: %lld",
+ return NotFound("no allocation record for global data handle: %d",
data.handle());
}
for (auto& shaped_buffer : it->second) {
@@ -143,7 +143,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
// the same for all buffers across replicas.
const ShapedBuffer* shaped_buffer = replicated_buffers[0];
if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) {
- return InvalidArgument("global data handle %lld is not a tuple",
+ return InvalidArgument("global data handle %d is not a tuple",
data.handle());
}
// If the on-host representation is a tuple, then the on-device one should be
@@ -200,14 +200,14 @@ StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
VLOG(2) << "resolve:" << data.handle();
auto it = handle_to_shaped_buffers_.find(data.handle());
if (it == handle_to_shaped_buffers_.end()) {
- return NotFound("no allocation record for global data handle: %lld",
+ return NotFound("no allocation record for global data handle: %d",
data.handle());
}
std::vector<const ShapedBuffer*> replicated_buffers;
for (const auto& shaped_buffer : it->second) {
if (shaped_buffer == nullptr) {
- return InvalidArgument(
- "global data handle %lld was previously deallocated", data.handle());
+ return InvalidArgument("global data handle %d was previously deallocated",
+ data.handle());
}
replicated_buffers.push_back(shaped_buffer.get());
}
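Aside, not part of the patch: the %lld -> %d and dropped .c_str() edits above reflect error messages now going through Abseil's typed formatting, where %d accepts any integral width and %s accepts std::string directly. A standalone sketch with an invented message:

#include <cstdint>
#include <iostream>
#include <string>
#include "absl/strings/str_format.h"

int main() {
  int64_t handle = 42;
  std::string platform = "Host";
  // %d works for int64_t and %s for std::string; absl::StrFormat checks the
  // argument types, so no %lld or .c_str() is needed.
  std::string msg = absl::StrFormat(
      "no allocation record for global data handle: %d on platform %s",
      handle, platform);
  std::cout << msg << "\n";
}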
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index d12be3e007..a6889cb171 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@@ -127,8 +128,8 @@ Backend::Backend(
}
}
// Create a memory allocator for the valid stream executors.
- memory_allocator_ =
- MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
+ memory_allocator_ = absl::make_unique<StreamExecutorMemoryAllocator>(
+ platform, stream_executors);
CHECK(!stream_executors_.empty())
<< "Service found no devices for backend " << platform_->Name() << '.';
@@ -176,7 +177,7 @@ StatusOr<se::StreamExecutor*> Backend::stream_executor(
}
}
return InvalidArgument("device %s not supported by XLA service",
- device_name(device_ordinal).c_str());
+ device_name(device_ordinal));
}
StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 1bc3796fa4..4a6a78daf0 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -130,7 +130,7 @@ class Backend {
// Return a string identifier for the given device, eg: "GPU:3".
string device_name(int device_ordinal) const {
- return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal);
+ return absl::StrCat(platform_->Name(), ":", device_ordinal);
}
// Returns true if the devices with the given ordinals are equivalent from
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index 2099916509..a16b85a0a5 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -63,6 +64,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
+ new_dot->set_precision_config(batch_dot->precision_config());
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
MakeReshapeHlo(batch_dot->shape(), new_dot));
@@ -76,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
return true;
}
-tensorflow::StringPiece BatchDotSimplification::name() const {
+absl::string_view BatchDotSimplification::name() const {
return "batch-dot-simplification";
}
@@ -84,10 +86,10 @@ StatusOr<bool> BatchDotSimplification::Run(HloModule* module) {
bool changed = false;
std::vector<HloInstruction*> dot_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
- c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
- [](HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kDot;
- });
+ absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
+ [](HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kDot;
+ });
}
for (HloInstruction* dot_instr : dot_instrs) {
TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one,
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
index c0ca8d8eba..79d37f08d3 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.h
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -28,7 +28,7 @@ namespace xla {
class BatchDotSimplification : public HloPassInterface {
public:
StatusOr<bool> Run(HloModule* module) override;
- tensorflow::StringPiece name() const override;
+ absl::string_view name() const override;
private:
StatusOr<bool> ElideDegenerateBatchDimensionFromBatchDot(
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
index 38f1a5d3a6..b342acb025 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
@@ -24,7 +24,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-class BatchDotSimplificationTest : public HloVerifiedTestBase {};
+class BatchDotSimplificationTest : public HloVerifiedTestBase {
+ public:
+ BatchDotSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(BatchDotSimplificationTest,
ElideSingleDegenerateBatchDotDim_VectorVector) {
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index c4cd60c120..01931b2d02 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -43,7 +43,7 @@ namespace xla {
namespace {
-using tensorflow::gtl::optional;
+using absl::optional;
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 7ae202c583..76e32174f3 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -36,7 +36,7 @@ class BatchNormExpander : public HloPassInterface {
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op) {}
~BatchNormExpander() = default;
- tensorflow::StringPiece name() const override { return "batchnorm_expander"; }
+ absl::string_view name() const override { return "batchnorm_expander"; }
// Run operation expander on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index a725351462..aba0d9bb5b 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
index c939838709..5dcd31b83d 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
@@ -37,7 +37,7 @@ class BFloat16ConversionFolding : public HloPassInterface {
: bfloat16_support_(bfloat16_support) {}
~BFloat16ConversionFolding() override = default;
- tensorflow::StringPiece name() const override { return "bfloat16-fold"; }
+ absl::string_view name() const override { return "bfloat16-fold"; }
// Run BF16 conversion folding on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 7cf05ca443..6363a21c3b 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -235,8 +235,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
- sum, /*replica_group_ids=*/{}, /*barrier=*/"",
- /*all_reduce_id=*/tensorflow::gtl::nullopt));
+ sum, /*replica_groups=*/{}, /*barrier=*/"",
+ /*all_reduce_id=*/absl::nullopt));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
index 16e99b5722..32573ed355 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
Status DefaultAction(HloInstruction* hlo) override;
- // Special handling for cross-replica-sum and sort which can have a tuple
- // output.
- Status HandleCrossReplicaSum(HloInstruction* crs) override;
- Status HandleSort(HloInstruction* sort) override;
-
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
@@ -150,23 +146,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations(
return Status::OK();
}
-Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
- HloInstruction* crs) {
- if (!ShapeUtil::IsTuple(crs->shape())) {
- return HandleInstruction(crs);
- } else {
- return HandleMultipleOutputs(crs);
- }
-}
-
-Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) {
- if (!ShapeUtil::IsTuple(sort->shape())) {
- return HandleInstruction(sort);
- } else {
- return HandleMultipleOutputs(sort);
- }
-}
-
Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
HloInstruction* hlo) {
std::vector<PrimitiveType> operand_types(hlo->operand_count());
@@ -380,6 +359,11 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
hlo->opcode() == HloOpcode::kConditional) {
return Status::OK();
}
+ if ((hlo->opcode() == HloOpcode::kSort ||
+ hlo->opcode() == HloOpcode::kCrossReplicaSum) &&
+ ShapeUtil::IsTuple(hlo->shape())) {
+ return HandleMultipleOutputs(hlo);
+ }
return HandleInstruction(hlo);
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
index 2a60fe0af3..30b6346312 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.h
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h
@@ -31,7 +31,7 @@ class BFloat16Normalization : public HloPassInterface {
: bfloat16_support_(bfloat16_support) {}
~BFloat16Normalization() override = default;
- tensorflow::StringPiece name() const override { return "bf16-normalization"; }
+ absl::string_view name() const override { return "bf16-normalization"; }
// Run BF16 normalization on the given computation. Returns whether the
// computation was changed.
@@ -54,7 +54,7 @@ class BFloat16MixedPrecisionRemoval : public HloPassInterface {
~BFloat16MixedPrecisionRemoval() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "bf16-mixed-precision-removal";
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index f9f1f64998..b08705d4c2 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase {
StatusOr<bool> result = normalization.Run(module);
EXPECT_IS_OK(result.status());
- HloVerifier verifier(/*allow_mixed_precision=*/true);
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true);
EXPECT_IS_OK(verifier.Run(module).status());
return result.ValueOrDie();
@@ -251,8 +252,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
- /*replica_group_ids=*/{}, /*barrier=*/"",
- /*all_reduce_id=*/tensorflow::gtl::nullopt));
+ /*replica_groups=*/{}, /*barrier=*/"",
+ /*all_reduce_id=*/absl::nullopt));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 02b8cad089..1ee64971ab 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -64,9 +64,7 @@ class BFloat16Propagation : public HloPassInterface {
~BFloat16Propagation() override = default;
- tensorflow::StringPiece name() const override {
- return "bfloat16-propagation";
- }
+ absl::string_view name() const override { return "bfloat16-propagation"; }
// Runs the pass on the given module. Returns whether the module was changed
// (precision reductions were added).
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index cfd26fc778..b11f15ec7b 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -22,8 +22,10 @@ limitations under the License.
#include <ostream>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -36,20 +38,15 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
+namespace {
+using absl::StrAppend;
+using absl::StrAppendFormat;
using ::tensorflow::gtl::FlatMap;
using ::tensorflow::gtl::FlatSet;
-using ::tensorflow::strings::Appendf;
using ::tensorflow::strings::HumanReadableNumBytes;
-using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrAppend;
-
-namespace {
template <typename T>
string ColocatedBufferSetsToString(const T& container, const char* title) {
@@ -107,7 +104,7 @@ Status GatherComputationsByAllocationType(
return InvalidArgument(
"computation %s has conflicting allocation requirements (global "
"and thread-local)",
- computation->name().c_str());
+ computation->name());
}
if (is_thread_local) {
@@ -130,7 +127,7 @@ Status GatherComputationsByAllocationType(
return InvalidArgument(
"computation %s cannot contain call/while op because it "
"requires thread-local buffer allocations",
- computation->name().c_str());
+ computation->name());
}
worklist.push_back(std::make_pair(subcomputation,
false)); // Not thread local.
@@ -147,9 +144,8 @@ Status GatherComputationsByAllocationType(
true)); // Thread local.
break;
default:
- return InternalError(
- "Unexpected calling opcode: %s",
- HloOpcodeString(instruction->opcode()).c_str());
+ return InternalError("Unexpected calling opcode: %s",
+ HloOpcodeString(instruction->opcode()));
}
}
}
@@ -236,8 +232,8 @@ size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
}
string BufferAllocation::Slice::ToString() const {
- return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_,
- ", size:", size_, "}");
+ return absl::StrCat("{index:", index(), ", offset:", offset_,
+ ", size:", size_, "}");
}
BufferAllocation::Slice BufferAllocation::GetSlice(
@@ -298,7 +294,7 @@ BufferAllocationProto BufferAllocation::ToProto() const {
string BufferAllocation::ToString() const {
string output;
- Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size());
+ StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
if (color().value() != 0) {
StrAppend(&output, ", color ", color().value());
}
@@ -330,11 +326,10 @@ string BufferAllocation::ToString() const {
});
for (const LogicalBuffer* buffer : sorted_buffers) {
const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
- StrAppend(&output,
- tensorflow::strings::Printf(
- " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(),
- offset_size.offset, offset_size.size,
- ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str()));
+ StrAppend(&output, absl::StrFormat(
+ " %s [%d,%d]: %s\n", buffer->ToString(),
+ offset_size.offset, offset_size.size,
+ ShapeUtil::HumanStringWithLayout(buffer->shape())));
}
return output;
}
@@ -427,7 +422,7 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
return FailedPrecondition(
"BufferAllocation::Slice for instruction %s at index %s cannot "
"be determined at compile-time.",
- instruction->name().c_str(), index.ToString().c_str());
+ instruction->name(), index.ToString());
}
} else {
VLOG(3) << "No allocation";
@@ -436,7 +431,7 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
if (result.allocation() == nullptr) {
return FailedPrecondition(
"BufferAllocation::Slice not assigned for instruction %s at index %s",
- instruction->name().c_str(), index.ToString().c_str());
+ instruction->name(), index.ToString());
}
return result;
}
@@ -627,7 +622,7 @@ Status BufferAssignment::ComputeSummaryStats() {
stats_.total_allocation_bytes += allocation.size();
}
- // Only compute total fragmentation if all computations are sequential.
+ // Only compute total fragmentation if all computations have schedules.
SequentialHloOrdering::HloModuleSequence module_sequence;
for (const auto& computation : module_->computations()) {
const std::vector<const HloInstruction*>* sequence =
@@ -648,39 +643,38 @@ Status BufferAssignment::ComputeSummaryStats() {
string BufferAssignment::Stats::ToString() const {
string s;
- Appendf(&s, "BufferAssignment stats:\n");
- Appendf(&s, " parameter allocation: %10s\n",
- HumanReadableNumBytes(parameter_allocation_bytes).c_str());
- Appendf(&s, " constant allocation: %10s\n",
- HumanReadableNumBytes(constant_allocation_bytes).c_str());
- Appendf(&s, " maybe_live_out allocation: %10s\n",
- HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str());
- Appendf(&s, " preallocated temp allocation: %10s\n",
- HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str());
+ StrAppendFormat(&s, "BufferAssignment stats:\n");
+ StrAppendFormat(&s, " parameter allocation: %10s\n",
+ HumanReadableNumBytes(parameter_allocation_bytes));
+ StrAppendFormat(&s, " constant allocation: %10s\n",
+ HumanReadableNumBytes(constant_allocation_bytes));
+ StrAppendFormat(&s, " maybe_live_out allocation: %10s\n",
+ HumanReadableNumBytes(maybe_live_out_allocation_bytes));
+ StrAppendFormat(&s, " preallocated temp allocation: %10s\n",
+ HumanReadableNumBytes(preallocated_temp_allocation_bytes));
if (preallocated_temp_fragmentation_bytes >= 0) {
const double percent = 100. * preallocated_temp_fragmentation_bytes /
preallocated_temp_allocation_bytes;
- Appendf(
+ StrAppendFormat(
&s, " preallocated temp fragmentation: %10s (%.2f%%)\n",
- HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(),
- percent);
+ HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
}
- Appendf(&s, " total allocation: %10s\n",
- HumanReadableNumBytes(total_allocation_bytes).c_str());
+ StrAppendFormat(&s, " total allocation: %10s\n",
+ HumanReadableNumBytes(total_allocation_bytes));
if (total_fragmentation_bytes >= 0) {
const double percent =
100. * total_fragmentation_bytes / total_allocation_bytes;
- Appendf(&s, " total fragmentation: %10s (%.2f%%)\n",
- HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent);
+ StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n",
+ HumanReadableNumBytes(total_fragmentation_bytes), percent);
}
return s;
}
string BufferAssignment::ToString() const {
string output;
- tensorflow::strings::StrAppend(&output, "BufferAssignment:\n");
+ absl::StrAppend(&output, "BufferAssignment:\n");
for (auto& allocation : allocations_) {
- tensorflow::strings::StrAppend(&output, allocation.ToString());
+ absl::StrAppend(&output, allocation.ToString());
}
return output;
}
@@ -1100,8 +1094,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<LazyBestFitHeap>(alignment)),
+ HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)),
assignment->module(), module_sequence,
assignment->points_to_analysis(),
assignment->buffer_size_, options));
@@ -1130,11 +1124,12 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<LazyBestFitHeap>(alignment)),
- *computation, *instruction_sequence,
- assignment->points_to_analysis(),
- assignment->buffer_size_, options));
+ HeapSimulator::Run(
+ absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)),
+ *computation, *instruction_sequence,
+ assignment->points_to_analysis(), assignment->buffer_size_,
+ options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
@@ -1646,7 +1641,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
XLA_VLOG_LINES(3, liveness->ToString());
XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
- // Can't use MakeUnique because BufferAssignment constructor is private.
+ // Can't use absl::make_unique because BufferAssignment constructor is
+ // private.
std::unique_ptr<BufferAssignment> assignment(
new BufferAssignment(module, std::move(liveness), std::move(buffer_size),
std::move(color_alignment)));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index eccb146a0d..52abda16c4 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
@@ -87,7 +87,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -98,7 +98,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
HloModule* module, int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -109,7 +109,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -127,7 +127,8 @@ class BufferAssignmentTest : public HloTestBase {
instruction_sequence.end());
return BufferAssigner::Run(
module,
- xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
+ absl::make_unique<SequentialHloOrdering>(module,
+ module_sequence),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -1769,7 +1770,8 @@ class WhileBufferAssignmentTest : public HloTestBase {
auto sequence =
ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
- module, xla::MakeUnique<SequentialHloOrdering>(module, sequence),
+ module,
+ absl::make_unique<SequentialHloOrdering>(module, sequence),
ByteSizeOf,
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -2083,7 +2085,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto assignment,
BufferAssigner::Run(
module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
+ absl::make_unique<SequentialHloOrdering>(module.get(), sequence),
backend().compiler()->BufferSizeBytesFunction(),
[](LogicalBuffer::Color) { return 1; },
/*allow_input_output_aliasing=*/false,
@@ -2340,7 +2342,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto assignment =
BufferAssigner::Run(
module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
+ absl::make_unique<SequentialHloOrdering>(module.get(), sequence),
ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true)
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index 810d597e73..9b2783a214 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -28,8 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -75,27 +75,25 @@ Status BufferLiveness::Analyze() {
string BufferLiveness::ToString() const {
std::vector<string> pieces;
- pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):",
- module_->name().c_str()));
+ pieces.push_back(
+ absl::StrFormat("BufferLiveness(module=%s):", module_->name()));
pieces.push_back("HloOrdering:");
pieces.push_back(hlo_ordering_->ToString());
- pieces.push_back(tensorflow::strings::Printf("Aliased buffers:"));
+ pieces.push_back("Aliased buffers:");
for (const LogicalBuffer* buffer : aliased_buffers_) {
- pieces.push_back(
- tensorflow::strings::Printf(" %s", buffer->ToString().c_str()));
+ pieces.push_back(absl::StrFormat(" %s", buffer->ToString()));
}
- pieces.push_back(tensorflow::strings::Printf("Live out buffers:"));
+ pieces.push_back("Live out buffers:");
for (const LogicalBuffer* buffer : maybe_live_out_buffers_) {
- pieces.push_back(
- tensorflow::strings::Printf(" %s", buffer->ToString().c_str()));
+ pieces.push_back(absl::StrFormat(" %s", buffer->ToString()));
}
- return tensorflow::str_util::Join(pieces, "\n");
+ return absl::StrJoin(pieces, "\n");
}
bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
const LogicalBuffer& b) const {
- TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
- TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b));
+ TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a));
+ TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b));
if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) {
return false;
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 4a927b5767..26e26e316d 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -18,8 +18,9 @@ limitations under the License.
#include <memory>
#include <string>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -119,8 +120,8 @@ TEST_F(BufferLivenessTest, ElementwiseChain) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
@@ -167,10 +168,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
SequentialHloOrdering::HloModuleSequence sequence;
sequence.insert({entry, {param0, negate, param1, exp, add}});
- auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), sequence))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(
+ module.get(), sequence))
+ .ConsumeValueOrDie();
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -215,8 +216,8 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@@ -249,8 +250,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
@@ -293,10 +294,10 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
SequentialHloOrdering::HloModuleSequence module_sequence;
std::vector<const HloInstruction*> order = {param, negate, exp, add};
module_sequence.emplace(computation, order);
- auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), module_sequence))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(
+ module.get(), module_sequence))
+ .ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@@ -342,10 +343,10 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
std::vector<const HloInstruction*> order = {param, add, recv,
recv_done, send, send_done};
module_sequence.emplace(computation, order);
- auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), module_sequence))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(
+ module.get(), module_sequence))
+ .ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
// Check the root instruction (add) buffer interferes with the recv buffer.
@@ -376,8 +377,8 @@ TEST_F(BufferLivenessTest, TupleLiveOut) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// All buffers should be live out except the param
@@ -412,8 +413,8 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Buffers in different computations should always interfere.
@@ -453,8 +454,8 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Only the element buffers of the tuple constant which are pointed to by
@@ -518,8 +519,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
module->AddEmbeddedComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
@@ -580,8 +581,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
module->AddEmbeddedComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
@@ -610,11 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
protected:
// Builds and runs a computation (see test case computation graphs below).
- // Runs BufferLiveness on this computation.
- // Returns whether buffer interference is detected between tuple-shaped
- // parameter and root instructions at tuple element 1.
- bool Run(const bool update_uses_tuple_element1,
- const bool fuse_gte0 = false) {
+ std::unique_ptr<HloModule> BuildModule(const bool update_uses_tuple_element1,
+ const bool fuse_gte0) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
@@ -645,12 +643,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
// Create output tuple.
- auto tuple_root = builder.AddInstruction(
+ builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
auto module = CreateNewModule();
- module->AddEntryComputation(BuildDummyComputation());
- auto* computation = module->AddEmbeddedComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
+ auto* computation = module->entry_computation();
// Create fusion instruction based on number of tuple element 1 users.
if (update_uses_tuple_element1) {
computation->CreateFusionInstruction(
@@ -666,16 +664,39 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
computation->CreateFusionInstruction({gte0},
HloInstruction::FusionKind::kLoop);
}
+ return module;
+ }
+ // Returns whether buffer interference is detected between tuple-shaped
+ // parameter and root instructions at tuple element 1.
+ bool Run(const bool update_uses_tuple_element1,
+ const bool fuse_gte0 = false) {
+ auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
// Run BufferLiveness on 'module'.
- auto liveness =
- BufferLiveness::Run(
- module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(
+ module.get(),
+ absl::make_unique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
// Return whether or not buffer interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+ auto tuple_param0 = FindInstruction(module.get(), "param0");
+ auto tuple_root = module->entry_computation()->root_instruction();
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
}
+ bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,
+ const bool fuse_gte0 = false) {
+ auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
+ // Run BufferLiveness on 'module'.
+ auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie();
+ auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get());
+  // Return whether or not buffer interference is detected between
+ // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+ auto tuple_param0 = FindInstruction(module.get(), "param0");
+ auto tuple_root = module->entry_computation()->root_instruction();
+ return hlo_ordering->MayInterfere(
+ dataflow->GetUniqueValueAt(tuple_param0, {1}),
+ dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow);
+ }
};
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
@@ -693,6 +714,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
+ EXPECT_FALSE(
+ RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
@@ -712,6 +735,8 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
+ EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false,
+ /*fuse_gte0=*/true));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
@@ -736,6 +761,7 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
+ EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true));
}
class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
@@ -780,10 +806,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
// Run BufferLiveness on 'module'.
- auto liveness =
- BufferLiveness::Run(
- module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(
+ module.get(),
+ absl::make_unique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
// Return whether or not buffers interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc
index 2bc556a9e2..fdf822c666 100644
--- a/tensorflow/compiler/xla/service/buffer_value.cc
+++ b/tensorflow/compiler/xla/service/buffer_value.cc
@@ -17,11 +17,10 @@ limitations under the License.
#include <iosfwd>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 985ff30e80..23b2a32709 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -17,21 +17,21 @@ limitations under the License.
#include <queue>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-using ::tensorflow::strings::Appendf;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppendFormat;
+using absl::StrCat;
string CallContextToString(CallContext context) {
switch (context) {
@@ -71,10 +71,10 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
}
string CallSite::ToString() const {
- return StrCat(instruction()->name(), " calls in context ",
- CallContextToString(context()), ": ",
- tensorflow::str_util::Join(
- called_computations(), ", ",
+ return StrCat(
+ instruction()->name(), " calls in context ",
+ CallContextToString(context()), ": ",
+ absl::StrJoin(called_computations(), ", ",
[](string* out, const HloComputation* computation) {
out->append(computation->name());
}));
@@ -237,8 +237,8 @@ void CallGraph::SetCallContexts() {
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
- // Constructor for CallGraph is private so MakeUnique can't be used.
- auto call_graph = WrapUnique<CallGraph>(new CallGraph(module));
+ // Constructor for CallGraph is private so absl::make_unique can't be used.
+ auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));
VLOG(2) << "Building call graph for:";
XLA_VLOG_LINES(2, module->ToString());
@@ -356,20 +356,20 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
string CallGraph::ToString() const {
string out;
- Appendf(&out, "Call graph for module %s:\n", module_->name().c_str());
+ StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
for (const CallGraphNode& node : nodes()) {
- Appendf(&out, "Computation %s:\n", node.computation()->name().c_str());
- Appendf(&out, " calls:\n");
+ StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
+ StrAppendFormat(&out, " calls:\n");
for (const HloComputation* callee : node.callees()) {
- Appendf(&out, " %s\n", callee->name().c_str());
+ StrAppendFormat(&out, " %s\n", callee->name());
}
- Appendf(&out, " called by:\n");
+ StrAppendFormat(&out, " called by:\n");
for (const HloComputation* caller : node.callers()) {
- Appendf(&out, " %s\n", caller->name().c_str());
+ StrAppendFormat(&out, " %s\n", caller->name());
}
- Appendf(&out, " callsites:\n");
+ StrAppendFormat(&out, " callsites:\n");
for (const CallSite& callsite : node.callsites()) {
- Appendf(&out, " %s\n", callsite.ToString().c_str());
+ StrAppendFormat(&out, " %s\n", callsite.ToString());
}
}
return out;
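Beyond the mechanical rename, absl::StrAppendFormat (like absl::StrCat and absl::StrJoin) accepts std::string and absl::string_view arguments directly, which is why the .c_str() calls disappear in this hunk. A minimal standalone sketch of the new formatting style, with an illustrative name that is not taken from the patch:

    #include <string>
    #include "absl/strings/str_format.h"

    std::string out;
    const std::string computation_name = "entry";  // hypothetical value
    // %s formats std::string / absl::string_view without .c_str().
    absl::StrAppendFormat(&out, "Computation %s:\n", computation_name);
    absl::StrAppendFormat(&out, "  calls:\n");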
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
index 97d3811508..3af2ab5edf 100644
--- a/tensorflow/compiler/xla/service/call_graph.h
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -15,8 +15,8 @@ limitations under the License.
// Call graph for an HLO module.
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
#include <ostream>
@@ -272,4 +272,4 @@ class CallGraph {
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
index 256d05a73e..1d42140444 100644
--- a/tensorflow/compiler/xla/service/call_inliner.cc
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -96,7 +96,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
if (it == subcomputation_hlo_to_new_hlo_.end()) {
return NotFound(
"Could not find mapping from subcomputation HLO %s to a cloned HLO.",
- subcomputation_hlo->ToString().c_str());
+ subcomputation_hlo->ToString());
}
return it->second;
}
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index a8345a394d..c5cd88b9ea 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_
#include <deque>
@@ -35,11 +35,11 @@ class CallInliner : public HloPassInterface {
static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
~CallInliner() override = default;
- tensorflow::StringPiece name() const override { return "CallInliner"; }
+ absl::string_view name() const override { return "CallInliner"; }
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_
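The StringPiece-to-absl::string_view change to HloPassInterface::name() seen here repeats in the other pass headers later in this diff (conditional_simplifier.h, convolution_feature_group_converter.h, copy_insertion.h, conv_canonicalization.h, cpu_copy_insertion.h, cpu_hlo_support_checker.h). Assuming only the two virtual methods visible in this diff, a newly written pass now has roughly this shape; MyPass is a hypothetical example, not part of the patch:

    class MyPass : public HloPassInterface {
     public:
      // Pass names are now returned as absl::string_view.
      absl::string_view name() const override { return "my-pass"; }
      // Returns true iff the pass changed the module; stubbed to false here.
      StatusOr<bool> Run(HloModule* module) override { return false; }
    };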
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index ff968bca29..5d85a3f173 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace op = xla::testing::opcode_matchers;
diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc
index 13008efed1..3c2d1ae6d8 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.cc
+++ b/tensorflow/compiler/xla/service/channel_tracker.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/channel_tracker.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -73,20 +73,20 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) {
Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
if (opaque_to_channel_.count(handle.handle()) == 0) {
- return NotFound("channel handle not found: %lld", handle.handle());
+ return NotFound("channel handle not found: %d", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
if (channel.type == ChannelHandle::HOST_TO_DEVICE) {
return FailedPrecondition(
"host-to-device channels cannot be used with a Send operation; "
- "channel handle: %lld",
+ "channel handle: %d",
handle.handle());
}
if (channel.has_sender) {
return FailedPrecondition(
"when registering send, passed a channel handle that is already used "
- "by a sender: %lld",
+ "by a sender: %d",
handle.handle());
}
channel.has_sender = true;
@@ -95,13 +95,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
if (opaque_to_channel_.count(handle.handle()) == 0) {
- return NotFound("channel handle not found: %lld", handle.handle());
+ return NotFound("channel handle not found: %d", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
if (channel.type == ChannelHandle::DEVICE_TO_HOST) {
return FailedPrecondition(
"device-to-host channels cannot be used with a Recv operation; "
- "channel handle: %lld",
+ "channel handle: %d",
handle.handle());
}
@@ -109,7 +109,7 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
if (channel.receiver_count >= 1) {
return FailedPrecondition(
"when registering recv, passed a channel handle that is already used "
- "by a receiver: %lld",
+ "by a receiver: %d",
handle.handle());
}
channel.receiver_count += 1;
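The %lld-to-%d rewrites in this file work because NotFound and FailedPrecondition now run their format strings through Abseil-style type-safe formatting (consistent with the .c_str() removals elsewhere in this patch), where %d accepts any integral argument width, including the 64-bit value returned by handle.handle(), so no length modifier is needed. A small standalone sketch of the same behavior; the handle value is made up:

    #include <cstdint>
    #include <string>
    #include "absl/strings/str_format.h"

    int64_t handle = int64_t{1} << 40;  // wider than 32 bits
    // %d formats the full 64-bit value; no %lld required.
    std::string msg = absl::StrFormat("channel handle not found: %d", handle);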
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 7426672a7a..3079695e96 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -76,9 +76,9 @@ CompileOnlyService::CompileAheadOfTime(
if (!directory_path.empty()) {
HloSnapshot hlo_snapshot;
*hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation;
- string filename = tensorflow::strings::StrCat(
- "computation_", instance.computation.id(), "__",
- instance.computation.entry_computation_name());
+ string filename =
+ absl::StrCat("computation_", instance.computation.id(), "__",
+ instance.computation.entry_computation_name());
const string& per_host_path = tensorflow::io::JoinPath(
directory_path, tensorflow::port::Hostname());
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
index 6b3b9820f0..687ecafe0c 100644
--- a/tensorflow/compiler/xla/service/compiler.cc
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -101,7 +101,7 @@ Compiler::GetPlatformCompilers() {
return NotFound(
"could not find registered compiler for platform %s -- check "
"target linkage",
- platform->Name().c_str());
+ platform->Name());
}
// And then we invoke the factory, placing the result into the mapping.
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc
index cb61f3da39..af8f7f1027 100644
--- a/tensorflow/compiler/xla/service/computation_layout.cc
+++ b/tensorflow/compiler/xla/service/computation_layout.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <algorithm>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -52,9 +52,8 @@ string ComputationLayout::ToString() const {
for (auto& param_layout : parameter_layouts_) {
params.push_back(param_layout.ToString());
}
- return tensorflow::strings::StrCat("(",
- tensorflow::str_util::Join(params, ", "),
- ") => ", result_layout_.ToString());
+ return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ",
+ result_layout_.ToString());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index 187ce568cb..2210a8578a 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -19,8 +19,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -29,12 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
namespace xla {
@@ -60,8 +60,8 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
"computation_count=%d",
proto.replica_count(), proto.computation_count());
}
- auto assignment = MakeUnique<DeviceAssignment>(proto.replica_count(),
- proto.computation_count());
+ auto assignment = absl::make_unique<DeviceAssignment>(
+ proto.replica_count(), proto.computation_count());
for (int computation = 0; computation < proto.computation_count();
++computation) {
const auto& computation_device = proto.computation_devices(computation);
@@ -132,7 +132,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
return NotFound(
"could not find registered computation placer for platform %s -- check "
"target linkage",
- platform->Name().c_str());
+ platform->Name());
}
if (it->second.placer == nullptr) {
@@ -156,7 +156,7 @@ ComputationPlacer::GetPlatformComputationPlacers() {
} // namespace xla
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
- return xla::MakeUnique<xla::ComputationPlacer>();
+ return absl::make_unique<xla::ComputationPlacer>();
}
static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
index b7be3ba605..4ea3a13f28 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -28,8 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
index 063261e26d..3de50cbd7f 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.h
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -27,9 +27,7 @@ namespace xla {
// with their true or false computation as appropriate.
class ConditionalSimplifier : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "simplify-conditional";
- }
+ absl::string_view name() const override { return "simplify-conditional"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index c43a31b167..6c477da038 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -39,6 +39,10 @@ namespace op = xla::testing::opcode_matchers;
class ConditionalSimplifierTest : public HloVerifiedTestBase {
public:
+ ConditionalSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
// Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module);
};
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 45252fc1ee..9c81a86bbb 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -214,7 +214,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
expanded_filter = add(HloInstruction::CreateConcatenate(
expanded_filter_shape, concat_operands, input_feature_dim));
}
- auto zero = add(HloInstruction::CreateConstant(MakeUnique<Literal>(
+ auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(expanded_filter_shape.element_type()))));
auto zero_filter =
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
@@ -224,6 +224,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
auto new_convolution = HloInstruction::CreateConvolve(
convolution->shape(), convolution->mutable_operand(0), new_filter,
convolution->window(), dim_numbers, /*feature_group_count=*/1);
+ new_convolution->set_precision_config(convolution->precision_config());
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
convolution, std::move(new_convolution)));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
index f213cc8709..498894737f 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -29,7 +29,7 @@ class ConvolutionFeatureGroupConverter : public HloPassInterface {
public:
ConvolutionFeatureGroupConverter() {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "convolution-feature-group-converter";
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 3e39c1bab1..1b7a7b36ea 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -31,18 +33,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace {
+using absl::StrAppend;
+
bool IsEntryParameterValue(const HloValue& value) {
const HloComputation* computation = value.defining_instruction()->parent();
return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
@@ -381,7 +378,7 @@ class CopyRemover {
}
string ToString() const {
- string out = StrCat("CopyRemover, module ", module_->name(), "\n");
+ string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n");
StrAppend(&out, " Buffer values, in dependency order:\n");
for (const HloBuffer& buffer : alias_analysis_.buffers()) {
StrAppend(&out, " HloBuffer ", buffer.id(), ":\n");
@@ -863,16 +860,16 @@ class CopyRemover {
for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
values.push_back(p->value);
}
- return StrCat("{",
- Join(values, ", ",
- [](string* s, const HloValue* value) {
- StrAppend(s, value->ToShortString());
- }),
- "}");
+ return absl::StrCat("{",
+ absl::StrJoin(values, ", ",
+ [](string* s, const HloValue* value) {
+ StrAppend(s, value->ToShortString());
+ }),
+ "}");
}
string ToString() const {
- string out = StrCat("BufferValueTracker:\n");
+ string out = absl::StrCat("BufferValueTracker:\n");
StrAppend(&out, " Def-use chains in each buffer:\n");
for (const ValueNode* head : value_lists_) {
StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
@@ -880,10 +877,10 @@ class CopyRemover {
const ValueNode* p = head;
do {
StrAppend(&out, " ", p->value->ToShortString(), ", uses: ",
- Join(p->uses, "; ",
- [](string* s, const HloUse* use) {
- StrAppend(s, use->ToString());
- }),
+ absl::StrJoin(p->uses, "; ",
+ [](string* s, const HloUse* use) {
+ StrAppend(s, use->ToString());
+ }),
"\n");
p = p->next;
@@ -960,16 +957,11 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
return Status::OK();
}
-// Add copies to address special constraints on the roots of computations not
-// related to live range interference:
-//
-// (1) Entry computation root must be unambiguous and distinct.
-//
-// (2) Any computation called by a kCall instruction must have an
-// unambiguous root.
-//
-// (3) Constants and parameters cannot be live out of the entry computation
-//
+Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+ return AddSpecialCaseCopies(*call_graph, module);
+}
+
Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
@@ -1065,15 +1057,6 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
}
- // Special case copies are not eligible for later copy elision passes.
- indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) {
- if (has_copy) {
- HloInstruction* copy = *copies_added.mutable_element(index);
- if (copy != nullptr) {
- copy->SetCopyElisionAllowed(false);
- }
- }
- });
if (instruction == instruction->parent()->root_instruction()) {
instruction->parent()->set_root_instruction(deep_copy);
}
@@ -1081,10 +1064,10 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
return Status::OK();
}
-Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) {
+Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering,
+ HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
- DependencyHloOrdering ordering(module);
TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
return Status::OK();
}
@@ -1101,8 +1084,7 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- instruction->CopyElisionAllowed()) {
+ if (instruction->opcode() == HloOpcode::kCopy) {
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
}
}
@@ -1168,10 +1150,10 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
- TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
+ DependencyHloOrdering dep_ordering(module);
+ TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module));
- DependencyHloOrdering ordering(module);
- TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module));
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module));
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
@@ -1179,7 +1161,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
- TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
+ TF_DCHECK_OK(
+ VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module));
MaybeDumpModule("after copy insertion", *module);
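Two API changes land here: VerifyNoLiveRangeInterference now takes the HloOrdering explicitly instead of constructing a DependencyHloOrdering itself, and AddSpecialCaseCopies gains a public overload that builds its own CallGraph (both declared in copy_insertion.h below). A rough sketch of how an external caller might drive them, assuming a default-constructible pass, an HloModule* module in scope, and a Status-returning function; none of this appears verbatim in the patch:

    CopyInsertion copy_insertion;
    // Public overload: the CallGraph is built internally.
    TF_RETURN_IF_ERROR(copy_insertion.AddSpecialCaseCopies(module));
    // The caller now states which ordering the interference check assumes.
    DependencyHloOrdering ordering(module);
    TF_RETURN_IF_ERROR(
        copy_insertion.VerifyNoLiveRangeInterference(ordering, module));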
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 5ba64b78a3..d308f6bc84 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -45,7 +45,7 @@ namespace xla {
// InstructionAliasSet::IsDistinct return true.
class CopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
// fusion_can_share_buffer: backend specific function that decides whether a
// fusion can share buffer with its operand.
@@ -77,15 +77,29 @@ class CopyInsertion : public HloPassInterface {
Status RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module);
- private:
- // Verifies that no HLO values have interfering live ranged assuming the
- // ordering used by copy insertion.
- Status VerifyNoLiveRangeInterference(HloModule* module);
+ // Add copies to address special constraints on the roots of computations not
+ // related to live range interference:
+ //
+ // (1) Entry computation root must be unambiguous and distinct.
+ //
+ // (2) Any computation called by a kCall instruction must have an
+ // unambiguous root.
+ //
+ // (3) Constants and parameters cannot be live out of the entry computation
+ //
+ Status AddSpecialCaseCopies(HloModule* module);
- Status AddCopiesToResolveInterference(HloModule* module);
+ // Verifies that no HLO values have interfering live ranges using the given
+ // ordering.
+ Status VerifyNoLiveRangeInterference(const HloOrdering& ordering,
+ HloModule* module);
+ private:
+ // Overload that requires the caller to pass in a call graph.
Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module);
+ Status AddCopiesToResolveInterference(HloModule* module);
+
// Backend specific function that decides whether a fusion can share buffer
// with its operand.
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_;
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index fe1ef78533..4cd192873f 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains per-platform transfer manager registration
)
@@ -85,6 +86,9 @@ cc_library(
":ir_emitter",
":parallel_task_assignment",
":simple_orc_jit",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ":target_machine_features",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
@@ -178,6 +182,7 @@ cc_library(
":runtime_single_threaded_conv2d",
":runtime_single_threaded_fft",
":runtime_single_threaded_matmul",
+ "@com_google_absl//absl/memory",
"@llvm//:execution_engine",
"@llvm//:core",
"@llvm//:mc", # fixdeps: keep
@@ -229,6 +234,8 @@ cc_library(
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:orc_jit",
],
)
@@ -271,11 +278,14 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:code_gen",
"@llvm//:core",
"@llvm//:support",
@@ -320,6 +330,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -330,12 +341,12 @@ cc_library(
hdrs = ["parallel_loop_emitter.h"],
deps = [
":ir_emission_utils",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:core",
],
)
@@ -362,6 +373,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -382,6 +394,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -395,6 +408,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:mc",
"@llvm//:mc_disassembler",
"@llvm//:object",
@@ -418,6 +432,7 @@ cc_library(
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
"@llvm//:analysis",
"@llvm//:core",
"@llvm//:ipo",
@@ -634,6 +649,8 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -648,6 +665,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -810,6 +828,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -846,6 +866,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -893,6 +914,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
"@llvm//:core",
"@llvm//:support",
],
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 128eea4828..73b03440cb 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -205,7 +205,7 @@ void CompilerFunctor::AddTargetInfoPasses(
llvm::legacy::PassManagerBase* passes) const {
llvm::Triple target_triple(target_machine_->getTargetTriple());
auto target_library_info_impl =
- MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple);
+ absl::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
target_library_info_impl->addVectorizableFunctions(
VectorFunctionsForTargetLibraryInfoImpl());
passes->add(
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index 0985b9297f..098ce17a56 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -132,6 +132,7 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
HloInstruction* new_conv = module->entry_computation()->AddInstruction(
HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
hlo->window(), new_dnums));
+ new_conv->set_precision_config(hlo->precision_config());
// Reshape the output back to the shape of the original convolution.
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index e6fd1499ed..59437e88af 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -38,7 +38,7 @@ class ConvCanonicalization : public HloPassInterface {
: target_machine_features_(*target_machine_features) {}
~ConvCanonicalization() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "convolution-canonicalization";
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index fde8fbd486..6420180b13 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -26,6 +26,8 @@ limitations under the License.
// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
// IWYU pragma: no_include "llvm/Config/Targets.def.inc"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
@@ -42,7 +44,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
@@ -101,8 +102,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace cpu {
@@ -235,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
-} // namespace
-Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
- llvm::TargetMachine* target_machine) {
- LLVMTargetMachineFeatures target_machine_features(target_machine);
+} // namespace
- // Optimization pipeline.
- HloPassPipeline pipeline("CPU");
- pipeline.AddInvariantChecker<HloVerifier>();
+Status CpuCompiler::RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes through layout assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
@@ -260,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
- pipeline.AddPass<ConvCanonicalization>(&target_machine_features);
+ pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
@@ -291,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
}
pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
pipeline.AddPass<TransposeFolding>(
- [&target_machine_features](
- const HloInstruction& dot,
+ [&](const HloInstruction& dot,
const TransposeFolding::OperandIndices& candidate_operands) {
- return PotentiallyImplementedAsEigenDot(dot, target_machine_features)
+ return PotentiallyImplementedAsEigenDot(dot, *target_machine_features)
? candidate_operands
: TransposeFolding::OperandIndices{};
},
@@ -309,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout(), &target_machine_features);
+ module->mutable_entry_computation_layout(), target_machine_features);
+ return pipeline.Run(module).status();
+}
+
+Status CpuCompiler::RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes after layout assignment");
+ // After layout assignment, use a layout-sensitive verifier.
+ auto& after_layout_assn =
+ pipeline.AddPass<HloPassPipeline>("after layout assignment");
+ after_layout_assn.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
+
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
{
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
- "after layout assignement");
+ "simplification after layout assignement");
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return true; },
@@ -322,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<HloDCE>();
pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
}
+
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
+
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0
@@ -335,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
// binary size (and most AOT applications are single-threaded).
// TODO(b/29630486) Support multi-threaded AOT.
pipeline.AddPass<ParallelTaskAssigner>(
- max_parallelism, ShapeSizeBytesFunction(), &target_machine_features);
+ max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
}
- // Copy insertion should be performed immediately before IR emission to avoid
- // inserting unnecessary copies (later pass adds an instruction which
- // materializes the value) or missing a necessary copy (later pass removes an
- // instruction which materializes a value). DCE must be run immediately before
- // (and sometime after) copy insertion, to avoid dead code from interfering
- // with the rewrites.
+ // Copy insertion should be performed immediately before IR emission to
+ // avoid inserting unnecessary copies (later pass adds an instruction which
+ // materializes the value) or missing a necessary copy (later pass removes
+ // an instruction which materializes a value). DCE must be run immediately
+ // before (and sometime after) copy insertion, to avoid dead code from
+ // interfering with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>();
@@ -350,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
return pipeline.Run(module).status();
}
+Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine) {
+ LLVMTargetMachineFeatures target_machine_features(target_machine);
+ TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
+ &target_machine_features));
+ return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
+ &target_machine_features);
+}
+
namespace {
// Align buffers to 16-byte boundaries.
@@ -453,7 +479,7 @@ Status CreateHloProfilingArtifacts(
computation_to_profile_idx,
std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
- *hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(module);
+ *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
const HloComputation& entry_computation = *module.entry_computation();
TF_ASSIGN_OR_RETURN(
@@ -520,11 +546,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
&pre_optimization_ir_hook, &post_optimization_ir_hook));
// Compile must be thread-safe so create a new LLVM context for the module.
- auto llvm_context = xla::MakeUnique<llvm::LLVMContext>();
+ auto llvm_context = absl::make_unique<llvm::LLVMContext>();
auto llvm_module =
- xla::MakeUnique<llvm::Module>("__compute_module", *llvm_context);
+ absl::make_unique<llvm::Module>("__compute_module", *llvm_context);
- auto jit = xla::MakeUnique<SimpleOrcJIT>(
+ auto jit = absl::make_unique<SimpleOrcJIT>(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
@@ -566,12 +592,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(
- module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence),
- BufferSizeBytesFunction(), memory_alignment,
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(
+ module.get(), module_sequence),
+ BufferSizeBytesFunction(), memory_alignment,
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -679,8 +705,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
if (target == nullptr) {
- return InternalError("TargetRegistry::lookupTarget failed: %s",
- error.c_str());
+ return InternalError("TargetRegistry::lookupTarget failed: %s", error);
}
llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
@@ -716,7 +741,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name());
llvm::StringRef features = llvm_ir::AsStringRef(options.features());
llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
- std::unique_ptr<llvm::TargetMachine> target_machine = WrapUnique(
+ std::unique_ptr<llvm::TargetMachine> target_machine = absl::WrapUnique(
target->createTargetMachine(triple.getTriple(), cpu_name, features,
CompilerTargetOptions(modules[0]->config()),
reloc_model, llvm::None, opt_level));
@@ -757,7 +782,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(
module,
- xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
+ absl::make_unique<SequentialHloOrdering>(module, module_sequence),
BufferSizeBytesFunction(), memory_alignment,
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true));
@@ -851,7 +876,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment->GetUniqueTopLevelOutputSlice());
- results.emplace_back(MakeUnique<CpuAotCompilationResult>(
+ results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
std::move(object_file_data), std::move(buffer_infos),
result_slice.index(), std::move(hlo_profile_printer_data)));
}
@@ -874,7 +899,7 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
stream_executor::host::kHostPlatformId,
- []() { return xla::MakeUnique<xla::cpu::CpuCompiler>(); });
+ []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
return true;
}
static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index 04e1c48872..47b5edabff 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler {
Status RunHloPasses(HloModule* module, bool is_aot_compile,
llvm::TargetMachine* target_machine);
+ // Runs HLO passes up to and including layout assignment.
+ Status RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features);
+
+ // Runs HLO passes after layout assignment.
+ Status RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features);
+
TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
};
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
index 3313d1e6eb..d49f7d7cc2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -32,11 +32,11 @@ namespace xla {
// (module-scoped).
class CpuCopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index c376864c3e..08773693fb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -22,6 +22,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -35,9 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
@@ -171,20 +171,18 @@ Status CpuExecutable::ExecuteComputeFunction(
void* result_buffer = buffer_pointers[result_slice.index()];
if (VLOG_IS_ON(3)) {
VLOG(3) << "Executing compute function:";
- VLOG(3) << tensorflow::strings::Printf(
- " func(void* result, void* params[null], void* temps[%zu], "
- "uint64 profile_counters[%zu])",
+ VLOG(3) << absl::StrFormat(
+ " func(void* result, void* params[null], void* temps[%u], "
+ "uint64 profile_counters[%u])",
buffer_pointers.size(), profile_counters_size);
- VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
+ VLOG(3) << absl::StrFormat(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
- tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
+ absl::StrAppend(out, absl::StrFormat("%p", p));
};
VLOG(3) << " params = nullptr";
- VLOG(3) << tensorflow::strings::Printf(
- " temps = [%s]",
- tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
- VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p",
- profile_counters);
+ VLOG(3) << absl::StrFormat(
+ " temps = [%s]", absl::StrJoin(buffer_pointers, ", ", ptr_printer));
+ VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters);
}
compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
index 7bd4741a04..7fbe0fa157 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
@@ -34,9 +34,8 @@ StatusOr<bool> CpuHloSupportChecker::Run(HloModule* module) {
return xla::Unimplemented(
"CPU backend does not support HLO instruction %s with shape "
"containing a sparse layout: %s",
- instruction->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(instruction->shape())
- .c_str());
+ instruction->ToString(),
+ ShapeUtil::HumanStringWithLayout(instruction->shape()));
}
return Status::OK();
}));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 2924b63659..6af724b2a5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -28,9 +28,7 @@ class CpuHloSupportChecker : public HloPassInterface {
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;
- tensorflow::StringPiece name() const override {
- return "cpu_hlo_support_checker";
- }
+ absl::string_view name() const override { return "cpu_hlo_support_checker"; }
// Note: always returns false (no instructions are ever modified by this
// pass).
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index b40d264c03..7f867fa149 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -78,7 +78,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
}
if (!CanBeLoopFused(*producer)) {
- VLOG(2) << "Producer is not fusile.";
+ VLOG(2) << "Producer is not fusible.";
return false;
}
@@ -140,7 +140,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
}
if (CanBeLoopFused(*consumer)) {
- VLOG(2) << "Fusing: consumer is elementwise or fusile.";
+ VLOG(2) << "Fusing: consumer is elementwise or fusible.";
return true;
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index e6130c7d76..28aaa28cdb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <set>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
@@ -566,7 +567,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) {
HloOpcode::kParameter, HloOpcode::kParameter});
}
-TEST_F(OpcodeFusionTest, MessOfFusileNodes) {
+TEST_F(OpcodeFusionTest, MessOfFusibleNodes) {
auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
@@ -773,8 +774,8 @@ class GatherLoopFusionTest
TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
const GatherLoopFusionTestSpec& spec = GetParam();
- string hlo_string = tensorflow::strings::StrCat(
- "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text);
+ string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n",
+ spec.hlo_computation_text);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_string));
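[Annotation] The absl::StrCat call above is a drop-in for tensorflow::strings::StrCat. A small sketch with placeholder values standing in for the GatherLoopFusionTestSpec fields:

    #include <iostream>
    #include <string>

    #include "absl/strings/str_cat.h"

    int main() {
      // Placeholder values; the test pulls these from its parameterized spec.
      std::string test_name = "ExampleTest";
      std::string hlo_computation_text = "ENTRY main { ... }";
      // StrCat concatenates its arguments into a single std::string.
      std::string hlo_string =
          absl::StrCat("HloModule ", test_name, "\n\n", hlo_computation_text);
      std::cout << hlo_string << "\n";
      return 0;
    }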
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
index aa872d5ec9..bfecbd6e01 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
@@ -34,8 +34,8 @@ namespace cpu {
// instruction stream.
namespace {
-using ::tensorflow::gtl::nullopt;
-using ::tensorflow::gtl::optional;
+using absl::nullopt;
+using absl::optional;
using ShouldMakeOperandColMajorCache =
tensorflow::gtl::FlatMap<const HloInstruction*, bool>;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index 3ed7876715..b8ace57026 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -15,8 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace {
@@ -45,17 +46,16 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) {
return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0;
}
-tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
- const HloModuleConfig& config) {
+absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
auto it = extra_options_map.find(kLlvmIrDotTilingFactor);
int64 tiling_factor;
if (it != extra_options_map.end() &&
- tensorflow::strings::safe_strto64(it->second, &tiling_factor)) {
+ absl::SimpleAtoi(it->second, &tiling_factor)) {
return tiling_factor;
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
@@ -64,38 +64,37 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0;
}
-static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str,
- tensorflow::StringPiece suffix) {
+static absl::string_view RemoveSuffix(absl::string_view str,
+ absl::string_view suffix) {
CHECK_GE(str.size(), suffix.size());
CHECK_EQ(str.substr(str.size() - suffix.size()), suffix);
return str.substr(0, str.size() - suffix.size());
}
-tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
auto it = extra_options_map.find(kLlvmIrGemmTileSize);
if (it == extra_options_map.end()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
- std::vector<string> tile_components =
- tensorflow::str_util::Split(it->second, ':');
+ std::vector<string> tile_components = absl::StrSplit(it->second, ':');
CHECK_EQ(tile_components.size(), 3);
int64 tile_size_m;
int64 tile_size_k;
int64 tile_size_n_in_vector_width;
- CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m));
- CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k));
+ CHECK(absl::SimpleAtoi(tile_components[0], &tile_size_m));
+ CHECK(absl::SimpleAtoi(tile_components[1], &tile_size_k));
- tensorflow::StringPiece tile_size_n_in_vector_width_str =
+ absl::string_view tile_size_n_in_vector_width_str =
RemoveSuffix(tile_components[2], "*vectwidth");
- CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str,
- &tile_size_n_in_vector_width));
+ CHECK(absl::SimpleAtoi(tile_size_n_in_vector_width_str,
+ &tile_size_n_in_vector_width));
return std::tuple<int64, int64, int64>(tile_size_m, tile_size_k,
tile_size_n_in_vector_width);
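[Annotation] LlvmIrGemmTileSize above now parses the "m:k:n*vectwidth" option string with absl::StrSplit and absl::SimpleAtoi instead of the tensorflow::str_util/strings helpers. A self-contained sketch of the same parse is below; unlike the real code it returns nullopt on malformed input rather than CHECK-failing, and the function name and sample string are invented for the sketch:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <tuple>
    #include <vector>

    #include "absl/strings/numbers.h"
    #include "absl/strings/str_split.h"
    #include "absl/strings/string_view.h"
    #include "absl/types/optional.h"

    // Parses a tile-size string of the form "m:k:n*vectwidth".
    absl::optional<std::tuple<int64_t, int64_t, int64_t>> ParseTileSize(
        absl::string_view value) {
      std::vector<std::string> parts = absl::StrSplit(value, ':');
      if (parts.size() != 3) return absl::nullopt;
      const absl::string_view kSuffix = "*vectwidth";
      absl::string_view n_part = parts[2];
      if (n_part.size() < kSuffix.size() ||
          n_part.substr(n_part.size() - kSuffix.size()) != kSuffix) {
        return absl::nullopt;
      }
      n_part.remove_suffix(kSuffix.size());
      int64_t m, k, n;
      if (!absl::SimpleAtoi(parts[0], &m) || !absl::SimpleAtoi(parts[1], &k) ||
          !absl::SimpleAtoi(n_part, &n)) {
        return absl::nullopt;
      }
      return std::make_tuple(m, k, n);
    }

    int main() {
      auto tile = ParseTileSize("3:5:4*vectwidth");
      if (tile) {
        std::cout << std::get<0>(*tile) << " " << std::get<1>(*tile) << " "
                  << std::get<2>(*tile) << "\n";  // prints "3 5 4"
      }
      return 0;
    }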
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 429b9e16cb..47c7eb13b6 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -27,9 +27,8 @@ namespace options {
bool OptimizeForSizeRequested(const HloModuleConfig& config);
bool VectorizedReduceDisabled(const HloModuleConfig& config);
bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
-tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
- const HloModuleConfig& config);
-tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config);
+absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
const HloModuleConfig& config);
} // namespace options
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index 2ac950e6d9..1ae3aa5711 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -19,16 +19,16 @@ limitations under the License.
#include <string>
#include <tuple>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -46,7 +46,7 @@ std::unique_ptr<Array2D<float>> MaybeTransposeArray2D(const Array2D<T>& array,
if (transpose) {
std::swap(output_width, output_height);
}
- auto output = MakeUnique<Array2D<float>>(output_height, output_width);
+ auto output = absl::make_unique<Array2D<float>>(output_height, output_width);
for (int y = 0; y < array.height(); y++) {
for (int x = 0; x < array.width(); x++) {
if (transpose) {
@@ -93,7 +93,7 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
// Since we're going to transpose c before returning it, swap the order of the
// dimension sizes to ensure the returned array is properly dimensioned.
- auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+ auto c_transpose = absl::make_unique<Array2D<float>>(n, m);
if (single_threaded) {
__xla_cpu_runtime_EigenSingleThreadedMatMulF32(
nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
@@ -142,10 +142,10 @@ class EigenMatMulTest : public CpuRuntimeTest,
bool transpose_rhs = std::get<2>(info.param);
bool single_threaded = std::get<3>(info.param);
- return tensorflow::strings::Printf(
- "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
- transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
- single_threaded ? "single" : "multi");
+ return absl::StrFormat("EigenMatMul_%d_%d_%d_%s%s%s_threaded", shape.m,
+ shape.k, shape.n, transpose_lhs ? "Tlhs_" : "",
+ transpose_rhs ? "Trhs_" : "",
+ single_threaded ? "single" : "multi");
}
};
@@ -178,10 +178,10 @@ class MKLMatMulTest : public CpuRuntimeTest,
bool transpose_rhs = std::get<2>(info.param);
bool single_threaded = std::get<3>(info.param);
- return tensorflow::strings::Printf(
- "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
- transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
- single_threaded ? "single" : "multi");
+ return absl::StrFormat("MKLMatMul_%d_%d_%d_%s%s%s_threaded", shape.m,
+ shape.k, shape.n, transpose_lhs ? "Tlhs_" : "",
+ transpose_rhs ? "Trhs_" : "",
+ single_threaded ? "single" : "multi");
}
};
@@ -204,7 +204,7 @@ std::unique_ptr<Array2D<float>> MKLMatrixMultiply(const Array2D<float>& a,
// Since we're going to transpose c before returning it, swap the order of the
// dimension sizes to ensure the returned array is properly dimensioned.
- auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+ auto c_transpose = absl::make_unique<Array2D<float>>(n, m);
if (single_threaded) {
__xla_cpu_runtime_MKLSingleThreadedMatMulF32(
nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
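[Annotation] The test-name changes above also drop the %lld length modifiers: absl::StrFormat is checked against the argument types, so %d accepts integers of any width. A minimal sketch with placeholder dimensions:

    #include <cstdint>
    #include <iostream>

    #include "absl/strings/str_format.h"

    int main() {
      // Placeholder shape values; the test derives these from its parameters.
      int64_t m = 128, k = 256, n = 512;
      // %d is matched against the int64_t arguments at compile time, so no
      // printf-style length modifier is required.
      std::cout << absl::StrFormat("EigenMatMul_%d_%d_%d", m, k, n) << "\n";
      return 0;
    }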
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 59bc7e0e16..0df2abf001 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
@@ -103,7 +104,7 @@ Status CpuTransferManager::TransferLiteralToInfeed(
if (ShapeUtil::IsNestedTuple(shape)) {
return Unimplemented(
"Infeed with a nested tuple shape is not supported: %s",
- ShapeUtil::HumanString(literal.shape()).c_str());
+ ShapeUtil::HumanString(literal.shape()));
}
// For a tuple, we transfer each of its elements to the device and
@@ -151,11 +152,11 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
int64 size,
const void* source) {
if (size > std::numeric_limits<int32>::max()) {
- return InvalidArgument("Infeed shape is too large: needs %lld bytes", size);
+ return InvalidArgument("Infeed shape is too large: needs %d bytes", size);
}
if (size <= 0) {
- return InvalidArgument("Infeed shape must have positive size; got %lld",
+ return InvalidArgument("Infeed shape must have positive size; got %d",
size);
}
@@ -243,12 +244,12 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
for (auto b : buffer_data) {
int64 size = b.second;
if (size > std::numeric_limits<int32>::max()) {
- return InvalidArgument("Outfeed shape is too large: needs %lld bytes",
+ return InvalidArgument("Outfeed shape is too large: needs %d bytes",
size);
}
if (size <= 0) {
- return InvalidArgument("Outfeed shape must have positive size; got %lld",
+ return InvalidArgument("Outfeed shape must have positive size; got %d",
size);
}
@@ -256,7 +257,7 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
VLOG(2)
<< "Enqueueing outfeed buffer (for the device to populate) of length "
<< size_32 << "B";
- buffers.emplace_back(MakeUnique<CpuOutfeedBuffer>(b.first, size_32));
+ buffers.emplace_back(absl::make_unique<CpuOutfeedBuffer>(b.first, size_32));
}
std::vector<cpu::runtime::XfeedBuffer*> buffer_pointers;
@@ -283,7 +284,7 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateCpuTransferManager() {
- return xla::MakeUnique<xla::CpuTransferManager>();
+ return absl::make_unique<xla::CpuTransferManager>();
}
static bool InitModule() {
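[Annotation] The MakeUnique replacements above swap xla::MakeUnique for absl::make_unique, which forwards its arguments to the constructor just like std::make_unique. A tiny sketch; the Buffer struct is only a stand-in for the Array2D/CpuOutfeedBuffer types at the real call sites:

    #include <memory>

    #include "absl/memory/memory.h"

    // Stand-in type for the sketch.
    struct Buffer {
      Buffer(int height, int width) : height(height), width(width) {}
      int height;
      int width;
    };

    int main() {
      // absl::make_unique constructs the object and wraps it in a
      // std::unique_ptr, replacing the removed xla::MakeUnique helper.
      std::unique_ptr<Buffer> b = absl::make_unique<Buffer>(3, 4);
      return (b->height == 3 && b->width == 4) ? 0 : 1;
    }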
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
index 80ef953d53..7b938e9fd7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_
#include <vector>
@@ -76,4 +76,4 @@ class CpuTransferManager : public GenericTransferManager {
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc
index e4c674e227..3ae64142cd 100644
--- a/tensorflow/compiler/xla/service/cpu/disassembler.cc
+++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc
@@ -21,13 +21,13 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/strings/str_format.h"
#include "llvm/MC/MCInst.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -151,7 +151,7 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile(
size = 1;
}
- ostream << tensorflow::strings::Printf("0x%08lx", index) << " ";
+ ostream << absl::StrFormat("0x%08lx", index) << " ";
if (decode_status == llvm::MCDisassembler::Success) {
// For branches, try to determine the actual address and emit it as an
@@ -163,7 +163,7 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile(
uint64_t target;
if (inst_analysis_->evaluateBranch(
instruction, section_address + index, size, target)) {
- annotation = tensorflow::strings::Printf("[0x%08lx]", target);
+ annotation = absl::StrFormat("[0x%08lx]", target);
}
}
inst_printer_->printInst(&instruction, ostream, annotation.c_str(),
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index f2ac742b6e..dd060f54a2 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
@@ -146,9 +147,9 @@ class GemvConfig {
bool has_addend() const { return has_addend_; }
string GetCacheKey() const {
- return tensorflow::strings::StrCat(
- name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_",
- tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : "");
+ return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_",
+ tile_rows(), "_", tile_cols(), "_", m(), "_", k(),
+ has_addend() ? "_with_addend" : "");
}
protected:
@@ -621,19 +622,19 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
}
// This class implements a tiled matrix multiplication algorithm, intended for
-// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto,
-// Kazushige, and Robert Van De Geijn. "High-performance implementation of the
-// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008):
-// 4).
+// multiplying small matrices that don't need cache tiling.
+//
+// In the future this can be used as the innermost GEBP loop in a GEMM kernel as
+// described in "Goto, Kazushige, and Robert A. van de Geijn. "Anatomy of
+// high-performance matrix multiplication." ACM Transactions on Mathematical
+// Software (TOMS) 34.3 (2008): 12.".
//
// This only supports canonical dot operations (i.e. where the lhs contraction
// dimension is 1 and the rhs contraction dimension is 0) over row major
// matrices.
-class MatrixMatrixBlockPanelEmitter {
+class TiledSmallGemmEmitter {
public:
- // Describe the dimensions of the GEBP kernel. These will usually not be the
- // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP
- // kernels with smaller dimensions.
+ // Describe the dimensions of the kernel.
class Dimensions {
public:
explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {}
@@ -642,9 +643,7 @@ class MatrixMatrixBlockPanelEmitter {
int64 k() const { return k_; }
int64 n() const { return n_; }
- string ToString() const {
- return tensorflow::strings::StrCat(m(), "x", k(), "x", n());
- }
+ string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); }
private:
const int64 m_;
@@ -652,9 +651,9 @@ class MatrixMatrixBlockPanelEmitter {
const int64 n_;
};
- // Represents the configuration of the GEBP emitter. The LLVM IR emitted by
- // the emitter, modulo the LLVM values holding the input and output buffers,
- // must be a function of the instance of `Config` passed to it.
+ // Represents the configuration of the emitter. The LLVM IR emitted by the
+ // emitter, modulo the LLVM values holding the input and output buffers, must
+ // be a function of the instance of `Config` passed to it.
//
// `dims` holds the matrix multiplication dimensions.
//
@@ -687,10 +686,10 @@ class MatrixMatrixBlockPanelEmitter {
tile_size_k_(tile_size_k) {}
string GetCacheKey() const {
- return tensorflow::strings::StrCat(
- "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
- "_", max_vectorization_width(), "_", min_vectorization_width(), "_",
- tile_size_m(), "_", tile_size_k());
+ return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
+ dims().ToString(), "_", max_vectorization_width(),
+ "_", min_vectorization_width(), "_", tile_size_m(),
+ "_", tile_size_k());
}
PrimitiveType scalar_type() const { return scalar_type_; }
@@ -712,11 +711,11 @@ class MatrixMatrixBlockPanelEmitter {
int64 tile_size_k_;
};
- // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
+ // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
// `lhs` with `rhs` and stores the result in `result`.
- explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs,
- llvm::Value* rhs, llvm::Value* result,
- llvm::IRBuilder<>* b)
+ explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
+ llvm::Value* rhs, llvm::Value* result,
+ llvm::IRBuilder<>* b)
: lhs_(lhs),
rhs_(rhs),
result_(result),
@@ -780,9 +779,9 @@ class MatrixMatrixBlockPanelEmitter {
KernelSupportLibrary ksl_;
};
-void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); }
+void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); }
-void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
+void TiledSmallGemmEmitter::HandleResiduesOnN() {
// We can only iterate the `n` dimension for an extent that is divisible by
// the vectorization width. So we emit an outer loop that first processes the
// largest extent in `n` that is divisible by max_vectorization_width, then
@@ -799,7 +798,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
if (n_start != n_end) {
VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
- "gebp");
+ "gemm");
HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
n_start = n_end;
}
@@ -813,7 +812,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
}
if (n_start != dims().n()) {
- VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp");
+ VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm");
ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
HandleResiduesOnK(&vsl, n_i, n_i_next);
@@ -821,9 +820,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
}
}
-void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
- llvm::Value* n_start,
- llvm::Value* n_end) {
+void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
+ llvm::Value* n_start,
+ llvm::Value* n_end) {
int64 k_start = 0;
int64 k_end = dims().k() - (dims().k() % tile_size_k());
if (k_end != k_start) {
@@ -838,7 +837,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
}
}
-void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
+void TiledSmallGemmEmitter::HandleResiduesOnM(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
const int64 m_end = dims().m() - dims().m() % tile_size_m();
@@ -921,7 +920,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
// +-------------------+-------------------+-------------------+---------
// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
// +-------------------+-------------------+-------------------+---------
-void MatrixMatrixBlockPanelEmitter::EmitTiledGemm(
+void TiledSmallGemmEmitter::EmitTiledGemm(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
@@ -1001,12 +1000,22 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
return dot_emitter.Emit();
}
-bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
+bool DotOpEmitter::EmitSmallGemmIfProfitable(
const DotOpEmitter::MatMultDims& mat_mult_dims) {
- if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) {
+ if (ShouldUseMultiThreadedEigen()) {
return false;
}
+ if (!EnableExperimentalLlvmIrGemm()) {
+ // TODO(sanjoy): We should make these numbers micro-arch specific.
+ bool small_gemm = mat_mult_dims.k <= 128 &&
+ ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) ||
+ (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32));
+ if (!small_gemm) {
+ return false;
+ }
+ }
+
if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) {
return false;
}
@@ -1054,15 +1063,15 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
GetGemmTileSize();
- MatrixMatrixBlockPanelEmitter::Config config(
+ TiledSmallGemmEmitter::Config config(
/*scalar_type=*/primitive_type,
- MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
+ TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
/*max_vectorization_width=*/max_target_vector_width,
/*max_vector_count=*/tile_size_n_in_vector_width,
/*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
/*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
- VLOG(2) << "Emitting GEBP kernel in LLVM IR with config "
+ VLOG(2) << "Emitting GEMM kernel in LLVM IR with config "
<< config.GetCacheKey();
const bool enable_fast_math =
@@ -1075,10 +1084,10 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
/*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs,
rhs, target,
[this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) {
- MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs,
- /*rhs=*/rhs,
- /*result=*/target, b_);
- gebp_emitter.Emit();
+ TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
+ /*rhs=*/rhs,
+ /*result=*/target, b_);
+ small_gemm_emitter.Emit();
});
return true;
@@ -1136,7 +1145,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
}
if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
- return EmitExperimentalGebpDotIfEnabled(mat_mult_dims);
+ return EmitSmallGemmIfProfitable(mat_mult_dims);
}
int64 tiling_factor = GetGemvTilingFactor();
@@ -1458,7 +1467,7 @@ Status DotOpEmitter::EmitCallToRuntime() {
break;
default:
return Unimplemented("Invalid type %s for dot operation",
- PrimitiveType_Name(type).c_str());
+ PrimitiveType_Name(type));
}
llvm::Type* float_ptr_type = float_type->getPointerTo();
@@ -1610,7 +1619,7 @@ bool PotentiallyImplementedAsEigenDot(
// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
-tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
+absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
hlo.shape().dimensions(0) == 1) {
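[Annotation] EmitSmallGemmIfProfitable above gates the tiled LLVM IR GEMM on a shape-based cut-off when the experimental flag is off: the contracted dimension must be at most 128 and one output dimension at most 32 while the other is at most 128. A standalone sketch of just that predicate; the struct and function names here are invented for illustration (the real code uses DotOpEmitter::MatMultDims):

    #include <cstdint>

    // Stand-in for the matrix multiplication dimensions.
    struct MatMulDims {
      int64_t m;  // rows of the result
      int64_t k;  // contracted dimension
      int64_t n;  // columns of the result
    };

    // Returns true for "small" products per the thresholds in this change.
    bool IsSmallGemm(const MatMulDims& dims) {
      return dims.k <= 128 && ((dims.m <= 32 && dims.n <= 128) ||
                               (dims.m <= 128 && dims.n <= 32));
    }

    int main() {
      return IsSmallGemm({/*m=*/16, /*k=*/64, /*n=*/100}) ? 0 : 1;
    }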
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index 590032fbe9..4c2041b556 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
+#include "absl/strings/string_view.h"
#include "llvm/IR/IRBuilder.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -38,7 +38,7 @@ bool PotentiallyImplementedAsEigenDot(
// Returns the index for an operand to `hlo` that should ideally be column
// major. Returns nullopt if there is no such operand or if `hlo` is not a dot
// or a fusion containing a dot.
-tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
+absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
const HloInstruction& hlo);
// Returns true to indicate that we can generate a tiled LLVM IR implementation
@@ -121,7 +121,7 @@ class DotOpEmitter {
// of rank 2 as well).
MatMultDims GetMatMultDims() const;
- bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims);
+ bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims);
// When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
// registers.
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index db54454707..c8312d80bd 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -30,15 +30,16 @@ limitations under the License.
namespace xla {
namespace cpu {
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
- PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) {
string function_name;
bool cast_result_to_fp16 = false;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- lhs = b_->CreateFPCast(lhs, b_->getFloatTy());
- rhs = b_->CreateFPCast(rhs, b_->getFloatTy());
+ lhs = FPCast(lhs, b_->getFloatTy());
+ rhs = FPCast(rhs, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "atan2f";
@@ -58,21 +59,21 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
- llvm::Value* result = b_->CreateCall(function, {lhs, rhs});
+ llvm::Value* result = Call(function, {lhs, rhs});
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
+ llvm::Value* value) {
bool cast_result_to_fp16 = false;
string function_name;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- value = b_->CreateFPCast(value, b_->getFloatTy());
+ value = FPCast(value, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
@@ -91,16 +92,16 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
- llvm::Value* result = b_->CreateCall(function, value);
+ llvm::Value* result = Call(function, value);
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const {
+ const HloToElementGeneratorMap& operand_to_generator) {
if (hlo->opcode() == HloOpcode::kMap) {
return [this, hlo, &operand_to_generator](
const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 76833e765d..e3fba9306b 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const override;
+ const HloToElementGeneratorMap& operand_to_generator) override;
protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
IrEmitter* ir_emitter_;
};
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 6f433b4f30..460363e18f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
@@ -67,8 +69,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
@@ -170,9 +170,9 @@ IrEmitter::~IrEmitter() {}
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
emitted_value_[bitcast] =
- b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)),
- IrShapeType(bitcast->shape())->getPointerTo(),
- AsStringRef(IrName(bitcast)));
+ BitCast(GetEmittedValueFor(bitcast->operand(0)),
+ IrShapeType(bitcast->shape())->getPointerTo(),
+ AsStringRef(IrName(bitcast)));
return Status::OK();
}
@@ -230,9 +230,8 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) {
// Use the elemental emitter for array shapes.
return DefaultAction(copy);
}
- return Unimplemented(
- "unsupported operand type %s for copy instruction",
- PrimitiveType_Name(copy->shape().element_type()).c_str());
+ return Unimplemented("unsupported operand type %s for copy instruction",
+ PrimitiveType_Name(copy->shape().element_type()));
}
// Calculate the alignment of a buffer allocated for a given primitive type.
@@ -389,7 +388,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
int64 length = ByteSizeOf(shape);
if (length <= 0 || length > std::numeric_limits<int32>::max()) {
return InvalidArgument(
- "xfeed (infeed or outfeed) buffer length %lld is outside the valid "
+ "xfeed (infeed or outfeed) buffer length %d is outside the valid "
"size range",
length);
}
@@ -440,22 +439,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer = b_.CreateCall(
- acquire_func,
- {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer =
+ Call(acquire_func,
+ {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
- b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
- /*SrcAlign=*/1, length_32);
+ MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
+ /*SrcAlign=*/1, length_32);
} else {
// Outfeed -- copy from the in-program address to the acquired buffer.
- b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
- /*SrcAlign=*/1, length_32);
+ MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
+ /*SrcAlign=*/1, length_32);
}
- b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer,
- shape_ptr, b_.getInt32(shape_length)});
+ Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr,
+ b_.getInt32(shape_length)});
return Status::OK();
}
@@ -502,7 +501,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
llvm::Value* IrEmitter::EmitElementalMap(
const HloMapInstruction& map_instr,
tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}
@@ -519,8 +518,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
"reduce_window_accumulator_address", &b_,
MinimumAlignmentForPrimitiveType(operand_element_type));
- b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))),
- accumulator_address);
+ Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
+ accumulator_address);
llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
std::vector<int64> window_size;
@@ -537,22 +536,21 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm::Value* in_bounds_condition = nullptr;
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
- b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] =
- b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
// We need to check if 0 <= input_index[i] < bound, as otherwise we are in
// the padding so that we can skip the computation. That is equivalent to
// input_index[i] < bound as an *unsigned* comparison, since a negative
// value will wrap to a large positive value.
- llvm::Value* index_condition = b_.CreateICmpULT(
- input_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ llvm::Value* index_condition =
+ ICmpULT(input_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
if (in_bounds_condition == nullptr) {
in_bounds_condition = index_condition;
} else {
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
}
CHECK(in_bounds_condition != nullptr);
@@ -565,12 +563,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
- *reduce_window->to_apply(),
- {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
- b_.CreateStore(result, accumulator_address);
+ *reduce_window->to_apply(), {Load(accumulator_address), input_value},
+ "reducer_function");
+ Store(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_address);
+ return Load(accumulator_address);
}
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
@@ -647,7 +645,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
[this, init_value](const llvm_ir::IrArray::Index& target_index) {
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- return b_.CreateLoad(init_value_addr);
+ return Load(init_value_addr);
}));
// Create a loop to iterate over the source array to scatter to the output.
@@ -667,7 +665,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
b_.getInt1Ty(), "initialized_flag_address", &b_);
- b_.CreateStore(b_.getInt1(false), initialized_flag_address);
+ Store(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
@@ -685,15 +683,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size());
llvm::Value* in_bounds_condition = b_.getTrue();
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = b_.CreateNSWMul(
- source_index[i], b_.getInt64(window.dimensions(i).stride()));
- operand_index[i] =
- b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = b_.CreateICmpULT(
- operand_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ llvm::Value* strided_index =
+ NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
+ operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition =
+ ICmpULT(operand_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
@@ -703,7 +700,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- b_.CreateLoad(initialized_flag_address), "initialized", &b_);
+ Load(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
@@ -712,38 +709,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
[&](const llvm_ir::IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- b_.CreateStore(operand_index[i], selected_index_address_slot);
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ Store(operand_index[i], selected_index_address_slot);
}
};
llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &b_);
- b_.CreateStore(operand_data, selected_value_address);
+ Store(operand_data, selected_value_address);
save_operand_index(operand_index);
- b_.CreateStore(b_.getInt1(true), initialized_flag_address);
+ Store(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
- llvm::Value* operand_element = b_.CreateLoad(operand_address);
+ llvm::Value* operand_element = Load(operand_address);
llvm::Value* result = EmitThreadLocalCall(
*select_and_scatter->select(),
- {b_.CreateLoad(selected_value_address), operand_element},
- "select_function");
+ {Load(selected_value_address), operand_element}, "select_function");
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = b_.CreateICmpNE(
+ llvm::Value* cond = ICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
- b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
+ Store(Load(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -754,8 +750,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::IrArray::Index selected_index(source_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(Load(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
llvm::Value* source_value =
@@ -837,7 +833,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
lhs_llvm_type, "convolution_sum_address", &b_,
MinimumAlignmentForPrimitiveType(lhs_element_type));
llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
- b_.CreateStore(constant_zero, sum_address);
+ Store(constant_zero, sum_address);
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
@@ -846,7 +842,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
loops
.AddLoop(
0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
- tensorflow::strings::StrCat("k", i))
+ absl::StrCat("k", i))
->GetIndVarValue();
}
llvm::Value* input_feature =
@@ -864,11 +860,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm::Value* kernel_index,
const WindowDimension& window_dim) {
llvm::Value* strided_index =
- b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride()));
- llvm::Value* dilated_kernel_index = b_.CreateNSWMul(
- kernel_index, b_.getInt64(window_dim.window_dilation()));
- return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index),
- b_.getInt64(window_dim.padding_low()));
+ NSWMul(output_index, b_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index =
+ NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
+ return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
+ b_.getInt64(window_dim.padding_low()));
};
std::vector<llvm::Value*> input_spatial(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
@@ -885,9 +881,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
// Also need to check that the input coordinates are not in one of the
// holes created by base dilation.
const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
- llvm::Value* remainder =
- b_.CreateSRem(input_index, b_.getInt64(base_dilation));
- return b_.CreateICmpEQ(remainder, b_.getInt64(0));
+ llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
+ return ICmpEQ(remainder, b_.getInt64(0));
};
llvm::Value* in_bounds_condition = b_.getInt1(true);
@@ -895,17 +890,17 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
window.dimensions(i).base_dilation()));
- llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound);
+ llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
llvm::Value* dim_not_in_hole =
not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
- llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole);
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok);
+ llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
+ in_bounds_condition = And(in_bounds_condition, dim_ok);
}
// Now we need to map the dilated base coordinates back to the actual
// data indices on the lhs.
const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
- return b_.CreateSDiv(input_index, b_.getInt64(base_dilation));
+ return SDiv(input_index, b_.getInt64(base_dilation));
};
for (int i = 0; i < num_spatial_dims; ++i) {
input_spatial[i] =
@@ -930,8 +925,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
for (int i = 0; i < num_spatial_dims; ++i) {
kernel_index[dnums.kernel_spatial_dimensions(i)] =
window.dimensions(i).window_reversal()
- ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1),
- kernel_spatial[i])
+ ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
+ kernel_spatial[i])
: kernel_spatial[i];
}
@@ -940,13 +935,13 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
llvm::Value* product =
- b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_),
- kernel_array.EmitReadArrayElement(kernel_index, &b_));
- llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product);
- b_.CreateStore(sum, sum_address);
+ FMul(input_array.EmitReadArrayElement(input_index, &b_),
+ kernel_array.EmitReadArrayElement(kernel_index, &b_));
+ llvm::Value* sum = FAdd(Load(sum_address), product);
+ Store(sum, sum_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(sum_address);
+ return Load(sum_address);
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
@@ -1072,34 +1067,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
conv_func->setCallingConv(llvm::CallingConv::C);
conv_func->setDoesNotThrow();
conv_func->setOnlyAccessesArgMemory();
- b_.CreateCall(
- conv_func,
- {
- GetExecutableRunOptionsArgument(),
- b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type),
- b_.CreateBitCast(lhs_address, ir_ptr_type),
- b_.CreateBitCast(rhs_address, ir_ptr_type),
- b_.getInt64(input_batch),
- b_.getInt64(input_rows),
- b_.getInt64(input_cols),
- b_.getInt64(input_channels),
- b_.getInt64(kernel_rows),
- b_.getInt64(kernel_cols),
- b_.getInt64(kernel_channels),
- b_.getInt64(kernel_filters),
- b_.getInt64(output_rows),
- b_.getInt64(output_cols),
- b_.getInt64(row_stride),
- b_.getInt64(col_stride),
- b_.getInt64(padding_top),
- b_.getInt64(padding_bottom),
- b_.getInt64(padding_left),
- b_.getInt64(padding_right),
- b_.getInt64(lhs_row_dilation),
- b_.getInt64(lhs_col_dilation),
- b_.getInt64(rhs_row_dilation),
- b_.getInt64(rhs_col_dilation),
- });
+ Call(conv_func, {
+ GetExecutableRunOptionsArgument(),
+ BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
+ BitCast(lhs_address, ir_ptr_type),
+ BitCast(rhs_address, ir_ptr_type),
+ b_.getInt64(input_batch),
+ b_.getInt64(input_rows),
+ b_.getInt64(input_cols),
+ b_.getInt64(input_channels),
+ b_.getInt64(kernel_rows),
+ b_.getInt64(kernel_cols),
+ b_.getInt64(kernel_channels),
+ b_.getInt64(kernel_filters),
+ b_.getInt64(output_rows),
+ b_.getInt64(output_cols),
+ b_.getInt64(row_stride),
+ b_.getInt64(col_stride),
+ b_.getInt64(padding_top),
+ b_.getInt64(padding_bottom),
+ b_.getInt64(padding_left),
+ b_.getInt64(padding_right),
+ b_.getInt64(lhs_row_dilation),
+ b_.getInt64(lhs_col_dilation),
+ b_.getInt64(rhs_row_dilation),
+ b_.getInt64(rhs_col_dilation),
+ });
return Status::OK();
}
@@ -1159,15 +1152,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
fft_func->setDoesNotThrow();
fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
const int fft_rank = fft_length.size();
- b_.CreateCall(
- fft_func,
- {GetExecutableRunOptionsArgument(),
- b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
- b_.CreateBitCast(operand_address, int8_ptr_type),
- b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank),
- b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
- b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
- b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
+ Call(fft_func,
+ {GetExecutableRunOptionsArgument(),
+ BitCast(GetEmittedValueFor(fft), int8_ptr_type),
+ BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
+ b_.getInt32(fft_rank), b_.getInt64(input_batch),
+ b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
+ b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
+ b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
return Status::OK();
}
@@ -1206,8 +1198,8 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
// TODO(b/63762267): Be more aggressive about specifying alignment.
- b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
- /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
+ MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
+ /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
}
llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_);
return Status::OK();
@@ -1466,19 +1458,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
accumulator_shard_type, "accumulator", &b_, 0));
}
- llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value));
+ llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
for (llvm::Value* accumulator_shard : accumulator) {
llvm::Value* initial_value;
auto shard_type = accumulator_shard->getType()->getPointerElementType();
if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
initial_value =
- b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa);
+ VectorSplat(vector_type->getNumElements(), init_value_ssa);
} else {
initial_value = init_value_ssa;
}
- b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment);
+ AlignedStore(initial_value, accumulator_shard, element_alignment);
}
llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
@@ -1500,24 +1492,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
}
CHECK(output_index.end() == it);
- llvm::Value* input_address = b_.CreateBitCast(
+ llvm::Value* input_address = BitCast(
arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
for (int i = 0; i < accumulator.size(); i++) {
auto input_address_typed =
- b_.CreateBitCast(input_address, accumulator[i]->getType());
+ BitCast(input_address, accumulator[i]->getType());
auto current_accumulator_value =
- b_.CreateAlignedLoad(accumulator[i], element_alignment);
- auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment);
+ AlignedLoad(accumulator[i], element_alignment);
+ auto addend = AlignedLoad(input_address_typed, element_alignment);
arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
auto reduced_result =
reduction_generator(&b_, current_accumulator_value, addend);
- b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment);
+ AlignedStore(reduced_result, accumulator[i], element_alignment);
if (i != (accumulator.size() - 1)) {
- input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(),
- input_address_typed, 1);
+ input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
+ input_address_typed, 1);
}
}
@@ -1526,8 +1518,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
ShardedVector result_ssa;
result_ssa.reserve(accumulator.size());
for (auto accumulator_shard : accumulator) {
- result_ssa.push_back(
- b_.CreateAlignedLoad(accumulator_shard, element_alignment));
+ result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
}
return result_ssa;
}
@@ -1536,18 +1527,18 @@ void IrEmitter::EmitShardedVectorStore(
llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
const int alignment, const llvm_ir::IrArray& containing_array) {
for (int i = 0; i < value_to_store.size(); i++) {
- auto store_address_typed = b_.CreateBitCast(
- store_address,
- llvm::PointerType::getUnqual(value_to_store[i]->getType()));
+ auto store_address_typed =
+ BitCast(store_address,
+ llvm::PointerType::getUnqual(value_to_store[i]->getType()));
- auto store_instruction = b_.CreateAlignedStore(
- value_to_store[i], store_address_typed, alignment);
+ auto store_instruction =
+ AlignedStore(value_to_store[i], store_address_typed, alignment);
containing_array.AnnotateLoadStoreInstructionWithMetadata(
store_instruction);
if (i != (value_to_store.size() - 1)) {
- store_address = b_.CreateConstInBoundsGEP1_32(
- value_to_store[i]->getType(), store_address_typed, 1);
+ store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
+ store_address_typed, 1);
}
}
}
@@ -1620,9 +1611,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
int64 start_index = 0;
int64 end_index = reduce->shape().dimensions(dimension);
- std::unique_ptr<llvm_ir::ForLoop> loop =
- loop_nest.AddLoop(start_index, end_index,
- tensorflow::strings::Printf("dim.%lld", dimension));
+ std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
+ start_index, end_index, absl::StrFormat("dim.%d", dimension));
array_index[dimension] = loop->GetIndVarValue();
}
@@ -1641,9 +1631,9 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
int64 start_index = 0;
int64 end_index = (innermost_dimension_size / vectorization_factor) *
vectorization_factor;
- std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
- start_index, end_index, vectorization_factor,
- tensorflow::strings::Printf("dim.%lld", innermost_dimension));
+ std::unique_ptr<llvm_ir::ForLoop> loop =
+ loop_nest.AddLoop(start_index, end_index, vectorization_factor,
+ absl::StrFormat("dim.%d", innermost_dimension));
array_index[innermost_dimension] = loop->GetIndVarValue();
SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
@@ -1713,8 +1703,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
&b_, MinimumAlignmentForPrimitiveType(accumulator_type));
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- llvm::Value* load_init_value = b_.CreateLoad(init_value_addr);
- b_.CreateStore(load_init_value, accumulator_addr);
+ llvm::Value* load_init_value = Load(init_value_addr);
+ Store(load_init_value, accumulator_addr);
// The enclosing loops go over all the target elements. Now we have to compute
// the actual target element. For this, we build a new loop nest to iterate
@@ -1747,12 +1737,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
// Apply the reduction function to the loaded value.
llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
- *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
+ *reduce->to_apply(), {Load(accumulator_addr), input_element},
"reduce_function");
- b_.CreateStore(result, accumulator_addr);
+ Store(result, accumulator_addr);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_addr);
+ return Load(accumulator_addr);
}
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
@@ -1990,7 +1980,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
[this, pad](const llvm_ir::IrArray::Index& target_index) {
const HloInstruction* padding_value = pad->operand(1);
llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
- return b_.CreateLoad(padding_value_addr);
+ return Load(padding_value_addr);
}));
// Create a loop to iterate over the operand elements and update the output
@@ -2012,10 +2002,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
const PaddingConfig& padding_config = pad->padding_config();
llvm_ir::IrArray::Index output_index(operand_index.GetType());
for (size_t i = 0; i < operand_index.size(); ++i) {
- llvm::Value* offset = b_.CreateMul(
- operand_index[i],
- b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
- llvm::Value* index = b_.CreateAdd(
+ llvm::Value* offset =
+ Mul(operand_index[i],
+ b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
+ llvm::Value* index = Add(
offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
output_index.push_back(index);
}
@@ -2118,7 +2108,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
- tensorflow::StringPiece custom_call_target(custom_call->custom_call_target());
+ absl::string_view custom_call_target(custom_call->custom_call_target());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
@@ -2126,10 +2116,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
for (size_t i = 0; i < operands.size(); ++i) {
const HloInstruction* operand = operands[i];
llvm::Value* operand_as_i8ptr =
- b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
+ PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
llvm::Value* slot_in_operands_alloca =
- b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)});
- b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
+ InBoundsGEP(operands_alloca, {b_.getInt64(i)});
+ Store(operand_as_i8ptr, slot_in_operands_alloca);
}
auto* custom_call_ir_function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
@@ -2141,9 +2131,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
auto* output_address_arg =
- b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
+ PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
- b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca});
+ Call(custom_call_ir_function, {output_address_arg, operands_alloca});
return Status::OK();
}
@@ -2170,8 +2160,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
return InternalError(
"instruction %s %s does not share slice with "
"instruction %s %s",
- a->ToString().c_str(), slice_a.ToString().c_str(),
- b->ToString().c_str(), slice_b.ToString().c_str());
+ a->ToString(), slice_a.ToString(), b->ToString(),
+ slice_b.ToString());
}
return Status::OK();
};
@@ -2202,15 +2192,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "header")),
compute_function_->function());
- b_.CreateBr(header_bb);
+ Br(header_bb);
b_.SetInsertPoint(header_bb);
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
- llvm::Value* while_predicate = b_.CreateICmpNE(
- b_.CreateLoad(
- GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
+ llvm::Value* while_predicate = ICmpNE(
+ Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2219,7 +2208,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
compute_function_->function());
llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "exit")));
- b_.CreateCondBr(while_predicate, body_bb, exit_bb);
+ CondBr(while_predicate, body_bb, exit_bb);
// Calls the body function from the body block.
b_.SetInsertPoint(body_bb);
@@ -2228,7 +2217,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
// Finishes with a branch back to the header.
- b_.CreateBr(header_bb);
+ Br(header_bb);
// Adds the exit block to the function and sets the insert point there.
compute_function_->function()->getBasicBlockList().push_back(exit_bb);
@@ -2275,7 +2264,6 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
output_min2maj.end());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
- llvm::Type* i8_type = b_.getInt8Ty();
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
@@ -2298,9 +2286,9 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
// Contiguous subregions from each operand to the concatenate contribute to a
// contiguous subregion in the target buffer starting at target_region_begin.
llvm::Value* target_region_begin =
- b_.CreateBitCast(target_array.EmitArrayElementAddress(
- outer_dims_index, &b_, "target_region"),
- i8_ptr_type);
+ BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_,
+ "target_region"),
+ i8_ptr_type);
int64 byte_offset_into_target_region = 0;
int64 inner_dims_product =
@@ -2314,13 +2302,12 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
for (HloInstruction* operand : operands) {
const Shape& input_shape = operand->shape();
llvm_ir::IrArray source_array = GetIrArrayFor(operand);
- llvm::Value* copy_source_address = b_.CreateBitCast(
+ llvm::Value* copy_source_address = BitCast(
source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"),
i8_ptr_type);
llvm::Value* copy_target_address =
- b_.CreateGEP(i8_type, target_region_begin,
- b_.getInt64(byte_offset_into_target_region));
+ GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
EmitTransferElements(
copy_target_address, copy_source_address,
@@ -2352,15 +2339,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
if (element_count == 1) {
- auto* load_instruction = b_.CreateAlignedLoad(
- b_.CreateBitCast(source, primitive_ptr_type), element_alignment);
+ auto* load_instruction =
+ AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
- auto* store_instruction = b_.CreateAlignedStore(
- load_instruction, b_.CreateBitCast(target, primitive_ptr_type),
- element_alignment);
+ auto* store_instruction =
+ AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
+ element_alignment);
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
} else {
- auto* memcpy_instruction = b_.CreateMemCpy(
+ auto* memcpy_instruction = MemCpy(
target, /*DstAlign=*/element_alignment, source,
/*SrcAlign=*/element_alignment, element_count * primitive_type_size);
@@ -2422,9 +2409,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
// cond_result = true_computation(true_operand)
// else
// cond_result = false_computation(false_operand)
- llvm::LoadInst* pred_value = b_.CreateLoad(
- GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
- llvm::Value* pred_cond = b_.CreateICmpNE(
+ llvm::LoadInst* pred_value =
+ Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
+ llvm::Value* pred_cond = ICmpNE(
pred_value,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
@@ -2450,11 +2437,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) {
return Status::OK();
}
-Status IrEmitter::HandleIota(HloInstruction* iota) {
- // TODO(b/64798317): implement iota on CPU.
- return Unimplemented("Iota is not implemented on CPU.");
-}
-
Status IrEmitter::HandleRng(HloInstruction* rng) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : rng->operands()) {
@@ -2511,8 +2493,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon(
int64 prof_counter_idx = it->second;
string counter_name = IrName("prof_counter", hlo.name());
- return b_.CreateGEP(GetProfileCountersArgument(),
- b_.getInt64(prof_counter_idx), AsStringRef(counter_name));
+ return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
+ AsStringRef(counter_name));
}
void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
@@ -2666,8 +2648,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
llvm::Value* params = compute_function_->parameters_arg();
llvm::Value* param_address_offset =
llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
- llvm::LoadInst* param_address_untyped =
- b_.CreateLoad(param_address_offset);
+ llvm::LoadInst* param_address_untyped = Load(param_address_offset);
if (!ShapeUtil::IsOpaque(target_shape)) {
AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
@@ -2687,17 +2668,15 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
auto buf_it = thread_local_buffers_.find(key);
if (buf_it == thread_local_buffers_.end()) {
llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
- IrShapeType(shape),
- tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
- MinimumAlignmentForShape(target_shape));
+ IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
+ &b_, MinimumAlignmentForShape(target_shape));
auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
CHECK(it_inserted_pair.second);
buf_it = it_inserted_pair.first;
}
return buf_it->second;
}();
- return b_.CreateBitCast(tempbuf_address,
- IrShapeType(target_shape)->getPointerTo());
+ return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}
llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
@@ -2705,7 +2684,7 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
- llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
+ llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
tempbuf_address_base->setMetadata(
@@ -2719,10 +2698,10 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
if (slice.offset() > 0) {
// Adjust the address to account for the slice offset.
tempbuf_address_untyped =
- b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
+ InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
- return b_.CreateBitCast(tempbuf_address_untyped,
- IrShapeType(target_shape)->getPointerTo());
+ return BitCast(tempbuf_address_untyped,
+ IrShapeType(target_shape)->getPointerTo());
}
llvm::Value* IrEmitter::EmitTempBufferPointer(
@@ -2753,7 +2732,7 @@ Status IrEmitter::EmitTargetElementLoop(
}
Status IrEmitter::EmitTargetElementLoop(
- HloInstruction* target_op, tensorflow::StringPiece desc,
+ HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator) {
VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
@@ -2808,8 +2787,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
llvm::Value* destination_value = GetEmittedValueFor(&destination);
int64 source_size = ByteSizeOf(source.shape());
// TODO(b/63762267): Be more aggressive about specifying alignment.
- b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value,
- /*SrcAlign=*/1, source_size);
+ MemCpy(destination_value, /*DstAlign=*/1, source_value,
+ /*SrcAlign=*/1, source_size);
return Status::OK();
}
@@ -2827,8 +2806,8 @@ Status IrEmitter::ElementTypesSameAndSupported(
if (std::find(supported_types.begin(), supported_types.end(),
primitive_type) == supported_types.end()) {
return Unimplemented("unsupported operand type %s in op %s",
- PrimitiveType_Name(primitive_type).c_str(),
- HloOpcodeString(instruction.opcode()).c_str());
+ PrimitiveType_Name(primitive_type),
+ HloOpcodeString(instruction.opcode()));
}
return Status::OK();
}
@@ -2848,7 +2827,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
llvm::Value* IrEmitter::EmitThreadLocalCall(
const HloComputation& callee,
tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
const Shape& return_shape = callee.root_instruction()->shape();
// Lifting this restriction to allow "small" arrays should be easy. Allowing
@@ -2863,38 +2842,37 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
CHECK(!parameter->getType()->isPointerTy());
llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
parameter->getType(), "arg_addr", &b_);
- b_.CreateStore(parameter, parameter_addr);
+ Store(parameter, parameter_addr);
parameter_addrs.push_back(parameter_addr);
}
llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(return_type, module_),
- tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
+ absl::StrCat(name, "_retval_addr"), &b_,
MinimumAlignmentForPrimitiveType(return_type));
- b_.CreateCall(
- FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- parameter_addrs, &b_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ Call(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ parameter_addrs, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
- return b_.CreateLoad(return_value_buffer);
+ return Load(return_value_buffer);
}
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name) {
- b_.CreateCall(FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- /*parameter_addresses=*/{}, &b_, name,
- /*return_value_buffer=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()),
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ absl::string_view name) {
+ Call(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ /*parameter_addresses=*/{}, &b_, name,
+ /*return_value_buffer=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
}
llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index c9a1dab62d..f98891246b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/string_view.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@@ -39,12 +40,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
@@ -55,7 +56,8 @@ namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
-class IrEmitter : public DfsHloVisitorWithDefault {
+class IrEmitter : public DfsHloVisitorWithDefault,
+ public IrBuilderMixin<IrEmitter> {
public:
// Create a new LLVM IR emitter.
//
@@ -100,6 +102,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::IRBuilder<>* b() { return &b_; }
+ // builder() is for IrBuilderMixin.
+ llvm::IRBuilder<>* builder() { return &b_; }
+
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
@@ -107,7 +112,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::Value* EmitElementalMap(
const HloMapInstruction& map_instr,
tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name);
+ absl::string_view name);
protected:
//
@@ -152,7 +157,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleConditional(HloInstruction* conditional) override;
Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* gen_token) override;
- Status HandleIota(HloInstruction* iota) override;
Status HandleRng(HloInstruction* rng) override;
Status FinishVisit(HloInstruction* root) override;
@@ -239,7 +243,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// function that a map operation applies.
StatusOr<llvm::Function*> EmitFunction(
HloComputation* function, // The function to emit.
- tensorflow::StringPiece
+ absl::string_view
function_name_suffix); // Used for LLVM IR register names.
// Emits a call to a thread local function (e.g. to the computation nested
@@ -251,14 +255,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::Value* EmitThreadLocalCall(
const HloComputation& callee,
tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- tensorflow::StringPiece name);
+ absl::string_view name);
// Emits a call to a "global" function (e.g. to the computation nested within
// a kWhile or a kCall). Buffer assignment unambiguously assigns buffers to
// the parameters and return values for these computations so there is no need
// to explicitly pass parameters or return results.
- void EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name);
+ void EmitGlobalCall(const HloComputation& callee, absl::string_view name);
// Returns the buffer to which a global call to `callee` would have written
// its result.
@@ -285,7 +288,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
HloInstruction* target_op,
const llvm_ir::ElementGenerator& element_generator);
Status EmitTargetElementLoop(
- HloInstruction* target_op, tensorflow::StringPiece desc,
+ HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator);
// Emits a memcpy from the source instruction's result value to the
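The ir_emitter edits above route b_.CreateLoad/CreateStore/CreateBitCast calls through short helpers (Load, Store, BitCast, ...) supplied by the new IrBuilderMixin<IrEmitter> base, which reaches the underlying IRBuilder through the builder() accessor added to the class. A minimal sketch of that CRTP forwarding pattern, assuming an LLVM release whose IRBuilder still offers the untyped CreateLoad(Value*) overload; the class names below are illustrative, not the real mixin:

#include "llvm/IR/IRBuilder.h"

// Sketch only: a CRTP mixin that forwards short-named helpers to the
// derived class's IRBuilder. BuilderMixinSketch/EmitterSketch are made-up
// names for illustration.
template <typename Derived>
class BuilderMixinSketch {
 public:
  llvm::Value* Load(llvm::Value* ptr, const llvm::Twine& name = "") {
    return derived()->builder()->CreateLoad(ptr, name);
  }
  llvm::StoreInst* Store(llvm::Value* value, llvm::Value* ptr) {
    return derived()->builder()->CreateStore(value, ptr);
  }
  llvm::Value* BitCast(llvm::Value* value, llvm::Type* type) {
    return derived()->builder()->CreateBitCast(value, type);
  }

 private:
  Derived* derived() { return static_cast<Derived*>(this); }
};

class EmitterSketch : public BuilderMixinSketch<EmitterSketch> {
 public:
  explicit EmitterSketch(llvm::LLVMContext& context) : b_(context) {}
  // Required by the mixin; mirrors the builder() accessor added above.
  llvm::IRBuilder<>* builder() { return &b_; }

 private:
  llvm::IRBuilder<> b_;
};

With a mixin like this in scope, emitter methods can write Load(ptr) instead of b_.CreateLoad(ptr), which is exactly the call-by-call substitution the ir_emitter.cc hunks perform.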
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 2db4d000f5..784045313d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -189,7 +190,7 @@ void IrFunction::Initialize(const string& function_name,
llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
CHECK_GT(num_dynamic_loop_bounds_, 0);
CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
- string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset);
+ string name = absl::StrCat("dynamic_loop_bound_", offset);
return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
b_->getInt64(offset), AsStringRef(name)));
}
@@ -200,7 +201,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
// address buffer).
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* b, absl::string_view name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer;
@@ -211,13 +212,13 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
} else {
parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
- tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
+ absl::StrCat(name, "_parameter_addresses"), b);
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
llvm::Value* parameter_as_i8ptr =
b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
- AsStringRef(tensorflow::strings::StrCat(
- name, "_parameter_", i, "_address_as_i8ptr")));
+ AsStringRef(absl::StrCat(name, "_parameter_", i,
+ "_address_as_i8ptr")));
llvm::Value* slot_in_param_addresses =
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
@@ -320,8 +321,7 @@ Status EmitCallToParallelForkJoin(
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/partitions_array,
/*Name=*/
- AsStringRef(
- tensorflow::strings::StrCat(name, "_parallel_dimension_partitions")));
+ AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions")));
// Add argument specifying parallel dimension partitions.
fork_join_arguments.push_back(
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index a41cbb64cd..ee7595f6e9 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -116,7 +116,7 @@ class IrFunction {
// Returns an array of compute function call argument ir values.
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* b, absl::string_view name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg);
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index 8560e4296a..f8441c3e34 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
namespace cpu {
@@ -30,8 +30,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
dynamic_loop_bounds_(dynamic_loop_bounds) {}
std::vector<llvm_ir::IrArray::Index>
-ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
+ llvm::Type* index_type) {
CHECK_NE(index_type, nullptr);
CHECK(!ShapeUtil::IsTuple(shape_));
@@ -52,15 +52,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second;
std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
- /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension),
- start_index, end_index);
+ /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index,
+ end_index);
array_index[dimension] = loop->GetIndVarValue();
} else {
// Emit static loop bounds for this dimension.
std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
/*start_index=*/0,
/*end_index=*/shape_.dimensions(dimension),
- /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension));
+ /*suffix=*/absl::StrFormat("dim.%d", dimension));
array_index[dimension] = loop->GetIndVarValue();
}
}
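The loop-name formatting above moves from tensorflow::strings::Printf("dim.%lld", ...) to absl::StrFormat("dim.%d", ...). StrFormat checks argument types at compile time, so a plain %d formats an int64 without the ll length modifier. A tiny sketch under that assumption:

#include <cstdint>
#include <string>
#include "absl/strings/str_format.h"

// Illustrative only: %d accepts any integral width with absl::StrFormat.
std::string LoopSuffix(int64_t dimension) {
  return absl::StrFormat("dim.%d", dimension);  // e.g. "dim.3"
}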
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index 076c683ca5..a604e1db22 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
~ParallelLoopEmitter() override = default;
std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) override;
+ absl::string_view loop_name, llvm::Type* index_type) override;
private:
const DynamicLoopBounds* dynamic_loop_bounds_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 4fa5984b04..b4c0c09ec0 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
@@ -109,7 +111,7 @@ ParallelTaskAssignment::ParallelTaskAssignment(
: target_machine_features_(*target_machine_features) {
VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
// Run cost analysis on 'module'.
- auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
+ auto cost_analysis = absl::make_unique<HloCostAnalysis>(shape_size);
HloComputation* computation = module->entry_computation();
Status status = computation->root_instruction()->Accept(cost_analysis.get());
if (status.ok()) {
@@ -216,8 +218,7 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper(
// Outline 'instruction' in 'computation' for parallel task assignment.
auto* call = module->OutlineExpressionFromComputation(
- {instruction},
- tensorflow::strings::StrCat("parallel_", instruction->name()),
+ {instruction}, absl::StrCat("parallel_", instruction->name()),
computation);
// Set assigned dimension partitioning to 'instruction'.
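Two library swaps recur in the hunks above: xla's MakeUnique becomes absl::make_unique and tensorflow::strings::StrCat becomes absl::StrCat. A trivial usage sketch; CostAnalysisStub and the instruction name are stand-ins, not the real HloCostAnalysis or a real HLO name:

#include <cstdint>
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"

struct CostAnalysisStub {
  explicit CostAnalysisStub(int64_t shape_size) : shape_size(shape_size) {}
  int64_t shape_size;
};

std::string OutlinedName() {
  auto analysis = absl::make_unique<CostAnalysisStub>(/*shape_size=*/8);
  (void)analysis;
  return absl::StrCat("parallel_", "dot.1");  // yields "parallel_dot.1"
}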
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index 8becc8fa23..a99cd99c14 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -73,7 +73,7 @@ class ParallelTaskAssigner : public HloPassInterface {
target_machine_features_(*target_machine_features) {}
~ParallelTaskAssigner() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cpu-parallel-task-assigner";
}
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index ee272b5f4f..a84ee78b19 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
@@ -36,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
ParallelTaskAssignmentTest()
- : target_machine_features_([](int64 shape_size) {
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false),
+ target_machine_features_([](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
}) {}
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index f227e4ae13..942e2ddd39 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -67,8 +67,8 @@ int main(int argc, char** argv) {
/*execution_profile=*/&profile);
std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie();
- LOG(INFO) << tensorflow::strings::Printf("computation took %lldns",
- profile.compute_time_ns());
+ LOG(INFO) << absl::StrFormat("computation took %dns",
+ profile.compute_time_ns());
LOG(INFO) << actual->ToString();
return 0;
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index be772cfb7e..bf98064647 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include <list>
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/Mangler.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Host.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
@@ -170,15 +170,14 @@ namespace {
bool RegisterKnownJITSymbols() {
CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();
-#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
- do { \
- auto* function_address = \
- reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
- registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
- function_address); \
- CHECK_EQ( \
- tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \
- "__xla_cpu_runtime_" #base_name); \
+#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
+ do { \
+ auto* function_address = \
+ reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
+ registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
+ function_address); \
+ CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
+ "__xla_cpu_runtime_" #base_name); \
} while (false)
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
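In the REGISTER_CPU_RUNTIME_SYMBOL macro above, the first CHECK_EQ operand is now wrapped in absl::string_view rather than tensorflow::StringPiece. The wrapper matters: comparing two const char* values directly compares pointers, while string_view comparison is by contents. A hedged sketch (the symbol name below is hypothetical):

#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"

// Illustrative only: string_view vs. string literal compares characters,
// not pointer identity.
void CheckRuntimeSymbolName(const char* registered_name) {
  CHECK_EQ(absl::string_view(registered_name),
           "__xla_cpu_runtime_ExampleSymbol");
}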
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 181cec3cdd..2384166fd2 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -51,6 +51,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -94,6 +95,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
)
@@ -108,6 +110,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -121,6 +124,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
index 6fcce42eaa..fcd87b36b3 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include <cctype>
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index d98856fdbf..22721051e5 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -129,8 +129,8 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
error_spec_);
}
-TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
- // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the
+TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
+ // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the
// middle.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
index 973aac8766..a434c04a98 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include <cctype>
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -32,9 +32,9 @@ const char* const kTriple_android_arm = "armv7-none-android";
struct IntrinsicTestSpec {
HloOpcode opcode;
- tensorflow::StringPiece triple;
- tensorflow::StringPiece features;
- tensorflow::StringPiece check_lines;
+ absl::string_view triple;
+ absl::string_view features;
+ absl::string_view check_lines;
};
// Tests that unary functions get lowered using intrinsic calls.
@@ -65,9 +65,8 @@ class CpuUnaryIntrinsicTest
features = "";
}
- return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(),
- features.empty() ? "" : "_With",
- features.c_str());
+ return absl::StrCat(opcode, "_On_", triple,
+ (features.empty() ? "" : "_With"), features);
}
};
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index 01daed4bcd..bb105194f1 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -62,7 +62,8 @@ TEST_F(CpuNoAliasTest, Concat) {
// Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it.
auto status_or_buffer_assn = BufferAssigner::Run(
- hlo_module.get(), MakeUnique<DependencyHloOrdering>(hlo_module.get()),
+ hlo_module.get(),
+ absl::make_unique<DependencyHloOrdering>(hlo_module.get()),
backend().compiler()->BufferSizeBytesFunction(),
[](LogicalBuffer::Color) { return /*alignment=*/1; });
ASSERT_EQ(status_or_buffer_assn.status(), Status::OK());
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index 3274be8d9d..962ea69c09 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
+#include "absl/algorithm/container.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -422,8 +423,8 @@ TileVariable::TileVariable(VectorSupportLibrary* vector_support,
std::vector<llvm::Value*> TileVariable::Get() const {
std::vector<llvm::Value*> result;
- c_transform(storage_, std::back_inserter(result),
- [&](VectorVariable vect_var) { return vect_var.Get(); });
+ absl::c_transform(storage_, std::back_inserter(result),
+ [&](VectorVariable vect_var) { return vect_var.Get(); });
return result;
}
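TileVariable::Get above switches from a local c_transform helper to absl::c_transform, the container-level wrapper over std::transform from absl/algorithm/container.h. A self-contained example of the same call shape:

#include <iterator>
#include <vector>
#include "absl/algorithm/container.h"

// Applies a unary op to every element of a container, appending results.
std::vector<int> Squares(const std::vector<int>& xs) {
  std::vector<int> result;
  absl::c_transform(xs, std::back_inserter(result),
                    [](int x) { return x * x; });
  return result;  // {1, 2, 3} -> {1, 4, 9}
}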
diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h
index 56b28fd22d..c326beb899 100644
--- a/tensorflow/compiler/xla/service/defuser.h
+++ b/tensorflow/compiler/xla/service/defuser.h
@@ -29,7 +29,7 @@ class Defuser : public HloPassInterface {
public:
Defuser() {}
~Defuser() override {}
- tensorflow::StringPiece name() const override { return "defuser"; }
+ absl::string_view name() const override { return "defuser"; }
// Run defusion on the given module. Returns whether the module was
// changed.
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc
index e727ba49cb..37d1895d41 100644
--- a/tensorflow/compiler/xla/service/defuser_test.cc
+++ b/tensorflow/compiler/xla/service/defuser_test.cc
@@ -26,6 +26,11 @@ namespace xla {
namespace {
class DefuserTest : public HloVerifiedTestBase {
+ public:
+ DefuserTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
// Returns the number of fusion instructions in the module.
int FusionCount() {
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index 48e4471499..ba2a674d9a 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -27,9 +27,7 @@ namespace {
class ControlDepRemover : public HloPassInterface {
public:
ControlDepRemover() = default;
- tensorflow::StringPiece name() const override {
- return "control-dep-remover";
- }
+ absl::string_view name() const override { return "control-dep-remover"; }
StatusOr<bool> Run(HloModule* module) override {
bool changed = false;
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
index cc1695b7f8..7be70add2f 100644
--- a/tensorflow/compiler/xla/service/despecializer.h
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -33,7 +33,7 @@ namespace xla {
class Despecializer : public HloPassInterface {
public:
Despecializer();
- tensorflow::StringPiece name() const override { return "despecializer"; }
+ absl::string_view name() const override { return "despecializer"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc
index e228bb56bc..1d0297cfbf 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.cc
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc
@@ -36,9 +36,8 @@ StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
se::DeviceMemoryBase result = stream_executor->AllocateArray<uint8>(size);
if (size > 0 && result == nullptr) {
return ResourceExhausted(
- "Failed to allocate request for %s (%lluB) on device ordinal %d",
- tensorflow::strings::HumanReadableNumBytes(size).c_str(), size,
- device_ordinal);
+ "Failed to allocate request for %s (%uB) on device ordinal %d",
+ tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal);
}
return OwningDeviceMemory(result, device_ordinal, this);
}
@@ -61,12 +60,12 @@ StatusOr<se::StreamExecutor*> StreamExecutorMemoryAllocator::GetStreamExecutor(
}
if (device_ordinal >= stream_executors_.size()) {
return InvalidArgument(
- "device ordinal value (%d) >= number of devices (%zu)", device_ordinal,
+ "device ordinal value (%d) >= number of devices (%u)", device_ordinal,
stream_executors_.size());
}
if (stream_executors_[device_ordinal] == nullptr) {
return NotFound("Device %s:%d present but not supported",
- platform()->Name().c_str(), device_ordinal);
+ platform()->Name(), device_ordinal);
}
return stream_executors_[device_ordinal];
}
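The error paths above drop the .c_str() calls and adjust the length modifiers (%llu/%zu become %u), which suggests the xla error helpers (ResourceExhausted, InvalidArgument, NotFound, ...) now take type-checked, StrFormat-style arguments where %s binds a std::string or absl::string_view directly. The sketch below uses absl::StrFormat itself to show the calling convention; it is not the actual helper implementation:

#include <string>
#include "absl/strings/str_format.h"

// Illustrative only: %s binds a std::string without .c_str(); %d binds int.
std::string NotFoundMessage(const std::string& platform_name, int ordinal) {
  return absl::StrFormat("Device %s:%d present but not supported",
                         platform_name, ordinal);
}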
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc
index 2172ae0a29..3e7373adc5 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc
@@ -28,14 +28,14 @@ template <typename HloInstructionPtr>
Status DfsHloVisitorBase<HloInstructionPtr>::HandleElementwiseUnary(
HloInstructionPtr hlo) {
return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s",
- HloOpcodeString(hlo->opcode()).c_str());
+ HloOpcodeString(hlo->opcode()));
}
template <typename HloInstructionPtr>
Status DfsHloVisitorBase<HloInstructionPtr>::HandleElementwiseBinary(
HloInstructionPtr hlo) {
return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s",
- HloOpcodeString(hlo->opcode()).c_str());
+ HloOpcodeString(hlo->opcode()));
}
template <typename HloInstructionPtr>
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 86d57581f8..f6f8fc5a2a 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -19,13 +19,13 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
@@ -107,6 +107,7 @@ class DfsHloVisitorBase {
virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
+ virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
virtual Status HandleCompare(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
}
@@ -208,7 +209,6 @@ class DfsHloVisitorBase {
virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
- virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0;
virtual Status HandleRng(HloInstructionPtr hlo) = 0;
virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
virtual Status HandleSort(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 617a5a2eb4..4f620e4c3a 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -16,13 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -94,8 +94,11 @@ class DfsHloVisitorWithDefaultBase
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
- Status HandleAllToAll(HloInstructionPtr crs) override {
- return DefaultAction(crs);
+ Status HandleAllToAll(HloInstructionPtr hlo) override {
+ return DefaultAction(hlo);
+ }
+ Status HandleCollectivePermute(HloInstructionPtr hlo) override {
+ return DefaultAction(hlo);
}
Status HandleRng(HloInstructionPtr random) override {
return DefaultAction(random);
@@ -106,9 +109,6 @@ class DfsHloVisitorWithDefaultBase
Status HandleOutfeed(HloInstructionPtr outfeed) override {
return DefaultAction(outfeed);
}
- Status HandleHostCompute(HloInstructionPtr host_compute) override {
- return DefaultAction(host_compute);
- }
Status HandleReverse(HloInstructionPtr reverse) override {
return DefaultAction(reverse);
}
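The visitor changes above add a pure-virtual HandleCollectivePermute to DfsHloVisitorBase and give DfsHloVisitorWithDefaultBase an override that forwards to DefaultAction, so visitors derived from the defaulted base keep compiling unchanged. A minimal sketch of a concrete visitor (hypothetical, not in the tree) that opts into the new hook:

#include <cstdint>
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Counts visited instructions, tracking collective-permutes separately.
class CountingVisitor : public xla::DfsHloVisitorWithDefault {
 public:
  xla::Status DefaultAction(xla::HloInstruction* hlo) override {
    ++visited_;
    return xla::Status::OK();
  }
  xla::Status HandleCollectivePermute(xla::HloInstruction* hlo) override {
    ++collective_permutes_;
    return DefaultAction(hlo);
  }

 private:
  int64_t visited_ = 0;
  int64_t collective_permutes_ = 0;
};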
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index 12faed6967..09cb10d6ee 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -136,6 +136,7 @@ Status DecomposeBatchDot(HloInstruction* dot) {
dot_dnums.add_rhs_contracting_dimensions(0);
auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));
+ dot_r2->set_precision_config(dot->precision_config());
// Reshape Dot to R3 so we can concat along batch dimension.
auto dot_r3 = computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index 1959b687f1..fc38e31700 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -29,7 +29,7 @@ class DotDecomposer : public HloPassInterface {
DotDecomposer(bool decompose_batch_dot = true)
: decompose_batch_dot_(decompose_batch_dot) {}
~DotDecomposer() = default;
- tensorflow::StringPiece name() const override { return "dot_decomposer"; }
+ absl::string_view name() const override { return "dot_decomposer"; }
// Run DotDecomposer pass on computations in 'module'.
// Returns whether the 'module' was changed.
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 891ae42141..813e93fafa 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -21,11 +21,15 @@ limitations under the License.
#include <vector>
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -38,17 +42,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
+using absl::StrCat;
using llvm_ir::AsStringRef;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
-using tensorflow::strings::StrCat;
namespace {
@@ -203,7 +206,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
} // namespace
StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
if (op->opcode() == HloOpcode::kCopy) {
return operand_value;
} else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
@@ -217,7 +220,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
switch (op->opcode()) {
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -229,14 +232,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
if (to_type == PRED) {
return b_->CreateZExt(
- b_->CreateICmpNE(operand_value, llvm::ConstantInt::get(
- operand_value->getType(), 0)),
+ ICmpNE(operand_value,
+ llvm::ConstantInt::get(operand_value->getType(), 0)),
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
if (primitive_util::IsIntegralType(to_type)) {
- return b_->CreateIntCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
- primitive_util::IsSignedIntegralType(from_type));
+ return IntCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_),
+ primitive_util::IsSignedIntegralType(from_type));
}
if (primitive_util::IsFloatingPointType(to_type)) {
if (to_type == BF16) {
@@ -252,19 +255,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
primitive_util::ComplexComponentType(to_type), module_);
if (primitive_util::IsSignedIntegralType(from_type)) {
return EmitComposeComplex(
- op, b_->CreateSIToFP(operand_value, to_ir_component_type),
- nullptr);
+ op, SIToFP(operand_value, to_ir_component_type), nullptr);
}
if (primitive_util::IsUnsignedIntegralType(from_type) ||
from_type == PRED) {
return EmitComposeComplex(
- op, b_->CreateUIToFP(operand_value, to_ir_component_type),
- nullptr);
+ op, UIToFP(operand_value, to_ir_component_type), nullptr);
}
}
return Unimplemented("conversion from primitive type %s to %s",
- PrimitiveType_Name(from_type).c_str(),
- PrimitiveType_Name(to_type).c_str());
+ PrimitiveType_Name(from_type),
+ PrimitiveType_Name(to_type));
}
case HloOpcode::kBitcastConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -275,14 +276,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
- return b_->CreateBitCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return BitCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
"bitcast conversion from primitive type %s to %s with unequal "
"bit-widths (%u versus %u) ",
- PrimitiveType_Name(from_type).c_str(),
- PrimitiveType_Name(to_type).c_str(),
+ PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
primitive_util::BitWidth(from_type),
primitive_util::BitWidth(to_type));
}
@@ -292,10 +292,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
if (is_signed) {
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto zero = llvm::ConstantInt::get(type, 0);
- auto cmp = b_->CreateICmpSGE(operand_value, zero);
- return b_->CreateSelect(cmp, operand_value,
- b_->CreateNeg(operand_value));
+ auto cmp = ICmpSGE(operand_value, GetZero(type));
+ return Select(cmp, operand_value, Neg(operand_value));
} else {
return operand_value;
}
@@ -307,44 +305,37 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
{operand_value->getType()}, b_);
}
case HloOpcode::kSign: {
- bool is_signed =
- primitive_util::IsSignedIntegralType(op->shape().element_type());
+ CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
+ << op->shape().element_type();
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto zero = llvm::ConstantInt::get(type, 0);
- auto cmp = b_->CreateICmpEQ(operand_value, zero);
- if (is_signed) {
- auto ashr =
- b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1);
- return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1));
- } else {
- return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1));
- }
+ auto cmp = ICmpEQ(operand_value, GetZero(type));
+ auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
+ return Select(cmp, GetZero(type), Or(ashr, 1));
}
case HloOpcode::kNegate:
- return b_->CreateNeg(operand_value);
+ return Neg(operand_value);
case HloOpcode::kNot: {
auto type = op->shape().element_type();
if (type == PRED) {
// It is not sufficient to just call CreateNot() here because a PRED
// is represented as an i8 and the truth value is stored only in the
// bottom bit.
- return b_->CreateZExt(
- b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())),
- llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
+ llvm_ir::PrimitiveTypeToIrType(PRED, module_));
} else if (primitive_util::IsIntegralType(type)) {
- return b_->CreateNot(operand_value);
+ return Not(operand_value);
}
return Unimplemented("unary op Not is not defined for type '%d'", type);
}
default:
return Unimplemented("unary integer op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
switch (op->opcode()) {
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -361,8 +352,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
return EmitComposeComplex(
op,
- b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType(
- to_component_type, module_)),
+ FPCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
nullptr);
}
if (from_type == BF16) {
@@ -378,26 +369,25 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
if (to_type == PRED) {
return b_->CreateZExt(
- b_->CreateFCmpUNE(
- operand_value,
- llvm::ConstantFP::get(operand_value->getType(), 0.0)),
+ FCmpUNE(operand_value,
+ llvm::ConstantFP::get(operand_value->getType(), 0.0)),
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
if (primitive_util::IsFloatingPointType(to_type)) {
- return b_->CreateFPCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsSignedIntegralType(to_type)) {
- return b_->CreateFPToSI(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPToSI(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsUnsignedIntegralType(to_type)) {
- return b_->CreateFPToUI(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPToUI(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return Unimplemented("unhandled conversion operation: %s => %s",
- PrimitiveType_Name(from_type).c_str(),
- PrimitiveType_Name(to_type).c_str());
+ PrimitiveType_Name(from_type),
+ PrimitiveType_Name(to_type));
}
case HloOpcode::kBitcastConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -408,14 +398,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
- return b_->CreateBitCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return BitCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
"bitcast conversion from primitive type %s to %s with unequal "
"bit-widths (%u versus %u) ",
- PrimitiveType_Name(from_type).c_str(),
- PrimitiveType_Name(to_type).c_str(),
+ PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
primitive_util::BitWidth(from_type),
primitive_util::BitWidth(to_type));
}
@@ -453,11 +442,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
// TODO(b/32151903): Ensure consistent sign behavior for -0.0.
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(operand_value, zero);
- auto olt = b_->CreateFCmpOLT(operand_value, zero);
- return b_->CreateSelect(
- oeq, zero,
- b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
+ auto oeq = FCmpOEQ(operand_value, zero);
+ auto olt = FCmpOLT(operand_value, zero);
+ return Select(oeq, zero,
+ Select(olt, llvm::ConstantFP::get(type, -1.0),
llvm::ConstantFP::get(type, 1.0)));
}
case HloOpcode::kIsFinite: {
@@ -467,24 +455,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
auto abs_value = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
auto infinity = llvm::ConstantFP::getInfinity(type);
- auto not_infinite = b_->CreateFCmpONE(abs_value, infinity);
+ auto not_infinite = FCmpONE(abs_value, infinity);
return b_->CreateZExt(not_infinite,
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
case HloOpcode::kNegate:
- return b_->CreateFNeg(operand_value);
+ return FNeg(operand_value);
case HloOpcode::kReal:
return operand_value;
case HloOpcode::kImag:
return llvm::ConstantFP::get(operand_value->getType(), 0.0);
default:
return Unimplemented("unary floating-point op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
PrimitiveType input_type = op->operand(0)->shape().element_type();
PrimitiveType component_type =
primitive_util::IsComplexType(input_type)
@@ -496,12 +484,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
- auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b));
+ auto sum_sq = FAdd(FMul(a, a), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq),
- angle);
+ return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
}
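      // Worked example of the identity implemented above, log(a+bi) =
      // 0.5*log(a^2+b^2) + i*atan2(b, a): for the input i (a=0, b=1) the sum
      // of squares is 1, so the real part is 0.5*log(1) = 0 and the imaginary
      // part is atan2(1, 0) = pi/2, i.e. log(i) = i*pi/2.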
case HloOpcode::kLog1p: {
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
@@ -509,14 +496,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
- auto a_plus_one = b_->CreateFAdd(a, one);
- auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one),
- b_->CreateFMul(b, b));
+ auto a_plus_one = FAdd(a, one);
+ auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq),
- angle);
+ return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
}
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -530,11 +515,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
primitive_util::ComplexComponentType(to_type);
auto to_ir_component_type =
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
- return EmitComposeComplex(op,
- b_->CreateFPCast(EmitExtractReal(operand_value),
- to_ir_component_type),
- b_->CreateFPCast(EmitExtractImag(operand_value),
- to_ir_component_type));
+ return EmitComposeComplex(
+ op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
+ FPCast(EmitExtractImag(operand_value), to_ir_component_type));
}
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
@@ -544,8 +527,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
- return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b),
- b_->CreateFMul(exp_a, sin_b));
+ return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
}
case HloOpcode::kExpm1: {
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
@@ -556,8 +538,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
- auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one);
- auto imag_result = b_->CreateFMul(exp_a, sin_b);
+ auto real_result = FSub(FMul(exp_a, cos_b), one);
+ auto imag_result = FMul(exp_a, sin_b);
return EmitComposeComplex(op, real_result, imag_result);
}
case HloOpcode::kCos: {
@@ -572,14 +554,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
- auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
- auto half_exp_neg_b =
- b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
- return EmitComposeComplex(
- op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)),
- b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b)));
+ return EmitComposeComplex(op,
+ FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
+ FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
}
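      // Reading of the terms above: half_exp_b + half_exp_neg_b is cosh(b)
      // and half_exp_neg_b - half_exp_b is -sinh(b), so the result is
      // cos(a)*cosh(b) - i*sin(a)*sinh(b), the usual expansion of cos(a+bi).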
case HloOpcode::kSin: {
// sin(z) = .5i(e^(-iz) - e^(iz))
@@ -595,14 +576,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
- auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
- auto half_exp_neg_b =
- b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
- return EmitComposeComplex(
- op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)),
- b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b)));
+ return EmitComposeComplex(op,
+ FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
+ FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
}
case HloOpcode::kTanh: {
/*
@@ -630,74 +610,63 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
- auto exp_neg_a =
- b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
- auto exp_2a_minus_exp_neg_2a = b_->CreateFSub(
- b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a));
- auto cos_b_sq = b_->CreateFMul(cos_b, cos_b);
- auto sin_b_sq = b_->CreateFMul(sin_b, sin_b);
- auto real_num =
- b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
- b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
- auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b);
- auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a);
+ auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
+ auto exp_2a_minus_exp_neg_2a =
+ FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a));
+ auto cos_b_sq = FMul(cos_b, cos_b);
+ auto sin_b_sq = FMul(sin_b, sin_b);
+ auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
+ FMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
+ auto cos_b_sin_b = FMul(cos_b, sin_b);
+ auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a);
auto exp_a_plus_exp_neg_a_sq =
- b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
- auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a);
+ FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
+ auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a);
auto exp_a_minus_exp_neg_a_sq =
- b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
- auto imag_num = b_->CreateFMul(
- cos_b_sin_b,
- b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
- auto denom =
- b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
- b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
- return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom),
- b_->CreateFDiv(imag_num, denom));
+ FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
+ auto imag_num = FMul(
+ cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
+ auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
+ FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
+ return EmitComposeComplex(op, FDiv(real_num, denom),
+ FDiv(imag_num, denom));
}
case HloOpcode::kAbs: {
- auto sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value),
- EmitExtractReal(operand_value)),
- b_->CreateFMul(EmitExtractImag(operand_value),
- EmitExtractImag(operand_value)));
+ auto sum_sq = FAdd(
+ FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
+ FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
{sum_sq->getType()}, b_);
}
case HloOpcode::kSign: { // Sign(c) = c / |c|
- auto sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value),
- EmitExtractReal(operand_value)),
- b_->CreateFMul(EmitExtractImag(operand_value),
- EmitExtractImag(operand_value)));
+ auto sum_sq = FAdd(
+ FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
+ FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_);
auto type = cplx_abs->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero);
- return b_->CreateSelect(
+ auto oeq = FCmpOEQ(cplx_abs, zero);
+ return Select(
oeq, EmitComposeComplex(op, zero, zero),
- EmitComposeComplex(
- op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
- b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs)));
+ EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
+ FDiv(EmitExtractImag(operand_value), cplx_abs)));
}
case HloOpcode::kNegate:
- return EmitComposeComplex(op,
- b_->CreateFNeg(EmitExtractReal(operand_value)),
- b_->CreateFNeg(EmitExtractImag(operand_value)));
+ return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
+ FNeg(EmitExtractImag(operand_value)));
case HloOpcode::kReal:
return EmitExtractReal(operand_value);
case HloOpcode::kImag:
return EmitExtractImag(operand_value);
default:
return Unimplemented("unary complex op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
PrimitiveType operand_type = op->operand(0)->shape().element_type();
if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
operand_type == PRED) {
@@ -712,21 +681,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
switch (op->opcode()) {
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
- return b_->CreateFAdd(lhs_value, rhs_value);
+ return FAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
- return b_->CreateFSub(lhs_value, rhs_value);
+ return FSub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
- return b_->CreateFMul(lhs_value, rhs_value);
+ return FMul(lhs_value, rhs_value);
case HloOpcode::kDivide:
- return b_->CreateFDiv(lhs_value, rhs_value);
+ return FDiv(lhs_value, rhs_value);
case HloOpcode::kRemainder:
- return b_->CreateFRem(lhs_value, rhs_value);
+ return FRem(lhs_value, rhs_value);
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
@@ -763,66 +731,52 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
default:
return Unimplemented("binary floating point op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
switch (op->opcode()) {
case HloOpcode::kAdd:
- return EmitComposeComplex(op,
- b_->CreateFAdd(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFAdd(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
+ return EmitComposeComplex(
+ op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
case HloOpcode::kSubtract:
- return EmitComposeComplex(op,
- b_->CreateFSub(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFSub(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
+ return EmitComposeComplex(
+ op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
case HloOpcode::kMultiply:
return EmitComposeComplex(
op,
- b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value))),
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value))));
+ FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
+ FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
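      // This is the usual expansion (a+bi)(c+di) = (ac - bd) + (ad + bc)i,
      // e.g. (1+2i)(3+4i) = (3 - 8) + (4 + 6)i = -5 + 10i.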
case HloOpcode::kDivide: {
// (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
// = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
auto rhs_sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(rhs_value),
- EmitExtractImag(rhs_value)));
+ FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value)));
auto type = rhs_sum_sq->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero);
- auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero);
- auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero);
- return b_->CreateSelect(
+ auto oeq = FCmpOEQ(rhs_sum_sq, zero);
+ auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero);
+ auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero);
+ return Select(
oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
- EmitComposeComplex(
- op,
- b_->CreateFDiv(
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value))),
- rhs_sum_sq),
- b_->CreateFDiv(
- b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value))),
- rhs_sum_sq)));
+ EmitComposeComplex(op,
+ FDiv(FAdd(FMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value))),
+ rhs_sum_sq),
+ FDiv(FSub(FMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value)),
+ FMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value))),
+ rhs_sum_sq)));
}
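      // Example of the formula in the comment above: (1+2i)/(3+4i) multiplies
      // by the conjugate to get ((3+8) + (6-4)i) / (9+16) = (11 + 2i)/25
      // = 0.44 + 0.08i; a zero denominator instead takes the inf/NaN branch
      // selected on oeq.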
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
@@ -832,21 +786,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
// unordered comparison. This makes x != y equivalent to !(x == y), and
// matches C++'s semantics.
case HloOpcode::kEq:
- return b_->CreateAnd(
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
- EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value), b_),
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
- EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value), b_));
+ return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
+ EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value), b_),
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
+ EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value), b_));
case HloOpcode::kNe:
- return b_->CreateOr(
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
- EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value), b_),
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
- EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value), b_));
+ return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
+ EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value), b_),
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
+ EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value), b_));
case HloOpcode::kPower: {
// (a+bi)^(c+di) =
@@ -858,45 +810,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
auto b = EmitExtractImag(lhs_value);
auto c = EmitExtractReal(rhs_value);
auto d = EmitExtractImag(rhs_value);
- auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b));
+ auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
- auto half_c = b_->CreateFMul(one_half, c);
+ auto half_c = FMul(one_half, c);
TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
EmitPow(component_type, aa_p_bb, half_c));
- auto neg_d = b_->CreateFNeg(d);
+ auto neg_d = FNeg(d);
TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
- auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs);
+ auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
EmitExp(component_type, neg_d_arg_lhs));
- auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
+ auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
- auto half_d = b_->CreateFMul(one_half, d);
- auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs),
- b_->CreateFMul(half_d, ln_aa_p_bb));
+ auto half_d = FMul(one_half, d);
+ auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
- return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q),
- b_->CreateFMul(coeff, sin_q));
+ return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q));
}
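      // Sketch of the derivation the code follows: with r^2 = a^2 + b^2 and
      // theta = atan2(b, a),
      //   (a+bi)^(c+di) = exp((c+di) * log(a+bi))
      // has modulus (a^2+b^2)^(c/2) * e^(-d*theta) and argument
      // q = c*theta + (d/2)*log(a^2+b^2), matching aa_p_bb_to_half_c,
      // e_to_neg_d_arg_lhs, coeff and q above.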
default:
return Unimplemented("binary complex op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_);
}
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
- llvm::Value* x) const {
+ llvm::Value* x) {
if (prim_type != F32) {
// TODO(b/34339814): Implement inverse erf for F64.
return Unimplemented(
@@ -909,9 +859,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
llvm::Value* w) {
llvm::Value* p = getFloat(coefficients.front());
- coefficients.pop_front();
+ coefficients.remove_prefix(1);
for (float coefficient : coefficients) {
- p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient));
+ p = FAdd(FMul(p, w), getFloat(coefficient));
}
return p;
};
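  // multiply_add evaluates the polynomial by Horner's rule: for coefficients
  // {c0, c1, c2} it computes ((c0 * w) + c1) * w + c2, i.e. c0*w^2 + c1*w + c2.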
@@ -931,25 +881,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
module_, llvm::Intrinsic::log, {b_->getFloatTy()});
- llvm::Value* w = b_->CreateFNeg(b_->CreateCall(
- logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x),
- b_->CreateFAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = FNeg(
+ Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
llvm::Value* p_addr =
llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
+ FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
// Handle true BB.
SetToFirstInsertPoint(if_data.true_block, b_);
{
- llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f));
+ llvm::Value* lw = FSub(w, getFloat(2.5f));
tensorflow::gtl::ArraySlice<float> lq{
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-0.00417768164f, 0.246640727f, 1.50140941f};
llvm::Value* p = multiply_add(lq, lw);
- b_->CreateStore(p, p_addr);
+ Store(p, p_addr);
}
// Handle false BB.
@@ -958,76 +907,73 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
- llvm::Value* gw =
- b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
tensorflow::gtl::ArraySlice<float> gq{
-0.000200214257f, 0.000100950558f, 0.00134934322f,
-0.00367342844f, 0.00573950773f, -0.0076224613f,
0.00943887047f, 1.00167406f, 2.83297682f};
llvm::Value* p = multiply_add(gq, gw);
- b_->CreateStore(p, p_addr);
+ Store(p, p_addr);
}
SetToFirstInsertPoint(if_data.after_block, b_);
- llvm::Value* p = b_->CreateLoad(p_addr);
- return b_->CreateFMul(p, x);
+ llvm::Value* p = Load(p_addr);
+ return FMul(p, x);
}
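// In short, the routine above approximates erfinv(x) as x * p(w) with
// w = -log((1-x)*(1+x)) = -log(1-x^2), where p is one of two fitted
// polynomials chosen by the w < 5 branch; this is a commonly used
// single-precision polynomial approximation, hence the F32-only restriction.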
-StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
+ llvm::Value* value) {
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
- return EmitErfInv(prim_type, b_->CreateFSub(one, value));
+ return EmitErfInv(prim_type, FSub(one, value));
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
auto x = value;
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
auto negative_half = llvm::ConstantFP::get(type, -0.5);
// When x is large, the naive evaluation of ln(x + 1) is more
// accurate than the Taylor series.
- TF_ASSIGN_OR_RETURN(auto for_large_x,
- EmitLog(prim_type, b_->CreateFAdd(x, one)));
+ TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
  // The Taylor series for ln(x+1) is x - x^2/2 + x^3/3 - ….
- auto for_small_x =
- b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x);
+ auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
const auto kAntilogarithmIsSmallThreshold = 1e-4;
auto abs_x =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
- auto x_is_small = b_->CreateFCmpOLT(
+ auto x_is_small = FCmpOLT(
abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
- return b_->CreateSelect(x_is_small, for_small_x, for_large_x);
+ return Select(x_is_small, for_small_x, for_large_x);
}
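// The small-x branch above expands to x - x^2/2, the first two terms of the
// series; at the 1e-4 threshold the dropped x^3/3 term is at most about
// 3.4e-13, negligible relative to the ~1e-4 result.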
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
auto x = value;
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
@@ -1035,40 +981,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
// When the exponent is large, the naive evaluation of e^(x) - 1 is more
// accurate than the Taylor series.
TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
- auto for_large_x = b_->CreateFSub(exp_x, one);
+ auto for_large_x = FSub(exp_x, one);
// The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
// We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
- auto x_squared = b_->CreateFAdd(x, x);
- auto x_squared_over_two = b_->CreateFMul(x_squared, half);
- auto for_small_x = b_->CreateFAdd(x, x_squared_over_two);
+ auto x_squared = FAdd(x, x);
+ auto x_squared_over_two = FMul(x_squared, half);
+ auto for_small_x = FAdd(x, x_squared_over_two);
const auto kExponentIsSmallThreshold = 1e-5;
auto abs_x =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
- auto x_is_small = b_->CreateFCmpOLT(
- abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
- return b_->CreateSelect(x_is_small, for_small_x, for_large_x);
+ auto x_is_small =
+ FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
+ return Select(x_is_small, for_small_x, for_large_x);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
{lhs->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return Unimplemented("atan2");
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return Unimplemented("tanh");
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
- const HloInstruction* hlo, llvm::Value* x) const {
+ const HloInstruction* hlo, llvm::Value* x) {
if (hlo->operand(0)->shape().element_type() != F32) {
return Unimplemented("reduce-precision only implemented for F32");
}
@@ -1099,23 +1045,103 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
}
+llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
+ return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
+}
+
+llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
+ return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
+}
+
+llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
+ auto* integer_type = llvm::cast<llvm::IntegerType>(type);
+ return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
+ integer_type->getBitWidth()));
+}
+
+llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
+ auto* integer_type = llvm::cast<llvm::IntegerType>(type);
+ return llvm::ConstantInt::get(
+ integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
+}
+
+llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
+ return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
+}
+
+llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
+ llvm::Value* rhs) {
+ return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
+ ICmpEQ(rhs, GetMinusOne(rhs->getType())));
+}
+
+llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
+ llvm::Value* rhs,
+ bool is_signed) {
+ // Integer division overflow behavior:
+ //
+ // X / 0 == -1
+ // INT_SMIN /s -1 = INT_SMIN
+
+ if (!is_signed) {
+ llvm::Value* udiv_is_unsafe = IsZero(rhs);
+ llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_div = UDiv(lhs, safe_rhs);
+ return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
+ }
+
+ llvm::Value* has_zero_divisor = IsZero(rhs);
+ llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
+ llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
+ llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_div = SDiv(lhs, safe_rhs);
+
+ return Select(
+ has_zero_divisor, GetMinusOne(lhs->getType()),
+ Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
+}
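// For example, with i32 operands: 7 / 2 follows the normal safe_div path;
// 7 / 0 returns -1 (all ones) via the zero-divisor select; and INT_MIN / -1,
// which would be undefined behavior for LLVM's raw sdiv, is routed through
// safe_rhs = 1 and then overridden to INT_MIN.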
+
+llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
+ llvm::Value* rhs,
+ bool is_signed) {
+ // Integer remainder overflow behavior:
+ //
+ // X % 0 == X
+ // INT_SMIN %s -1 = 0
+
+ if (!is_signed) {
+ llvm::Value* urem_is_unsafe = IsZero(rhs);
+ llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_rem = URem(lhs, safe_rhs);
+ return Select(urem_is_unsafe, lhs, safe_rem);
+ }
+
+ llvm::Value* has_zero_divisor = IsZero(rhs);
+ llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
+ llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
+ llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_rem = SRem(lhs, safe_rhs);
+
+ return Select(
+ has_zero_divisor, lhs,
+ Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
+}
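// Analogously for remainder: 7 % 0 yields 7 (the dividend is returned
// unchanged) and INT_MIN % -1 yields 0, while ordinary cases such as
// 7 % 2 == 1 take the safe_rem path.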
+
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const {
+ bool is_signed) {
switch (op->opcode()) {
// TODO(jingyue): add the "nsw" attribute for signed types.
case HloOpcode::kAdd:
- return b_->CreateAdd(lhs_value, rhs_value);
+ return Add(lhs_value, rhs_value);
case HloOpcode::kSubtract:
- return b_->CreateSub(lhs_value, rhs_value);
+ return Sub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
- return b_->CreateMul(lhs_value, rhs_value);
+ return Mul(lhs_value, rhs_value);
case HloOpcode::kDivide:
- return is_signed ? b_->CreateSDiv(lhs_value, rhs_value)
- : b_->CreateUDiv(lhs_value, rhs_value);
+ return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
case HloOpcode::kRemainder:
- return is_signed ? b_->CreateSRem(lhs_value, rhs_value)
- : b_->CreateURem(lhs_value, rhs_value);
+ return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
rhs_value, b_);
@@ -1143,11 +1169,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
case HloOpcode::kMaximum:
return EmitIntegralMax(lhs_value, rhs_value, is_signed);
case HloOpcode::kAnd:
- return b_->CreateAnd(lhs_value, rhs_value);
+ return And(lhs_value, rhs_value);
case HloOpcode::kOr:
- return b_->CreateOr(lhs_value, rhs_value);
+ return Or(lhs_value, rhs_value);
case HloOpcode::kXor:
- return b_->CreateXor(lhs_value, rhs_value);
+ return Xor(lhs_value, rhs_value);
// Shifting out bits >= the number of bits in the type being shifted
// produces a poison value in LLVM which is basically "deferred undefined
@@ -1156,43 +1182,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
// UB.
case HloOpcode::kShiftRightArithmetic:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateAShr(lhs_value, rhs_value),
+ AShr(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/true);
case HloOpcode::kShiftLeft:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateShl(lhs_value, rhs_value),
+ Shl(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/false);
case HloOpcode::kShiftRightLogical:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateLShr(lhs_value, rhs_value),
+ LShr(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/false);
default:
return Unimplemented("binary integer op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const {
- return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
- : llvm::ICmpInst::ICMP_UGE,
- lhs_value, rhs_value),
- lhs_value, rhs_value);
+ bool is_signed) {
+ return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
+ : llvm::ICmpInst::ICMP_UGE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
}
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const {
- return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
- : llvm::ICmpInst::ICMP_ULE,
- lhs_value, rhs_value),
- lhs_value, rhs_value);
+ bool is_signed) {
+ return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
+ : llvm::ICmpInst::ICMP_ULE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
}
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
- int64 operand_no) const {
+ int64 operand_no) {
CHECK(hlo.IsElementwise())
<< "HLO " << hlo.ToString() << " is not elementwise.";
@@ -1233,7 +1259,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const {
+ const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) {
TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean,
operand_to_generator.at(hlo->operand(0))(index));
TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma,
@@ -1251,17 +1277,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
// Perform the division using the float type with the same number of bits
// as the raw value to avoid overflow.
if (raw_value_size_in_bits == 32) {
- elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy());
- elem_value = b_->CreateFDiv(
- elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
+ elem_value = UIToFP(elem_value, b_->getFloatTy());
+ elem_value = FDiv(elem_value,
+ llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
} else {
- elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy());
- elem_value = b_->CreateFDiv(
+ elem_value = UIToFP(elem_value, b_->getDoubleTy());
+ elem_value = FDiv(
elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64)));
}
if (elem_ir_ty != elem_value->getType()) {
- elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty);
+ elem_value = FPTrunc(elem_value, elem_ir_ty);
}
}
@@ -1269,9 +1295,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
switch (hlo->random_distribution()) {
case RNG_UNIFORM: {
if (elem_ir_ty->isFloatingPointTy()) {
- return b_->CreateFAdd(
- b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value),
- a_or_mean);
+ return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean);
} else {
// To generate a uniform random value in [a, b) from a raw random sample
// in range [0, 2^N), we let range = b - a and return
@@ -1284,22 +1308,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
// the same cost as if the whole warp were to re-sample. So an
// efficient re-sampling implementation on GPU would need to do
// nontrivial work to share entropy between threads in the warp.
- auto range = b_->CreateSub(b_or_sigma, a_or_mean);
- return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range));
+ auto range = Sub(b_or_sigma, a_or_mean);
+ return Add(a_or_mean, URem(elem_value, range));
}
}
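      // E.g. for a float uniform with a_or_mean = 2.0 and b_or_sigma = 5.0, a
      // raw sample mapped to elem_value = 0.25 produces (5 - 2) * 0.25 + 2 =
      // 2.75; the integer path instead returns a + (raw % (b - a)), accepting
      // some bias rather than re-sampling, as discussed above.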
case RNG_NORMAL: {
TF_ASSIGN_OR_RETURN(
llvm::Value * r,
- EmitErfcInv(elem_prim_ty,
- b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0),
- elem_value)));
- return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean);
+ EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0),
+ elem_value)));
+ return FAdd(FMul(r, b_or_sigma), a_or_mean);
}
default:
return InvalidArgument(
"unhandled distribution %s",
- RandomDistribution_Name(hlo->random_distribution()).c_str());
+ RandomDistribution_Name(hlo->random_distribution()));
}
}
@@ -1414,8 +1437,7 @@ std::array<llvm::Value*, 4> CalculateSampleValues(
// Precondition: the RNG instruction is not fused.
llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
const HloInstruction* hlo,
- const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
- const {
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
VLOG(3) << "Using philox RNG algorithm";
CHECK(!hlo->IsFused());
// A random number generated by the per module random number generator.
@@ -1438,7 +1460,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// Load the global state variable for the Philox RNG algorithm.
llvm::GlobalVariable* rng_state_ptr =
llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_);
- llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value");
+ llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value");
// Build and return the elemental IR generator to generate a random value for
// the element corresponding to the current thread.
@@ -1464,8 +1486,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// element within the sample.
llvm::Value* elems_per_sample_value =
llvm::ConstantInt::get(index_ty, elems_per_sample);
- llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value);
- llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value);
+ llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value);
+ llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value);
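    // E.g. for a 32-bit element type each sample supplies elems_per_sample =
    // 4 values, so elem_idx = 10 maps to sample_idx = 10 / 4 = 2 and
    // elem_offset = 10 % 4 = 2: the third word of the third sample.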
std::array<llvm::Value*, 4> counter_values = CalculateSampleValues(
sample_idx, hlo_random_value, global_random_number, rng_state, b_);
@@ -1473,18 +1495,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// Store the four counter_values into the sample_address alloca so we can
// load the elem_offset'th one below.
for (int idx = 0; idx < 4; ++idx) {
- b_->CreateStore(counter_values[idx],
- b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx)));
+ Store(counter_values[idx],
+ InBoundsGEP(sample_address, b_->getInt32(idx)));
}
llvm::Type* int64_ty = b_->getInt64Ty();
CHECK(elems_per_sample == 2 || elems_per_sample == 4);
llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty;
// Retrieve the raw value for the current element from the current sample.
- llvm::Value* raw_elem_value = b_->CreateLoad(
- b_->CreateInBoundsGEP(
- b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()),
- elem_offset),
+ llvm::Value* raw_elem_value = Load(
+ InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()),
+ elem_offset),
"raw_elem_value");
return ConvertValueForDistribution(hlo, operand_to_generator, index,
@@ -1495,7 +1516,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
@@ -1505,14 +1526,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
- return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()),
- on_true_value, on_false_value);
+ return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
+ on_false_value);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
@@ -1531,14 +1552,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
} else {
return Unimplemented("Clamp unimplemented for %s",
- PrimitiveType_Name(prim_type).c_str());
+ PrimitiveType_Name(prim_type));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& target_index) const {
+ const llvm_ir::IrArray::Index& target_index) {
const int64 concat_dim = hlo->dimensions(0);
auto source_index = target_index;
@@ -1560,9 +1581,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
}
llvm_ir::SetToFirstInsertPoint(exit_block, b_);
- llvm::PHINode* output = b_->CreatePHI(
- llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
- hlo->operands().size());
+ llvm::PHINode* output =
+ PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+ hlo->operands().size());
auto prior_insert_point = b_->GetInsertPoint();
b_->SetInsertPoint(init_block);
@@ -1577,9 +1598,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
auto concat_dim_size =
llvm::ConstantInt::get(source_index[concat_dim]->getType(),
operand->shape().dimensions(concat_dim));
- b_->CreateCondBr(
- b_->CreateICmpULT(source_index[concat_dim], concat_dim_size),
- true_block, false_block);
+ CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block,
+ false_block);
// Create the terminator of the true block before calling operand
// generators, because they require non-degenerate basic blocks.
@@ -1592,11 +1612,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
// Subtract the size of the concat dimension of the current operand
// from the source index.
b_->SetInsertPoint(false_block);
- source_index[concat_dim] =
- b_->CreateSub(source_index[concat_dim], concat_dim_size);
+ source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size);
}
- b_->CreateUnreachable();
+ Unreachable();
b_->SetInsertPoint(exit_block, prior_insert_point);
return output;
}
@@ -1604,7 +1623,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
// Emit IR to read dynamic start indices from hlo->operand(1).
const HloInstruction* input_hlo = hlo->operand(0);
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
@@ -1621,7 +1640,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// Clamp the start index so that the sliced portion fits in the operand:
// start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
- start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type);
+ start_index_value = SExtOrTrunc(start_index_value, index_type);
int64 largest_valid_start_index =
input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
CHECK_GE(largest_valid_start_index, 0);
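    // E.g. slicing 4 elements out of a dimension of size 10 clamps any
    // requested start index into [0, 6], so a start of 8 becomes 6 and the
    // slice stays fully inside the operand.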
@@ -1641,7 +1660,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
for (int64 i = 0; i < rank; ++i) {
// Emit IR which computes:
// input_index = start_index + offset_index
- input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]);
+ input_index[i] = Add(slice_start_index[i], index[i]);
}
return operand_to_generator.at(input_hlo)(input_index);
}
@@ -1649,7 +1668,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
const Shape& operand_shape = hlo->operand(0)->shape();
const Shape& indices_shape = hlo->operand(1)->shape();
const Shape& output_shape = hlo->shape();
@@ -1672,7 +1691,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
i < e; i++) {
- if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
@@ -1686,7 +1705,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
{
std::vector<llvm::Value*> gather_index_index_components;
for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.offset_dims(), i)) {
+ if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
gather_index_index.push_back(index[i]);
}
}
@@ -1698,7 +1717,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
llvm::Value* gather_dim_component_extended =
- b_->CreateSExtOrTrunc(index_component, index_type);
+ SExtOrTrunc(index_component, index_type);
int64 operand_dim = dim_numbers.start_index_map(dim);
int64 output_dim = operand_to_output_dim[operand_dim];
// If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
@@ -1722,8 +1741,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
gather_dim_component_extended, is_signed),
is_signed);
- operand_index[operand_dim] = b_->CreateAdd(
- operand_index[operand_dim], gather_dim_component_extended_inbound);
+ operand_index[operand_dim] =
+ Add(operand_index[operand_dim], gather_dim_component_extended_inbound);
};
if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
@@ -1747,7 +1766,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
const HloInstruction* input_hlo = hlo->operand(0);
const HloInstruction* update_hlo = hlo->operand(1);
const HloInstruction* start_hlo = hlo->operand(2);
@@ -1770,7 +1789,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Clamp the start index so that the update region fits in the operand.
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
- start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type);
+ start_index_value = SExtOrTrunc(start_index_value, index_type);
llvm::Value* update_dim_size =
index_typed_const(update_hlo->shape().dimensions(i));
int64 largest_valid_start_index =
@@ -1786,14 +1805,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
slice_start_index[i] = start_index_value;
- slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size);
-
- slice_intersection = b_->CreateAnd(
- slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]),
- "slice_intersection");
- slice_intersection = b_->CreateAnd(
- slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]),
- "slice_intersection");
+ slice_limit_index[i] = Add(slice_start_index[i], update_dim_size);
+
+ slice_intersection =
+ And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]),
+ "slice_intersection");
+ slice_intersection =
+ And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]),
+ "slice_intersection");
}
// Emit:
@@ -1810,26 +1829,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Compute update index for intersection case.
llvm_ir::IrArray::Index update_index(index.GetType(), rank);
for (int64 i = 0; i < rank; ++i) {
- update_index[i] = b_->CreateSub(index[i], slice_start_index[i]);
+ update_index[i] = Sub(index[i], slice_start_index[i]);
}
TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
operand_to_generator.at(update_hlo)(update_index));
- b_->CreateStore(true_value, ret_value_addr);
+ Store(true_value, ret_value_addr);
// Handle false BB (return data from 'input')
SetToFirstInsertPoint(if_data.false_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
operand_to_generator.at(input_hlo)(index));
- b_->CreateStore(false_value, ret_value_addr);
+ Store(false_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, b_);
- return b_->CreateLoad(ret_value_addr);
+ return Load(ret_value_addr);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& padded_index) const {
+ const llvm_ir::IrArray::Index& padded_index) {
auto index = padded_index;
llvm::Value* in_bounds = b_->getTrue();
for (size_t i = 0; i < index.size(); ++i) {
@@ -1837,26 +1856,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
return llvm::ConstantInt::get(index[i]->getType(), n);
};
const auto& pad_dim = hlo->padding_config().dimensions(i);
- index[i] =
- b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low()));
- in_bounds = b_->CreateAnd(in_bounds,
- b_->CreateICmpSGE(index[i], index_typed_const(0)),
- "in_bounds");
- in_bounds = b_->CreateAnd(
+ index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low()));
+ in_bounds =
+ And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds");
+ in_bounds = And(
in_bounds,
- b_->CreateICmpEQ(
+ ICmpEQ(
index_typed_const(0),
- b_->CreateURem(index[i],
- index_typed_const(pad_dim.interior_padding() + 1))),
- "in_bounds");
- index[i] = b_->CreateSDiv(
- index[i], index_typed_const(pad_dim.interior_padding() + 1));
- in_bounds = b_->CreateAnd(
- in_bounds,
- b_->CreateICmpSLT(
- index[i],
- index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+ URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))),
"in_bounds");
+ index[i] =
+ SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1));
+ in_bounds =
+ And(in_bounds,
+ ICmpSLT(index[i],
+ index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+ "in_bounds");
}
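  // Worked example for one dimension with edge_padding_low = 1,
  // interior_padding = 1 and an operand extent of 3: padded positions 1, 3, 5
  // pass both the remainder and bounds checks and map to operand indices
  // 0, 1, 2, while every other padded position falls through to the padding
  // value below.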
// if (in_bounds) {
@@ -1872,26 +1887,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
SetToFirstInsertPoint(if_data.true_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(index));
- b_->CreateStore(operand_value, ret_value_addr);
+ Store(operand_value, ret_value_addr);
SetToFirstInsertPoint(if_data.false_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index.GetType())));
- b_->CreateStore(padding_value, ret_value_addr);
+ Store(padding_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, b_);
// Don't create phi(operand_value, padding_value) here, because invoking
// operand_to_generator may create new basic blocks, making the parent
// of operand_value or padding_value no longer a predecessor of
// if_data.after_block.
- return b_->CreateLoad(ret_value_addr);
+ return Load(ret_value_addr);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& dot_result_index) const {
+ const llvm_ir::IrArray::Index& dot_result_index) {
auto lhs_generator = operand_to_generator.at(hlo->operand(0));
auto rhs_generator = operand_to_generator.at(hlo->operand(1));
@@ -1919,8 +1934,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
llvm::Value* accumulator_alloca =
llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
- b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm),
- accumulator_alloca);
+ Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);
SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
@@ -1942,42 +1956,37 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
}
rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
- llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca);
+ llvm::Value* current_accumulator = Load(accumulator_alloca);
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
llvm::Value* next_accumulator;
if (primitive_util::IsComplexType(primitive_type)) {
- llvm::Value* product_real = b_->CreateFSub(
- b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
- llvm::Value* product_imag = b_->CreateFAdd(
- b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
- next_accumulator = b_->CreateInsertValue(
+ llvm::Value* product_real =
+ FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
+ llvm::Value* product_imag =
+ FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
+ next_accumulator = InsertValue(
current_accumulator,
- b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real),
- {0});
- next_accumulator = b_->CreateInsertValue(
+ FAdd(EmitExtractReal(current_accumulator), product_real), {0});
+ next_accumulator = InsertValue(
next_accumulator,
- b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag),
- {1});
+ FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
- next_accumulator = b_->CreateFAdd(current_accumulator,
- b_->CreateFMul(lhs_value, rhs_value));
+ next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
} else {
- next_accumulator =
- b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value));
+ next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
}
- b_->CreateStore(next_accumulator, accumulator_alloca);
+ Store(next_accumulator, accumulator_alloca);
SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
- return b_->CreateLoad(accumulator_alloca);
+ return Load(accumulator_alloca);
}
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
- const {
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
@@ -2071,10 +2080,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
auto source_index = target_index;
for (int64 dim : hlo->dimensions()) {
- source_index[dim] = b_->CreateSub(
- llvm::ConstantInt::get(target_index[dim]->getType(),
- hlo->shape().dimensions(dim) - 1),
- target_index[dim]);
+ source_index[dim] =
+ Sub(llvm::ConstantInt::get(target_index[dim]->getType(),
+ hlo->shape().dimensions(dim) - 1),
+ target_index[dim]);
}
return operand_to_generator.at(operand)(source_index);
};
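The kReverse case above mirrors each reversed dimension when mapping the output coordinate back to the operand; the scalar equivalent of the emitted Sub is simply (an illustrative helper, not part of the emitter):

#include <cstdint>

// source_index[dim] = (dimension size - 1) - target_index[dim]
int64_t ReverseSourceCoord(int64_t dim_size, int64_t target_coord) {
  return (dim_size - 1) - target_coord;
}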
@@ -2088,6 +2097,50 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
hlo->dimensions(), b_));
};
+ case HloOpcode::kIota:
+ return [this, hlo](
+ const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
+ auto* iota = Cast<HloIotaInstruction>(hlo);
+ PrimitiveType element_type = iota->shape().element_type();
+ IrArray::Index elem_index =
+ ShapeUtil::Rank(iota->shape()) > 1
+ ? target_index.SourceIndexOfBroadcast(
+ iota->shape(),
+ ShapeUtil::MakeShapeWithDescendingLayout(
+ element_type,
+ {iota->shape().dimensions(iota->iota_dimension())}),
+ {iota->iota_dimension()}, b_)
+ : target_index;
+ llvm::Value* elem_index_linear = elem_index.linear();
+ if (elem_index_linear == nullptr) {
+ std::vector<int64> iota_bound = {
+ iota->shape().dimensions(iota->iota_dimension())};
+ elem_index_linear = elem_index.Linearize(iota_bound, b_);
+ }
+ if (ShapeUtil::ElementIsIntegral(iota->shape())) {
+ return b_->CreateIntCast(
+ elem_index_linear,
+ llvm_ir::PrimitiveTypeToIrType(element_type, module_),
+ /*isSigned=*/false);
+ } else {
+ TF_RET_CHECK(ShapeUtil::ElementIsFloating(iota->shape()))
+ << element_type;
+ llvm::Type* float_ir_type;
+ if (element_type == BF16) {
+ float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
+ } else {
+ float_ir_type =
+ llvm_ir::PrimitiveTypeToIrType(element_type, module_);
+ }
+ llvm::Value* float_val =
+ b_->CreateUIToFP(elem_index_linear, float_ir_type);
+ if (element_type == BF16) {
+ return EmitF32ToBF16(float_val, b_);
+ } else {
+ return float_val;
+ }
+ }
+ };
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
@@ -2153,28 +2206,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
default:
return [hlo](const IrArray::Index& index) {
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
- HloOpcodeString(hlo->opcode()).c_str());
+ HloOpcodeString(hlo->opcode()));
};
}
}
-llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
- return b_->CreateExtractValue(value, {0});
+llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
+ return ExtractValue(value, {0});
}
-llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
- return b_->CreateExtractValue(value, {1});
+llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
+ return ExtractValue(value, {1});
}
llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
llvm::Value* real,
- llvm::Value* imag) const {
+ llvm::Value* imag) {
auto cplx_type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto complex = b_->CreateInsertValue(
- llvm::ConstantAggregateZero::get(cplx_type), real, {0});
+ auto complex =
+ InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
if (imag != nullptr) {
- complex = b_->CreateInsertValue(complex, imag, {1});
+ complex = InsertValue(complex, imag, {1});
}
return complex;
}
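The new kIota generator above boils down to: take the element's coordinate along the iota dimension (linearized when the shape has rank greater than 1), then cast it to the element type, with BF16 produced by rounding an F32 intermediate. A scalar model of that rule, assuming (as the IR above does) that the linear index is treated as unsigned:

#include <cstdint>

// Hypothetical scalar model of the iota element generator: the value at
// multi-index `idx` is idx[iota_dimension] converted to the element type.
// Integral outputs use a plain integer cast; floating outputs go through a
// uint -> float conversion (BF16 via an F32 intermediate in the IR above).
int64_t IotaElementAsInt(const int64_t* idx, int64_t iota_dimension) {
  return idx[iota_dimension];
}

float IotaElementAsFloat(const int64_t* idx, int64_t iota_dimension) {
  const uint64_t linear = static_cast<uint64_t>(idx[iota_dimension]);
  return static_cast<float>(linear);  // UIToFP in the emitted IR
}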
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 1598a4dd85..d3e2acaabd 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -23,12 +23,13 @@ limitations under the License.
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
-class ElementalIrEmitter {
+class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
public:
using HloToElementGeneratorMap =
std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
@@ -40,100 +41,114 @@ class ElementalIrEmitter {
virtual ~ElementalIrEmitter() = default;
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
- llvm::Value* operand_value) const;
+ llvm::Value* operand_value);
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
// Returns a function to generate an element of the output of `hlo`, given a
// map of functions to generate elements of its operands.
virtual llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const;
+ const HloToElementGeneratorMap& operand_to_generator);
- llvm::IRBuilder<>* b() const { return b_; }
- llvm::Module* module() const { return module_; }
+ llvm::IRBuilder<>* b() { return b_; }
+
+ // builder() is for IrBuilderMixin.
+ llvm::IRBuilder<>* builder() { return b_; }
+
+ llvm::Module* module() { return module_; }
protected:
- virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
+
+ virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
- virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
- virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ llvm::Value* IsZero(llvm::Value* v);
+ llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs);
+ llvm::Value* GetZero(llvm::Type* type);
+ llvm::Value* GetOne(llvm::Type* type);
+ llvm::Value* GetIntSMin(llvm::Type* type);
+ llvm::Value* GetMinusOne(llvm::Type* type);
+
+ llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs,
+ bool is_signed);
+ llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
+ bool is_signed);
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
- virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value);
- virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value);
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) const;
+ llvm::Value* lhs, llvm::Value* rhs);
virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) const;
+ llvm::Value* lhs, llvm::Value* rhs);
virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
- llvm::Value* x) const;
+ llvm::Value* x);
- virtual llvm::Value* EmitExtractReal(llvm::Value* value) const;
- virtual llvm::Value* EmitExtractImag(llvm::Value* value) const;
+ virtual llvm::Value* EmitExtractReal(llvm::Value* value);
+ virtual llvm::Value* EmitExtractImag(llvm::Value* value);
// Composes a complex struct. imag may be nullptr for simple cast operations.
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
- llvm::Value* imag) const;
+ llvm::Value* imag);
// A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
// the target array index, computes the source array index of its
@@ -142,50 +157,50 @@ class ElementalIrEmitter {
// Precondition: `hlo` is an elementwise op.
llvm_ir::IrArray::Index ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
- int64 operand_no) const;
+ int64 operand_no);
  // Identifier of the thread, unique among all threads on the device.
- virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); }
+ virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
StatusOr<llvm::Value*> EmitElementalSelect(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalClamp(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalConcatenate(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& target_index) const;
+ const llvm_ir::IrArray::Index& target_index);
StatusOr<llvm::Value*> EmitElementalDynamicSlice(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalGather(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalPad(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& padded_index) const;
+ const llvm_ir::IrArray::Index& padded_index);
StatusOr<llvm::Value*> EmitElementalDot(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& dot_result_index) const;
+ const llvm_ir::IrArray::Index& dot_result_index);
llvm::IRBuilder<>* const b_;
@@ -200,13 +215,13 @@ class ElementalIrEmitter {
// random number generation algorithm.
llvm_ir::ElementGenerator MakePhiloxRngElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const;
+ const HloToElementGeneratorMap& operand_to_generator);
// Converts the raw value generated by a random number generation algorithm
// to the distribution requested by the RNG HloInstruction.
StatusOr<llvm::Value*> ConvertValueForDistribution(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const;
+ const llvm_ir::IrArray::Index& index, llvm::Value* raw_value);
};
} // namespace xla
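The unqualified Load/Store/FAdd/InsertValue calls introduced throughout this change come from the new IrBuilderMixin base class, which forwards to the derived emitter's builder(). The real member list lives in llvm_ir/ir_builder_mixin.h; the sketch below only illustrates the CRTP forwarding idea and is not the actual interface:

#include "llvm/IR/IRBuilder.h"

// Hypothetical sketch of the mixin pattern: each helper forwards to the
// derived class's builder(), so emitter code can write Load(x) instead of
// b_->CreateLoad(x).
template <typename Derived>
class IrBuilderMixinSketch {
 protected:
  llvm::IRBuilder<>* mixin_builder() {
    return static_cast<Derived*>(this)->builder();
  }

 public:
  llvm::Value* Load(llvm::Value* ptr) {
    return mixin_builder()->CreateLoad(ptr);
  }
  llvm::Value* FAdd(llvm::Value* lhs, llvm::Value* rhs) {
    return mixin_builder()->CreateFAdd(lhs, rhs);
  }
  // ...one thin wrapper per IRBuilder method used by the emitters.
};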
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index addb016b04..5ab0756219 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::gtl::nullopt;
+using absl::nullopt;
class ElementalIrEmitterExecutionTest : public HloTestBase {
protected:
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index fd75847d0c..78edf918a4 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/executable.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/status.h"
@@ -22,7 +24,6 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
using tensorflow::gtl::ArraySlice;
@@ -76,8 +77,8 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
std::unique_ptr<HloExecutionProfile> profile_ptr =
module_config().debug_options().xla_hlo_profile() &&
hlo_profiling_enabled()
- ? MakeUnique<HloExecutionProfile>(&hlo_profile_printer_data(),
- &hlo_profile_index_map())
+ ? absl::make_unique<HloExecutionProfile>(&hlo_profile_printer_data(),
+ &hlo_profile_index_map())
: nullptr;
StatusOr<ScopedShapedBuffer> return_value =
@@ -154,9 +155,9 @@ Status Executable::DumpHloSnapshot() {
const string& directory_path =
module_config().debug_options().xla_dump_executions_to();
const auto& module = hlo_snapshot_->hlo().hlo_module();
- string filename = tensorflow::strings::Printf(
- "computation_%lld__%s__execution_%lld", module.id(),
- module.entry_computation_name().c_str(), ++execution_count_);
+ string filename =
+ absl::StrFormat("computation_%d__%s__execution_%d", module.id(),
+ module.entry_computation_name(), ++execution_count_);
return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_);
}
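The %lld to %d migrations in this file (and in the files below) work because absl::StrFormat is type-safe: %d binds to an integral argument of any width and %s accepts std::string directly, so explicit length modifiers and .c_str() calls become unnecessary. A small standalone example with made-up values:

#include <cstdint>
#include <string>
#include "absl/strings/str_format.h"

std::string FormatExample() {
  const int64_t id = 42;
  const std::string name = "entry";
  // %d works for any integer type and %s for std::string; printf-style
  // formatting would have required %lld and name.c_str() here.
  return absl::StrFormat("computation_%d__%s", id, name);
}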
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc
index 228c3fac95..997db7c058 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.cc
+++ b/tensorflow/compiler/xla/service/execution_tracker.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -53,8 +53,8 @@ ExecutionHandle ExecutionTracker::Register(Backend* backend,
tensorflow::mutex_lock lock(execution_mutex_);
int64 handle = next_handle_++;
auto inserted = handle_to_execution_.emplace(
- handle,
- MakeUnique<AsyncExecution>(backend, std::move(streams), profile, result));
+ handle, absl::make_unique<AsyncExecution>(backend, std::move(streams),
+ profile, result));
CHECK(inserted.second);
ExecutionHandle execution_handle;
@@ -66,7 +66,7 @@ Status ExecutionTracker::Unregister(const ExecutionHandle& handle) {
tensorflow::mutex_lock lock(execution_mutex_);
auto it = handle_to_execution_.find(handle.handle());
if (it == handle_to_execution_.end()) {
- return NotFound("no execution record for execution handle: %lld",
+ return NotFound("no execution record for execution handle: %d",
handle.handle());
}
handle_to_execution_.erase(handle.handle());
@@ -78,7 +78,7 @@ StatusOr<const AsyncExecution*> ExecutionTracker::Resolve(
tensorflow::mutex_lock lock(execution_mutex_);
auto it = handle_to_execution_.find(handle.handle());
if (it == handle_to_execution_.end()) {
- return NotFound("no execution record for execution handle: %lld",
+ return NotFound("no execution record for execution handle: %d",
handle.handle());
}
return it->second.get();
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
index d3efab3614..3cccec9862 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.h
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -28,7 +28,7 @@ namespace xla {
// points-to analysis (see b/36865746 for details).
class FlattenCallGraph : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "flatten-call-graph"; }
+ absl::string_view name() const override { return "flatten-call-graph"; }
// Duplicates computations called from multiple call- or while-nodes to
// flatten the call graph.
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 9370c88710..3f1a881372 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <utility>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -230,7 +231,7 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
accumulator_state_shape_dims.push_back(gather_loop_trip_count);
for (int64 i = 0; i < slice_sizes.size(); i++) {
- if (!c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
+ if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
accumulator_state_shape_dims.push_back(slice_sizes[i]);
}
}
@@ -251,7 +252,7 @@ static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
int64 batch_idx_counter = 0;
int64 offset_idx_counter = output_rank - offset_dims.size();
for (int64 i = 0; i < output_rank; i++) {
- bool is_offset_dim = c_binary_search(offset_dims, i);
+ bool is_offset_dim = absl::c_binary_search(offset_dims, i);
if (is_offset_dim) {
permutation.push_back(offset_idx_counter++);
} else {
@@ -322,7 +323,7 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
return Unimplemented(
"Gather operations with more than 2147483647 gather indices are not "
"supported. This error occurred for %s.",
- gather_instr->ToString().c_str());
+ gather_instr->ToString());
}
TF_ASSIGN_OR_RETURN(
@@ -373,8 +374,8 @@ StatusOr<bool> GatherExpander::Run(HloModule* module) {
std::vector<HloInstruction*> gather_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
- c_copy_if(computation->instructions(), std::back_inserter(gather_instrs),
- is_nontrivial_gather);
+ absl::c_copy_if(computation->instructions(),
+ std::back_inserter(gather_instrs), is_nontrivial_gather);
}
for (HloInstruction* inst : gather_instrs) {
diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h
index c1fc8574da..7bd9ea5984 100644
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -25,7 +25,7 @@ namespace xla {
// nevertheless have a minimum level of support.
class GatherExpander : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "gather_expander"; }
+ absl::string_view name() const override { return "gather_expander"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 8ef72850dc..82290bfea8 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -56,6 +56,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -91,6 +93,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -107,6 +110,8 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -126,6 +131,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -171,6 +177,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
@@ -180,6 +187,11 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
"@llvm//:core",
"@llvm//:support",
],
@@ -224,6 +236,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:math_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:support",
],
@@ -243,6 +256,7 @@ cc_library(
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -257,6 +271,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -337,6 +352,10 @@ cc_library(
"//tensorflow/core/platform/default/build_config:cufft_plugin",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -373,6 +392,9 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -390,6 +412,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -420,7 +443,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:shape_inference",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:test",
],
@@ -466,6 +489,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:multi_output_fusion",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -483,6 +507,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -513,6 +538,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -544,6 +571,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:shape_inference",
+ "@com_google_absl//absl/memory",
],
)
@@ -600,6 +628,7 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
alwayslink = True, # Contains per-platform transfer manager registration
@@ -670,6 +699,9 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
"@llvm//:core",
],
alwayslink = True, # Contains compiler registration
@@ -702,8 +734,8 @@ cc_library(
":xfeed_queue",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -718,6 +750,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -756,6 +789,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "@com_google_absl//absl/strings",
],
)
@@ -767,12 +801,12 @@ cc_library(
":stream_assignment",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_value",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/compiler/xla/service:hlo_scheduling",
+ "@com_google_absl//absl/memory",
],
)
@@ -789,6 +823,8 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -839,7 +875,9 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:stream_executor_no_cuda",
],
)
@@ -868,9 +906,8 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service:hlo_runner",
- "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index 537295292b..528209abc7 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
const BufferAssignment* buffer_assignment, int device_ordinal,
DeviceMemoryAllocator* memory_allocator) {
const int64 num_buffers = buffer_assignment->Allocations().size();
- auto buffer_allocations = WrapUnique(new BufferAllocations(
+ auto buffer_allocations = absl::WrapUnique(new BufferAllocations(
num_buffers, device_ordinal, memory_allocator, buffer_assignment));
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
@@ -62,7 +62,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
if (reinterpret_cast<uintptr_t>(address.opaque()) % expected_alignment !=
0) {
return InternalError(
- "Address of registered buffer %lld must be a multiple of %llx, but "
+ "Address of registered buffer %d must be a multiple of %x, but "
"was %p",
i, kEntryParameterAlignBytes, address.opaque());
}
@@ -83,7 +83,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
0) {
return InternalError(
"Address returned by memory_allocator->Allocate must be a "
- "multiple of %llx, but was %p",
+ "multiple of 0x%x, but was %p",
kXlaAllocatedBufferAlignBytes, buffer.opaque());
}
// We do manual memory management within BufferAllocations. Be sure not
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 6a285a6b98..13c83c9199 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include <cmath>
+#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace gpu {
@@ -74,9 +74,8 @@ ENTRY MaxDifference {
%error = f32[SIZE] divide(%sub_abs, %denominator)
ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32
})";
- auto size_string = std::to_string(num_elements);
- return tensorflow::str_util::StringReplace(
- kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true);
+ return absl::StrReplaceAll(kF16CompHloText,
+ {{"SIZE", absl::StrCat(num_elements)}});
}
StatusOr<F16BufferComparator> F16BufferComparator::Create(
@@ -125,7 +124,7 @@ StatusOr<F16BufferComparator> F16BufferComparator::Create(
StatusOr<bool> F16BufferComparator::CompareEqualImpl(
se::DeviceMemory<Eigen::half> test_buffer) {
if (ref_buffer_.root_buffer().size() != test_buffer.size()) {
- return InternalError("Mismatched buffer size: %lld vs %lld",
+ return InternalError("Mismatched buffer size: %d vs %d",
ref_buffer_.root_buffer().size(), test_buffer.size());
}
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index 5780e0af40..9ed523998b 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -59,7 +59,7 @@ Status ConditionalThunk::ExecuteOnStream(
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to retrieve predicate value on stream %p: %s.",
- stream, block_status.error_message().c_str());
+ stream, block_status.error_message());
}
// Execute the true or the false computation depending on the value of the
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 7833a4077e..eea31f3de1 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -17,12 +17,11 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d76ca6698d..f7952787c1 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
index e09cde9abf..6e2e330edd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
@@ -54,9 +54,7 @@ namespace gpu {
// BatchNormRewriter.
class CudnnBatchNormRewriter : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "cudnn_batchnorm_rewriter";
- }
+ absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index 7b172812c3..bc3c6f72f6 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -17,12 +17,11 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index caeb89d78e..dbdf8e7a0e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -14,24 +14,25 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
namespace gpu {
namespace {
+using absl::optional;
using se::DeviceMemoryBase;
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
-using tensorflow::gtl::optional;
class ScratchAllocator : public se::ScratchAllocator {
public:
@@ -59,8 +60,8 @@ StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
if (byte_size > GetMemoryLimitInBytes(stream)) {
return se::port::Status(
se::port::error::RESOURCE_EXHAUSTED,
- tensorflow::strings::Printf(
- "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
+ absl::StrFormat(
+ "Allocating %d bytes exceeds the memory limit of %d bytes.",
byte_size, GetMemoryLimitInBytes(stream)));
}
@@ -128,14 +129,14 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
string AlgorithmToString(const AlgorithmDesc& algo) {
if (algo.tensor_ops_enabled()) {
- return tensorflow::strings::StrCat(algo.algo_id(), "+TC");
+ return absl::StrCat(algo.algo_id(), "+TC");
}
- return tensorflow::strings::StrCat(algo.algo_id());
+ return absl::StrCat(algo.algo_id());
}
string NumBytesToString(int64 bytes) {
- return tensorflow::strings::StrCat(
- tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)");
+ return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (",
+ bytes, "B)");
}
// Acquires a process-global lock on the device pointed to by the given
@@ -361,7 +362,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
return InternalError(
"All algorithms tried for convolution %s failed. Falling back to "
"default algorithm.",
- instr->ToString().c_str());
+ instr->ToString());
}
StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index 8b7749628a..f76d273e8c 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -39,7 +39,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
Compiler* compiler)
: stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cudnn-convolution-algorithm-picker";
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 905b5ee876..0b1ee2dc33 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -234,6 +234,23 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
<< "Backward input convolution should reverse all kernel dimensions.";
return no_match_result;
}
+ } else if (reverse_filter->IsConstant()) {
+ // If the filter is a constant, we're willing to pattern-match to a
+ // backwards-input conv, on the theory that
+ //
+ // a) reversing a constant is free, and
+ // b) even if the user specified this filter as reverse(constant), we would
+ // long ago have constant-folded away the reverse.
+ //
+ // If the constant has any other uses, reversing it isn't entirely free,
+ // since we'd now have two constants to keep in memory. But hopefully it's
+ // free enough.
+ //
+ // TODO(jlebar): Should we do this even if the filter is not a constant?
+ // Reversing a non-constant filter is probably cheaper than padding the
+ // input!
+
+ // Nothing to do, just fall through.
} else {
// Possibly 1x1 filter.
for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
@@ -373,22 +390,25 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
}
}
- // Fuse the matched HLOs into a backward convolution instruction.
- //
- // If the reverse is omitted (for 1x1 filters) in the original pattern, we add
- // it back in the fusion instruction so that later passes (such as
- // PadInsertion) can handle such fusion instructions easily.
+ // OK, it's a match! Canonicalize the conv's filter so that it's a reverse.
+ // This simplifies things for our caller, and algebraic-simplifier will later
+ // remove any unnecessary reverses.
if (reverse_filter->opcode() != HloOpcode::kReverse) {
- reverse_filter = reverse_filter->parent()->AddInstruction(
+ // Create a double-reverse, which is a nop.
+ HloComputation* c = conv->parent();
+ reverse_filter = c->AddInstruction(
+ HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(kernel_spatial_dims)));
+ reverse_filter = c->AddInstruction(
HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
AsInt64Slice(kernel_spatial_dims)));
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
}
+
dnums.set_kernel_input_feature_dimension(
conv->convolution_dimension_numbers().kernel_output_feature_dimension());
dnums.set_kernel_output_feature_dimension(
conv->convolution_dimension_numbers().kernel_input_feature_dimension());
-
return std::make_tuple(true, new_window, dnums);
}
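The double reverse inserted above is value-preserving, which is why it is safe to add purely for canonicalization and to let the algebraic simplifier strip it later; a trivial illustration of the identity:

#include <algorithm>
#include <vector>

// Reversing a dimension twice restores the original order, so wrapping the
// filter in reverse(reverse(...)) changes only the HLO's syntactic form.
std::vector<int> DoubleReverse(std::vector<int> v) {
  std::reverse(v.begin(), v.end());
  std::reverse(v.begin(), v.end());
  return v;  // equal to the input
}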
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
index 0c0578d888..fbe7e98494 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
@@ -26,7 +26,7 @@ namespace gpu {
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
class CudnnConvolutionRewriter : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cudnn-convolution-rewriter";
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index 65588b6aaf..46c23db465 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -32,10 +32,13 @@ namespace gpu {
namespace {
namespace op = xla::testing::opcode_matchers;
+using ::testing::_;
-class CudnnConvolutionRewriterTest : public HloTestBase {
+class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
public:
- CudnnConvolutionRewriterTest() {
+ CudnnConvolutionRewriterTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false) {
for (int i = 0; i < 2; ++i) {
WindowDimension* window_dim = default_conv_window_.add_dimensions();
window_dim->set_size(1);
@@ -114,7 +117,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -142,7 +145,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -172,7 +175,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -202,7 +205,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -230,7 +233,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -280,7 +283,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
@@ -325,7 +328,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -357,7 +360,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
@@ -410,7 +413,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -457,7 +460,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
@@ -510,7 +513,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
const HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -562,12 +565,38 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}
+// Check that we will materialize a reversed version of a constant in order to
+// pattern-match a backwards input convolution.
+TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
+ Array4D<float> constant_arr(4, 4, 2, 2);
+ constant_arr.FillIota(0);
+ string constant_str =
+ LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
+ ParseAndVerifyModule(absl::StrFormat(R"(
+ HloModule test
+
+ ENTRY entry_computation {
+ param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
+ constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
+ ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
+ window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
+ dim_labels=bf01_01oi->bf01, feature_group_count=1
+ })",
+ constant_str));
+ EXPECT_TRUE(RunPass(&module()));
+ EXPECT_THAT(
+ module().entry_computation()->root_instruction(),
+ op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _,
+ op::Reverse(op::Constant())),
+ 0));
+}
+
} // anonymous namespace
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 7b0d9e53d6..07b96fbd3f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -56,7 +57,7 @@ class ScratchBufAllocator : public se::ScratchAllocator {
"Can't allocate twice from a ScratchBufAllocator.");
}
if (byte_size > scratch_.size()) {
- return se::port::InternalError(tensorflow::strings::StrCat(
+ return se::port::InternalError(absl::StrCat(
"Can't allocate ", byte_size,
" bytes from a ScratchBufAllocator of size ", scratch_.size()));
}
@@ -196,8 +197,8 @@ Status RunCudnnConvolution(
if (!stream->ok()) {
return InternalError(
- "Unable to launch convolution with type %s and algorithm (%lld, %lld)",
- CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(),
+ "Unable to launch convolution with type %s and algorithm (%d, %d)",
+ CudnnConvKindToString(kind), algorithm.algorithm().algo_id(),
algorithm.algorithm_no_scratch().algo_id());
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 9b6de115ad..57a3a43a6f 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
@@ -43,16 +45,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace gpu {
+using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
-using tensorflow::strings::StrAppend;
namespace {
// Returns whether operand is a floating-point literal with the given value.
@@ -77,7 +77,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ PrimitiveType output_type) {
// The libdevice math functions differentiate between "double" and "float" by
// appending an 'f' to the function's name. libdevice doesn't have f16 math
// functions, so we convert the operands to f32 before calling the function
@@ -94,7 +94,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
for (int64 i = 0; i < operands.size(); ++i) {
if (input_types[i] == F16) {
converted_operands[i] =
- b_->CreateFPCast(converted_operands[i], b_->getFloatTy());
+ FPCast(converted_operands[i], b_->getFloatTy());
converted_input_types[i] = F32;
}
}
@@ -107,13 +107,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
break;
default:
return Unimplemented("Bad type for libdevice math call: %s",
- PrimitiveType_Name(output_type).c_str());
+ PrimitiveType_Name(output_type));
}
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
converted_input_types, output_type)
.ValueOrDie();
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
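As the surrounding comments note, libdevice only provides f32/f64 math routines, so f16 operands are widened before the call and the result is narrowed back. A minimal sketch of that wrapper using the IRBuilder calls visible above (the callee is a placeholder for whichever __nv_* function was requested, not a real declaration from this code):

#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"

// Hypothetical helper: cast an f16 operand to f32, invoke a float-only math
// routine, and cast the result back to f16, mirroring the FPCast pair above.
llvm::Value* CallF32OnlyMath(llvm::IRBuilder<>* b, llvm::Function* callee,
                             llvm::Value* f16_operand) {
  llvm::Value* widened = b->CreateFPCast(f16_operand, b->getFloatTy());
  llvm::Value* result = b->CreateCall(callee, {widened});
  return b->CreateFPCast(result, b->getHalfTy());
}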
@@ -122,7 +122,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ PrimitiveType output_type) {
// llvm intrinsics differentiate between half/float/double functions via
// the suffixes ".f16", ".f32" and ".f64".
string munged_callee = callee_name;
@@ -138,7 +138,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
break;
default:
return Unimplemented("Bad type for llvm intrinsic math call: %s",
- PrimitiveType_Name(output_type).c_str());
+ PrimitiveType_Name(output_type));
}
return EmitMathCall(munged_callee, operands, input_types, output_type);
}
@@ -147,13 +147,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ PrimitiveType output_type) {
  // Binary math functions are of type [T] -> T.
for (PrimitiveType input_type : input_types) {
if (output_type != input_type) {
return Unimplemented("Input type ≠ output type: %s ≠ %s",
- PrimitiveType_Name(input_type).c_str(),
- PrimitiveType_Name(output_type).c_str());
+ PrimitiveType_Name(input_type),
+ PrimitiveType_Name(output_type));
}
}
@@ -163,8 +163,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
PrimitiveType output_type = op->shape().element_type();
@@ -183,8 +182,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
CHECK_EQ(op->opcode(), HloOpcode::kPower);
PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
@@ -218,7 +216,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
// TODO(jlebar): Does this happen with fastmath disabled? If not, should
// we force-enable it?
TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
- return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
+ return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
}
VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString();
@@ -227,55 +225,56 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv(
- PrimitiveType prim_type, llvm::Value* value) const {
+ PrimitiveType prim_type, llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type},
prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
- PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) {
return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type},
prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
+ llvm::Value* value) {
// Emit a fast approximation of tanh instead of calling __nv_tanh.
// __nv_tanh is particularly bad because it contains branches, thus
// preventing LLVM's load-store vectorizer from working its magic across a
@@ -285,9 +284,9 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
// Upcast F16 to F32 if necessary.
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
- llvm::Value* input = b_->CreateFPCast(value, type);
+ llvm::Value* input = FPCast(value, type);
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
- return b_->CreateFPCast(fast_tanh, value->getType());
+ return FPCast(fast_tanh, value->getType());
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
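A host-side sketch (not the emitted IR) of the widen-compute-narrow pattern the tanh emitter uses for F16 above; Half is a hypothetical storage-only half-precision type, and std::tanh stands in for the branch-free approximation emitted by llvm_ir::EmitFastTanh.

#include <cmath>

template <typename Half>
Half TanhViaF32(Half h) {
  float wide = static_cast<float>(h);  // FPCast(value, getFloatTy())
  float t = std::tanh(wide);           // stand-in for EmitFastTanh
  return static_cast<Half>(t);         // FPCast back to the original type
}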
@@ -295,7 +294,7 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const {
+ tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) {
std::vector<llvm::Type*> ir_input_types;
for (PrimitiveType input_type : input_types) {
ir_input_types.push_back(
@@ -315,29 +314,28 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
callee->addFnAttr(attribute);
}
- return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands));
+ return Call(callee, llvm_ir::AsArrayRef(operands));
}
-llvm::Value* GpuElementalIrEmitter::EmitThreadId() const {
- llvm::Value* block_id = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "block.id");
- llvm::Value* thread_id_in_block = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
- llvm::Value* threads_per_block = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
- return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block),
- thread_id_in_block);
+llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
+ llvm::Value* block_id =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "block.id");
+ llvm::Value* thread_id_in_block =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
+ llvm::Value* threads_per_block =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
+ return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}
llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const {
+ const HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kMap:
return [=, &operand_to_generator](
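The EmitThreadId rewrite earlier in this hunk computes the flattened global thread index through the nvvm special-register intrinsics (ctaid.x, ntid.x, tid.x map to blockIdx.x, blockDim.x, threadIdx.x). The arithmetic it emits is sketched here as plain host code:

// block_id, threads_per_block and thread_id_in_block correspond to the
// ctaid.x, ntid.x and tid.x registers read above.
long long FlattenedThreadId(long long block_id, long long threads_per_block,
                            long long thread_id_in_block) {
  return block_id * threads_per_block + thread_id_in_block;
}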
@@ -383,7 +381,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index.GetType())));
- b_->CreateStore(init_value, accum_ptr);
+ Store(init_value, accum_ptr);
}
llvm::Type* index_type = index.GetType();
@@ -405,22 +403,21 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
IrArray::Index input_index(index_type, index.size());
llvm::Value* in_bounds = b_->getInt1(true);
for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* stridden_index = b_->CreateNSWMul(
+ llvm::Value* stridden_index = NSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
- input_index[i] = b_->CreateNSWSub(
- b_->CreateNSWAdd(stridden_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ input_index[i] =
+ NSWSub(NSWAdd(stridden_index, window_index[i]),
+ index_typed_const(window.dimensions(i).padding_low()));
            // We must check whether 0 <= input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
// comparison is equivalent to the unsigned comparison
// input_index[i] < bound, as a negative value wraps to a large
// positive value.
- in_bounds = b_->CreateAnd(
- in_bounds,
- b_->CreateICmpULT(
- input_index[i],
- index_typed_const(operand->shape().dimensions(i))));
+ in_bounds =
+ And(in_bounds,
+ ICmpULT(input_index[i],
+ index_typed_const(operand->shape().dimensions(i))));
}
llvm_ir::LlvmIfData if_data =
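The in_bounds computation above relies on the trick described in the comment: after casting to unsigned, a negative (below-range) index wraps to a huge value, so a single ICmpULT covers both ends of the range. A minimal sketch:

#include <cstdint>

bool InBounds(int64_t index, int64_t bound) {
  // Equivalent to (0 <= index && index < bound) when bound >= 0.
  return static_cast<uint64_t>(index) < static_cast<uint64_t>(bound);
}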
@@ -432,12 +429,11 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
operand_to_generator.at(operand)(input_index));
TF_ASSIGN_OR_RETURN(
llvm::Value * accum_value,
- compute_nested_(*hlo->to_apply(),
- {b_->CreateLoad(accum_ptr), input_value}));
- b_->CreateStore(accum_value, accum_ptr);
+ compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value}));
+ Store(accum_value, accum_ptr);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
- return b_->CreateLoad(accum_ptr);
+ return Load(accum_ptr);
};
case HloOpcode::kReduce:
// TODO(b/112040122): This should be supported.
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 84454d31bb..91942785d2 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -48,50 +48,50 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const override;
+ const HloToElementGeneratorMap& operand_to_generator) override;
protected:
- StatusOr<llvm::Value*> EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const override;
+ StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value) override;
StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
- llvm::Value* EmitThreadId() const override;
+ llvm::Value* EmitThreadId() override;
private:
// Emits IR for op, which must have opcode kPower.
StatusOr<llvm::Value*> EmitPowerOp(const HloInstruction* op,
llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
// Emits IR to call a device function named "callee_name" on the given
// operand. Returns the IR value that represents the return value.
@@ -100,7 +100,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_type,
PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const;
+ tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes);
// Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
@@ -109,7 +109,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ PrimitiveType output_type);
// Emits IR to call a libdevice function of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
@@ -118,7 +118,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ PrimitiveType output_type);
// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
// Returns the IR value that represents the return value of the function.
@@ -126,7 +126,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ PrimitiveType output_type);
const HloModuleConfig& hlo_module_config_;
NestedComputer compute_nested_;
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index 0cdddf8bcf..11549cdac5 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -17,11 +17,11 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -43,8 +43,8 @@ StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
if (byte_size > GetMemoryLimitInBytes(stream)) {
return se::port::Status(
se::port::error::RESOURCE_EXHAUSTED,
- tensorflow::strings::Printf(
- "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
+ absl::StrFormat(
+ "Allocating %d bytes exceeds the memory limit of %d bytes.",
byte_size, GetMemoryLimitInBytes(stream)));
}
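This hunk illustrates the string-formatting migration running through the whole change: absl::StrFormat and the absl-backed error helpers are type-safe, so the .c_str() calls and %lld/%zu length modifiers needed by tensorflow::strings::Printf are dropped. A minimal sketch of the new form, with a hypothetical helper name:

#include <cstdint>
#include <string>
#include "absl/strings/str_format.h"

std::string MemoryLimitError(int64_t byte_size, int64_t limit) {
  // %d accepts any integral type with absl::StrFormat; no %lld is needed.
  return absl::StrFormat(
      "Allocating %d bytes exceeds the memory limit of %d bytes.",
      byte_size, limit);
}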
@@ -213,7 +213,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
return Status::OK();
}
return InternalError("Unable to launch fft for thunk %p with type %s", this,
- FftTypeToString(fft_type_).c_str());
+ FftTypeToString(fft_type_));
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index 8c53be5077..4adec7ee54 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index 2fd2206324..88f0b4d71c 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -28,7 +28,7 @@ ForThunk::ForThunk(const int64 loop_limit,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
loop_limit_(loop_limit),
- body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_
// constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it.
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 3cd30b754c..1bd88233e1 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -18,12 +18,13 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace gpu {
@@ -64,10 +65,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) {
// Slice for a more accurate estimate of bytes read.
double bytes = 0.0;
for (auto& instruction : instructions) {
- if (c_all_of(instruction->users(), [](const HloInstruction* instruction) {
- return instruction->opcode() == HloOpcode::kSlice ||
- instruction->opcode() == HloOpcode::kDynamicSlice;
- })) {
+ if (absl::c_all_of(
+ instruction->users(), [](const HloInstruction* instruction) {
+ return instruction->opcode() == HloOpcode::kSlice ||
+ instruction->opcode() == HloOpcode::kDynamicSlice;
+ })) {
// All users are slice: accumulate bytes of all user slice instructions.
for (auto& user : instruction->users()) {
bytes += ShapeUtil::ByteSizeOf(user->shape());
@@ -223,7 +225,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// Skip 'fusion' instruction if we cannot merge into all of its users.
// Merging into all users enables the removal of 'fusion' from the
// computation.
- if (!c_all_of(fusion->users(), [](const HloInstruction* user) {
+ if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) {
return user->opcode() == HloOpcode::kFusion &&
(user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
user->fusion_kind() == HloInstruction::FusionKind::kInput);
@@ -241,11 +243,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// If 'fusion' has just one user, then an earlier fusion pass chose not to
  // fuse this producer/consumer pair (likely because of expensive instruction
// re-use by the consumer), and so we honor that choice here as well.
- if (c_any_of(fusion->fused_instructions(),
- [](const HloInstruction* instruction) {
- return instruction->opcode() != HloOpcode::kParameter &&
- GpuInstructionFusion::IsExpensive(*instruction);
- })) {
+ if (absl::c_any_of(fusion->fused_instructions(),
+ [](const HloInstruction* instruction) {
+ return instruction->opcode() != HloOpcode::kParameter &&
+ GpuInstructionFusion::IsExpensive(*instruction);
+ })) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Contains one or more expensive instructions.";
++num_fail_expensive_fused_instruction_;
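The two hunks above switch the container algorithms to their absl:: equivalents, which take the container directly rather than a begin/end iterator pair. A small usage sketch:

#include <vector>
#include "absl/algorithm/container.h"

bool AnyNegative(const std::vector<int>& values) {
  return absl::c_any_of(values, [](int v) { return v < 0; });
}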
@@ -287,11 +289,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
<< " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion)
<< " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio
<< " into users { "
- << tensorflow::str_util::Join(users, ", ",
- [](string* out, HloInstruction* user) {
- tensorflow::strings::StrAppend(
- out, user->name());
- })
+ << absl::StrJoin(users, ", ",
+ [](string* out, HloInstruction* user) {
+ absl::StrAppend(out, user->name());
+ })
<< " }";
// Remove 'fusion' instruction.
CHECK_EQ(0, fusion->user_count());
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 4c523a66de..7e3f5775b8 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -34,7 +34,7 @@ namespace gpu {
//
class FusionMerger : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "fusion merger"; }
+ absl::string_view name() const override { return "fusion merger"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 74282c568c..9c4a490366 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <functional>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -186,7 +186,7 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
}
return InternalError(
- "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms "
+ "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms "
"ran successfully",
stream, algorithms.size());
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 0c6f9b511f..8ffae18fe8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -27,7 +27,7 @@ namespace gpu {
// inserting kCopy instructions.
class GpuCopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 7060837904..71a02e70df 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -112,7 +112,7 @@ Status GpuExecutable::ExecuteThunks(
//
// TODO(jlebar): Should we cache the results of HloInstruction::ToString(),
// since we expect it to be an expensive call?
- tensorflow::gtl::optional<ScopedAnnotation> op_annotation;
+ absl::optional<ScopedAnnotation> op_annotation;
if (top_level_annotation.IsEnabled()) {
op_annotation.emplace(
thunk->hlo_instruction() != nullptr
@@ -144,7 +144,7 @@ Status GpuExecutable::ExecuteThunks(
TF_RETURN_IF_ERROR(
thunk->ExecuteOnStream(buffer_allocations, stream, &profiler));
if (thunk_schedule_->Depended(thunk)) {
- auto finish_event = MakeUnique<se::Event>(main_stream->parent());
+ auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
@@ -160,7 +160,7 @@ Status GpuExecutable::ExecuteThunks(
if (!block_status.ok()) {
return InternalError(
"Failed to complete all kernels launched on stream %p: %s",
- main_stream, block_status.error_message().c_str());
+ main_stream, block_status.error_message());
}
}
@@ -260,10 +260,9 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
if (buffer.is_null() && buffer.size() > 0) {
return FailedPrecondition(
"Cannot run XLA computation because pointer to (sub-)buffer at "
- "index %s of parameter %lld was null. All pointers to "
- "(sub-)buffers must not be null, unless the (sub-)buffer has zero "
- "elements.",
- allocation.param_shape_index().ToString().c_str(), param_no);
+ "index %s of parameter %d was null. All pointers to (sub-)buffers "
+ "must not be null, unless the (sub-)buffer has zero elements.",
+ allocation.param_shape_index().ToString(), param_no);
}
buffer_allocations_builder.RegisterBuffer(i, buffer);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index c7ce6d0acb..627a05e240 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
@@ -32,10 +34,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
index 4944c41f7d..4268fb2c7a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
@@ -34,9 +34,8 @@ StatusOr<bool> GpuHloSupportChecker::Run(HloModule* module) {
return xla::Unimplemented(
"GPU backend does not support HLO instruction %s with shape "
"containing a sparse layout: %s",
- instruction->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(instruction->shape())
- .c_str());
+ instruction->ToString(),
+ ShapeUtil::HumanStringWithLayout(instruction->shape()));
}
return Status::OK();
}));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index d63e213d2b..bbb3340760 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -28,9 +28,7 @@ class GpuHloSupportChecker : public HloPassInterface {
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;
- tensorflow::StringPiece name() const override {
- return "gpu_hlo_support_checker";
- }
+ absl::string_view name() const override { return "gpu_hlo_support_checker"; }
// Note: always returns false (no instructions are ever modified by this
// pass).
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 286547ebae..fbc8ddf599 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -119,7 +120,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
@@ -192,7 +193,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
// Enumerate all combinations of shapes.
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
@@ -265,7 +266,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
for (int constrained_param_no : {0, 4}) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index a2f53f8446..f3c2744292 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "llvm/IR/DataLayout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -83,7 +84,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed(
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to complete data transfer on stream %p: %s",
- stream, block_status.error_message().c_str());
+ stream, block_status.error_message());
}
infeed_manager->EnqueueDestination(std::move(buffers));
@@ -96,7 +97,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed(
StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
se::StreamExecutor* executor, int64 size, const void* source) {
if (size > std::numeric_limits<int32>::max()) {
- return InvalidArgument("Infeed shape is too large: needs %lld bytes", size);
+ return InvalidArgument("Infeed shape is too large: needs %d bytes", size);
}
if (size == 0) {
@@ -160,9 +161,10 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
if (ShapeUtil::IsTuple(shape)) {
return;
}
- *buffer = MakeUnique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
+ *buffer = absl::make_unique<gpu::OutfeedBuffer>(
+ GetByteSizeRequirement(shape));
(*buffer)->set_destination(
- MakeUnique<MutableBorrowingLiteral>(literal, index));
+ absl::make_unique<MutableBorrowingLiteral>(literal, index));
});
   // Give the tree of buffers to the outfeed manager. The device will fill it
@@ -179,7 +181,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() {
- return xla::MakeUnique<xla::gpu::GpuTransferManager>(
+ return absl::make_unique<xla::gpu::GpuTransferManager>(
/*id=*/stream_executor::cuda::kCudaPlatformId,
/*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout)
.getPointerSize(0 /* default address space */));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index 7929042869..fa88816bc8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_
#include <vector>
@@ -61,4 +61,4 @@ class GpuTransferManager : public GenericTransferManager {
} // namespace gpu
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
index 1722676930..b9c21e8edb 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -33,7 +34,7 @@ namespace gpu {
namespace {
void InitAndStartTimer(std::stack<std::unique_ptr<se::Timer>>* timers,
se::Stream* stream) {
- timers->push(MakeUnique<se::Timer>(stream->parent()));
+ timers->push(absl::make_unique<se::Timer>(stream->parent()));
stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get());
}
@@ -115,7 +116,7 @@ HloExecutionProfiler::MakeScopedInstructionProfiler(
CHECK(hlo_instructions_.insert(hlo_instruction).second)
<< hlo_instruction->name();
}
- return MakeUnique<ScopedInstructionProfiler>(this, hlo_instruction);
+ return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction);
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
index 19de37b0fb..76055ff009 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
@@ -59,8 +59,8 @@ GpuHloOrdering::GpuHloOrdering(
: PredecessorHloOrdering(module) {
// The entry computation has a total order when there's only one stream.
if (stream_assignment.StreamCount() == 1) {
- entry_sequence_ =
- MakeUnique<std::vector<const HloInstruction*>>(thunk_launch_order);
+ entry_sequence_ = absl::make_unique<std::vector<const HloInstruction*>>(
+ thunk_launch_order);
}
// The ordering of instructions for the entry computation is determined by the
@@ -75,7 +75,7 @@ GpuHloOrdering::GpuHloOrdering(
// same-stream predecessors of each instruction.
// Compute the set of all instructions we will want to set reachability on.
- auto predecessor_map = MakeUnique<HloReachabilityMap>(
+ auto predecessor_map = absl::make_unique<HloReachabilityMap>(
module->entry_computation()->MakeInstructionPostOrder());
// The most recently visited instruction per stream.
@@ -208,7 +208,7 @@ StatusOr<std::unique_ptr<HloSchedule>> HloSchedule::Build(
BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_);
}
- schedule->hlo_ordering_ = MakeUnique<GpuHloOrdering>(
+ schedule->hlo_ordering_ = absl::make_unique<GpuHloOrdering>(
&module, stream_assignment, schedule->thunk_launch_order_);
return std::move(schedule);
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
index 45f0a1c645..bb147c8d98 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <algorithm>
#include <unordered_set>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -47,7 +49,7 @@ class HloScheduleTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", config);
+ return absl::make_unique<HloModule>("test_module", config);
}
HloVec RemoveHlo(const HloVec& input,
@@ -265,7 +267,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
params.reserve(6);
for (int i = 0; i < 6; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
- i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
+ i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 8c11cd0541..0e205b9c02 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
@@ -24,16 +25,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace gpu {
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
void HloToIrBindings::EmitBasePointersForHlos(
tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
index c5f0cdf6cd..a4364b0deb 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
namespace xla {
namespace gpu {
@@ -24,7 +24,7 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
tensorflow::mutex_lock l(host_to_device_stream_mu_);
if (host_to_device_executor_ == nullptr) {
host_to_device_executor_ = executor;
- host_to_device_stream_ = MakeUnique<se::Stream>(executor);
+ host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
host_to_device_stream_->Init();
}
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index fee6d2af3b..8c3a026740 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -96,7 +96,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to complete data transfer on stream %p: %s",
- stream, block_status.error_message().c_str());
+ stream, block_status.error_message());
}
VLOG(2) << "Infeeding to GPU complete";
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 0f2c83aeb2..0bcaaee2b7 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -26,7 +26,7 @@ namespace gpu {
namespace {
-bool IsFusile(const HloInstruction& hlo) {
+bool IsFusible(const HloInstruction& hlo) {
// Don't fuse get-tuple-element on GPU: We can, but it's slower than not
// fusing. We never generate kernels for unfused GTEs. Instead, if an
// unfused GTE is an input to a kernel (including a fusion kernel), we
@@ -245,7 +245,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return true;
}
- if (!IsFusile(*producer) || !IsFusile(*consumer) ||
+ if (!IsFusible(*producer) || !IsFusible(*consumer) ||
!InstructionFusion::ShouldFuse(consumer, operand_index)) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 8d0522bd8f..f53dfaee3d 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -365,7 +365,7 @@ static StatusOr<const HloInstruction*> FindHloInstruction(
}
return NotFound(
"Computation '%s' does not contain an instruction with op code '%s'.",
- computation.name().c_str(), HloOpcodeString(op).c_str());
+ computation.name(), HloOpcodeString(op));
}
TEST_F(InstructionFusionTest, MultiOutputFusion) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index c349063c71..f544bcc919 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -215,7 +215,7 @@ bool IsReductionToVector(const HloInstruction& reduce) {
// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
-llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
+llvm::Value* EmitPrintf(absl::string_view fmt,
tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
llvm::IRBuilder<>* builder) {
std::vector<llvm::Type*> argument_types;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 5d23a3d018..a35e250101 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -126,7 +126,7 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo);
bool IsReductionToVector(const HloInstruction& reduce);
// Emits call to "vprintf" with given format and arguments.
-llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
+llvm::Value* EmitPrintf(absl::string_view fmt,
tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
llvm::IRBuilder<>* builder);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 6675dbd3f9..bdf6aadde6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
@@ -155,7 +156,7 @@ Status IrEmitter::EmitCallToNestedComputation(
std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
arguments.push_back(output);
arguments.push_back(bindings_.GetTempBufferBase());
- b_.CreateCall(emitted_function, arguments);
+ Call(emitted_function, arguments);
return Status::OK();
}
@@ -177,7 +178,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
computation.root_instruction()->shape().element_type();
bool is_atomic_integral = element_type == S32 || element_type == U32 ||
element_type == S64 || element_type == U64;
- llvm::Value* source = b_.CreateLoad(source_address, "source");
+ llvm::Value* source = Load(source_address, "source");
if (root_opcode == HloOpcode::kAdd) {
// NVPTX supports atomicAdd on F32 and integer types.
if (element_type == F32) {
@@ -189,8 +190,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
}
if (is_atomic_integral) {
// integral + integral
- b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
}
@@ -201,8 +202,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Max
: llvm::AtomicRMWInst::UMax;
- b_.CreateAtomicRMW(opcode, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(opcode, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@@ -211,8 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Min
: llvm::AtomicRMWInst::UMin;
- b_.CreateAtomicRMW(opcode, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(opcode, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@@ -291,10 +292,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
// cas_old_output_address and cas_new_output_address point to the scratch
// memory where we store the old and new values for the repeated atomicCAS
// operations.
- llvm::Value* cas_old_output_address = b_.CreateAlloca(
- atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
- llvm::Value* cas_new_output_address = b_.CreateAlloca(
- atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
+ llvm::Value* cas_old_output_address =
+ Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
+ llvm::Value* cas_new_output_address =
+ Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
// Emit preparation code to the preheader.
llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();
@@ -308,29 +309,26 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
CHECK_EQ((element_size % sizeof(char)), 0);
llvm::Type* address_int_type =
module_->getDataLayout().getIntPtrType(output_address_type);
- atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type);
+ atomic_memory_address = PtrToInt(output_address, address_int_type);
llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
- llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask);
+ llvm::Value* offset = And(atomic_memory_address, mask);
mask = llvm::ConstantInt::get(address_int_type, -4);
- atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask);
+ atomic_memory_address = And(atomic_memory_address, mask);
atomic_memory_address =
- b_.CreateIntToPtr(atomic_memory_address, atomic_address_type);
- binop_output_address = b_.CreateAdd(
- b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset);
+ IntToPtr(atomic_memory_address, atomic_address_type);
binop_output_address =
- b_.CreateIntToPtr(binop_output_address, element_address_type);
+ Add(PtrToInt(cas_new_output_address, address_int_type), offset);
+ binop_output_address = IntToPtr(binop_output_address, element_address_type);
} else {
- atomic_memory_address =
- b_.CreateBitCast(output_address, atomic_address_type);
+ atomic_memory_address = BitCast(output_address, atomic_address_type);
binop_output_address =
- b_.CreateBitCast(cas_new_output_address, element_address_type);
+ BitCast(cas_new_output_address, element_address_type);
}
// Use the value from the memory that atomicCAS operates on to initialize
// cas_old_output.
- llvm::Value* cas_old_output =
- b_.CreateLoad(atomic_memory_address, "cas_old_output");
- b_.CreateStore(cas_old_output, cas_old_output_address);
+ llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
+ Store(cas_old_output, cas_old_output_address);
llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
b_.GetInsertPoint(), "atomic_op_loop_exit");
@@ -343,32 +341,29 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
// Emit the body of the loop that repeatedly invokes atomicCAS.
//
// Use cas_old_output to initialize cas_new_output.
- cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output");
- b_.CreateStore(cas_old_output, cas_new_output_address);
+ cas_old_output = Load(cas_old_output_address, "cas_old_output");
+ Store(cas_old_output, cas_new_output_address);
// Emits code to calculate new_output = operation(old_output, source);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
computation, {binop_output_address, source_address},
binop_output_address));
- llvm::Value* cas_new_output =
- b_.CreateLoad(cas_new_output_address, "cas_new_output");
+ llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");
// Emit code to perform the atomicCAS operation
// (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
// cas_new_output);
- llvm::Value* ret_value = b_.CreateAtomicCmpXchg(
- atomic_memory_address, cas_old_output, cas_new_output,
- llvm::AtomicOrdering::SequentiallyConsistent,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ llvm::Value* ret_value =
+ AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
+ llvm::AtomicOrdering::SequentiallyConsistent,
+ llvm::AtomicOrdering::SequentiallyConsistent);
// Extract the memory value returned from atomicCAS and store it as
// cas_old_output.
- b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"),
- cas_old_output_address);
+ Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
// Extract the success bit returned from atomicCAS and generate a
// conditional branch on the success bit.
- b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb,
- loop_body_bb);
+ CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);
// Set the insertion point to the exit basic block so that the caller of
// this method can continue emitting code to the right place.
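A host-side analogue (a sketch, not the emitted IR) of the atomicCAS loop built above: load the old value, compute new = op(old, source), and retry the compare-exchange until no other thread has raced in between. The emitted code additionally remaps sub-word elements onto a 4-byte-aligned word, which this sketch omits.

#include <atomic>

template <typename T, typename BinaryOp>
void AtomicApply(std::atomic<T>* target, T source, BinaryOp op) {
  T old_value = target->load();
  T new_value;
  do {
    new_value = op(old_value, source);
    // On failure, compare_exchange_weak refreshes old_value with the
    // current contents, so the next iteration recomputes from it.
  } while (!target->compare_exchange_weak(old_value, new_value));
}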
@@ -383,8 +378,8 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation(
// TODO(b/30258929): We only accept binary computations so far.
return Unimplemented(
"We only support atomic functions with exactly two parameters, but "
- "computation %s has %lld.",
- computation.name().c_str(), computation.num_parameters());
+ "computation %s has %d.",
+ computation.name(), computation.num_parameters());
}
if (MaybeEmitDirectAtomicOperation(computation, output_address,
@@ -471,10 +466,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
- result = b_.CreateInsertValue(result, value.first, {0});
- result = b_.CreateInsertValue(result, value.second, {1});
+ result = InsertValue(result, value.first, {0});
+ result = InsertValue(result, value.second, {1});
} else {
- result = b_.CreateFMul(lhs_value, rhs_value);
+ result = FMul(lhs_value, rhs_value);
}
target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
return Status::OK();
@@ -518,7 +513,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
   // We don't have to iterate over the batch dimensions in both arrays; simplify
// the loop nest of the rhs.
for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
- DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i));
+ DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
rhs_index[i] = lhs_index[i];
}
@@ -558,21 +553,21 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
&*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
- llvm::Value* accum = b_.CreateLoad(accum_address);
+ llvm::Value* accum = Load(accum_address);
llvm::Value* updated_accum;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
llvm::Value* accum_real = Real(accum, &b_);
- llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first);
- updated_accum = b_.CreateInsertValue(accum, real_sum, {0});
+ llvm::Value* real_sum = FAdd(accum_real, value.first);
+ updated_accum = InsertValue(accum, real_sum, {0});
llvm::Value* accum_imag = Imag(accum, &b_);
- llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second);
- updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1});
+ llvm::Value* imag_sum = FAdd(accum_imag, value.second);
+ updated_accum = InsertValue(updated_accum, imag_sum, {1});
} else {
- llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element);
- updated_accum = b_.CreateFAdd(accum, product);
+ llvm::Value* product = FMul(lhs_element, rhs_element);
+ updated_accum = FAdd(accum, product);
}
- b_.CreateStore(updated_accum, accum_address);
+ Store(updated_accum, accum_address);
// After the reduction loop exits, store the accumulator into the target
// address. The index into the target address is the concatenation of the rhs
@@ -594,7 +589,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
target_array.EmitWriteArrayElement(
target_index,
- b_.CreateLoad(accum_address), // The value written to the target array.
+ Load(accum_address), // The value written to the target array.
&b_);
// Set the IR builder insert point to the exit basic block of the outer most
@@ -645,10 +640,9 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
[=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
// Initialize an accumulator with init_value.
llvm::AllocaInst* accumulator_addr =
- b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
+ Alloca(llvm_ir::PrimitiveTypeToIrType(
reduce->shape().element_type(), module_));
- b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)),
- accumulator_addr);
+ Store(Load(GetBasePointer(*init_value)), accumulator_addr);
// The enclosing loops go over all the target elements. Now we have to
// compute the actual target element. For this, we build a new loop nest
@@ -685,7 +679,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
*function, {accumulator_addr, input_address}, accumulator_addr));
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_addr);
+ return Load(accumulator_addr);
});
}
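The HandleReduce hunk above follows the usual emitted-reduction shape: seed an accumulator with init_value, fold every reduced element through the user's computation, and read the accumulator back at the loop exit. A scalar sketch, with a callable standing in for the nested computation:

#include <vector>

template <typename T, typename Reducer>
T ReduceElements(const std::vector<T>& elements, T init_value,
                 Reducer reduce) {
  T accumulator = init_value;                    // Store(Load(init), accum)
  for (const T& element : elements) {
    accumulator = reduce(accumulator, element);  // nested computation call
  }
  return accumulator;                            // Load(accumulator_addr)
}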
@@ -752,11 +746,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
-Status IrEmitter::HandleIota(HloInstruction*) {
- // TODO(b/64798317): implement iota on GPU.
- return Unimplemented("Iota is not implemented on GPU.");
-}
-
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
const HloComputation& computation,
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
@@ -768,11 +757,11 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
for (llvm::Value* parameter_element : parameter_elements) {
parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
parameter_element->getType(), "parameter_buffer", &b_));
- b_.CreateStore(parameter_element, parameter_buffers.back());
+ Store(parameter_element, parameter_buffers.back());
}
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
return_buffer));
- return b_.CreateLoad(return_buffer);
+ return Load(return_buffer);
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 561c683879..3673b9f58d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/string_view.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
@@ -35,12 +36,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
@@ -64,7 +65,8 @@ namespace gpu {
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator. See comments on that class.
-class IrEmitter : public DfsHloVisitorWithDefault {
+class IrEmitter : public DfsHloVisitorWithDefault,
+ public IrBuilderMixin<IrEmitter> {
public:
IrEmitter(const IrEmitter&) = delete;
IrEmitter& operator=(const IrEmitter&) = delete;
@@ -95,10 +97,11 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
- Status HandleIota(HloInstruction* iota) override;
Status FinishVisit(HloInstruction* root) override { return Status::OK(); }
+ llvm::IRBuilder<>* builder() { return &b_; }
+
protected:
// Constructs an IrEmitter with the given IrEmitter context.
// ir_emitter_context is owned by the caller and should outlive the IrEmitter
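Note: the unqualified Load/Store/FMul/Alloca calls introduced throughout this change presumably resolve through IrBuilderMixin<IrEmitter>, with the new builder() accessor giving the mixin access to b_. A minimal CRTP sketch of how such forwarding could look (illustrative only; the method set and signatures are assumptions, not the actual ir_builder_mixin.h):
#include "llvm/ADT/Twine.h"
#include "llvm/IR/IRBuilder.h"
// Hypothetical CRTP mixin: unqualified calls in the derived emitter forward to
// its llvm::IRBuilder<> through the builder() accessor the derived class adds.
template <typename Derived>
class IrBuilderMixinSketch {
 public:
  llvm::Value* Load(llvm::Value* ptr, const llvm::Twine& name = "") {
    return mixin_builder()->CreateLoad(ptr, name);
  }
  llvm::StoreInst* Store(llvm::Value* value, llvm::Value* ptr) {
    return mixin_builder()->CreateStore(value, ptr);
  }
  llvm::Value* FMul(llvm::Value* lhs, llvm::Value* rhs) {
    return mixin_builder()->CreateFMul(lhs, rhs);
  }
 private:
  llvm::IRBuilder<>* mixin_builder() {
    return static_cast<Derived*>(this)->builder();
  }
};
// Usage mirrors the change above:
//   class IrEmitter : public DfsHloVisitorWithDefault,
//                     public IrBuilderMixin<IrEmitter> { ... };
// so `Store(v, addr)` inside IrEmitter means `b_.CreateStore(v, addr)`.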
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1e81cbde35..c0c8ae181a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -21,6 +21,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@@ -29,7 +34,6 @@ limitations under the License.
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
@@ -77,7 +81,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -85,13 +88,13 @@ namespace gpu {
namespace {
+using absl::InlinedVector;
+using absl::nullopt;
+using absl::optional;
+using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::InlinedVector;
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
-using tensorflow::strings::StrCat;
// If a dimensions is smaller than this, untiled transposition may be more
// efficient.
@@ -314,13 +317,13 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
};
// Check the size of input tensors
- if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
+ if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
return i64_ty;
}
// Check the size of the internal result tensors
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
- if (!c_all_of(
+ if (!absl::c_all_of(
unnested_hlo->fused_instructions_computation()->instructions(),
hlo_shape_in_range)) {
return i64_ty;
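Note on the i32/i64 choice above: the kernel falls back to 64-bit indices as soon as any operand (or, for fusions, any internal result) might not fit in a signed 32-bit range. A hedged sketch of the intent (hlo_shape_in_range and the surrounding logic are only partially visible here, so this is an assumption):
#include <cstdint>
#include <limits>
#include <vector>
// Sketch: prefer 32-bit induction variables only when the launch size and
// every tensor involved have fewer elements than INT32_MAX.
bool FitsInInt32(int64_t num_elements) {
  return num_elements <= std::numeric_limits<int32_t>::max();
}
enum class IndexWidth { k32, k64 };
IndexWidth ChooseIndexWidth(int64_t launch_size,
                            const std::vector<int64_t>& tensor_element_counts) {
  if (!FitsInInt32(launch_size)) return IndexWidth::k64;
  for (int64_t count : tensor_element_counts) {
    if (!FitsInInt32(count)) return IndexWidth::k64;
  }
  return IndexWidth::k32;
}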
@@ -383,7 +386,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
int64 feature_index_value = feature_index->literal().Get<int64>({});
thunk_sequence_->emplace_back(
- MakeUnique<CudnnBatchNormForwardInferenceThunk>(
+ absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@@ -413,7 +416,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
thunk_sequence_->emplace_back(
- MakeUnique<CudnnBatchNormForwardTrainingThunk>(
+ absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@@ -443,19 +446,20 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
auto output_grad_offset =
assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
- thunk_sequence_->emplace_back(MakeUnique<CudnnBatchNormBackwardThunk>(
- /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
- /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
- /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
- /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
- /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
- /*epsilon=*/epsilon_value,
- /*feature_index=*/feature_index_value,
- /*output_grad_data=*/output_grad_data,
- /*output_grad_scale=*/output_grad_scale,
- /*output_grad_offset=*/output_grad_offset,
- /*output_tuple=*/GetAllocationSlice(*custom_call),
- /*hlo=*/custom_call));
+ thunk_sequence_->emplace_back(
+ absl::make_unique<CudnnBatchNormBackwardThunk>(
+ /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
+ /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
+ /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
+ /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
+ /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
+ /*epsilon=*/epsilon_value,
+ /*feature_index=*/feature_index_value,
+ /*output_grad_data=*/output_grad_data,
+ /*output_grad_scale=*/output_grad_scale,
+ /*output_grad_offset=*/output_grad_offset,
+ /*output_tuple=*/GetAllocationSlice(*custom_call),
+ /*hlo=*/custom_call));
return Status::OK();
}
@@ -475,7 +479,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
const auto& target = custom_call->custom_call_target();
std::unique_ptr<ConvolutionThunk> thunk;
if (target == kCudnnConvForwardCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
+ thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kForward,
/*input_buffer=*/lhs_slice,
/*filter_buffer=*/rhs_slice,
@@ -489,7 +493,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
backend_config.algorithm(), backend_config.tensor_ops_enabled(),
custom_call);
} else if (target == kCudnnConvBackwardInputCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
+ thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardInput,
/*input_buffer=*/conv_result_slice,
/*filter_buffer=*/rhs_slice,
@@ -503,7 +507,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
backend_config.algorithm(), backend_config.tensor_ops_enabled(),
custom_call);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
+ thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardFilter,
/*input_buffer=*/lhs_slice,
/*filter_buffer=*/conv_result_slice,
@@ -576,7 +580,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
thunks.push_back(
BuildKernelThunk(fusion, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), fusion));
+ absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
std::vector<IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
@@ -725,7 +729,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
"extra_output_element_address");
TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
extra_output_gens[i].first(index));
- b_.CreateStore(extra_output_ir_value, extra_output_address);
+ Store(extra_output_ir_value, extra_output_address);
}
return Status::OK();
}
@@ -798,8 +802,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize),
// //
// // and threads_per_block is a multiple of warpSize.
- // reduce_kernel<<<num_blocks, threads_per_block>>>();
- //
+ //      reduce_kernel<<<num_blocks, threads_per_block>>>();
auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type =
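A worked example of the launch-shape arithmetic in the pseudo-code comment above (kTileSize and the element count are illustrative values, not taken from this change):
#include <cstdint>
#include <iostream>
// num_blocks = RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize)
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }
int64_t RoundUpToNextMultipleOf(int64_t v, int64_t m) {
  return CeilOfRatio(v, m) * m;
}
int main() {
  const int64_t kWarpSize = 32;
  const int64_t kTileSize = 16;     // assumed value for illustration
  const int64_t num_elems = 10000;  // assumed value for illustration
  int64_t num_blocks =
      RoundUpToNextMultipleOf(CeilOfRatio(num_elems, kTileSize), kWarpSize);
  std::cout << num_blocks << "\n";  // Ceil(10000 / 16) = 625, rounds up to 640
  return 0;
}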
@@ -807,17 +810,17 @@ Status IrEmitterUnnested::EmitReductionToScalar(
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
llvm::Value* x_in_tiles = tile_index[0];
- x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
@@ -829,15 +832,14 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&b_);
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
- tile_element_loop->GetIndVarValue());
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)),
+ tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds",
- &b_);
+ ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
@@ -846,11 +848,11 @@ Status IrEmitterUnnested::EmitReductionToScalar(
IrArray::Index input_index(
/*linear=*/x, input_shape, &b_);
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -861,14 +863,14 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
- llvm::Value* x_end = b_.CreateNSWAdd(
- index_typed_constant(kTileSize),
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
+ llvm::Value* x_end =
+ NSWAdd(index_typed_constant(kTileSize),
+ NSWMul(x_in_tiles, index_typed_constant(kTileSize)));
// The tile is entirely in bound if all_threads_in_bounds or
// x_end <= num_elems.
llvm::Value* tile_in_bounds =
- b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
- b_.getInt1(all_threads_in_bounds));
+ Or(ICmpULE(x_end, index_typed_constant(num_elems)),
+ b_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_);
@@ -889,20 +891,18 @@ Status IrEmitterUnnested::EmitReductionToScalar(
for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* result_from_other_lane =
- b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
+ Alloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = b_.CreateLoad(
- b_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
+ llvm::Value* partial_reduction_result =
+ Load(BitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- b_.CreateStore(
- EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- b_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ Store(EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
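Note: the shuffle loop above implements a standard warp-level tree reduction: at each step every lane combines its partial result with the value held shuffle_distance lanes away, halving the distance until lane 0 holds the warp's total. A small host-side C++ analogue of that combining schedule (plain addition stands in for the nested reducer computation; this is illustrative, not the emitted IR):
#include <array>
constexpr int kWarpSize = 32;
// Host-side analogue of the shuffle-down schedule: lane i reads the value held
// by lane i + shuffle_distance (when it exists) and folds it into its own
// partial result. After log2(kWarpSize) rounds, lane 0 holds the warp's sum.
void WarpTreeReduceSketch(std::array<float, kWarpSize>& partial) {
  for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
       shuffle_distance /= 2) {
    std::array<float, kWarpSize> from_other_lane = partial;  // snapshot, like the shuffle
    for (int lane = 0; lane + shuffle_distance < kWarpSize; ++lane) {
      partial[lane] += from_other_lane[lane + shuffle_distance];
    }
  }
}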
@@ -917,10 +917,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm::Value* lane_id =
- b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
+ URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
- &b_);
+ ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
@@ -1040,12 +1039,12 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." +
- llvm::Twine(i * kTileWidth + x_offset));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." +
+ llvm::Twine(i * kTileWidth + x_offset));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1056,8 +1055,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* y_in_tiles = tile_index[0];
llvm::Value* x_in_tiles = tile_index[1];
- y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty);
- x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty);
+ x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
bool tile_in_x_bounds) -> Status {
@@ -1069,34 +1068,32 @@ Status IrEmitterUnnested::EmitColumnReduction(
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&b_);
- llvm::Value* y = b_.CreateNSWAdd(
- b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
- tile_element_loop->GetIndVarValue());
+ llvm::Value* y =
+ NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
+ tile_element_loop->GetIndVarValue());
// Unless we know that y is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_y_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds",
- &b_);
+ ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
// Unless we know that x is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_x_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds",
- &b_);
+ ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
// {y,x} is an index to input_matrix_shape [height,width]. We need to
// convert that to an index to input_shape (the shape of the operand of
// "reduce"). This conversion is composed of a transposition from
@@ -1123,7 +1120,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i * kTileWidth + x_offset],
@@ -1138,20 +1135,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
// y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
// that's immediately beyond the tile.
- llvm::Value* y_end = b_.CreateNSWAdd(
- index_typed_constant(kTileHeight),
- b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
+ llvm::Value* y_end =
+ NSWAdd(index_typed_constant(kTileHeight),
+ NSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
// x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
// that's immediately beyond the tile.
- llvm::Value* x_end = b_.CreateNSWAdd(
- index_typed_constant(kTileWidth),
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
+ llvm::Value* x_end =
+ NSWAdd(index_typed_constant(kTileWidth),
+ NSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
llvm::Value* tile_in_y_bounds =
- b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)),
- b_.getInt1(height % kTileHeight == 0));
+ Or(ICmpULE(y_end, index_typed_constant(height)),
+ b_.getInt1(height % kTileHeight == 0));
llvm::Value* tile_in_x_bounds =
- b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)),
- b_.getInt1(width % kTileWidth == 0));
+ Or(ICmpULE(x_end, index_typed_constant(width)),
+ b_.getInt1(width % kTileWidth == 0));
// The tile is in y bounds if "height" is a multiple of kTileHeight or
// y_end <= height.
llvm_ir::LlvmIfData if_tile_in_y_bounds_data =
@@ -1185,9 +1182,9 @@ Status IrEmitterUnnested::EmitColumnReduction(
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
@@ -1376,11 +1373,11 @@ Status IrEmitterUnnested::EmitRowReduction(
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1389,22 +1386,20 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* y = tile_index[1];
llvm::Value* x_tile = tile_index[2];
- x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty);
+ x_tile = ZExtOrTrunc(x_tile, index_ty);
llvm::Value* warp_id =
- b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
+ UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
llvm::Value* lane_id =
- b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id");
+ URem(x_tile, index_typed_constant(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
- llvm::Value* last_x = b_.CreateNSWAdd(
+ llvm::Value* last_x = NSWAdd(
lane_id,
- b_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- b_.CreateNSWAdd(
- index_typed_constant(x_tile_size - 1),
- b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size)))));
+ NSWMul(index_typed_constant(kWarpSize),
+ NSWAdd(index_typed_constant(x_tile_size - 1),
+ NSWMul(warp_id, index_typed_constant(x_tile_size)))));
KernelSupportLibrary ksl(
&b_,
@@ -1416,9 +1411,8 @@ Status IrEmitterUnnested::EmitRowReduction(
auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
int64 x_tile_loop_bound) -> Status {
auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
- llvm::Value* z = b_.CreateNSWAdd(
- z_indvar,
- b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile));
+ llvm::Value* z =
+ NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile));
TF_RETURN_IF_ERROR(ksl.For(
"x_tile",
/*start=*/index_typed_constant(0),
@@ -1426,22 +1420,20 @@ Status IrEmitterUnnested::EmitRowReduction(
/*step=*/1, [&](llvm::Value* x_indvar) -> Status {
// x = lane_id +
// warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
- llvm::Value* x = b_.CreateNSWAdd(
+ llvm::Value* x = NSWAdd(
lane_id,
- b_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- b_.CreateNSWAdd(
- x_indvar, b_.CreateNSWMul(
- warp_id, llvm::ConstantInt::get(
- index_ty, x_tile_size)))));
+ NSWMul(index_typed_constant(kWarpSize),
+ NSWAdd(x_indvar,
+ NSWMul(warp_id, llvm::ConstantInt::get(
+ index_ty, x_tile_size)))));
// Unless we know the x-tile is entirely in bounds, we have to
// emit a x-in-bounds check before reading from the input.
if (!x_tile_in_bounds) {
llvm_ir::LlvmIfData if_x_in_bounds_data =
llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(width)),
- "x_in_bounds", &b_);
+ ICmpULT(x, index_typed_constant(width)), "x_in_bounds",
+ &b_);
// Points b_ to the then-block.
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
&b_);
@@ -1449,7 +1441,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// Emit code that reads the input element and accumulates it
// to the partial reduction result.
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
{
// {z,y,x} is an index to input_3d_tensor_shape
// [depth,height,width]. We need to convert that to an index
@@ -1480,7 +1472,7 @@ Status IrEmitterUnnested::EmitRowReduction(
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -1500,8 +1492,8 @@ Status IrEmitterUnnested::EmitRowReduction(
};
llvm::Value* tile_in_bounds =
- b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- b_.CreateICmpULT(last_x, index_typed_constant(width)));
+ Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
+ ICmpULT(last_x, index_typed_constant(width)));
TF_RETURN_IF_ERROR(
ksl.If(tile_in_bounds,
@@ -1529,20 +1521,18 @@ Status IrEmitterUnnested::EmitRowReduction(
for (int shuffle_distance = 16; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* result_from_other_lane =
- b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
+ Alloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = b_.CreateLoad(
- b_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
+ llvm::Value* partial_reduction_result =
+ Load(BitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- b_.CreateStore(
- EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- b_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ Store(EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -1557,8 +1547,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
- &b_);
+ ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
@@ -1718,7 +1707,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
thunks.push_back(
BuildKernelThunk(reduce, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), reduce));
+ absl::make_unique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
reduce, input->shape(), {[&](const IrArray::Index& index) {
@@ -1738,7 +1727,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
bool all_tuple_elements_have_buffer =
- c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
+ absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
return ir_emitter_context_->buffer_assignment()
.GetUniqueTopLevelSlice(tuple_element)
.ok();
@@ -1760,7 +1749,7 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
for (const HloInstruction* tuple_element : tuple->operands()) {
tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
}
- thunk_sequence_->emplace_back(MakeUnique<TupleThunk>(
+ thunk_sequence_->emplace_back(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
return Status::OK();
}
@@ -1792,8 +1781,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
thunks.push_back(std::move(initializer_thunk));
thunks.push_back(BuildKernelThunk(select_and_scatter,
/*implements_whole_instruction=*/false));
- thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
+ thunk_sequence_->emplace_back(absl::make_unique<SequentialThunk>(
+ std::move(thunks), select_and_scatter));
// TODO(b/31410564): Implement dilation rate for select-and-scatter.
if (window_util::HasDilation(window)) {
@@ -1842,7 +1831,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
&b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
b_.getInt1Ty(), "initialized_flag_address", &b_);
- b_.CreateStore(b_.getInt1(false), initialized_flag_address);
+ Store(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
@@ -1863,15 +1852,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
IrArray::Index operand_index(index_type, source_index.size());
llvm::Value* in_bounds_condition = b_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = b_.CreateNSWMul(
+ llvm::Value* strided_index = NSWMul(
source_index[i], index_typed_constant(window.dimensions(i).stride()));
- operand_index[i] = b_.CreateNSWSub(
- b_.CreateNSWAdd(strided_index, window_index[i]),
- index_typed_constant(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = b_.CreateICmpULT(
+ operand_index[i] =
+ NSWSub(NSWAdd(strided_index, window_index[i]),
+ index_typed_constant(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition = ICmpULT(
operand_index[i],
index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
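Note: the loop above maps a (source index, window index) pair to an operand index per dimension as stride * source_index + window_index - padding_low, and marks the position out of bounds if any dimension lands outside the operand shape. A compact C++ sketch of that index arithmetic (the struct and dimension parameters are illustrative assumptions):
#include <cstdint>
#include <vector>
struct WindowDimSketch {
  int64_t stride;
  int64_t padding_low;
};
// Returns true and fills operand_index when the window position maps inside
// the operand; mirrors the per-dimension bounds check emitted above.
bool MapToOperandIndex(const std::vector<int64_t>& source_index,
                       const std::vector<int64_t>& window_index,
                       const std::vector<WindowDimSketch>& window_dims,
                       const std::vector<int64_t>& operand_dims,
                       std::vector<int64_t>* operand_index) {
  operand_index->resize(source_index.size());
  for (size_t i = 0; i < source_index.size(); ++i) {
    int64_t strided = source_index[i] * window_dims[i].stride;
    int64_t idx = strided + window_index[i] - window_dims[i].padding_low;
    if (idx < 0 || idx >= operand_dims[i]) return false;  // out of bounds
    (*operand_index)[i] = idx;
  }
  return true;
}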
@@ -1881,7 +1870,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- b_.CreateLoad(initialized_flag_address), "initialized", &b_);
+ Load(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
@@ -1889,16 +1878,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- b_.CreateStore(operand_index[i], selected_index_address_slot);
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ Store(operand_index[i], selected_index_address_slot);
}
};
IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &b_);
- b_.CreateStore(operand_data, selected_value_address);
+ Store(operand_data, selected_value_address);
save_operand_index(operand_index);
- b_.CreateStore(b_.getInt1(true), initialized_flag_address);
+ Store(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to
// potentially update the selected value and index with the currently
@@ -1914,11 +1903,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*select_and_scatter->select(),
{selected_value_address, operand_address}, select_return_buffer));
- llvm::Value* result = b_.CreateLoad(select_return_buffer);
+ llvm::Value* result = Load(select_return_buffer);
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = b_.CreateICmpNE(
+ llvm::Value* cond = ICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
PRED, ir_emitter_context_->llvm_module()),
@@ -1927,7 +1916,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
- b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
+ Store(Load(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -1939,8 +1928,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(Load(selected_index_address_slot));
}
llvm::Value* source_value_address =
GetIrArray(*source, *select_and_scatter)
@@ -2018,7 +2007,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
thunks.push_back(std::move(rng_thunk));
thunks.push_back(std::move(increment_seed_thunk));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), rng));
+ absl::make_unique<SequentialThunk>(std::move(thunks), rng));
return Status::OK();
}
@@ -2043,7 +2032,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
auto values_destination = GetAllocationSlice(*sort, values_shape_index);
if (keys_destination != GetAllocationSlice(*keys)) {
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*keys),
/*destination_buffer=*/keys_destination,
/*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr));
@@ -2051,7 +2040,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
if (values != nullptr && values_destination != GetAllocationSlice(*values)) {
// TODO(b/26783907): Figure out why we never seem to share buffers for
// key/value sort.
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*values),
/*destination_buffer=*/values_destination,
/*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr));
@@ -2095,15 +2084,15 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace(
dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index),
- values != nullptr ? tensorflow::gtl::make_optional<IrArray>(
+ values != nullptr ? absl::make_optional<IrArray>(
GetIrArray(*sort, *sort, values_shape_index))
- : tensorflow::gtl::nullopt,
+ : absl::nullopt,
IrName(sort), xor_mask, &b_, &launch_dimensions));
}
}
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), sort));
+ absl::make_unique<SequentialThunk>(std::move(thunks), sort));
return Status::OK();
}
@@ -2130,7 +2119,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
if (crs->operand_count() == 1) {
CHECK(ShapeUtil::IsArray(crs->operand(0)->shape()))
<< "Operands to cross-replica-sum must be arrays: " << crs->ToString();
- thunk_sequence_->push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunk_sequence_->push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
/*destination_buffer=*/GetAllocationSlice(*crs),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
@@ -2145,17 +2134,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
.GetUniqueSlice(crs, {i})
.ValueOrDie());
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
}
// Output a tuple of the buffers above.
- thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
- GetAllocationSlice(*crs), nullptr));
+ thunks.push_back(absl::make_unique<TupleThunk>(
+ tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
thunk_sequence_->push_back(
- MakeUnique<SequentialThunk>(std::move(thunks), crs));
+ absl::make_unique<SequentialThunk>(std::move(thunks), crs));
return Status::OK();
}
@@ -2305,7 +2294,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
for (const auto& kv : hlo_slices) {
buffers_needed.insert(kv.second.first.allocation());
}
- tensorflow::gtl::optional<const BufferAllocation*> temp_buffer;
+ absl::optional<const BufferAllocation*> temp_buffer;
for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
if (alloc.IsPreallocatedTempBuffer()) {
if (!temp_buffer.has_value()) {
@@ -2322,10 +2311,10 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
// We'll pass a pointer to each of the elements of `buffers` to our kernel, in
// this order.
std::vector<const BufferAllocation*> non_constant_buffers;
- c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
- [](const BufferAllocation* allocation) {
- return !allocation->is_constant();
- });
+ absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
+ [](const BufferAllocation* allocation) {
+ return !allocation->is_constant();
+ });
std::sort(non_constant_buffers.begin(), non_constant_buffers.end(),
[](const BufferAllocation* a, const BufferAllocation* b) {
@@ -2364,8 +2353,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
*slice.allocation())));
CHECK_NE(loc, nullptr);
} else {
- loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
- {b_.getInt64(slice.offset())});
+ loc = InBoundsGEP(kernel_args.at(slice.allocation()),
+ {b_.getInt64(slice.offset())});
}
// If gte_index is nonempty, we have to dereference `loc` to get to the
@@ -2373,8 +2362,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::Type* int8_double_pointer =
llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
for (int64 idx : gte_index) {
- loc = b_.CreateBitCast(loc, int8_double_pointer);
- loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)}));
+ loc = BitCast(loc, int8_double_pointer);
+ loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
}
bindings_.BindHloToIrValue(*instr, loc, index);
@@ -2389,7 +2378,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
}
- return MakeUnique<KernelThunk>(
+ return absl::make_unique<KernelThunk>(
non_constant_buffers, llvm_ir::AsString(kernel->getName()),
implements_whole_instruction ? inst : nullptr, unroll_factor);
}
@@ -2398,7 +2387,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
CHECK_EQ(HloOpcode::kConstant, operand->opcode());
- return MakeUnique<HostToDeviceCopyThunk>(
+ return absl::make_unique<HostToDeviceCopyThunk>(
/*source_address=*/operand->literal().untyped_data(),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/
@@ -2410,7 +2399,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
- return MakeUnique<DeviceToDeviceCopyThunk>(
+ return absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*operand),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/
@@ -2430,7 +2419,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
.GetUniqueSlice(inst, index)
.ConsumeValueOrDie();
});
- return MakeUnique<InfeedThunk>(slices, inst);
+ return absl::make_unique<InfeedThunk>(slices, inst);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
@@ -2447,7 +2436,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
*slice = status_or_slice.ConsumeValueOrDie();
}
});
- return MakeUnique<OutfeedThunk>(std::move(slices), inst);
+ return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
}
namespace {
@@ -2470,7 +2459,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (inst->opcode() == HloOpcode::kDot) {
const HloInstruction* lhs = inst->operand(0);
const HloInstruction* rhs = inst->operand(1);
- return MakeUnique<GemmThunk>(
+ return absl::make_unique<GemmThunk>(
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
@@ -2512,7 +2501,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
const HloInstruction* rhs =
inst->operand(rhs_parameter->parameter_number());
- return MakeUnique<GemmThunk>(
+ return absl::make_unique<GemmThunk>(
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
@@ -2529,11 +2518,12 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
- return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(),
- /*input_buffer=*/GetAllocationSlice(*operand),
- /*output_buffer=*/GetAllocationSlice(*inst),
- /*input_shape=*/operand->shape(),
- /*output_shape=*/inst->shape(), inst);
+ return absl::make_unique<FftThunk>(
+ inst->fft_type(), inst->fft_length(),
+ /*input_buffer=*/GetAllocationSlice(*operand),
+ /*output_buffer=*/GetAllocationSlice(*inst),
+ /*input_shape=*/operand->shape(),
+ /*output_shape=*/inst->shape(), inst);
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
@@ -2582,9 +2572,9 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
// MemzeroThunk.
ArraySlice<uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
- if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
- return {
- MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)};
+ if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
+ return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
+ nullptr)};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@@ -2601,7 +2591,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
- return {MakeUnique<Memset32BitValueThunk>(
+ return {absl::make_unique<Memset32BitValueThunk>(
pattern32, GetAllocationSlice(*hlo, index), nullptr)};
}
@@ -2612,7 +2602,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
literal_bytes.size() - 4) == 0) {
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
- return {MakeUnique<Memset32BitValueThunk>(
+ return {absl::make_unique<Memset32BitValueThunk>(
word, GetAllocationSlice(*hlo, index), nullptr)};
}
}
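Note: the initializer path above recognizes literals whose bytes repeat with a period dividing four, so the whole buffer can be filled by a single 32-bit memset: an 8- or 16-bit element is widened by replication into a 32-bit pattern, and a 32-bit element is used directly when every word in the literal matches the first. A small sketch of the pattern widening (an illustrative helper, not the XLA API; the 8-bit branch is assumed since it falls outside the visible hunk):
#include <cstdint>
#include <cstring>
// Widen an 8- or 16-bit fill value to the 32-bit pattern a memset-32 expects.
uint32_t WidenToMemset32Pattern(const void* literal_bytes, int element_size) {
  uint16_t pattern16 = 0;
  if (element_size == 1) {
    uint8_t b;
    std::memcpy(&b, literal_bytes, sizeof(b));
    pattern16 = uint16_t{b} | (uint16_t{b} << 8);  // replicate byte into 16 bits
  } else {                                         // element_size == 2
    std::memcpy(&pattern16, literal_bytes, sizeof(pattern16));
  }
  return uint32_t{pattern16} | (uint32_t{pattern16} << 16);  // replicate to 32 bits
}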
@@ -2670,8 +2660,7 @@ Status CheckHloBuffersShareAllocation(
if (slice_a != slice_b) {
return InternalError(
"instruction %s %s does not share allocation with instruction %s %s",
- a->ToString().c_str(), slice_a.ToString().c_str(),
- b->ToString().c_str(), slice_b.ToString().c_str());
+ a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString());
}
return Status::OK();
}
@@ -2764,7 +2753,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
- return MakeUnique<WhileThunk>(
+ return absl::make_unique<WhileThunk>(
GetAllocationSlice(*condition->root_instruction()), // cond result
ir_emitter_condition.ConsumeThunkSequence(),
ir_emitter_body.ConsumeThunkSequence(), hlo);
@@ -2782,8 +2771,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
- return MakeUnique<ForThunk>(loop_limit,
- ir_emitter_body.ConsumeThunkSequence(), hlo);
+ return absl::make_unique<ForThunk>(
+ loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
@@ -2803,7 +2792,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
ir_emitter_context_);
TF_CHECK_OK(false_computation->Accept(&ir_emitter_false));
- return MakeUnique<ConditionalThunk>(
+ return absl::make_unique<ConditionalThunk>(
GetAllocationSlice(*hlo->operand(0)),
GetAllocationSlice(*hlo->operand(1)),
GetAllocationSlice(*hlo->operand(2)),
@@ -3105,7 +3094,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize);
}
const int64 num_tiles =
- c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
+ absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile);
llvm::Type* index_ty =
@@ -3151,9 +3140,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
const IrArray::Index output_tile_origin = [&] {
IrArray::Index index = output_tile_index;
for (int i = 1; i < 3; ++i) {
- index[i] =
- b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize),
- "tile_origin." + std::to_string(i));
+ index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize),
+ "tile_origin." + std::to_string(i));
}
return index;
}();
@@ -3166,12 +3154,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
std::vector<llvm::Value*> output_tile_bounds(3);
for (int i = 1; i < 3; ++i) {
// Only last row or column may not have full size.
- output_tile_bounds[i] = b_.CreateSelect(
- b_.CreateICmpEQ(output_tile_index[i],
- index_typed_constant(output_dims_in_tiles[i] - 1)),
- index_typed_constant(reduced_output_dims[i] -
- (output_dims_in_tiles[i] - 1) * kTileSize),
- index_typed_constant(kTileSize), "kTileSize");
+ output_tile_bounds[i] =
+ Select(ICmpEQ(output_tile_index[i],
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
+ index_typed_constant(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * kTileSize),
+ index_typed_constant(kTileSize), "kTileSize");
}
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
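Note: the Select above sizes each tile along a dimension: every tile spans kTileSize elements except possibly the last one, which gets whatever remains. A scalar C++ sketch of that bound computation (names are illustrative; only the arithmetic is taken from the code above):
#include <cstdint>
// Bound of tile `tile_index` along a dimension of size `dim_size` split into
// `num_tiles` tiles of kTileSize: full tiles get kTileSize, the last tile gets
// the remainder.
int64_t TileBoundSketch(int64_t tile_index, int64_t num_tiles, int64_t dim_size,
                        int64_t kTileSize) {
  if (tile_index == num_tiles - 1) {
    return dim_size - (num_tiles - 1) * kTileSize;
  }
  return kTileSize;
}
// Example: dim_size = 70, kTileSize = 32, num_tiles = 3 -> bounds 32, 32, 6.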
@@ -3189,7 +3177,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// Adds `addend` to the given `dim` of `index`.
auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
- index[dim] = b_.CreateAdd(index[dim], addend);
+ index[dim] = Add(index[dim], addend);
return index;
};
const IrArray::Index input_index =
@@ -3205,10 +3193,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* shmem_buffer = param_shmem_buffers[id];
// TODO(jlebar): Add AA metadata to this store. Tile buffers are
// global variables, so LLVM can't infer much about it.
- b_.CreateStore(
- input_in_logical_shape.EmitReadArrayElement(index, &b_,
- "input_element"),
- b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
+ Store(input_in_logical_shape.EmitReadArrayElement(index, &b_,
+ "input_element"),
+ GEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
}
});
@@ -3229,9 +3216,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
[&](const IrArray::Index& index, llvm::Value* y_loc) {
// TODO(jlebar): Add AA metadata to this load.
- llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad(
- b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
- "output_element");
+ llvm::Instruction* load_from_shmem_buffer =
+ Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
+ "output_element");
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
index, load_from_shmem_buffer, &b_);
});
@@ -3259,7 +3246,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_in_reduced_shape_arrays.size());
for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
- index, b_.CreateExtractValue(output_value, i), &b_);
+ index, ExtractValue(output_value, i), &b_);
}
} else {
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
@@ -3341,7 +3328,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
// if there's a Right Choice.
//
// This is only sound if tiled transposes are the only place where we use
- // shared memory in fusions. If in the future other fusile ops use shared
+ // shared memory in fusions. If in the future other fusible ops use shared
// memory, we'll have to adjust this heuristic.
constexpr int kMinBlocksPerCore = 3;
constexpr int64 kShmemPerCore = 48 * 1024;
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index e76823ad10..3259eaa2a2 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -41,8 +41,8 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
tensorflow::mutex_lock lock(mutex_);
if (!loader_spec_) {
loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size()));
- tensorflow::StringPiece ptx = executable.ptx();
- // Convert tensorflow::StringPiece to se::port::StringPiece because
+ absl::string_view ptx = executable.ptx();
+ // Convert absl::string_view to se::port::StringPiece because
// StreamExecutor uses the latter.
loader_spec_->AddCudaPtxInMemory(
se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_);
@@ -63,7 +63,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
if (kernel_cache_.end() == it) {
it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first;
if (!executor->GetKernel(*loader_spec_, &it->second)) {
- return InternalError("Unable to load kernel %s", kernel_name_.c_str());
+ return InternalError("Unable to load kernel %s", kernel_name_);
}
}
@@ -95,7 +95,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
VLOG(3) << "Launching " << kernel->name();
// Launch the kernel with potentially multiple blocks and threads.
static constexpr int kKernelArgsLimit = 1024;
- auto kernel_args = MakeUnique<se::KernelArgsArray<kKernelArgsLimit>>();
+ auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
for (const BufferAllocation* arg : args_) {
const auto& buf = buffer_allocations.GetDeviceAddress(arg->index());
kernel_args->add_device_memory_argument(buf);
@@ -107,7 +107,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
stream, se::ThreadDim(launch_dimensions.threads_per_block()),
se::BlockDim(launch_dimensions.block_count()), *kernel,
*kernel_args)) {
- return InternalError("Unable to launch kernel %s", kernel_name_.c_str());
+ return InternalError("Unable to launch kernel %s", kernel_name_);
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
index eb93efc560..698d2d51cc 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -34,6 +34,9 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:amdgpu_code_gen",
"@llvm//:analysis",
"@llvm//:bit_reader",
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
index 12a8a59488..85bc58cb44 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -86,10 +86,11 @@ void IrDumpingPassManager::run(llvm::Module &module) {
const llvm::PassInfo *PI =
llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID());
const string basename = ReplaceFilenameExtension(
- tensorflow::io::Basename(input_filename_),
- tensorflow::strings::Printf(
+ absl::string_view(tensorflow::io::Basename(input_filename_)),
+ absl::StrFormat(
"pass-%02d.before.%s.ll", i,
- (PI == nullptr ? "unknown" : PI->getPassArgument().data())));
+ absl::string_view(PI == nullptr ? "unknown"
+ : PI->getPassArgument().data())));
llvm::legacy::PassManager::add(
new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename)));
}
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
index ff4ae1f9ef..8751e3a9c2 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
@@ -20,13 +20,15 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
@@ -54,10 +56,7 @@ limitations under the License.
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Scalar.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
@@ -107,8 +106,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
<< ", " << compute_capability.second << ") ."
<< "Defaulting to libdevice for compute_" << libdevice_version;
}
- return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version,
- ".10.bc");
+ return absl::StrCat("libdevice.compute_", libdevice_version, ".10.bc");
}
// Gets the GPU name as it's known to LLVM for a given compute capability. If
@@ -138,15 +136,16 @@ static string GetSmName(std::pair<int, int> compute_capability) {
<< "Defaulting to telling LLVM that we're compiling for sm_"
<< sm_version;
}
- return tensorflow::strings::StrCat("sm_", sm_version);
+ return absl::StrCat("sm_", sm_version);
}
// Convenience function for producing a name of a temporary compilation product
// from the input filename.
string MakeNameForTempProduct(const std::string& input_filename,
- tensorflow::StringPiece extension) {
- return ReplaceFilenameExtension(
- tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension);
+ absl::string_view extension) {
+ return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename(
+ llvm_ir::AsString(input_filename))),
+ extension);
}
// Initializes LLVM passes. Uses the PassRegistry mechanism.
@@ -167,7 +166,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) {
// Returns the TargetMachine, given a triple.
std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
- llvm::Triple triple, tensorflow::StringPiece cpu_name,
+ llvm::Triple triple, absl::string_view cpu_name,
const HloModuleConfig& hlo_module_config) {
std::string error;
const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error);
@@ -205,7 +204,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
default:
codegen_opt_level = CodeGenOpt::None;
}
- return WrapUnique(target->createTargetMachine(
+ return absl::WrapUnique(target->createTargetMachine(
triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options,
Optional<Reloc::Model>(RelocModel), Optional<CodeModel::Model>(CMModel),
codegen_opt_level));
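
For readers following the ptr_util to absl/memory migration: absl::WrapUnique adopts an already-constructed raw pointer (which is why it suits factories like createTargetMachine that return owning raw pointers), while absl::make_unique constructs the object itself. A small sketch with a hypothetical Widget type that is not part of XLA:

#include <memory>
#include "absl/memory/memory.h"

struct Widget { int id = 0; };

// A factory that hands back an owning raw pointer, as many C APIs and LLVM
// factories do.
Widget* NewWidget() { return new Widget{42}; }

void Demo() {
  // Take ownership of an already-allocated object.
  std::unique_ptr<Widget> adopted = absl::WrapUnique(NewWidget());
  // Allocate and construct in one step.
  auto built = absl::make_unique<Widget>();
  (void)adopted;
  (void)built;
}
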
@@ -243,9 +242,9 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
}
// Emits the given module to a bit code file.
-void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) {
+void EmitBitcodeToFile(const Module& module, absl::string_view filename) {
std::error_code error_code;
- llvm::ToolOutputFile outfile(filename.ToString().c_str(), error_code,
+ llvm::ToolOutputFile outfile(string(filename).c_str(), error_code,
llvm::sys::fs::F_None);
if (error_code) {
LOG(FATAL) << "opening bitcode file for writing: " << error_code.message();
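
The string(filename).c_str() change above reflects that absl::string_view, unlike tensorflow::StringPiece, has no ToString() member and is not guaranteed to point at NUL-terminated data, so an explicit std::string copy is needed before handing the name to a C-style API. A minimal sketch of that pattern:

#include <string>
#include "absl/strings/string_view.h"

void OpenForWriting(absl::string_view filename) {
  // Copy into a std::string first so that c_str() yields a NUL-terminated buffer.
  std::string owned(filename);
  const char* c_name = owned.c_str();
  // ... pass c_name to an API that expects a C string ...
  (void)c_name;
}
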
@@ -266,8 +265,9 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) {
// get creative to add a suffix.
string module_id(llvm_ir::AsString(module->getModuleIdentifier()));
IrDumpingPassManager codegen_passes(
- ReplaceFilenameExtension(tensorflow::io::Basename(module_id),
- "-nvptx.dummy"),
+ ReplaceFilenameExtension(
+ absl::string_view(tensorflow::io::Basename(module_id)),
+ "-nvptx.dummy"),
"", false);
codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
llvm::Triple(module->getTargetTriple())));
@@ -332,8 +332,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module,
return !GV.hasName() || (GVS.count(GV.getName()) == 0);
});
})) {
- return tensorflow::errors::Internal(tensorflow::strings::StrCat(
- "Error linking libdevice from ", libdevice_path));
+ return tensorflow::errors::Internal(
+ absl::StrCat("Error linking libdevice from ", libdevice_path));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
index 54e0e140de..9654175bfa 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/strings/string_view.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
namespace gpu {
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
index 9ef9bc3a50..3b2c3591d9 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
@@ -17,13 +17,13 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace {
@@ -52,14 +52,13 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename,
return module;
}
-string ReplaceFilenameExtension(tensorflow::StringPiece filename,
- tensorflow::StringPiece new_extension) {
+string ReplaceFilenameExtension(absl::string_view filename,
+ absl::string_view new_extension) {
auto pos = filename.rfind('.');
- tensorflow::StringPiece stem =
- pos == tensorflow::StringPiece::npos
- ? filename
- : tensorflow::StringPiece(filename.data(), pos);
- return tensorflow::strings::StrCat(stem, ".", new_extension);
+ absl::string_view stem = pos == absl::string_view::npos
+ ? filename
+ : absl::string_view(filename.data(), pos);
+ return absl::StrCat(stem, ".", new_extension);
}
} // namespace gpu
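
A few expected results for the ported ReplaceFilenameExtension, based on its implementation above; the assertions are illustrative, not taken from an existing test:

#include <cassert>
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"

void CheckReplaceFilenameExtension() {
  using xla::gpu::ReplaceFilenameExtension;
  assert(ReplaceFilenameExtension("/foo/baz.txt", "cc") == "/foo/baz.cc");
  // A name without an extension simply gets one appended.
  assert(ReplaceFilenameExtension("/foo/baz", "ll") == "/foo/baz.ll");
  // Only the last extension is replaced.
  assert(ReplaceFilenameExtension("module.before.ll", "ptx") == "module.before.ptx");
}
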
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
index a6daeca95a..60f4926849 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace llvm {
class LLVMContext;
@@ -41,8 +41,8 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename,
//
// For example:
// ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc"
-string ReplaceFilenameExtension(tensorflow::StringPiece filename,
- tensorflow::StringPiece new_extension);
+string ReplaceFilenameExtension(absl::string_view filename,
+ absl::string_view new_extension);
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c62bae0628..7a43f0be54 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -48,7 +49,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
// If possible, we want to pick a reduce operand of the fusion root,
// because it has the most constraints.
for (const auto* inst : fused_expression_root->operands()) {
- if (inst->opcode() == HloOpcode::kReduce) {
+ if (IsReductionToVector(*inst)) {
return inst;
}
}
@@ -63,7 +64,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
auto get_element_shape = [&](const HloInstruction* element_instr) {
// Special handling of kReduce instructions -- the fusion
// applies to the first operand.
- if (element_instr->opcode() == HloOpcode::kReduce) {
+ if (IsReductionToVector(*element_instr)) {
return element_instr->operand(0)->shape();
}
return element_instr->shape();
@@ -131,7 +132,7 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
max_rank_layout = &param->shape().layout();
}
}
- return c_all_of(params, [&](HloInstruction* param) {
+ return absl::c_all_of(params, [&](HloInstruction* param) {
return (ShapeUtil::Rank(param->shape()) < max_rank) ||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
});
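
absl::c_all_of and absl::c_any_of, used above and later in this file, are container-based wrappers around std::all_of / std::any_of from absl/algorithm/container.h, so the begin()/end() boilerplate goes away. A standalone sketch on a plain vector:

#include <vector>
#include "absl/algorithm/container.h"

bool AllPositive(const std::vector<int>& xs) {
  return absl::c_all_of(xs, [](int x) { return x > 0; });
}

bool AnyNegative(const std::vector<int>& xs) {
  return absl::c_any_of(xs, [](int x) { return x < 0; });
}
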
@@ -140,10 +141,15 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
} // namespace
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
- // We can fuse reduces and loop fusions.
- return IsInputFusibleReduction(instr) ||
- (instr->opcode() == HloOpcode::kFusion &&
- instr->fusion_kind() == HloInstruction::FusionKind::kLoop);
+ // We can fuse reduces and loop fusions. Elementwise instructions can be fused
+ // with any other instruction.
+ // TODO(b/112957171): This should use the same isFusible logic as
+ // instruction_fusion.
+ return instr->IsFusible() &&
+ (IsInputFusibleReduction(instr) ||
+ (instr->opcode() == HloOpcode::kFusion &&
+ instr->fusion_kind() == HloInstruction::FusionKind::kLoop) ||
+ instr->IsElementwise());
}
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
@@ -177,11 +183,12 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
// merge into bigger loop fusions and input (reduce) fusions become fusions
// with multiple reduce outputs. We could fuse reduce and loop fusions
// together too (the result being an input fusion) if we find cases where this
- // improves things.
+ // improves things. Also disable fusing standalone input-fusible reduces into
+ // loop fusions.
CHECK(instr1->opcode() == HloOpcode::kFusion);
if ((instr2->opcode() == HloOpcode::kFusion &&
instr1->fusion_kind() != instr2->fusion_kind()) ||
- (instr2->opcode() != HloOpcode::kFusion &&
+ (IsReductionToVector(*instr2) &&
instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) {
return false;
}
@@ -197,7 +204,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
// Keep a list of the instructions to fuse after making all the fusion
// decisions. We first aggressively add instructions to potential_fusion_list,
- // then filter out instructions that will be no longer fusable because of
+  // then filter out instructions that will no longer be fusible because of
// reachability change. This avoids recalculating reachability on a large set
// of instructions.
std::vector<std::pair<HloInstruction*, HloInstruction*>>
@@ -213,7 +220,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
continue;
}
if (!IsInputFusibleReduction(consumer)) {
- VLOG(3) << consumer->name() << " is not an input-fusable reduction.";
+ VLOG(3) << consumer->name() << " is not an input-fusible reduction.";
continue;
}
VLOG(3) << consumer->name()
@@ -222,8 +229,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
auto consumer_operands = consumer->operands();
for (size_t i = 0; i < consumer_operands.size(); ++i) {
HloInstruction* producer = consumer_operands[i];
- if (!producer->IsFusable()) {
- VLOG(3) << producer->name() << " is not fusable.";
+ if (!producer->IsFusible()) {
+ VLOG(3) << producer->name() << " is not fusible.";
continue;
}
const bool is_loop_fusion =
@@ -248,7 +255,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
}
// Do not fuse a producer if the other operands of the fusion are
// reachable from the producer, this would create a cycle.
- if (c_any_of(consumer_operands, [&](HloInstruction* operand) {
+ if (absl::c_any_of(consumer_operands, [&](HloInstruction* operand) {
return producer != operand &&
reachability()->IsReachable(producer, operand);
})) {
@@ -263,12 +270,12 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
}
}
- // Filter out pairs that will be no longer fusable because of reachability
+  // Filter out pairs that will no longer be fusible because of reachability
// change.
for (auto& fusion_pair : potential_fusion_list) {
HloInstruction* producer = fusion_pair.first;
HloInstruction* consumer = fusion_pair.second;
- if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) {
+ if (!absl::c_any_of(consumer->operands(), [&](HloInstruction* operand) {
return producer != operand &&
reachability()->IsReachable(producer, operand);
})) {
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
index 67ca5d49ee..f0b4d67ab8 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -22,7 +22,7 @@ namespace xla {
namespace gpu {
// Multi-output fusion of sibling and producer-consumer instructions for the
-// Jellyfish backend.
+// GPU backend.
class GpuMultiOutputFusion : public MultiOutputFusion {
public:
GpuMultiOutputFusion();
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 14f157a5e5..c822c94f1b 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -15,19 +15,19 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace gpu {
+namespace op = xla::testing::opcode_matchers;
+
using MultiOutputFusionTest = HloTestBase;
const char kModulePrefix[] = R"(
@@ -47,7 +47,7 @@ const char kModulePrefix[] = R"(
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
// Fusion with reduce instruction root and a sibling reduce instruction
// sharing the same input param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation {
p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
@@ -74,7 +74,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[6400]{0} parameter(1)
mul = f32[6400]{0} multiply(p1.1, p1.1)
@@ -101,7 +101,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
@@ -130,7 +130,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
// Two sibling fusions with reduce instruction roots sharing the same input
// param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
@@ -165,7 +165,7 @@ TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
// Multi-output fusion with two reduce instructions root and a sibling reduce
// instruction sharing the same input param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
const.1 = f32[] constant(1)
p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
@@ -198,7 +198,7 @@ TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
// Verify that if we already have a multi-output fusion that we prefer to pick
// a reduce op from its operands for checking shape compatibility.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
@@ -228,7 +228,7 @@ TEST_F(MultiOutputFusionTest,
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[6400]{0} parameter(0)
ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
@@ -256,8 +256,136 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
+TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) {
+ // Fusing a reduce into a loop fusion would require changing the fusion kind.
+ // That's not supported yet.
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ const.2 = f32[] constant(0)
+ reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation
+ ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ const.2 = f32[] constant(1)
+ div = f32[6400]{0} divide(p0, const.2)
+ ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Multiply(), op::Divide()));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2)
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Multiply(), op::Exp(), op::Add()));
+}
+
+TEST_F(MultiOutputFusionTest,
+ MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -277,7 +405,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_add {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
@@ -304,7 +432,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
c0 = f32[] constant(0)
@@ -345,7 +473,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_element_wise {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
@@ -372,7 +500,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
TEST_F(MultiOutputFusionTest,
ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
c0 = f16[] constant(0)
@@ -413,7 +541,7 @@ TEST_F(MultiOutputFusionTest,
TEST_F(MultiOutputFusionTest,
ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
mixed_input_layouts_computation {
p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 6c1eab4f8c..8e4a8e5f54 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -21,13 +21,15 @@ limitations under the License.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
@@ -85,7 +87,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -140,7 +141,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
Compiler* compiler) {
{
HloPassPipeline pipeline("optimization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
@@ -156,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
// where possible. Not every batchnorm op can be implemented as a call to
@@ -203,10 +206,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
// (PadInsertion).
HloPassPipeline pipeline("conv_canonicalization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// TODO(b/31709653): Directly use the grouped convolution support of Cudnn.
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
pipeline.AddPass<CudnnConvolutionRewriter>();
+ // CudnnConvolutionRewriter may add instructions of the form
+ // reverse(constant), which it expects will be simplified by constant
+ // folding.
+ pipeline.AddPass<HloConstantFolding>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
@@ -218,9 +226,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
}
{
- HloPassPipeline pipeline("layout_assignment");
+ // Run layout assignment in a separate pipeline from
+ // "post-layout-assignment" because we want everything after layout
+ // assignment to have a layout-sensitive invariant-checker, but
+ // HloPassPipeline also runs its invariant checker before any passes are
+    // run, meaning the pipeline that contains layout assignment cannot contain
+ // a layout-sensitive verifier!
+ HloPassPipeline pipeline("layout assignment");
pipeline.AddPass<GpuLayoutAssignment>(
hlo_module->mutable_entry_computation_layout(), stream_exec);
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
+
+ {
+ HloPassPipeline pipeline("post-layout_assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -266,17 +287,20 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassFix<HloPassPipeline> fusion("fusion");
- fusion.AddInvariantChecker<HloVerifier>();
+ fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
fusion.AddPass<GpuMultiOutputFusion>();
fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
+ fusion.AddPass<HloDCE>();
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
- reduce_pipeline.AddInvariantChecker<HloVerifier>();
+ reduce_pipeline.AddInvariantChecker<HloVerifier>(
+      /*layout_sensitive=*/true, /*allow_mixed_precision=*/false);
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -302,7 +326,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
@@ -352,9 +377,9 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
string vmaj_str, vmin_str, vdot_str;
if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str,
&vmin_str, &vdot_str) ||
- !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) ||
- !tensorflow::strings::safe_strto64(vmin_str, &vmin) ||
- !tensorflow::strings::safe_strto64(vdot_str, &vdot)) {
+ !absl::SimpleAtoi(vmaj_str, &vmaj) ||
+ !absl::SimpleAtoi(vmin_str, &vmin) ||
+ !absl::SimpleAtoi(vdot_str, &vdot)) {
LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path
<< " --version:\n"
<< out;
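
absl::SimpleAtoi, which replaces safe_strto64 here, returns false unless the entire input parses as an integer, which is what makes the combined early-out above safe. A minimal sketch of the parsing pattern (the example values are illustrative):

#include <cstdint>
#include <string>
#include "absl/strings/numbers.h"

bool ParseMajorVersion(const std::string& text, int64_t* major) {
  // Succeeds only if the whole string is a valid integer, e.g. "9"; "9.2" fails.
  return absl::SimpleAtoi(text, major);
}
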
@@ -466,7 +491,7 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
tensorflow::SubProcess ptxas_info_dumper;
std::vector<string> ptxas_args = {
ptxas_path, ptx_path, "-o", cubin_path,
- tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)};
+ absl::StrCat("-arch=sm_", cc_major, cc_minor)};
if (VLOG_IS_ON(2)) {
ptxas_args.push_back("-v");
}
@@ -674,7 +699,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// Write PTX to IR dump directory, if IR dumping was requested.
if (!ir_dump_directory.empty()) {
const string ptx_outfile = tensorflow::io::JoinPath(
- ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx"));
+ ir_dump_directory, absl::StrCat(module->name(), ".ptx"));
auto status = [&] {
auto* env = tensorflow::Env::Default();
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory));
@@ -690,7 +715,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
const std::vector<uint8> cubin =
CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor);
- auto thunk_schedule = MakeUnique<ThunkSchedule>(
+ auto thunk_schedule = absl::make_unique<ThunkSchedule>(
ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
hlo_schedule->ThunkLaunchOrder());
VLOG(2) << "Printing the thunk schedule...";
@@ -704,7 +729,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
- profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
+ profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis);
}
@@ -813,7 +838,7 @@ se::Platform::Id NVPTXCompiler::PlatformId() const {
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
stream_executor::cuda::kCudaPlatformId,
- []() { return xla::MakeUnique<xla::gpu::NVPTXCompiler>(); });
+ []() { return absl::make_unique<xla::gpu::NVPTXCompiler>(); });
return true;
}
static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index d4d2909f1b..08ef6ef56c 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -20,13 +20,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
index 4aaf0c9e14..2fa170964e 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
index b99d998c4d..e0f3e84a4c 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
@@ -96,7 +96,7 @@ Status OutfeedThunk::ExecuteOnStream(
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to complete data transfer on stream %p: %s",
- stream, block_status.error_message().c_str());
+ stream, block_status.error_message());
}
VLOG(2) << "Outfeeding from GPU complete";
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 192359f026..11dc56a64f 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -32,9 +32,7 @@ namespace gpu {
// TODO(jlebar): Also pad dots.
class PadForTensorCores : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "pad for tensor cores";
- }
+ absl::string_view name() const override { return "pad for tensor cores"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
index 99e7580b82..104af48c82 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -29,7 +29,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-using PadForTensorCoresTest = HloVerifiedTestBase;
+class PadForTensorCoresTest : public HloVerifiedTestBase {
+ public:
+ PadForTensorCoresTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
ParseAndVerifyModule(R"(
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index b22040eee1..98cc21ccac 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -69,7 +70,7 @@ HloInstruction* MaybePaddedAndSlicedInput(
PrimitiveType element_type = input->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -126,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
PrimitiveType element_type = kernel->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -236,7 +237,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(MakeUnique<Literal>(
+ HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(input->shape().element_type()))));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index 67e51509e4..a622e894ed 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -26,7 +26,7 @@ namespace gpu {
// padding, so that they can be lowered to cuDNN convolution.
class PadInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "pad insertion"; }
+ absl::string_view name() const override { return "pad insertion"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index 3838fee674..ca57cacb98 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -57,8 +57,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
unroll_factor_(unroll_factor) {}
std::vector<llvm_ir::IrArray::Index>
-ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
+ llvm::Type* index_type) {
// Emit the following code in LLVM IR:
// linear_index = blockIdx.x * blockDim.x + threadIdx.x;
// if (linear_index < num_elements) {
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index b82a23419d..cc7da2e73b 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
~ParallelLoopEmitter() override = default;
std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) override;
+ absl::string_view loop_name, llvm::Type* index_type) override;
private:
// The thread and block dimension to parallelize the loop on.
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
index d3fd0544fb..cf9f102d31 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -18,15 +18,15 @@ limitations under the License.
#include <ostream>
#include <string>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -34,9 +34,8 @@ namespace gpu {
std::ostream& operator<<(std::ostream& out,
const LaunchDimensions& launch_dims) {
- out << tensorflow::strings::Printf("[block: %lld, thread: %lld]",
- launch_dims.block_count(),
- launch_dims.threads_per_block());
+ out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(),
+ launch_dims.threads_per_block());
return out;
}
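
One convenience of absl::StrFormat visible here: %d is checked against the argument type and works for any integer width, so the platform-sensitive %lld specifiers needed by tensorflow::strings::Printf can be dropped. A small sketch:

#include <cstdint>
#include <string>
#include "absl/strings/str_format.h"

std::string DescribeLaunch(int64_t block_count, int64_t threads_per_block) {
  // %d formats 64-bit values correctly; no %lld needed.
  return absl::StrFormat("[block: %d, thread: %d]", block_count, threads_per_block);
}
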
@@ -91,9 +90,9 @@ LaunchDimensions CalculateLaunchDimensions(
}
int64 block_count = CeilOfRatio(num_elements, threads_per_block);
- VLOG(2) << tensorflow::strings::Printf(
+ VLOG(2) << absl::StrFormat(
"Initialized the block count to ceil(# of elements / threads per "
- "block) = ceil(%lld/%lld) = %lld",
+ "block) = ceil(%d/%d) = %d",
num_elements, threads_per_block, block_count);
return LaunchDimensions(block_count, threads_per_block);
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
index 0806dd5161..5b6cf2c04d 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
@@ -119,7 +119,7 @@ int ComputeStreamToAssign(
} // namespace
std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
- auto stream_assignment = MakeUnique<StreamAssignment>();
+ auto stream_assignment = absl::make_unique<StreamAssignment>();
const HloComputation& computation = *module.entry_computation();
std::unique_ptr<HloReachabilityMap> reachability =
computation.ComputeReachability();
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 6f4bb0580e..091aca23e5 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -15,13 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
namespace gpu {
@@ -33,7 +34,7 @@ class StreamAssignmentTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", config);
+ return absl::make_unique<HloModule>("test_module", config);
}
// Pre-canned shapes.
@@ -97,7 +98,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
params.reserve(6);
for (int i = 0; i < 6; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
- i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
+ i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
index 05b305ea4c..08ff52211a 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace gpu {
@@ -53,8 +55,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
input_layout.push_back(dnums.input_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid input layout: ",
- DataLayoutString(input));
+ return InternalError("Invalid input layout %s for conv with dnums %s",
+ DataLayoutString(input),
+ ConvolutionDimensionNumbersToString(dnums));
}
std::vector<int64> filter_layout;
@@ -74,8 +77,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
filter_layout.push_back(dnums.kernel_input_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid filter layout: ",
- FilterLayoutString(filter));
+ return InternalError("Invalid filter layout %s for conv with dnums %s",
+ FilterLayoutString(filter),
+ ConvolutionDimensionNumbersToString(dnums));
}
std::vector<int64> output_layout;
@@ -95,8 +99,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
output_layout.push_back(dnums.output_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid output layout: ",
- DataLayoutString(output));
+ return InternalError("Invalid output layout %s for conv with dnums %s",
+ DataLayoutString(output),
+ ConvolutionDimensionNumbersToString(dnums));
}
return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
@@ -128,8 +133,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(input, nhwc_input)) {
input_layout = DataLayout::kBatchYXDepth;
} else {
- return tensorflow::errors::Internal("Invalid input layout: ",
- input.ShortDebugString());
+ return InternalError("Invalid input layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(input),
+ ConvolutionDimensionNumbersToString(dnums));
}
FilterLayout filter_layout;
@@ -138,8 +144,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(filter, nhwc_filter)) {
filter_layout = FilterLayout::kOutputYXInput;
} else {
- return tensorflow::errors::Internal("Invalid filter layout: ",
- filter.ShortDebugString());
+ return InternalError("Invalid filter layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(filter),
+ ConvolutionDimensionNumbersToString(dnums));
}
DataLayout output_layout;
@@ -148,8 +155,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(output, nhwc_output)) {
output_layout = DataLayout::kBatchYXDepth;
} else {
- return tensorflow::errors::Internal("Invalid output layout: ",
- output.ShortDebugString());
+ return InternalError("Invalid output layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(output),
+ ConvolutionDimensionNumbersToString(dnums));
}
return std::make_tuple(input_layout, filter_layout, output_layout);
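
The InternalError calls above now take a printf-style format string (with %s for each piece) instead of pre-concatenated arguments. A hedged sketch of how such a helper can be layered on absl::StrFormat and tensorflow::errors::Internal; the real XLA helper in util.h may differ in details:

#include "absl/strings/str_format.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Illustrative only: format the message, then wrap it in an Internal status.
template <typename... Args>
tensorflow::Status InternalErrorSketch(const absl::FormatSpec<Args...>& format,
                                       const Args&... args) {
  return tensorflow::errors::Internal(absl::StrFormat(format, args...));
}
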
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 4fad3f46cf..db4a33dc56 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -35,13 +35,13 @@ cc_library(
"requires-gpu-sm35",
],
deps = [
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service/gpu:gpu_executable",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -60,6 +60,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -94,6 +95,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -150,6 +152,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -168,6 +171,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
index 4b8415fe91..79e77d4c4d 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/core/platform/logging.h"
@@ -32,15 +32,14 @@ std::unique_ptr<HloModule> GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) {
debug_options.add_xla_disable_hlo_passes("constant_folding");
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>(TestName(), config);
+ return absl::make_unique<HloModule>(TestName(), config);
}
void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr<HloModule> hlo_module,
const string& pattern) {
std::unique_ptr<Executable> executable =
std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
- string ptx_str =
- std::string(static_cast<GpuExecutable*>(executable.get())->ptx());
+ string ptx_str(static_cast<GpuExecutable*>(executable.get())->ptx());
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
ASSERT_TRUE(filecheck_result.ok());
EXPECT_TRUE(filecheck_result.ValueOrDie());
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index ce69e058e6..4550f36fdf 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
index e5958165ef..a06576df7b 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index cca35316f0..15d1e269cc 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -27,13 +27,22 @@ namespace {
class GpuKernelTilingTest : public GpuCodegenTest {
protected:
- GpuKernelTilingTest() {
+ GpuKernelTilingTest() {}
+
+ // Most tests in this file want to skip layout assignment, but a few need it
+ // enabled.
+ HloModuleConfig ConfigWithLayoutAssignment() {
+ return GetModuleConfigForTest();
+ }
+
+ HloModuleConfig ConfigWithoutLayoutAssignment() {
+ HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
- config_.set_debug_options(debug_options);
// Disable layout_assignment to use the preassigned layouts.
- debug_options.add_xla_disable_hlo_passes("layout_assignment");
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
+ config.set_debug_options(debug_options);
+ return config;
}
- HloModuleConfig config_;
};
TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
@@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ //
+ // We must enable layout assignment in order for this test to work correctly.
+ // AlgebraicSimplifier removes copy1; it's added back by layout assignment,
+ // which respects the module's entry computation layout. But if we don't run
+ // layout assignment...well, nobody else adds the copy back.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) {
ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0)
})";
- // Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ // Check that a call to llvm.nvvm.barrier0 is not generated. As in
+ // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment
+ // here.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest,
})";
// Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
index 6c9ae7bada..6a9ecd9dae 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
index c42e5704a4..15198865bd 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
index 9622936306..0f2d5568ca 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
@@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) {
HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_max_kernel_unroll_factor(2);
+ // Disable layout assignment for this test. Layout assignment does not expect
+ // fusions to be present, and so it handles them incorrectly.
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
config.set_debug_options(debug_options);
const char *const kMultiOutputFusionModule = R"(
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
index bdb062837c..141f321938 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
@@ -144,16 +144,15 @@ const std::list<const Thunk*>& ThunkSchedule::DependsOn(
string ThunkSchedule::ToString() const {
string result = "Total order:\n";
for (Thunk* thunk : thunk_total_order_) {
- tensorflow::strings::StrAppend(&result, "\t",
- thunk->hlo_instruction()->ToString(), "\n");
+ absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n");
}
- tensorflow::strings::StrAppend(&result, "Dependencies:\n");
+ absl::StrAppend(&result, "Dependencies:\n");
for (const auto& entry : depends_on_) {
const Thunk* dependent = entry.first;
for (const Thunk* dependency : entry.second) {
- tensorflow::strings::StrAppend(
- &result, "\t", dependent->hlo_instruction()->name(), " depends on ",
- dependency->hlo_instruction()->name(), "\n");
+ absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(),
+ " depends on ", dependency->hlo_instruction()->name(),
+ "\n");
}
}
return result;
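The StrAppend/StrCat migration above repeats throughout this change; a minimal self-contained sketch of the absl calls involved (the thunk names are made up for illustration):

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"

int main() {
  std::string result = "Total order:\n";
  std::vector<std::string> names = {"copy.1", "fusion.2"};
  for (const std::string& name : names) {
    // absl::StrAppend appends all of its pieces to the target string in
    // place, matching the old tensorflow::strings::StrAppend call sites.
    absl::StrAppend(&result, "\t", name, "\n");
  }
  // absl::StrCat builds a new string from its pieces.
  result = absl::StrCat(result, "Dependencies:\n");
  return 0;
}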
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
index 8579b1545f..989b542ff4 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
@@ -25,7 +26,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) {
auto size = tuple_element_buffers_.size();
- auto tuple_element_buffer_addresses = MakeUnique<void*[]>(size);
+ auto tuple_element_buffer_addresses = absl::make_unique<void*[]>(size);
for (int i = 0; i != size; ++i) {
tuple_element_buffer_addresses[i] =
buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque();
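absl::make_unique mirrors C++14 std::make_unique, including the array form used above; a minimal sketch (the buffer count is illustrative):

#include <cstddef>

#include "absl/memory/memory.h"

int main() {
  const size_t size = 4;
  // For array types, absl::make_unique<T[]>(n) returns std::unique_ptr<T[]>
  // holding n value-initialized elements, so every pointer starts out null.
  auto addresses = absl::make_unique<void*[]>(size);
  for (size_t i = 0; i != size; ++i) {
    addresses[i] = nullptr;  // would be filled with device addresses
  }
  return 0;
}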
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index d81d87e7dc..c4754fe378 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -34,9 +34,9 @@ WhileThunk::WhileThunk(
// and body_thunk_sequence_ constructors because these SequentialThunks
// are logically "part of" this WhileThunk, and shouldn't be profiled
// separately from it.
- condition_thunk_sequence_(MakeUnique<SequentialThunk>(
+ condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*condition_thunk_sequence), nullptr)),
- body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ body_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*body_thunk_sequence), nullptr)) {}
Status WhileThunk::Initialize(const GpuExecutable& executable,
@@ -70,7 +70,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
if (!block_status.ok()) {
return InternalError(
"Failed to complete all kernels launched on stream %p: %s", stream,
- block_status.error_message().c_str());
+ block_status.error_message());
}
if (!condition_result) {
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index c5f3906356..40183de96e 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase {
}
void RunCopyInsertionPass() {
- HloVerifier verifier;
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
TF_ASSERT_OK(verifier.Run(module_.get()).status());
CopyInsertion copy_insertion;
TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index aa89567ee8..a2be89511b 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -22,9 +22,10 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/types.h"
@@ -43,8 +43,7 @@ namespace {
// Adds a computation to the given HLO module which adds a scalar constant to
// its parameter and returns the result.
HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) {
- auto builder =
- HloComputation::Builder(tensorflow::strings::StrCat("add_", addend));
+ auto builder = HloComputation::Builder(absl::StrCat("add_", addend));
auto x_value = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "x_value"));
auto half = builder.AddInstruction(
@@ -84,7 +83,7 @@ HloComputation* CallForwardingComputation(HloComputation* computation,
// the module.
std::unique_ptr<HloModule> MakeBigGraph() {
HloModuleConfig config;
- auto module = MakeUnique<HloModule>("BigGraph", config);
+ auto module = absl::make_unique<HloModule>("BigGraph", config);
auto builder = HloComputation::Builder("TestBigGraphvizGraph");
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 4005fc0d11..38c3982ebf 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -45,7 +46,7 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
// bound, by minimizing the liveness of sub-computations.
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
+ HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
module_sequence, *points_to_analysis, size_function));
return result.heap_size;
}
@@ -60,9 +61,10 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
memory_by_computation) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function,
- HeapSimulator::Options(), memory_by_computation));
+ HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
+ computation, sequence, points_to_analysis,
+ size_function, HeapSimulator::Options(),
+ memory_by_computation));
return result.heap_size;
}
@@ -142,7 +144,7 @@ Status HeapSimulator::RunComputation(
}
} else {
// A GetTupleElement doesn't need to keep all of its operand's buffers
- // alive. It only needs the buffers that relate to the element its
+ // alive. It only needs the buffers that relate to the element it's
// extracting, and the tuple it's extracting from, but not the buffers
// for the other elements.
for (const BufferValue* buffer : points_to.element({})) {
@@ -275,13 +277,13 @@ Status HeapSimulator::RunComputation(
*memory_by_computation_);
}
- // If the whole module is sequential, we can save memory by running the
- // heap-simulation for sub-computations inline. E.g. the buffers for the
- // condition and body of a kWhile instruction are only live for the duration
- // of the instruction itself.
+ // If all computations in the module have been scheduled, we can save memory
+ // by running the heap-simulation for sub-computations inline. E.g. the
+ // buffers for the condition and body of a kWhile instruction are only live
+ // for the duration of the instruction itself.
//
// The order that the sub-computations are simulated does not affect
- // correctness; since the whole module is sequential, we know that the
+ // correctness; since the whole module has been scheduled, we know that the
// sub-computations will never be run concurrently.
if (module_sequence_ != nullptr) {
if (instruction->opcode() == HloOpcode::kCall ||
@@ -344,7 +346,7 @@ HeapSimulator::HeapSimulator(
const SequentialHloOrdering::HloModuleSequence* module_sequence,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation)
- : no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
+ : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
@@ -378,9 +380,10 @@ void HeapSimulator::Alloc(const BufferValue* buffer,
allocated_buffers_.insert(buffer);
const int64 size = size_fn_(*buffer);
- algorithm_->Alloc(buffer, size);
- no_fragmentation_stats_->Alloc(buffer, size);
-
+ const HloInstruction* instruction_to_calc_aliasing =
+ memory_by_computation_ == nullptr ? nullptr : instruction;
+ algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing);
+ no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing);
FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
nullptr);
}
@@ -518,6 +521,18 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
}
}
+void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size,
+ const HloInstruction* instruction) {
+ // The output buffer of while/call/conditional is always aliased with the
+ // output buffer of the root instruction in the body. Don't double count.
+ if (instruction == nullptr ||
+ (instruction->opcode() != HloOpcode::kWhile &&
+ instruction->opcode() != HloOpcode::kCall &&
+ instruction->opcode() != HloOpcode::kConditional)) {
+ Alloc(buffer, size);
+ }
+}
+
void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 811a6042df..af05bedee7 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -36,6 +36,7 @@ namespace xla {
// Forward declare classes defined below.
class HeapAlgorithm;
+class NoFragmentationStatsHeap;
// HeapSimulator assigns buffer offsets by running a simulation of a regular
// memory heap with Alloc and Free calls. It only works for completely
@@ -161,7 +162,10 @@ class HeapSimulator {
const HloInstruction* instruction,
const BufferValue* shared_with_canonical);
- const std::unique_ptr<HeapAlgorithm> no_fragmentation_stats_;
+ // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap,
+ // in which case we are calculating the same allocs/frees twice in the
+ // simulation.
+ const std::unique_ptr<NoFragmentationStatsHeap> no_fragmentation_stats_;
const std::unique_ptr<HeapAlgorithm> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
@@ -216,6 +220,21 @@ class HeapAlgorithm {
// Alloc allocates a buffer of 'size' bytes.
virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
+ // NoFragmentationStatsHeap overrides this method.
+ virtual void Alloc(const BufferValue* buffer, int64 size,
+ const HloInstruction* instruction) {
+ Alloc(buffer, size);
+ }
+
+ // Takes memory usage of subcomputations into account when calculating the
+ // memory usage of a computation. Currently, we don't handle buffer aliasing
+ // between computations entirely correctly. We are careful to not double count
+ // for the output buffers of whiles/conds/calls. But we don't take into
+ // account other aliases, such as for the while init. A more thorough solution
+ // would require something like BufferAssignment::BuildColocatedBufferSets.
+ // TODO(b/65835246):
+ // Since TuplePointsToAnalysis is being replaced with a module-aware alias
+ // analysis, it's not worth making major changes to HeapSimulator now.
virtual void AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
@@ -240,6 +259,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
void Alloc(const BufferValue* buffer, int64 size) override;
+ void Alloc(const BufferValue* buffer, int64 size,
+ const HloInstruction* instruction) override;
+
void AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
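The new three-argument Alloc is a default-forwarding overload: the base class ignores the extra instruction and calls the two-argument form, and only NoFragmentationStatsHeap overrides it to skip double counting. A stripped-down standalone sketch of that pattern, with toy Buffer/Op types standing in for BufferValue and HloInstruction:

#include <cstdint>

struct Buffer {};
enum class Op { kWhile, kCall, kConditional, kOther };

class HeapAlgorithm {
 public:
  virtual ~HeapAlgorithm() = default;
  virtual void Alloc(const Buffer* buffer, int64_t size) = 0;
  // Default: ignore the extra context and fall back to the two-arg form.
  virtual void Alloc(const Buffer* buffer, int64_t size, Op /*op*/) {
    Alloc(buffer, size);
  }
};

class NoFragmentationStats : public HeapAlgorithm {
 public:
  void Alloc(const Buffer* /*buffer*/, int64_t size) override {
    current_ += size;
  }
  // Skip while/call/conditional outputs: they alias the root of the called
  // computation, which has already been counted.
  void Alloc(const Buffer* buffer, int64_t size, Op op) override {
    if (op != Op::kWhile && op != Op::kCall && op != Op::kConditional) {
      Alloc(buffer, size);
    }
  }

 private:
  int64_t current_ = 0;
};

int main() {
  NoFragmentationStats stats;
  Buffer b;
  stats.Alloc(&b, 128, Op::kOther);  // counted
  stats.Alloc(&b, 64, Op::kWhile);   // skipped: aliases the body root
  return 0;
}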
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index b41dc66fe9..5f85f14565 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -137,7 +138,7 @@ class HeapSimulatorTracker {
const string& name, std::unique_ptr<HloComputation> computation,
const std::vector<const HloInstruction*>& instruction_sequence) {
HloModuleConfig config;
- module_ = MakeUnique<HloModule>(name, config);
+ module_ = absl::make_unique<HloModule>(name, config);
module_->AddEntryComputation(std::move(computation));
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
@@ -146,8 +147,8 @@ class HeapSimulatorTracker {
// the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by
// buffer id, for determinism in the tests.
auto zero_size = [](const BufferValue& buffer) { return 0; };
- auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<HeapCallRecorder>(&actual_calls_));
+ auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<HeapCallRecorder>(&actual_calls_));
result_ = HeapSimulator::Run(
std::move(algorithm), *module_->entry_computation(),
instruction_sequence, *points_to_analysis_, zero_size)
@@ -156,7 +157,7 @@ class HeapSimulatorTracker {
explicit HeapSimulatorTracker(const string& name) {
HloModuleConfig config;
- module_ = MakeUnique<HloModule>(name, config);
+ module_ = absl::make_unique<HloModule>(name, config);
}
// Similar to the single entry computation constructor above, but runs the
@@ -182,8 +183,8 @@ class HeapSimulatorTracker {
auto size_fn = [&reverse_position](const BufferValue& buffer) {
return reverse_position[buffer.instruction()];
};
- auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<HeapCallRecorder>(&actual_calls_));
+ auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<HeapCallRecorder>(&actual_calls_));
result_ = HeapSimulator::Run(std::move(algorithm), *module_,
module_sequence, *points_to_analysis_, size_fn)
.ConsumeValueOrDie();
@@ -675,7 +676,8 @@ class HeapAlgorithmTestBase : public ::testing::Test {
const BufferValue::Id id = buffers_.size();
auto const0 = builder_.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
- buffers_.emplace_back(MakeUnique<HloValue>(id, const0, ShapeIndex{}));
+ buffers_.emplace_back(
+ absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
return buffers_.back().get();
}
@@ -724,7 +726,8 @@ class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};
TEST_F(DecreasingSizeRunsHeapTest, Empty) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Finish();
EXPECT_EQ(call_sequence, CallSequence({
{kFinish, nullptr},
@@ -733,7 +736,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Empty) {
TEST_F(DecreasingSizeRunsHeapTest, Simple) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Alloc(buffer_a_, 10);
heap.Alloc(buffer_b_, 20);
heap.Alloc(buffer_c_, 30);
@@ -760,7 +764,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Simple) {
TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Alloc(buffer_a_, 10);
heap.Alloc(buffer_b_, 20);
heap.Free(buffer_b_, 20);
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index fa218657fe..58b7af93eb 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 51
+// Next ID: 53
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -46,6 +46,8 @@ message HloInstructionProto {
reserved "control_predecessor_names";
reserved 6;
reserved "called_computation_names";
+ reserved 44;
+ reserved "replica_group_ids";
string name = 1;
string opcode = 2;
@@ -158,9 +160,6 @@ message HloInstructionProto {
string backend_config = 43;
// Cross replica op fields.
- // TODO(b/112107579): remove replica_group_ids field and always use
- // replica_groups.
- repeated int64 replica_group_ids = 44;
repeated ReplicaGroup replica_groups = 49;
int64 all_reduce_id = 45;
string cross_replica_sum_barrier = 46;
@@ -171,6 +170,12 @@ message HloInstructionProto {
bool is_host_transfer = 47;
xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
+
+ // Precision configuration for the instruction. Has backend-specific meaning.
+ xla.PrecisionConfigProto precision_config = 51;
+
+ // Collective permute field.
+ repeated SourceTarget source_target_pairs = 52;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index e8a4b034b4..0986da65cb 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -28,15 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
// Data structure used to construct the alias analysis. Thrown away after alias
// analysis is complete. This data structure keeps track of which sets of
@@ -414,7 +412,7 @@ Status HloAliasAnalysis::Verify() const {
}
string HloAliasAnalysis::ToString() const {
- string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
+ string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Buffers at each position:\n");
for (const HloComputation* computation : module_->computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
@@ -457,7 +455,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
XLA_VLOG_LINES(2, module->ToString());
- auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
+ auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
/*bitcast_defines_value=*/false,
@@ -537,10 +535,10 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
if (ordering.MayInterfere(*values[i - 1], *values[i],
dataflow_analysis())) {
VLOG(1) << "In buffer " << buffer.id() << " containing values:\n "
- << Join(values, ", ",
- [](string* out, const HloValue* value) {
- StrAppend(out, value->ToShortString());
- })
+ << absl::StrJoin(values, ", ",
+ [](string* out, const HloValue* value) {
+ StrAppend(out, value->ToShortString());
+ })
<< "\nValue " << values[i - 1]->ToShortString()
<< " may interfere with value " << values[i]->ToShortString();
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc
index e16413f361..6c11a073b7 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.cc
+++ b/tensorflow/compiler/xla/service/hlo_buffer.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -27,15 +29,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrCat;
-
bool HloBuffer::operator==(const HloBuffer& other) const {
bool equal = id() == other.id();
if (equal) {
@@ -59,10 +56,11 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const {
}
string HloBuffer::ToString() const {
- return StrCat("HloBuffer ", id_, ", values: ",
- Join(values_, ", ", [](string* result, const HloValue* value) {
- result->append(value->ToShortString());
- }));
+ return absl::StrCat(
+ "HloBuffer ", id_, ", values: ",
+ absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
+ result->append(value->ToShortString());
+ }));
}
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 441288da1a..c2d0673f49 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -23,9 +23,13 @@ limitations under the License.
#include <set>
#include <sstream>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -36,13 +40,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
std::unique_ptr<HloComputation> HloComputation::Builder::Build(
HloInstruction* root_instruction) {
@@ -56,8 +58,8 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
HloInstruction* root =
root_instruction ? root_instruction : last_added_instruction_;
CHECK_NE(nullptr, root);
- return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
- root, fusion_instruction_));
+ return absl::WrapUnique(new HloComputation(
+ name_, parameter_count, &instructions_, root, fusion_instruction_));
}
HloComputation::HloComputation(
@@ -135,7 +137,7 @@ string RenameFusionParameter(const string& original_name, int64 new_param_no) {
}
string after_param = original_name.substr(index + param_underscore.size());
int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_param, &numeric_suffix)) {
return StrCat(original_name.substr(0, index + param_underscore.size()),
new_param_no);
}
@@ -317,11 +319,12 @@ void ComputeComputationPostOrder(
}
}
-enum State { kVisiting, kVisited };
+} // namespace
-void ComputeInstructionPostOrder(
+void HloComputation::ComputeInstructionPostOrder(
+ const HloComputation::ChannelDependencyMap& channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
- tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const {
std::vector<HloInstruction*> dfs_stack;
dfs_stack.push_back(root);
while (!dfs_stack.empty()) {
@@ -354,16 +357,71 @@ void ComputeInstructionPostOrder(
for (HloInstruction* op : current->control_predecessors()) {
dfs_stack.emplace_back(op);
}
+
+ // Add inputs for send->recv_done dependencies and cross-replica-sum
+ // dependencies.
+ switch (current->opcode()) {
+ case HloOpcode::kRecvDone: {
+ auto it = channel_dependency_map.find(current->channel_id());
+ if (it != channel_dependency_map.end()) {
+ for (HloInstruction* op : it->second) {
+ dfs_stack.emplace_back(op);
+ }
+ }
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = current->all_reduce_id();
+ if (all_reduce_id) {
+ auto it = channel_dependency_map.find(all_reduce_id.value());
+ if (it != channel_dependency_map.end()) {
+ for (HloInstruction* op : it->second) {
+ dfs_stack.emplace_back(op);
+ }
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
}
}
-} // namespace
+HloComputation::ChannelDependencyMap
+HloComputation::ComputeChannelDependencies() const {
+ ChannelDependencyMap channel_dependency_map;
+ for (const auto& instruction : instructions_) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend: {
+ channel_dependency_map[instruction->channel_id()].push_back(
+ instruction.get());
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = instruction->all_reduce_id();
+ if (all_reduce_id) {
+ auto& dependencies = channel_dependency_map[all_reduce_id.value()];
+ absl::c_copy(instruction->operands(),
+ std::back_inserter(dependencies));
+ absl::c_copy(instruction->control_predecessors(),
+ std::back_inserter(dependencies));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ return channel_dependency_map;
+}
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
- tensorflow::gtl::FlatMap<HloInstruction*, State> visited;
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
// Trace instructions aren't handled by the DFS visitor. Add trace
@@ -371,7 +429,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- ComputeInstructionPostOrder(&post_order, instruction.get(), &visited);
+ ComputeInstructionPostOrder(channel_dependency_map, &post_order,
+ instruction.get(), &visited);
}
}
post_order.insert(post_order.end(), trace_instructions.begin(),
@@ -493,9 +552,9 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
- return WrapUnique(new HloComputation(proto.name(), parameter_count,
- &instructions, root,
- /*fusion_instruction=*/nullptr));
+ return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
+ &instructions, root,
+ /*fusion_instruction=*/nullptr));
}
void HloComputation::FuseInstructionsInto(
@@ -566,16 +625,15 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
if (instruction->parent() != this) {
return FailedPrecondition(
"Can't deep copy instruction %s: instruction is not in computation %s",
- instruction->name().c_str(), name().c_str());
+ instruction->name(), name());
}
if (indices_to_copy != nullptr &&
!ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
return FailedPrecondition(
"Can't deep copy instruction %s: given shape tree of indices to copy "
"has incompatible shapes: %s vs. %s",
- instruction->name().c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
- ShapeUtil::HumanString(indices_to_copy->shape()).c_str());
+ instruction->name(), ShapeUtil::HumanString(instruction->shape()),
+ ShapeUtil::HumanString(indices_to_copy->shape()));
}
ShapeIndex index;
@@ -605,7 +663,7 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
if (instruction->parent() != this) {
return FailedPrecondition(
"Can't deep copy instruction %s: instruction is not in computation %s",
- instruction->name().c_str(), name().c_str());
+ instruction->name(), name());
}
ShapeIndex index;
return DeepCopyHelper(instruction, &index, copy_leaf);
@@ -624,6 +682,9 @@ ProgramShape HloComputation::ComputeProgramShape() const {
}
bool HloComputation::operator==(const HloComputation& other) const {
+ if (this == &other) {
+ return true;
+ }
std::set<std::pair<const HloInstruction*, const HloInstruction*>> visited;
std::function<bool(const HloInstruction*, const HloInstruction*)> eq =
[&visited, &eq](const HloInstruction* a, const HloInstruction* b) {
@@ -674,13 +735,37 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
const auto& all = MakeInstructionPostOrder();
- auto result = MakeUnique<HloReachabilityMap>(all);
+ auto result = absl::make_unique<HloReachabilityMap>(all);
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> inputs;
for (const HloInstruction* hlo : all) {
inputs.assign(hlo->operands().begin(), hlo->operands().end());
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
hlo->control_predecessors().end());
+
+ switch (hlo->opcode()) {
+ case HloOpcode::kRecvDone: {
+ auto it = channel_dependency_map.find(hlo->channel_id());
+ if (it != channel_dependency_map.end()) {
+ absl::c_copy(it->second, std::back_inserter(inputs));
+ }
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = hlo->all_reduce_id();
+ if (all_reduce_id) {
+ auto it = channel_dependency_map.find(all_reduce_id.value());
+ if (it != channel_dependency_map.end()) {
+ absl::c_copy(it->second, std::back_inserter(inputs));
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;
@@ -723,11 +808,10 @@ std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
}
}
VLOG(3) << "Unreachable roots:"
- << tensorflow::str_util::Join(
- unreachable_roots, "\n\t",
- [](string* out, const HloInstruction* hlo) {
- tensorflow::strings::StrAppend(out, hlo->ToString());
- });
+ << absl::StrJoin(unreachable_roots, "\n\t",
+ [](string* out, const HloInstruction* hlo) {
+ absl::StrAppend(out, hlo->ToString());
+ });
return unreachable_roots;
}
@@ -829,7 +913,7 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
HloCloneContext* context, const string& suffix) {
std::unique_ptr<HloCloneContext> context_ptr;
if (context == nullptr) {
- context_ptr = MakeUnique<HloCloneContext>(parent(), suffix);
+ context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
context = context_ptr.get();
}
@@ -898,12 +982,11 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
name_ = name_uniquer->GetUniqueName(name_);
}
-HloInstruction* HloComputation::GetInstructionWithName(
- tensorflow::StringPiece name) {
+HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
auto instructions_in_computation = instructions();
- auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) {
- return instr->name() == name;
- });
+ auto it = absl::c_find_if(
+ instructions_in_computation,
+ [&](HloInstruction* instr) { return instr->name() == name; });
return it == instructions_in_computation.end() ? nullptr : *it;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 49ed65910f..59016624f7 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -367,7 +367,7 @@ class HloComputation {
// Returns the instruction in this computation that has name `name`. Returns
// null if there is no such instruction.
- HloInstruction* GetInstructionWithName(tensorflow::StringPiece name);
+ HloInstruction* GetInstructionWithName(absl::string_view name);
int64 unique_id() const { return unique_id_; }
@@ -399,6 +399,20 @@ class HloComputation {
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
+ // Returns a map from channel-id to directed dependencies of the channel
+ // instructions. For send/recv pairs the dependency is the send instruction;
+ // for cross-replica-sums it is the union of the dependencies of all
+ // participating instructions.
+ using ChannelDependencyMap =
+ tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>;
+ ChannelDependencyMap ComputeChannelDependencies() const;
+
+ enum VisitState { kVisiting, kVisited };
+ void ComputeInstructionPostOrder(
+ const HloComputation::ChannelDependencyMap& channel_dependency_map,
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const;
+
string name_;
int64 unique_id_;
HloInstruction* root_instruction_;
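The channel-dependency machinery declared above boils down to: record, per channel id, which instructions the consuming side must wait on, then consult that map when visiting a RecvDone (or a cross-replica-sum with an all_reduce_id). A simplified standalone sketch of the idea, using standard containers in place of tensorflow::gtl::FlatMap and absl::InlinedVector and a toy instruction type:

#include <cstdint>
#include <unordered_map>
#include <vector>

// Toy stand-in for HloInstruction: just an opcode and a channel id.
enum class Opcode { kSend, kRecvDone, kOther };
struct Instr {
  Opcode opcode;
  int64_t channel_id;
};

using ChannelDependencyMap =
    std::unordered_map<int64_t, std::vector<const Instr*>>;

// Pass 1: map each channel id to the instructions that feed it (the sends).
ChannelDependencyMap ComputeChannelDependencies(
    const std::vector<Instr>& instructions) {
  ChannelDependencyMap map;
  for (const Instr& instr : instructions) {
    if (instr.opcode == Opcode::kSend) {
      map[instr.channel_id].push_back(&instr);
    }
  }
  return map;
}

// Pass 2: when ordering a RecvDone, treat the matching sends as extra inputs
// so the post order (and the reachability map) see the cross-channel edge.
std::vector<const Instr*> ExtraInputs(const Instr& instr,
                                      const ChannelDependencyMap& map) {
  if (instr.opcode != Opcode::kRecvDone) return {};
  auto it = map.find(instr.channel_id);
  return it == map.end() ? std::vector<const Instr*>{} : it->second;
}

int main() {
  std::vector<Instr> instrs = {{Opcode::kSend, 1}, {Opcode::kRecvDone, 1}};
  ChannelDependencyMap map = ComputeChannelDependencies(instrs);
  std::vector<const Instr*> deps = ExtraInputs(instrs[1], map);
  return deps.size() == 1 ? 0 : 1;  // the RecvDone depends on the Send
}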
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index e4c5470331..f7ed1b0316 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -691,6 +691,27 @@ TEST_F(HloComputationTest, StringificationCanonical) {
EXPECT_EQ(computation->ToString(options), expected_computation2);
}
-} // namespace
+TEST_F(HloComputationTest, ChannelReachability) {
+ const Shape shape = ShapeUtil::MakeShape(F32, {5, 7});
+ HloComputation::Builder builder("ChannelReachability");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto send =
+ builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1));
+ auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+ auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto recv =
+ builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1));
+ auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build(recv_done));
+ auto reachability = computation->ComputeReachability();
+ EXPECT_TRUE(reachability->IsReachable(param, recv_done));
+ EXPECT_FALSE(reachability->IsReachable(send, recv));
+ EXPECT_FALSE(reachability->IsReachable(send_done, recv));
+}
+
+} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 7229031c0c..2ed645c3ae 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -38,7 +39,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
// Limit the constant folding to 0 iterations to skip folding loops. This
// retains the behavior from before while loop support in HloEvaluator and may
// be revised.
- auto evaluator = MakeUnique<HloEvaluator>(/*max_loop_iterations=*/0);
+ auto evaluator = absl::make_unique<HloEvaluator>(/*max_loop_iterations=*/0);
XLA_VLOG_LINES(2,
"HloConstantFolding::Run(), before:\n" + module->ToString());
@@ -51,9 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
computation->root_instruction() != instruction) {
continue;
}
- // Skip Constant, Parameter, Reduce, and AfterAll operation.
- // TODO(b/35975797): Enable Reduce operation once arbitrary computation
- // are supported by the evaluator.
+ // Skip Constant, Parameter, and AfterAll operations.
// TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
// TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
// operand in which case constant folding will be impossible and this
@@ -61,7 +60,6 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kTuple ||
- instruction->opcode() == HloOpcode::kReduce ||
instruction->opcode() == HloOpcode::kAfterAll) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 331480bd02..4557983a9c 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -25,7 +25,7 @@ namespace xla {
// computation on constants.
class HloConstantFolding : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "constant_folding"; }
+ absl::string_view name() const override { return "constant_folding"; }
// Run constant folding operations on the given module. Returns whether the
// module was changed (constant expressions folded).
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 64a42c1efc..7cd1481a8a 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -202,5 +203,45 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
EXPECT_TRUE(matched);
}
+const char* const kConstantFoldReduce = R"(
+ HloModule ConstantFoldReduce
+
+ add {
+ a = s32[] parameter(0)
+ b = s32[] parameter(1)
+ ROOT add = s32[] add(a, b)
+ }
+
+ ENTRY r {
+ x = s32[3] constant({1, 2, 3})
+ init = s32[] constant(0)
+ ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
+ })";
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(kConstantFoldReduce));
+ HloConstantFolding const_folder;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
+
+ EXPECT_EQ(6, module->entry_computation()
+ ->root_instruction()
+ ->literal()
+ .GetFirstElement<int32>());
+}
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(kConstantFoldReduce));
+ HloInstruction* add = module->computations().begin()->root_instruction();
+ LayoutUtil::ClearLayout(add->mutable_shape());
+ HloConstantFolding const_folder;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ EXPECT_FALSE(result);
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 1bbb0ff08e..0e12a1ee03 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -258,10 +258,6 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) {
- return Status::OK();
-}
-
Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
// Compute properties of the mapped function.
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
@@ -544,15 +540,10 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
}
Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
- // TODO(b/110096724): Compute correct cost here.
- double flops = 0.0;
- ShapeUtil::ForEachSubshape(hlo->shape(),
- [&](const Shape& subshape, const ShapeIndex&) {
- if (ShapeUtil::IsArray(subshape)) {
- flops += ShapeUtil::ElementsIn(subshape);
- }
- });
- current_properties_[kFlopsKey] = flops;
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 193a04bea0..c6a2007904 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -72,9 +72,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleFft(const HloInstruction* fft) override;
Status HandleCrossReplicaSum(const HloInstruction* crs) override;
Status HandleAllToAll(const HloInstruction* hlo) override;
+ Status HandleCollectivePermute(const HloInstruction* hlo) override;
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
- Status HandleHostCompute(const HloInstruction* host_compute) override;
Status HandleRng(const HloInstruction* random) override;
Status HandleReverse(const HloInstruction* reverse) override;
Status HandleSort(const HloInstruction* sort) override;
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 858992a326..131846794d 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -14,15 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
+using absl::StrCat;
using tensorflow::gtl::ArraySlice;
-using tensorflow::strings::StrCat;
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
HloInstruction* rhs) {
@@ -149,13 +151,13 @@ StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
CHECK_GT(operands.size(), 0);
HloComputation* computation = operands[0]->parent();
- CHECK(c_all_of(operands, [&](HloInstruction* instr) {
+ CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
return instr->parent() == computation;
}));
std::vector<const Shape*> operand_shapes;
- c_transform(operands, std::back_inserter(operand_shapes),
- [](HloInstruction* instr) { return &instr->shape(); });
+ absl::c_transform(operands, std::back_inserter(operand_shapes),
+ [](HloInstruction* instr) { return &instr->shape(); });
TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
operand_shapes, dimension));
@@ -228,7 +230,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
const Shape& operand_shape = operand->shape();
new_shape_dims.reserve(n + operand_shape.dimensions_size());
new_shape_dims.insert(new_shape_dims.begin(), n, 1);
- c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
+ absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
return MakeReshapeHlo(new_shape_dims, operand);
}
@@ -240,7 +242,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
std::vector<int64> expanded_shape_dim_bounds;
expanded_shape_dim_bounds.reserve(expanded_dims.size() +
operand->shape().dimensions_size() - 1);
- c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
+ absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
std::copy(operand->shape().dimensions().begin() + 1,
operand->shape().dimensions().end(),
std::back_inserter(expanded_shape_dim_bounds));
@@ -251,7 +253,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
ArraySlice<int64> dims_to_elide) {
- CHECK(c_is_sorted(dims_to_elide));
+ CHECK(absl::c_is_sorted(dims_to_elide));
const Shape& input_shape = operand->shape();
// First accumulate in reverse
@@ -268,7 +270,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
}
}
- c_reverse(new_shape_dim_bounds);
+ absl::c_reverse(new_shape_dim_bounds);
Shape output_shape =
ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
return MakeReshapeHlo(output_shape, operand);
@@ -276,7 +278,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
StatusOr<HloInstruction*> InsertDegenerateDims(
HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
- CHECK(c_is_sorted(dims_to_insert));
+ CHECK(absl::c_is_sorted(dims_to_insert));
const Shape& operand_shape = operand->shape();
int64 output_shape_rank =
@@ -318,7 +320,7 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
*padding_config.add_dimensions() = padding_config_dim;
HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(MakeUnique<Literal>(
+ HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(operand->shape().element_type()))));
return MakePadHlo(operand, zero, padding_config);
}
@@ -328,15 +330,15 @@ StatusOr<HloInstruction*> BroadcastZeros(
ArraySlice<int64> broadcast_dimensions) {
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
ArraySlice<const Shape*> domain, const Shape& range,
- tensorflow::StringPiece name) {
- HloComputation::Builder b{std::string(name)};
+ absl::string_view name) {
+ HloComputation::Builder b{string(name)};
int64 param_idx = 0;
for (const Shape* param_shape : domain) {
b.AddInstruction(HloInstruction::CreateParameter(
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 5ff8946fb0..1bc6d09b45 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -177,7 +177,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
// a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
tensorflow::gtl::ArraySlice<const Shape*> domain, const Shape& range,
- tensorflow::StringPiece name);
+ absl::string_view name);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index 60d3e71757..a8de285d16 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -28,7 +28,7 @@ using tensorflow::gtl::ArraySlice;
class HloCreationUtilsTest : public HloTestBase {
protected:
- static std::unique_ptr<HloModule> CreateModuleWithProgramShape(
+ std::unique_ptr<HloModule> CreateModuleWithProgramShape(
PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
ArraySlice<int64> output_shape_dims, HloInstruction** param,
HloComputation** entry_computation) {
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index 06484f4012..cb367adf5e 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/hash/hash.h"
namespace xla {
@@ -103,6 +104,9 @@ int64 CseHash(const HloInstruction* instruction) {
for (auto operand : instruction->operands()) {
hash = tensorflow::Hash64Combine(hash, operand->unique_id());
}
+ if (instruction->opcode() == HloOpcode::kConstant) {
+ hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash());
+ }
return hash;
}
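Hashing the constant's literal into CseHash means identical constants land in the same bucket even though they are distinct instructions, so the deep equality check afterwards actually gets a chance to merge them. A standalone sketch of the combining step, with std::hash and a boost-style combiner standing in for tensorflow::Hash64/Hash64Combine:

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Boost-style 64-bit combiner; stand-in for tensorflow::Hash64Combine.
uint64_t HashCombine(uint64_t seed, uint64_t value) {
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

// Toy instruction: opcode name, operand ids, and (for constants only) a
// serialized literal standing in for Literal::Hash().
struct FakeInstruction {
  std::string opcode;
  std::vector<int64_t> operand_ids;
  std::string literal_bytes;
};

uint64_t CseHash(const FakeInstruction& instr) {
  uint64_t hash = std::hash<std::string>{}(instr.opcode);
  for (int64_t id : instr.operand_ids) {
    hash = HashCombine(hash, static_cast<uint64_t>(id));
  }
  if (instr.opcode == "constant") {
    // The new step above: distinguish constants by value, not by identity.
    hash = HashCombine(hash, std::hash<std::string>{}(instr.literal_bytes));
  }
  return hash;
}

int main() {
  FakeInstruction c1{"constant", {}, "{1, 2, 3}"};
  FakeInstruction c2{"constant", {}, "{1, 2, 3}"};
  // Identical constants hash identically, so CSE can find them in one bucket.
  return CseHash(c1) == CseHash(c2) ? 0 : 1;
}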
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index 5e2b348bdd..a28c03599a 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -34,7 +34,7 @@ class HloCSE : public HloPassInterface {
: is_layout_sensitive_(is_layout_sensitive),
only_fusion_computations_(only_fusion_computations) {}
~HloCSE() override = default;
- tensorflow::StringPiece name() const override { return "cse"; }
+ absl::string_view name() const override { return "cse"; }
// Run CSE on the given module. Returns whether the module was changed (common
// subexpressions were found and eliminated).
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 90fbaa37c5..406d712ec6 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index bbfb0c253f..3376d170e6 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -19,8 +19,10 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -29,8 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -78,8 +78,8 @@ bool MultiDynamicSliceUseShareSameIndices(
} // namespace
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
HloDataflowAnalysis::HloDataflowAnalysis(
const HloModule& module, bool ssa_form, bool bitcast_defines_value,
@@ -93,7 +93,7 @@ HloDataflowAnalysis::HloDataflowAnalysis(
bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
const HloInstruction* inst) {
tensorflow::gtl::FlatSet<const HloInstruction*> visited;
- tensorflow::gtl::InlinedVector<const HloInstruction*, 4> stack;
+ absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(inst);
while (!stack.empty()) {
const HloInstruction* current = stack.back();
@@ -837,7 +837,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
return Unimplemented(
"Computation %s is called in both a parallel (eg, kMap) and "
"sequential (eg, kCall) context",
- computation->name().c_str());
+ computation->name());
}
if (call_graph_node.caller_callsites().empty() ||
call_graph_node.context() == CallContext::kParallel) {
@@ -886,7 +886,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
- auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis(
+ auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
@@ -976,28 +976,22 @@ Status HloDataflowAnalysis::Verify() const {
bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
const HloInstruction* operand, const ShapeIndex& index,
const HloInstruction* user) const {
- CHECK(user->IsUserOf(operand))
- << "user: " << user->ToString() << " operand: " << operand->ToString();
- if (user->opcode() == HloOpcode::kFusion &&
- user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
- // Find fusion parameter associated with 'operand'.
- HloInstruction* fusion_param =
- user->fused_parameter(user->operand_index(operand));
- // Iterate through all users of all uses of the fusion parameter value.
- // Return false if any uses are detected, returns true otherwise.
- const HloValue& value = GetValueDefinedAt(fusion_param, index);
- return value.uses().empty();
- } else {
- // Return false if no value at 'operand' and 'index' is used at 'user'.
- for (const HloValue* value : GetValueSet(operand, index).values()) {
- for (const HloUse& use : value->uses()) {
- if (use.instruction == user) {
- return false;
+ // Return false if no value at 'operand' and 'index' is used at 'user'.
+ for (const HloValue* value : GetValueSet(operand, index).values()) {
+ for (const HloUse& use : value->uses()) {
+ if (use.instruction == user) {
+ if (user->opcode() == HloOpcode::kFusion &&
+ user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ HloInstruction* fusion_param =
+ user->fused_parameter(use.operand_number);
+ const HloValue& value =
+ GetValueDefinedAt(fusion_param, use.operand_index);
+ return value.uses().empty();
}
+ return false;
}
}
}
-
return true;
}
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index f4abc7a7c7..a1678d4943 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -138,7 +138,8 @@ class HloDataflowAnalysis {
// Returns true if 'user' cannot possibly use the buffer at 'index' in
// 'operand'. Returns false otherwise.
//
- // REQUIRES: 'operand' is an operand of 'user'.
+ // 'operand' does not have to be an operand of 'user'. This can be the case
+ // with indirect uses.
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
const ShapeIndex& index,
const HloInstruction* user) const;
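The relaxed precondition means callers may now ask about instructions that reach 'user' only indirectly (for example through an intervening tuple). A minimal usage sketch, assuming a populated analysis named dataflow_analysis and instructions named as in the IndirectUses test added below:

  // 'tuple' is not a direct operand of 'fusion'; the query is still legal and
  // is answered from the dataflow values at ('tuple', index).
  bool unused_at_0 =
      dataflow_analysis->DoesNotUseOperandBuffer(tuple, /*index=*/{0}, fusion);
  bool unused_at_1 =
      dataflow_analysis->DoesNotUseOperandBuffer(tuple, /*index=*/{1}, fusion);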
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 4755c4a0cf..d1a96c10f8 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1963,6 +1963,54 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
}
+// Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
+// parameter tuple.
+TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto t0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
+ auto t1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
+ // Swap the tuple elements.
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0}));
+
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, gte1, update, starts));
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {dynamic_update_slice, starts, update, gte1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction never uses tuple element 0, but does use element 1.
+ EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
+ EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
+ // The same holds for the parameter tuple, except that the tuple elements are
+ // swapped in 'tuple'.
+ EXPECT_TRUE(
+ dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
+ EXPECT_FALSE(
+ dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
+}
+
class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
index 4e244494d6..1fe69b1395 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -36,7 +36,7 @@ namespace xla {
class HloDCE : public HloPassInterface {
public:
~HloDCE() override {}
- tensorflow::StringPiece name() const override { return "dce"; }
+ absl::string_view name() const override { return "dce"; }
// Run the pass on the given module. Returns whether the module was changed
// (instructions were removed).
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index 26e3736e01..3b5cde2996 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
index 78955db0da..72185698c9 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
@@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext {
StatusOr<bool> Run();
private:
- // Inserts a kDomain instruction between parent and operand, in case
- // the attribute (ie, sharding) values change between instruction and operand.
- // Returns the newly inserted kDomain instruction, or nullptr if no kDomain
- // instruction was necessary.
- StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction,
- HloInstruction* parent,
- HloInstruction* operand);
-
HloModule* module_;
HloDomainIsolator* isolator_;
};
-StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain(
- HloInstruction* instruction, HloInstruction* parent,
- HloInstruction* operand) {
- HloInstruction* domain = nullptr;
- std::unique_ptr<HloInstruction> domain_instruction =
- isolator_->creator_(instruction, operand);
- if (domain_instruction != nullptr) {
- domain = operand->parent()->AddInstruction(std::move(domain_instruction));
- TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain));
- }
- return domain;
-}
-
StatusOr<bool> HloDomainIsolator::RunContext::Run() {
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator");
@@ -71,16 +50,16 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() {
// When applying multiple domains, we could end up stacking more than
// one in one edge, so here we want to build the effective
// (kDomain-less) instruction->operand edge.
- HloInstruction* parent = instruction;
- while (operand->opcode() == HloOpcode::kDomain) {
- parent = operand;
- operand = operand->mutable_operand(0);
+ HloInstruction* root = operand;
+ while (root->opcode() == HloOpcode::kDomain) {
+ root = root->mutable_operand(0);
}
// Check whether a kDomain is necessary between instruction and operand.
- TF_ASSIGN_OR_RETURN(HloInstruction * domain,
- CreateDomain(instruction, parent, operand));
+ HloInstruction* domain =
+ isolator_->creator_(instruction, root, operand);
if (domain != nullptr) {
VLOG(4) << "New domain: " << domain->ToString();
+ TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
++added_domains;
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index eded3e78ee..d36631fc2f 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -34,14 +34,16 @@ class HloDomainIsolator : public HloPassInterface {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the
- // second HloInstruction argument).
+ // third HloInstruction argument) if the interesting attribute of the
+  // instruction differs from the attribute of the root (the second
+ // HloInstruction argument).
// Returns nullptr in case no domain separation is necessary.
- using DomainCreator = std::function<std::unique_ptr<HloInstruction>(
- HloInstruction*, HloInstruction*)>;
+ using DomainCreator = std::function<HloInstruction*(
+ HloInstruction*, HloInstruction*, HloInstruction*)>;
explicit HloDomainIsolator(DomainCreator creator);
- tensorflow::StringPiece name() const override { return "domain_isolator"; }
+ absl::string_view name() const override { return "domain_isolator"; }
StatusOr<bool> Run(HloModule* module) override;
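For reference, a creator matching the new three-argument signature can be shaped roughly as below. This is a sketch only: AttributeOf and MakeMetadataFor are hypothetical helpers standing in for whatever attribute and metadata a concrete creator cares about. Note the creator now adds the kDomain instruction to the computation itself and returns it; the isolator only rewires the use afterwards.

  HloInstruction* ExampleDomainCreator(HloInstruction* instruction,
                                       HloInstruction* root,
                                       HloInstruction* operand) {
    // 'root' is 'operand' with any stacked kDomain instructions peeled off.
    if (AttributeOf(instruction) == AttributeOf(root)) {
      return nullptr;  // No domain separation needed on this edge.
    }
    return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
        operand->shape(), operand,
        /*operand_side_metadata=*/MakeMetadataFor(root),
        /*user_side_metadata=*/MakeMetadataFor(instruction)));
  }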
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 9e096320db..8b2846e0c2 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
@@ -25,14 +26,14 @@ namespace xla {
/* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
HloComputation* computation, string domain_kind) {
- auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind)));
+ auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
TF_RETURN_IF_ERROR(domain_map->Populate(computation));
return std::move(domain_map);
}
/* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
HloModule* module, string domain_kind) {
- auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind)));
+ auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
for (HloComputation* computation : module->computations()) {
TF_RETURN_IF_ERROR(domain_map->Populate(computation));
}
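A quick note on the two absl helpers used throughout this change, with an illustrative (non-XLA) sketch: absl::make_unique constructs the object itself, while absl::WrapUnique adopts an already constructed raw pointer. The factories above need WrapUnique because HloDomainMap's constructor is private and make_unique, being a free function, could not call it.

  class Widget {
   public:
    static std::unique_ptr<Widget> Create() {
      return absl::WrapUnique(new Widget());  // adopts the raw pointer
    }
   private:
    Widget() = default;  // private: absl::make_unique<Widget>() cannot call this
  };
  // Types with public constructors can simply use absl::make_unique<T>(...).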
@@ -56,14 +57,14 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
// both sides.
for (HloInstruction* operand : instruction->unique_operands()) {
if (IsDomainInstruction(operand)) {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
domain->enter_domains.insert(operand);
domain->exit_domains.insert(instruction);
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
}
if (instruction == instruction->parent()->root_instruction()) {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
domain->enter_domains.insert(instruction);
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
@@ -71,6 +72,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
}
Status HloDomainMap::Populate(HloComputation* computation) {
+ InstructionOrderMap instructions_post_order;
+ int64 count = 0;
+ for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
+ instructions_post_order.insert(std::make_pair(instruction, count++));
+ }
for (HloInstruction* instruction : computation->instructions()) {
if (IsDomainInstruction(instruction)) {
// If this is a kDomain of the kind we are currently processing, check
@@ -84,7 +90,7 @@ Status HloDomainMap::Populate(HloComputation* computation) {
continue;
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<DomainMetadata::Domain> domain,
- CreateDomain(instruction));
+ CreateDomain(instruction, instructions_post_order));
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
return Status::OK();
@@ -142,10 +148,12 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction,
}
StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
- HloInstruction* instruction) const {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ HloInstruction* instruction,
+ const InstructionOrderMap& instructions_order) const {
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get()));
- domain->instructions = MakeNonDomainInstructions(domain->reach_set);
+ domain->instructions =
+ MakeNonDomainInstructions(domain->reach_set, instructions_order);
return std::move(domain);
}
@@ -167,7 +175,8 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
/* static */ std::vector<HloInstruction*>
HloDomainMap::MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set) {
+ const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const InstructionOrderMap& instructions_order) {
std::vector<HloInstruction*> instructions;
instructions.reserve(instruction_set.size());
for (HloInstruction* instruction : instruction_set) {
@@ -175,9 +184,10 @@ HloDomainMap::MakeNonDomainInstructions(
instructions.push_back(instruction);
}
}
+  // Sort instructions according to instructions_order.
std::sort(instructions.begin(), instructions.end(),
- [](HloInstruction* a, HloInstruction* b) {
- return a->unique_id() < b->unique_id();
+ [&instructions_order](HloInstruction* a, HloInstruction* b) {
+ return instructions_order.at(a) < instructions_order.at(b);
});
return instructions;
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index 1ca7159725..633109249a 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -70,6 +70,11 @@ class HloDomainMap {
int64 GetDomainId(HloInstruction* instruction) const;
private:
+ // Map used for representing instruction ordering, i.e.
+ // order_map[a] < order_map[b] means a must be ordered before b.
+ using InstructionOrderMap =
+ tensorflow::gtl::FlatMap<const HloInstruction*, int64>;
+
HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
// Check if the kDomain instruction is facing (via its operand link) another
@@ -95,12 +100,14 @@ class HloDomainMap {
// Creates a domain data structure using the ExpandDomain() API.
StatusOr<std::unique_ptr<DomainMetadata::Domain>> CreateDomain(
- HloInstruction* instruction) const;
+ HloInstruction* instruction,
+ const InstructionOrderMap& instructions_order) const;
// Out of an instruction set, returns a vector of all the ones which are not
// a kDomain kind.
static std::vector<HloInstruction*> MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set);
+ const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const InstructionOrderMap& instructions_order);
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
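Taken together with the hlo_domain_map.cc changes above, the new ordering amounts to the following sketch (helper names are illustrative, not part of the class): record a post-order position for every instruction once per computation, then sort a domain's instruction list by that position instead of by unique_id.

  using InstructionOrderMap =
      tensorflow::gtl::FlatMap<const HloInstruction*, int64>;

  InstructionOrderMap ComputePostOrderPositions(HloComputation* computation) {
    InstructionOrderMap order;
    int64 pos = 0;
    for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
      order[instruction] = pos++;
    }
    return order;
  }

  void SortByPostOrder(std::vector<HloInstruction*>* instructions,
                       const InstructionOrderMap& order) {
    std::sort(instructions->begin(), instructions->end(),
              [&order](HloInstruction* a, HloInstruction* b) {
                return order.at(a) < order.at(b);
              });
  }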
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index f855f2a1fc..6c142ee474 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -20,10 +20,10 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -44,7 +44,10 @@ class DomainMetadata {
// two domains of different kind intersect each other.
tensorflow::gtl::FlatSet<HloInstruction*> reach_set;
- // The same instructions in reach_set, but purged from kDomain instructions.
+  // The same instructions in reach_set, but purged of kDomain instructions
+ // and ordered according to their computation graph post-order, i.e.
+ // if instructions[pos_a] depends on instructions[pos_b], then pos_a >
+ // pos_b.
std::vector<HloInstruction*> instructions;
// If we consider a graph edge as an arrow oriented from the operand to the
@@ -63,7 +66,7 @@ class DomainMetadata {
// Returns the metadata type. A unique identifier which describes the real
// metadata type.
- virtual tensorflow::StringPiece Kind() const = 0;
+ virtual absl::string_view Kind() const = 0;
// Compares the metadata object with another one and returns true if the
// two matches.
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
index c859e05f02..97bc8ef604 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -35,13 +35,13 @@ class HloDomainRemover : public HloPassInterface {
// instructions in it with the same attributes (ie, sharding), a normalizer
// function is tasked at applying attribute normalization on the instructions
// within such domain.
- HloDomainRemover(tensorflow::StringPiece kind,
+ HloDomainRemover(absl::string_view kind,
std::function<Status(const DomainMetadata::Domain&,
const DomainMetadata* metadata)>
normalizer)
- : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {}
+ : kind_(kind), normalizer_(std::move(normalizer)) {}
- tensorflow::StringPiece name() const override { return "domain_remover"; }
+ absl::string_view name() const override { return "domain_remover"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 70271be304..c8e0a9e289 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
@@ -28,6 +29,11 @@ namespace xla {
namespace {
class HloDomainTest : public HloVerifiedTestBase {
+ public:
+ HloDomainTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
bool FindUserViaDomainPath(HloInstruction* instruction,
HloInstruction* operand) const {
@@ -45,9 +51,8 @@ class HloDomainTest : public HloVerifiedTestBase {
// Checks whether there is a kDomain instruction in the edge between the
// instruction and the operand.
- bool HasDomainEdge(HloModule* module,
- tensorflow::StringPiece instruction_name,
- tensorflow::StringPiece operand_name) {
+ bool HasDomainEdge(HloModule* module, absl::string_view instruction_name,
+ absl::string_view operand_name) {
HloInstruction* instruction = FindInstruction(module, instruction_name);
HloInstruction* operand = FindInstruction(module, operand_name);
CHECK_NE(instruction, nullptr);
@@ -65,7 +70,7 @@ class HloDomainTest : public HloVerifiedTestBase {
return false;
}
- StatusOr<HloModule*> ParseModule(tensorflow::StringPiece hlo_string) {
+ StatusOr<HloModule*> ParseModule(absl::string_view hlo_string) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
ParseAndVerifyModule(hlo_string, config);
@@ -80,10 +85,10 @@ class OpNameMetadata : public DomainMetadata {
explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {}
std::unique_ptr<DomainMetadata> Clone() const override {
- return MakeUnique<OpNameMetadata>(opname_);
+ return absl::make_unique<OpNameMetadata>(opname_);
}
- tensorflow::StringPiece Kind() const override { return KindName(); }
+ absl::string_view Kind() const override { return KindName(); }
bool Matches(const DomainMetadata& other) const override {
const OpNameMetadata* other_ptr =
@@ -97,25 +102,26 @@ class OpNameMetadata : public DomainMetadata {
string ToString() const override { return opname_; }
- static tensorflow::StringPiece KindName() { return "opname"; }
+ static absl::string_view KindName() { return "opname"; }
private:
string opname_;
};
// Creator function for OpNameMetadata domains.
-std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
- HloInstruction* operand) {
- if (instruction->metadata().op_name() == operand->metadata().op_name()) {
+HloInstruction* OpNameDomainCreator(HloInstruction* instruction,
+ HloInstruction* root,
+ HloInstruction* operand) {
+ if (instruction->metadata().op_name() == root->metadata().op_name()) {
return nullptr;
}
std::unique_ptr<DomainMetadata> operand_side_metadata =
- MakeUnique<OpNameMetadata>(operand->metadata().op_name());
+ absl::make_unique<OpNameMetadata>(root->metadata().op_name());
std::unique_ptr<DomainMetadata> user_side_metadata =
- MakeUnique<OpNameMetadata>(instruction->metadata().op_name());
- return HloInstruction::CreateDomain(operand->shape(), operand,
- std::move(operand_side_metadata),
- std::move(user_side_metadata));
+ absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
+ return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
+ operand->shape(), operand, std::move(operand_side_metadata),
+ std::move(user_side_metadata)));
}
Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain,
@@ -142,7 +148,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -184,7 +190,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(!isolator_changed);
}
@@ -211,7 +217,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -248,7 +254,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_FALSE(isolator_changed);
}
@@ -302,7 +308,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator sharding_isolator(CreateShardingDomain);
+ HloDomainIsolator sharding_isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed,
sharding_isolator.Run(module));
EXPECT_TRUE(sharding_isolator_changed);
@@ -344,7 +350,8 @@ ENTRY entry {
token = token[] after-all()
infeed = ((f32[4], f32[4]), token[]) infeed(token),
sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}}
- infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0
+ infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0,
+ sharding={{maximal device=1}, {maximal device=0}}
gte0 = f32[4] get-tuple-element(infeed.data), index=0
gte1 = f32[4] get-tuple-element(infeed.data), index=1
copy0 = f32[4] copy(gte0)
@@ -356,7 +363,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -378,11 +385,8 @@ ENTRY entry {
// \ /
// TUPLE
// |
- HloInstruction* infeed = FindInstruction(module, "infeed");
- ASSERT_NE(infeed, nullptr);
- HloInstruction* infeed_data =
- infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0));
+ HloInstruction* infeed_data = FindInstruction(module, "infeed.data");
+ ASSERT_NE(infeed_data, nullptr);
auto infeed_data_users = infeed_data->users();
HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction(
@@ -445,7 +449,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -474,8 +478,8 @@ ENTRY entry {
TEST_F(HloDomainTest, DumpParseNullSharding) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {});
- auto sharding_md_0 = MakeUnique<ShardingMetadata>(nullptr);
- auto sharding_md_1 = MakeUnique<ShardingMetadata>(nullptr);
+ auto sharding_md_0 = absl::make_unique<ShardingMetadata>(nullptr);
+ auto sharding_md_1 = absl::make_unique<ShardingMetadata>(nullptr);
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain(
@@ -490,6 +494,7 @@ TEST_F(HloDomainTest, DumpParseNullSharding) {
ASSERT_TRUE(ParseModule(hlo_string).status().ok());
}
+// Tuple inputs are domain instructions.
TEST_F(HloDomainTest, DomainTuple) {
const char* const hlo_string = R"(
HloModule Module
@@ -497,14 +502,15 @@ HloModule Module
ENTRY entry {
p0 = f32[4] parameter(0), sharding={maximal device=0}
cst = u32[] constant(0), sharding={maximal device=1}
- tpl = (u32[], f32[4]) tuple(cst, p0), sharding={{maximal device=1}, {maximal device=0}}
+ tpl = (u32[], f32[4]) tuple(cst, p0),
+ sharding={{maximal device=1}, {maximal device=0}}
ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -523,5 +529,168 @@ ENTRY entry {
tpl->sharding());
}
+TEST_F(HloDomainTest, MultiDomainMultiUser) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) {
+ %p0 = (f32[4], f32[4]) parameter(0)
+ %a = f32[4]{0} get-tuple-element(%p0), index=0
+ %domain = f32[4] domain(%a),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %b = f32[4] get-tuple-element(%p0), index=1
+ %domain.1 = f32[4] domain(%b),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1}
+ %domain.2 = f32[4] domain(%c),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %d = f32[4] subtract(%domain, %c),
+ sharding={maximal device=1}, metadata={op_name="D"}
+ %domain.3 = f32[4] domain(%d),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %e = f32[4] multiply(%c, %d),
+ sharding={maximal device=1}, metadata={op_name="D"}
+ %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1}
+ %domain.4 = f32[4]{0} domain(%f),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator opname_isolator(OpNameDomainCreator);
+ TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
+ opname_isolator.Run(module));
+ EXPECT_TRUE(opname_isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module, "c", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "c", "b"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "c"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
+
+ HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
+ ShardingMetadata::NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
+ sharding_remover.Run(module));
+ EXPECT_TRUE(sharding_remover_changed);
+
+ HloDomainRemover opname_remover(OpNameMetadata::KindName(),
+ OpNameDomainNormalizer);
+ TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed,
+ opname_remover.Run(module));
+ EXPECT_TRUE(opname_remover_changed);
+
+ EXPECT_FALSE(HasDomainEdge(module, "c", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "c"));
+}
+
+// Emulate instructions inserted at top and bottom within nested tuple domain.
+TEST_F(HloDomainTest, DomainTupleTopBottomInsert) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ p0 = f32[4] parameter(0), sharding={maximal device=1}
+ p1 = (f32[5], f32[6]) parameter(1),
+ sharding={{maximal device=1}, {maximal device=0}}
+ tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1),
+ sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}}
+ ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1,
+ sharding={{maximal device=1}, {maximal device=0}}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
+
+ HloDomainIsolator isolator(ShardingDomainCreator{});
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
+ EXPECT_TRUE(isolator_changed);
+
+ // Clear sharding of tuple.0 instruction, in order to test domain sharding
+ // application.
+ auto tuple0 = FindInstruction(module, "tuple.0");
+ tuple0->clear_sharding();
+
+  // Insert the following instructions above and below tuple.0, to emulate
+  // other passes' effects:
+ // COPY.0
+ // \ /
+ // TUPLE.0
+ // / \
+ // COPY.1 \
+ // / \
+ // GTE.0 GTE.1
+ // | |
+ // | COPY.2
+ // \ /
+ // \ /
+ // TUPLE.1
+ // |
+ auto tuple0_users = tuple0->users();
+ auto computation = tuple0->parent();
+ HloInstruction* copy0 = computation->AddInstruction(
+ HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy,
+ tuple0->mutable_operand(1)));
+ TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0));
+
+ HloInstruction* copy1 = computation->AddInstruction(
+ HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0));
+ HloInstruction* gte0 =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0));
+ HloInstruction* gte1 =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1));
+ HloInstruction* copy2 = computation->AddInstruction(
+ HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1));
+ HloInstruction* tuple1 =
+ computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2}));
+
+ for (HloInstruction* user : tuple0_users) {
+ TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1));
+ }
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ ShardingMetadata::NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
+ EXPECT_TRUE(remover_changed);
+
+ EXPECT_TRUE(tuple0->has_sharding());
+ EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}),
+ tuple0->sharding());
+
+ EXPECT_TRUE(copy0->has_sharding());
+ EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}),
+ copy0->sharding());
+
+  // copy1 has partial information only from gte.0, so in the end it gets no
+  // sharding at all. During propagation it still propagates the information
+  // from gte.0, though, enabling tuple.0 to be fully sharded.
+ EXPECT_FALSE(copy1->has_sharding());
+
+ EXPECT_TRUE(gte0->has_sharding());
+ EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding());
+
+ EXPECT_TRUE(gte1->has_sharding());
+ EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}),
+ gte1->sharding());
+
+ EXPECT_TRUE(copy2->has_sharding());
+ EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}),
+ copy2->sharding());
+
+ EXPECT_TRUE(tuple1->has_sharding());
+ EXPECT_EQ(tuple0->sharding(), tuple1->sharding());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
index 751fc677e2..dc514ae3e5 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
@@ -52,7 +52,7 @@ Status HloDomainVerifier::RunContext::PopulateDomainKinds() {
TF_RET_CHECK(instruction->user_side_metadata().Kind() ==
instruction->operand_side_metadata().Kind())
<< instruction->ToString();
- kinds.insert(instruction->user_side_metadata().Kind().ToString());
+ kinds.insert(string(instruction->user_side_metadata().Kind()));
}
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
index 8e53cf97f8..81d6d69a8c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -33,7 +33,7 @@ class HloDomainVerifier : public HloPassInterface {
public:
HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
- tensorflow::StringPiece name() const override { return "domain_verifier"; }
+ absl::string_view name() const override { return "domain_verifier"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
index 2b109225d0..44ded2c2fa 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -32,9 +32,7 @@ class HloElementTypeConverter : public HloPassInterface {
HloElementTypeConverter(PrimitiveType eliminate_type,
PrimitiveType replace_with_type);
- tensorflow::StringPiece name() const override {
- return "element_type_converter";
- }
+ absl::string_view name() const override { return "element_type_converter"; }
// Runs the pass on the module and returns whether the module was modified.
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 36d6a2eed6..71f91fde93 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -23,13 +23,15 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -43,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -95,7 +96,7 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
@@ -125,7 +126,7 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
<< HloOpcodeString(opcode);
}
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
@@ -138,44 +139,57 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
- typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this);
- typed_visitors_[U8] = MakeUnique<HloEvaluatorTypedVisitor<uint8>>(this);
- typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
- "U16.");
- });
- typed_visitors_[U32] = MakeUnique<HloEvaluatorTypedVisitor<uint32>>(this);
- typed_visitors_[U64] = MakeUnique<HloEvaluatorTypedVisitor<uint64>>(this);
- typed_visitors_[S8] = MakeUnique<HloEvaluatorTypedVisitor<int8>>(this);
- typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
- "S16.");
- });
- typed_visitors_[S32] = MakeUnique<HloEvaluatorTypedVisitor<int32>>(this);
- typed_visitors_[S64] = MakeUnique<HloEvaluatorTypedVisitor<int64>>(this);
+ typed_visitors_[PRED] =
+ absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
+ typed_visitors_[U8] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
+ typed_visitors_[U16] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
+ "U16.");
+ });
+ typed_visitors_[U32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
+ typed_visitors_[U64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
+ typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
+ typed_visitors_[S16] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
+ "S16.");
+ });
+ typed_visitors_[S32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
+ typed_visitors_[S64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
typed_visitors_[F16] =
- MakeUnique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
- typed_visitors_[F32] = MakeUnique<HloEvaluatorTypedVisitor<float>>(this);
- typed_visitors_[F64] = MakeUnique<HloEvaluatorTypedVisitor<double>>(this);
- typed_visitors_[C64] = MakeUnique<HloEvaluatorTypedVisitor<complex64>>(this);
+ absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
+ typed_visitors_[F32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
+ typed_visitors_[F64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
+ typed_visitors_[C64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
// Most of the evaluator computations we use don't support BF16 (e.g.,
// std::ceil, std::tanh). To make evaluator work with BF16, we set all
// elementwise computations to be done in F32 and do BF16<->F32 conversion
// around the input and the output of the computations.
typed_visitors_[BF16] =
- MakeUnique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
-
- typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
- });
- typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
- });
+ absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
+
+ typed_visitors_[TUPLE] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
+ });
+ typed_visitors_[OPAQUE] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
+ });
}
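The BF16 comment above describes a widen-compute-narrow pattern. A standalone illustration of the idea (not the evaluator's actual code; 'bfloat16' stands in for the BF16 type, which converts to and from float):

  #include <cmath>

  bfloat16 CeilViaF32(bfloat16 x) {
    float wide = static_cast<float>(x);    // BF16 -> F32
    float result = std::ceil(wide);        // compute in F32
    return static_cast<bfloat16>(result);  // F32 -> BF16
  }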
template <typename LiteralPtr>
@@ -216,7 +230,6 @@ template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
evaluated_.clear();
arg_literals_.clear();
@@ -253,7 +266,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
return tensorflow::errors::FailedPrecondition(
"Not all operands are constants.");
}
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
arg_literals_.clear();
evaluated_.clear();
@@ -423,7 +435,7 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
if (!ShapeUtil::ElementIsFloating(operand->shape())) {
return InvalidArgument(
"expected element type in shape to be float for IsFinite op, got: %s",
- PrimitiveType_Name(operand->shape().element_type()).c_str());
+ PrimitiveType_Name(operand->shape().element_type()));
}
switch (operand->shape().element_type()) {
@@ -464,9 +476,9 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s vs %s",
- ShapeUtil::HumanString(compare->shape()).c_str(),
- ShapeUtil::HumanString(lhs->shape()).c_str(),
- ShapeUtil::HumanString(rhs->shape()).c_str());
+ ShapeUtil::HumanString(compare->shape()),
+ ShapeUtil::HumanString(lhs->shape()),
+ ShapeUtil::HumanString(rhs->shape()));
}
TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
@@ -564,7 +576,8 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
std::vector<int64> index_count;
index_count.reserve(output_rank);
for (int64 i = 0; i < output_rank; i++) {
- bool is_output_batch_dim = !c_binary_search(dim_numbers.offset_dims(), i);
+ bool is_output_batch_dim =
+ !absl::c_binary_search(dim_numbers.offset_dims(), i);
index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
}
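The absl::c_* helpers these hunks switch to are thin range wrappers over the corresponding <algorithm> functions, roughly:

  // absl::c_binary_search(c, v)  ~=  std::binary_search(c.begin(), c.end(), v)
  // absl::c_find(c, v)           ~=  std::find(c.begin(), c.end(), v)
  std::vector<int64> offset_dims = {0, 2, 3};  // sorted, as binary_search requires
  bool has_dim_2 = absl::c_binary_search(offset_dims, 2);                 // true
  bool has_dim_1 = absl::c_find(offset_dims, 1) != offset_dims.end();     // false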
@@ -581,10 +594,11 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
std::vector<int64> index_count(output_rank, 1);
int64 slice_sizes_idx = 0;
for (int64 i = 0; i < output_rank; i++) {
- bool is_output_window_dim = c_binary_search(dim_numbers.offset_dims(), i);
+ bool is_output_window_dim =
+ absl::c_binary_search(dim_numbers.offset_dims(), i);
if (is_output_window_dim) {
- while (c_binary_search(dim_numbers.collapsed_slice_dims(),
- slice_sizes_idx)) {
+ while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
+ slice_sizes_idx)) {
slice_sizes_idx++;
}
index_count[i] = slice_sizes[slice_sizes_idx++];
@@ -610,13 +624,13 @@ class OutputBatchIndexToInputIndex {
: dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
output_dim_is_batch_dims_.push_back(
- !c_binary_search(dim_numbers_.offset_dims(), i));
+ !absl::c_binary_search(dim_numbers_.offset_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
int64 index_of_input_dim_in_index_vector =
std::distance(dim_numbers_.start_index_map().begin(),
- c_find(dim_numbers_.start_index_map(), i));
+ absl::c_find(dim_numbers_.start_index_map(), i));
if (index_of_input_dim_in_index_vector ==
dim_numbers_.start_index_map_size()) {
input_dim_value_to_index_vector_.push_back(-1);
@@ -736,7 +750,7 @@ class OutputOffsetIndexToInputIndex {
std::vector<int64> window_index_to_output_index;
int64 output_index_count = 0;
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.offset_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
window_index_to_output_index.push_back(output_index_count++);
} else {
output_index_count++;
@@ -745,7 +759,7 @@ class OutputOffsetIndexToInputIndex {
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
input_dim_value_to_output_index_.push_back(-1);
} else {
input_dim_value_to_output_index_.push_back(
@@ -953,7 +967,7 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
- evaluated_[get_tuple_element] = MakeUnique<Literal>(
+ evaluated_[get_tuple_element] = absl::make_unique<Literal>(
ShapeUtil::GetTupleElementShape(operand->shape(), index));
return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
/*dest_shape_index=*/{},
@@ -1091,8 +1105,8 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloEvaluator loop_body_evaluator(max_loop_iterations_);
while (keep_going) {
if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
- return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).",
- while_hlo->name().c_str(), max_loop_iterations_);
+ return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
+ while_hlo->name(), max_loop_iterations_);
}
TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
*cond_comp, {lcv.get()}));
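The dropped .c_str() calls and the %lld -> %d switches in these error sites follow absl-style formatting, where %s accepts string_view or std::string directly and %d accepts any integral width. A small sketch of that behavior using absl::StrFormat (assuming the XLA error helpers forward to the same formatting rules):

  #include "absl/strings/str_format.h"

  int64 limit = 100;
  absl::string_view loop_name = "while.1";
  // No .c_str() for %s, and %d accepts an int64 argument.
  std::string msg = absl::StrFormat(
      "Loop %s exceeded loop iteration limit (%d).", loop_name, limit);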
@@ -1155,10 +1169,11 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
}
- auto result_keys_literal = MakeUnique<Literal>(keys_literal.shape());
+ auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
result_keys_literal->PopulateR1(
tensorflow::gtl::ArraySlice<KeyType>(result_keys));
- auto result_values_literal = MakeUnique<Literal>(values_literal.shape());
+ auto result_values_literal =
+ absl::make_unique<Literal>(values_literal.shape());
result_values_literal->PopulateR1(
tensorflow::gtl::ArraySlice<ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
@@ -1173,8 +1188,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto keys_result_literal = MakeUnique<Literal>(keys_literal.shape());
- auto values_result_literal = MakeUnique<Literal>(values_literal.shape());
+ auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
+ auto values_result_literal =
+ absl::make_unique<Literal>(values_literal.shape());
int64 r1_length = keys_literal.shape().dimensions(1);
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
@@ -1246,7 +1262,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
if (sort_dim != rank - 1) {
return Unimplemented(
- "Trying to support along dimension %lld, which is not the last "
+        "Trying to sort along dimension %d, which is not the last "
"dimension",
sort_dim);
}
@@ -1267,7 +1283,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
- return Status::OK();
+ return ShapeUtil::ValidateShape(hlo->shape());
}
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index a4c37ef328..0ea7089552 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -222,11 +222,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(operand->shape()).c_str());
+ ShapeUtil::HumanString(shape),
+ ShapeUtil::HumanString(operand->shape()));
}
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 1394be68e4..c3af15c6a8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
@@ -51,8 +52,11 @@ static std::array<bool, 2> use_bf16_params{true, false};
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
protected:
- HloEvaluatorTest() : use_bfloat16_(GetParam()) {
- evaluator_ = MakeUnique<HloEvaluator>();
+ HloEvaluatorTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false),
+ use_bfloat16_(GetParam()) {
+ evaluator_ = absl::make_unique<HloEvaluator>();
}
std::unique_ptr<Literal> Evaluate(
@@ -523,7 +527,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected_array = MakeUnique<Array4D<float>>(8, 5, 1, 1);
+ auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected_array->Fill(kPadValue);
(*expected_array)(1, 0, 0, 0) = 1.0f;
(*expected_array)(1, 2, 0, 0) = 2.0f;
@@ -547,7 +551,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto input_array = MakeUnique<Array2D<float>>(4, 3);
+ auto input_array = absl::make_unique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
@@ -568,7 +572,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
std::unique_ptr<Literal> result = Evaluate();
// f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
- auto expected_array = MakeUnique<Array2D<float>>(1, 5);
+ auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
(*expected_array)(0, 0) = 7.0f;
(*expected_array)(0, 1) = 2.718f;
(*expected_array)(0, 2) = 2.718f;
@@ -588,7 +592,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto input_array = MakeUnique<Array2D<float>>(4, 3);
+ auto input_array = absl::make_unique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
@@ -612,7 +616,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected_array = MakeUnique<Array2D<float>>(0, 9);
+ auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
@@ -628,7 +632,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// { 3 },
// { 4 },
// }
- auto lhs_array = MakeUnique<Array2D<float>>(4, 1);
+ auto lhs_array = absl::make_unique<Array2D<float>>(4, 1);
lhs_array->FillUnique(1.0f);
auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
@@ -679,7 +683,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
// { 3, 4 },
// { 5, 6 },
// }
- auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
+ auto rhs_array = absl::make_unique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
@@ -710,7 +714,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto lhs_array = MakeUnique<Array2D<float>>(4, 3);
+ auto lhs_array = absl::make_unique<Array2D<float>>(4, 3);
lhs_array->FillUnique(1.0f);
auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
@@ -722,7 +726,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// { 3, 4 },
// { 5, 6 },
// }
- auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
+ auto rhs_array = absl::make_unique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
@@ -1215,7 +1219,12 @@ TEST_P(HloEvaluatorTest,
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
-class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
+class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {
+ public:
+ HloEvaluatorPreciseReduceTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Tests that Reduce doesn't lose precision when adding many numbers (because
// it accumulates its result in a double).
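As a tiny standalone illustration of the precision issue this guards against: a float accumulator stops making progress once the running sum reaches 2^24, while a double accumulator keeps counting exactly.

  float f = 0.0f;
  double d = 0.0;
  for (int i = 0; i < 20 * 1000 * 1000; ++i) {
    f += 1.0f;  // stalls at 16777216.0f (2^24): 16777216 + 1 rounds back down
    d += 1.0;   // exact
  }
  // f == 16777216.0f, d == 20000000.0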
@@ -1297,7 +1306,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1339,7 +1348,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1390,7 +1399,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1511,7 +1520,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
// { 9, 10, 11, 12, 13 },
// { 17, 18, 19, 20, 21 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(3, 5);
+ auto operand_array = absl::make_unique<Array2D<float>>(3, 5);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1544,7 +1553,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
// { 1, 2, 3, 4 },
// { 5, 6, 7, 8 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(2, 4);
+ auto operand_array = absl::make_unique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1580,7 +1589,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
// { 1, 2, 3, 4 },
// { 5, 6, 7, 8 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(2, 4);
+ auto operand_array = absl::make_unique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1614,7 +1623,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
@@ -1651,7 +1660,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
auto operand_literal2 =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
@@ -1687,7 +1696,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 7fdf4521de..f682e69ee9 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,11 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -105,7 +110,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
double GetAsDouble(const Literal& literal,
tensorflow::gtl::ArraySlice<int64> input_index) {
- CHECK(false);
+ LOG(FATAL) << "Trying to get complex literal as double: "
+ << literal.ToString();
}
public:
@@ -139,7 +145,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status DefaultAction(HloInstruction* hlo_instruction) override {
return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
- HloOpcodeString(hlo_instruction->opcode()).c_str());
+ HloOpcodeString(hlo_instruction->opcode()));
}
// TODO(b/35950897): many of the stl functions used in the handlers are not
@@ -547,7 +553,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
- Status HandleDivide(HloInstruction* divide) override {
+ template <
+ typename NativeT,
+ typename std::enable_if<std::is_floating_point<NativeT>::value ||
+ is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleDivide(HloInstruction* divide) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
ElementwiseT rhs_elem) {
@@ -557,6 +567,46 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT,
+ typename std::enable_if<std::is_signed<NativeT>::value &&
+ std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleDivide(HloInstruction* divide) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[divide],
+ ElementWiseBinaryOp(
+ divide,
+ [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) -> ElementwiseT {
+ if (rhs_elem == 0) {
+ return static_cast<ElementwiseT>(-1);
+ }
+ if (rhs_elem == -1 &&
+ lhs_elem == std::numeric_limits<ElementwiseT>::min()) {
+ return lhs_elem;
+ }
+ return lhs_elem / rhs_elem;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
+ nullptr>
+ Status HandleDivide(HloInstruction* divide) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
+ ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
+ ElementwiseT rhs_elem) {
+ return rhs_elem == 0
+ ? std::numeric_limits<ElementwiseT>::max()
+ : (lhs_elem / rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ Status HandleDivide(HloInstruction* divide) {
+ return HandleDivide<ElementwiseT>(divide);
+ }
+
+ template <typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* =
nullptr>
Status HandleMaximum(HloInstruction* maximum) {
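The HandleDivide overloads added above dispatch on the element type so that integer division has defined results in the evaluator: for signed types x / 0 yields -1 and INT_MIN / -1 yields INT_MIN (avoiding overflow), while for unsigned types x / 0 yields the maximum value. A minimal standalone sketch of the same enable_if dispatch and semantics; the helper name EvalDivide is hypothetical and not part of the patch:

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <type_traits>

    // Signed integers: mirror the rules in the signed HandleDivide overload.
    template <typename T,
              typename std::enable_if<std::is_signed<T>::value &&
                                      std::is_integral<T>::value>::type* = nullptr>
    T EvalDivide(T lhs, T rhs) {
      if (rhs == 0) return static_cast<T>(-1);
      if (rhs == -1 && lhs == std::numeric_limits<T>::min()) return lhs;
      return lhs / rhs;
    }

    // Unsigned integers: division by zero saturates to the maximum value.
    template <typename T,
              typename std::enable_if<std::is_unsigned<T>::value>::type* = nullptr>
    T EvalDivide(T lhs, T rhs) {
      return rhs == 0 ? std::numeric_limits<T>::max() : lhs / rhs;
    }

    int main() {
      std::cout << EvalDivide<int32_t>(std::numeric_limits<int32_t>::min(), -1) << "\n";  // -2147483648
      std::cout << EvalDivide<int32_t>(7, 0) << "\n";     // -1
      std::cout << EvalDivide<uint32_t>(7u, 0u) << "\n";  // 4294967295
    }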
@@ -642,9 +692,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
Status HandleRemainder(HloInstruction* remainder) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
@@ -654,6 +703,40 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ template <typename NativeT,
+ typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
+ nullptr>
+ Status HandleRemainder(HloInstruction* remainder) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
+ ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
+ ElementwiseT rhs_el) {
+ return rhs_el == 0 ? lhs_el : (lhs_el % rhs_el);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_signed<NativeT>::value &&
+ std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleRemainder(HloInstruction* remainder) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[remainder],
+ ElementWiseBinaryOp(
+ remainder,
+ [](ElementwiseT lhs_el, ElementwiseT rhs_el) -> ElementwiseT {
+ if (rhs_el == 0) {
+ return lhs_el;
+ }
+ if (rhs_el == -1 &&
+ lhs_el == std::numeric_limits<ElementwiseT>::min()) {
+ return 0;
+ }
+ return lhs_el % rhs_el;
+ }));
+ return Status::OK();
+ }
+
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
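The HandleRemainder overloads follow the same dispatch pattern: for signed types x % 0 returns x and INT_MIN % -1 returns 0 (both would otherwise be undefined behavior), and for unsigned types x % 0 returns x. A short sketch of just the signed rule; EvalRemainder is a hypothetical helper, not part of the patch:

    #include <cassert>
    #include <cstdint>
    #include <limits>

    int32_t EvalRemainder(int32_t lhs, int32_t rhs) {
      if (rhs == 0) return lhs;                                               // x % 0 -> x
      if (rhs == -1 && lhs == std::numeric_limits<int32_t>::min()) return 0;  // avoid overflow
      return lhs % rhs;
    }

    int main() {
      assert(EvalRemainder(7, 0) == 7);
      assert(EvalRemainder(std::numeric_limits<int32_t>::min(), -1) == 0);
      assert(EvalRemainder(-7, 3) == -1);  // C++ remainder truncates toward zero
      return 0;
    }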
@@ -895,7 +978,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = MakeUnique<Literal>(result_shape);
+ auto result = absl::make_unique<Literal>(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> out_index) {
@@ -1052,7 +1135,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return static_cast<ReturnT>(result_val);
};
- auto result = MakeUnique<Literal>(result_shape);
+ auto result = absl::make_unique<Literal>(result_shape);
TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
@@ -1100,7 +1183,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// result_index_locations[i] contains one or two pointers to the locations
// in lhs_index or rhs_index where the i'th result index should go.
- tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
+ absl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
result_index_locations;
result_index_locations.reserve(lhs_rank + rhs_rank - 2);
@@ -1126,7 +1209,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
}
- auto result = MakeUnique<Literal>(dot->shape());
+ auto result = absl::make_unique<Literal>(dot->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
@@ -1175,7 +1258,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = MakeUnique<Literal>(pad->shape());
+ auto result = absl::make_unique<Literal>(pad->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
return scalar;
@@ -1340,7 +1423,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
- auto result = MakeUnique<Literal>(map->shape());
+ auto result = absl::make_unique<Literal>(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
@@ -1454,7 +1537,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
- auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
result_literal->PopulateR1(
tensorflow::gtl::ArraySlice<ReturnT>(result_data));
VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
@@ -1466,7 +1549,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
@@ -1540,11 +1623,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = MakeUnique<Literal>(reduce->shape());
+ auto result = absl::make_unique<Literal>(reduce->shape());
+ Status eval_status;
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
ReturnT result_val = init_scalar;
+ if (!eval_status.ok()) {
+ return result_val;
+ }
std::vector<int64> base(arg_dimensions.size());
for (int64 i = 0; i < multi_index.size(); ++i) {
@@ -1565,7 +1652,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_dim_steps, func);
return static_cast<ReturnT>(computed_result);
}
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
+ -> StatusOr<bool> {
auto curr_val = arg_literal.Get<ReturnT>(input_index);
// Evaluate computation with specified literal operands.
@@ -1573,12 +1661,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
- .ConsumeValueOrDie();
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ embedded_evaluator.Evaluate<const Literal*>(
+ *function, {result_val_literal.get(),
+ curr_val_literal.get()}));
// Clear visit states so that we can use the evaluator again on
// the same computation.
embedded_evaluator.ResetVisitStates();
@@ -1588,13 +1674,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
};
// Computes one element of the result, reducing all dimensions that
// contribute to that element.
- ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
- arg_dim_steps, func);
+ eval_status = ShapeUtil::ForEachIndexWithStatus(
+ arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func);
return result_val;
}));
parent_->evaluated_[reduce] = std::move(result);
- return Status::OK();
+ return eval_status;
}
bool IsScalarAdd(HloComputation* computation) {
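Literal::Populate takes a value-returning generator, so errors from the embedded evaluator cannot be returned from inside it directly. The hunks above therefore capture eval_status in the lambda, short-circuit once it is no longer OK, and return it after Populate completes. A framework-free sketch of that capture-and-short-circuit pattern; the Status and Populate stand-ins below are simplified placeholders, not the XLA types:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    // Simplified stand-in for a status type.
    struct Status {
      Status() : ok_(true) {}
      Status(bool ok, std::string msg) : ok_(ok), msg_(std::move(msg)) {}
      bool ok() const { return ok_; }
      bool ok_;
      std::string msg_;
    };

    // Stand-in for Literal::Populate: calls `gen` once per element and has no
    // way to report a failure from the generator itself.
    void Populate(std::vector<float>* out, const std::function<float(int)>& gen) {
      for (int i = 0; i < static_cast<int>(out->size()); ++i) (*out)[i] = gen(i);
    }

    int main() {
      std::vector<float> result(4);
      Status eval_status;  // captured by the generator, checked afterwards
      Populate(&result, [&](int i) -> float {
        if (!eval_status.ok()) return 0.0f;  // short-circuit the remaining elements
        if (i == 2) {                        // pretend the embedded evaluation failed here
          eval_status = Status(false, "embedded evaluation failed");
          return 0.0f;
        }
        return static_cast<float>(i) * 2.0f;
      });
      std::cout << (eval_status.ok() ? "ok" : eval_status.msg_) << "\n";
    }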
@@ -1621,7 +1707,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
- auto result = MakeUnique<Literal>(select_and_scatter->shape());
+ auto result = absl::make_unique<Literal>(select_and_scatter->shape());
// Initialize result array with the init value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
@@ -1665,8 +1751,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// 2. Using the selected index, scatter value from `source` to result. We
// do this by iterating through the window, and compare each index with
// the selected index.
- tensorflow::gtl::optional<ReturnT> selected_val;
- tensorflow::gtl::optional<std::vector<int64>> selected_index;
+ absl::optional<ReturnT> selected_val;
+ absl::optional<std::vector<int64>> selected_index;
IterateThroughWindow(
window_shape, window, operand_literal.shape(), source_index,
@@ -1757,7 +1843,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = MakeUnique<Literal>(reduce_window->shape());
+ auto result = absl::make_unique<Literal>(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> output_index) {
@@ -1824,7 +1910,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> index_count(updates_rank, 1);
for (int64 i = 0; i < updates_rank; i++) {
bool is_update_scatter_dim =
- !c_binary_search(dim_numbers.update_window_dims(), i);
+ !absl::c_binary_search(dim_numbers.update_window_dims(), i);
if (is_update_scatter_dim) {
index_count[i] = updates_shape.dimensions(i);
}
@@ -1843,7 +1929,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> index_count(updates_rank, 1);
for (int64 i = 0; i < updates_rank; i++) {
bool is_update_window_dim =
- c_binary_search(dim_numbers.update_window_dims(), i);
+ absl::c_binary_search(dim_numbers.update_window_dims(), i);
if (is_update_window_dim) {
index_count[i] = updates_shape.dimensions(i);
}
@@ -1870,7 +1956,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
: dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) {
for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
update_dim_is_scatter_dims_.push_back(
- !c_binary_search(dim_numbers_.update_window_dims(), i));
+ !absl::c_binary_search(dim_numbers_.update_window_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
@@ -2000,7 +2086,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> window_index_to_update_index;
int64 update_index_count = 0;
for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.update_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
window_index_to_update_index.push_back(update_index_count++);
} else {
update_index_count++;
@@ -2009,7 +2095,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.inserted_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
input_dim_value_to_update_index_.push_back(-1);
} else {
input_dim_value_to_update_index_.push_back(
@@ -2409,11 +2495,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::is_same<NativeT, float>::value ||
std::is_same<NativeT, int32>::value ||
std::is_same<NativeT, uint32>::value>::type* = nullptr>
- Status HandleIota(HloInstruction* iota) {
- auto result = MakeUnique<Literal>(iota->shape());
- auto data = result->data<ReturnT>();
+ Status HandleIota(HloInstruction* instruction) {
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ std::vector<NativeT> data(iota->shape().dimensions(iota->iota_dimension()));
std::iota(data.begin(), data.end(), 0);
- parent_->evaluated_[iota] = std::move(result);
+ auto result = LiteralUtil::CreateR1<NativeT>(data);
+
+ if (ShapeUtil::Rank(iota->shape()) > 1) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[iota],
+ result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+ } else {
+ TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
+ parent_->evaluated_[iota] = std::move(result);
+ }
+
return Status::OK();
}
template <typename NativeT,
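The rewritten HandleIota fills a rank-1 vector of length dimensions(iota_dimension) and, for rank > 1 shapes, broadcasts it along iota_dimension, so each output element equals its index along that dimension. A framework-free rank-2 illustration; IotaR2 is a hypothetical name used only for this sketch:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Row-major [rows, cols] iota broadcast from the chosen dimension.
    std::vector<int32_t> IotaR2(int64_t rows, int64_t cols, int iota_dimension) {
      std::vector<int32_t> out(rows * cols);
      for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < cols; ++c) {
          out[r * cols + c] = static_cast<int32_t>(iota_dimension == 0 ? r : c);
        }
      }
      return out;
    }

    int main() {
      // iota_dimension = 1: every row reads 0, 1, 2; dimension 0 is broadcast.
      for (int32_t v : IotaR2(2, 3, /*iota_dimension=*/1)) std::cout << v << " ";
      std::cout << "\n";  // 0 1 2 0 1 2
    }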
@@ -2492,7 +2588,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
- auto result = MakeUnique<Literal>(result_shape);
+ auto result = absl::make_unique<Literal>(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
@@ -2570,15 +2666,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s vs %s: ",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(lhs->shape()).c_str(),
- ShapeUtil::HumanString(rhs->shape()).c_str());
+ ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()),
+ ShapeUtil::HumanString(rhs->shape()));
}
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
@@ -2606,17 +2701,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s vs %s vs %s: ",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(lhs->shape()).c_str(),
- ShapeUtil::HumanString(rhs->shape()).c_str(),
- ShapeUtil::HumanString(ehs->shape()).c_str());
+ ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()),
+ ShapeUtil::HumanString(rhs->shape()),
+ ShapeUtil::HumanString(ehs->shape()));
}
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index c3ccbf0f0c..de3d7a1677 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
@@ -49,7 +51,7 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
size_t profile_counters_size = hlo_profile_index_map.total_count();
std::unique_ptr<HloProfilePrinterData> profile_printer_data =
- MakeUnique<HloProfilePrinterData>();
+ absl::make_unique<HloProfilePrinterData>();
profile_printer_data->set_profile_counters_size(profile_counters_size);
profile_printer_data->mutable_computation_infos()->Reserve(
hlo_profile_index_map.computation_count());
@@ -67,11 +69,11 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
// The profile indices were computed deterministically in
// HloProfileIndexMap::HloProfileIndexMap.
- c_sort(computation_and_profile_idx_list,
- [](const std::pair<const HloComputation*, int64>& left,
- const std::pair<const HloComputation*, int64>& right) {
- return left.second < right.second;
- });
+ absl::c_sort(computation_and_profile_idx_list,
+ [](const std::pair<const HloComputation*, int64>& left,
+ const std::pair<const HloComputation*, int64>& right) {
+ return left.second < right.second;
+ });
for (const auto& pair : computation_and_profile_idx_list) {
CHECK_LT(pair.second, profile_counters_size);
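absl::c_sort (from absl/algorithm/container.h, newly included above) is the container-based equivalent of std::sort over begin()/end(), which is what lets the local c_sort wrapper go away. A tiny standalone usage sketch mirroring the comparator above; it assumes Abseil is available to link against:

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    #include "absl/algorithm/container.h"

    int main() {
      std::vector<std::pair<const char*, int64_t>> items = {{"b", 3}, {"a", 1}, {"c", 2}};
      // Sort the whole container by the second element, no begin()/end() boilerplate.
      absl::c_sort(items, [](const std::pair<const char*, int64_t>& left,
                             const std::pair<const char*, int64_t>& right) {
        return left.second < right.second;
      });
      for (const auto& p : items) std::cout << p.first << " ";
      std::cout << "\n";  // a c b
    }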
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
index eba80c0f19..460ae2b5ec 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
@@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
-using tensorflow::strings::StrCat;
+using absl::StrCat;
using ::testing::AllOf;
using ::testing::ContainsRegex;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 1efa6eb5bd..3041d94fa9 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -26,6 +26,12 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -37,50 +43,25 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
-using ::tensorflow::Env;
-using ::tensorflow::WriteStringToFile;
-using ::tensorflow::gtl::nullopt;
-using ::tensorflow::gtl::optional;
-using ::tensorflow::io::JoinPath;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::str_util::StringReplace;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace xla {
namespace hlo_graph_dumper {
namespace {
-// Helpers for Printf and Appendf.
-template <typename T>
-struct PrintfConvert {
- const T& operator()(const T& t) const { return t; }
-};
-template <>
-struct PrintfConvert<string> {
- const char* operator()(const string& s) const { return s.c_str(); }
-};
-
-// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str()
-// on strings.
-template <typename... Ts>
-string Printf(const char* fmt, const Ts&... ts) {
- return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...);
-}
-template <typename... Ts>
-void Appendf(string* s, const char* fmt, const Ts&... ts) {
- tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...);
-}
+using absl::nullopt;
+using absl::optional;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrFormat;
+using absl::StrJoin;
+using tensorflow::Env;
+using tensorflow::WriteStringToFile;
+using tensorflow::io::JoinPath;
// Used to indicate how we should treat a given HLOInstruction in the graph.
// should we treat it like normal, hide it, and so on?
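The Printf/PrintfConvert/Appendf shims deleted above existed only so callers did not have to call c_str() on string arguments. absl::StrFormat and absl::StrAppendFormat are type-checked and accept std::string or absl::string_view for %s and any integral width for %d, so the shims become unnecessary. A short standalone sketch, assuming Abseil is linked; the example strings are made up:

    #include <cstdint>
    #include <iostream>
    #include <string>

    #include "absl/strings/str_format.h"

    int main() {
      const std::string name = "add.42";
      const int64_t cycles = 12345;
      // No c_str() needed for %s, and %d accepts 64-bit integers.
      std::string label = absl::StrFormat("<b>%s</b> total cycles = %d", name, cycles);
      absl::StrAppendFormat(&label, " (%.2f%%)", 97.5);
      std::cout << label << "\n";
    }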
@@ -209,17 +190,15 @@ NodeColors NodeColorsForScheme(ColorScheme color) {
string NodeColorAttributes(ColorScheme color) {
NodeColors node_colors = NodeColorsForScheme(color);
- return Printf(
- R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
- node_colors.style, node_colors.font_color, node_colors.stroke_color,
- node_colors.fill_color);
+ return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
+ node_colors.style, node_colors.font_color,
+ node_colors.stroke_color, node_colors.fill_color);
}
// Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
// graphviz HTML-like string.
-string HtmlLikeStringSanitize(tensorflow::StringPiece s) {
- return StringReplace(StringReplace(s, "<", "&lt;", /*replace_all=*/true), ">",
- "&gt;", /*replace_all=*/true);
+string HtmlLikeStringSanitize(absl::string_view s) {
+ return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
}
// Tries to generate a human-readable one-word description of the given
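absl::StrReplaceAll applies every listed substitution in a single pass over the input, which replaces the two nested StringReplace calls removed above. A standalone usage sketch, assuming Abseil is available; the input string is made up:

    #include <iostream>
    #include <string>

    #include "absl/strings/str_replace.h"

    int main() {
      // Escape angle brackets for a graphviz HTML-like label in one pass.
      std::string sanitized = absl::StrReplaceAll(
          "vector<float> -> tuple<>", {{"<", "&lt;"}, {">", "&gt;"}});
      std::cout << sanitized << "\n";  // vector&lt;float&gt; -&gt; tuple&lt;&gt;
    }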
@@ -322,11 +301,11 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
class HloDotDumper {
public:
- HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
+ HloDotDumper(const HloComputation* computation, absl::string_view label,
const DebugOptions& debug_options, bool show_backend_config,
const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
- label_(std::string(label)),
+ label_(label),
debug_options_(debug_options),
show_backend_config_(show_backend_config),
profile_(profile),
@@ -448,7 +427,7 @@ string HloDotDumper::Dump() {
}
string HloDotDumper::Header() {
- const char* fmt = R"(digraph G {
+ constexpr char fmt[] = R"(digraph G {
rankdir = TB;
compound = true;
label = <<b>%s</b>>;
@@ -457,7 +436,7 @@ labelloc = t;
tooltip = " ";
// DOT graphs accept a stylesheet as a URI. So naturally, an inline
// stylesheet is a data URI!
-stylesheet="
+stylesheet=<
data:text/css,
@import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
svg text {
@@ -466,7 +445,7 @@ stylesheet="
}
%s
-"
+>
)";
@@ -481,8 +460,8 @@ stylesheet="
}
if (profile_ != nullptr) {
auto cycles = profile_->total_cycles_executed(*computation_);
- Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles,
- tensorflow::strings::HumanReadableNum(cycles));
+ absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
+ tensorflow::strings::HumanReadableNum(cycles));
}
// Create CSS rules that say, when you hover over the given node or cluster,
@@ -509,14 +488,14 @@ stylesheet="
// One could imagine other ways of writing this CSS rule that involve
// less duplication, but this way seems to be relatively performant.
edge_css_rules.push_back(
- Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n"
- " #%s%d:hover ~ #edge%lld path { "
- "stroke: %s; stroke-width: .2em; }\n"
- " #%s%d:hover ~ #edge%lld polygon { "
- "fill: %s; stroke: %s; stroke-width: .2em; }\n",
- elem_type, elem_id, edge_id, color, //
- elem_type, elem_id, edge_id, color, //
- elem_type, elem_id, edge_id, color, color));
+ StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n"
+ " #%s%d:hover ~ #edge%d path { "
+ "stroke: %s; stroke-width: .2em; }\n"
+ " #%s%d:hover ~ #edge%d polygon { "
+ "fill: %s; stroke: %s; stroke-width: .2em; }\n",
+ elem_type, elem_id, edge_id, color, //
+ elem_type, elem_id, edge_id, color, //
+ elem_type, elem_id, edge_id, color, color));
};
// The "to_node" value may be a NULL, indicating that this points to the
@@ -559,10 +538,10 @@ stylesheet="
}
}
- return Printf(fmt, graph_label, Join(edge_css_rules, "\n"));
+ return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n"));
}
-string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }
+string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
@@ -600,9 +579,9 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
<< " as " << next_edge_id_;
edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
- const char* edge_fmt =
+ constexpr char edge_fmt[] =
R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
- edges_.push_back(Printf(
+ edges_.push_back(StrFormat(
edge_fmt, InstructionId(from), InstructionId(parent_instr),
SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
}
@@ -619,9 +598,10 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
string subcomp_label, style;
if (parent_instr->opcode() == HloOpcode::kFusion) {
- subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s",
- HtmlLikeStringSanitize(parent_instr->name()),
- HtmlLikeStringSanitize(parent_instr->ToCategory()));
+ subcomp_label =
+ StrFormat("Fused expression for <b>%s</b><br/>%s",
+ HtmlLikeStringSanitize(parent_instr->name()),
+ HtmlLikeStringSanitize(parent_instr->ToCategory()));
string extra_info = GetInstructionNodeExtraInfo(parent_instr);
if (!extra_info.empty()) {
StrAppend(&subcomp_label, "<br/>", extra_info);
@@ -647,18 +627,18 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
}
style =
- Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
- fillcolor, strokecolor);
+ StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
+ fillcolor, strokecolor);
} else {
- subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s",
- HtmlLikeStringSanitize(parent_instr->name()),
- HtmlLikeStringSanitize(subcomp->name()));
+ subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
+ HtmlLikeStringSanitize(parent_instr->name()),
+ HtmlLikeStringSanitize(subcomp->name()));
style = "style=rounded; color=black;";
}
string comp_body = DumpComputation(subcomp);
- const char* computation_fmt = R"(subgraph %s {
+ constexpr char computation_fmt[] = R"(subgraph %s {
%s
label = <%s>;
labelloc = t;
@@ -667,7 +647,7 @@ tooltip = " ";
} // %s
)";
- return Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
+ return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
}
string HloDotDumper::DumpComputation(const HloComputation* comp) {
@@ -718,11 +698,11 @@ string HloDotDumper::DumpRootTag() {
VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
<< next_edge_id_;
edge_ids_.insert({{from, to}, next_edge_id_++});
- edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
+ edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
- return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
- "\n",
- to_id, node_body, node_shape, NodeColorAttributes(color));
+ return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
+ "\n",
+ to_id, node_body, node_shape, NodeColorAttributes(color));
}
static const HloConstantInstruction* TryGetFusionParameterConstant(
@@ -817,10 +797,10 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
}
}
- return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
- "\n",
- InstructionId(instr), node_body, node_shape, node_metadata,
- NodeColorAttributes(color));
+ return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
+ "\n",
+ InstructionId(instr), node_body, node_shape, node_metadata,
+ NodeColorAttributes(color));
}
string HloDotDumper::GetInstructionNodeInlinedOperands(
@@ -833,7 +813,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
// is just noise.
if (ShapeUtil::IsZeroElementArray(shape)) {
- return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape()));
+ return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
}
// Print the literal value of constants with <= K elements.
@@ -848,19 +828,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// collected from profiling tools. Those constants may not have a valid
// literal.
if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
- return Printf("%s (%s)", constant->literal().ToString(),
- ShapeUtil::HumanString(constant->shape()));
+ return StrFormat("%s (%s)", constant->literal().ToString(),
+ ShapeUtil::HumanString(constant->shape()));
}
// Otherwise, print e.g. "%constant.42 (s32[100])".
string constant_name;
- if (tensorflow::str_util::StartsWith(constant->name(), "constant")) {
+ if (absl::StartsWith(constant->name(), "constant")) {
constant_name = constant->name();
} else {
constant_name = StrCat("constant ", constant->name());
}
- return Printf("%s %s", constant_name,
- ShapeUtil::HumanString(constant->shape()));
+ return StrFormat("%s %s", constant_name,
+ ShapeUtil::HumanString(constant->shape()));
};
std::vector<string> lines;
@@ -881,7 +861,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
TryGetFusionParameterConstant(operand)) {
operand_str = stringify_constant(constant);
} else {
- operand_str = Printf("Parameter %lld", operand->parameter_number());
+ operand_str = StrFormat("Parameter %d", operand->parameter_number());
}
} else {
operand_str = operand->name();
@@ -890,13 +870,13 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
if (operand_str) {
if (instr->operand_count() > 1) {
- lines.push_back(Printf("<b>operand %lld</b> = %s", i, *operand_str));
+ lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
} else {
- lines.push_back(Printf("<b>operand</b> = %s", *operand_str));
+ lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
}
}
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
@@ -1049,6 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kGray;
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kRecv:
@@ -1059,7 +1040,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
- case HloOpcode::kHostCompute:
case HloOpcode::kWhile:
return kDarkGreen;
case HloOpcode::kConstant:
@@ -1080,14 +1060,13 @@ string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
// If we have a parameter, put the param number in the name.
if (instr->opcode() == HloOpcode::kParameter) {
- return Printf("<b>Parameter %lld</b>", instr->parameter_number());
+ return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
}
// The HLO instruction name contains usually the opcode, e.g. "%add.42" is
// an add instruction. In this case we render just the name.
- if (tensorflow::str_util::StartsWith(instr->name(),
- HloOpcodeString(instr->opcode()))) {
- return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
+ if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
+ return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
}
string extended_opcode =
StrCat(HloOpcodeString(instr->opcode()),
@@ -1095,8 +1074,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
? ""
: StrCat(":", xla::ToString(instr->fusion_kind())));
// If the name does not contain the opcode, render both.
- return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
- HtmlLikeStringSanitize(instr->name()));
+ return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
+ HtmlLikeStringSanitize(instr->name()));
}
string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
@@ -1105,16 +1084,16 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
}
if (!instr->metadata().op_type().empty()) {
- lines.push_back(Printf(
+ lines.push_back(StrFormat(
"op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
}
if (!instr->metadata().source_file().empty() &&
instr->metadata().source_line() != 0) {
- lines.push_back(Printf("op_type: %s", instr->metadata().source_file(),
- instr->metadata().source_line()));
+ lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(),
+ instr->metadata().source_line()));
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
string HloDotDumper::GetInstructionNodeBackendConfig(
@@ -1161,13 +1140,12 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
constexpr int kMaxShapeLen = 64;
if (instr_shape.length() > kMaxShapeLen) {
instr_shape = StrCat(
- tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3),
- "...");
+ absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
}
lines.push_back(instr_shape);
}
if (debug_options_.xla_hlo_graph_addresses()) {
- lines.push_back(Printf("[%p]", instr));
+ lines.push_back(StrFormat("[%p]", instr));
}
if (profile_ != nullptr) {
double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
@@ -1175,11 +1153,11 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
profile_->total_cycles_executed(*instr->parent());
if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
lines.push_back(
- Printf("%% of cycles executed=%.2f",
- 100 * hlo_cycles_executed / total_cycles_executed));
+ StrFormat("%% of cycles executed=%.2f",
+ 100 * hlo_cycles_executed / total_cycles_executed));
}
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
// Gets the total number of array elements in the given shape. For tuples, this
@@ -1211,7 +1189,8 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
string edge_label;
if (instr->operand_count() > 1 && !control_edge) {
- edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num);
+ edge_label =
+ StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
} else if (control_edge) {
edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
}
@@ -1221,10 +1200,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
// means.
bool is_big_array = TotalElementsInShape(from->shape()) >= 4096;
- const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
- edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to),
- (is_big_array ? "normal" : "empty"), from->name(),
- to->name(), edge_label));
+ constexpr char kEdgeFmt[] =
+ R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
+ edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
+ (is_big_array ? "normal" : "empty"),
+ from->name(), to->name(), edge_label));
};
// Add edges from instr's operands to instr. Parameters within fusion
@@ -1265,14 +1245,14 @@ string HloDotDumper::GetInstructionTrivialComputationStr(
continue;
}
if (instr->called_computations().size() == 1) {
- lines.push_back(Printf("Subcomputation: <b>%s</b>",
- HtmlLikeStringSanitize(*computation_type)));
+ lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
+ HtmlLikeStringSanitize(*computation_type)));
} else {
- lines.push_back(Printf("Subcomputation %lld: <b>%s</b>", i,
- HtmlLikeStringSanitize(*computation_type)));
+ lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
+ HtmlLikeStringSanitize(*computation_type)));
}
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
const HloInstruction* HloDotDumper::GetNodeForEdge(
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 1d7a062c55..064c53252c 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -23,12 +24,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
using ::testing::HasSubstr;
string TestName() {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 57e75cf931..ed4e159910 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -21,10 +21,17 @@ limitations under the License.
#include <unordered_set>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -39,17 +46,15 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using tensorflow::str_util::CEscape;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::CEscape;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrJoin;
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
@@ -224,7 +229,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
Literal::CreateFromProto(proto.literal()));
instruction = CreateConstant(std::move(literal));
} else {
- instruction = MakeUnique<HloConstantInstruction>(proto.shape());
+ instruction = absl::make_unique<HloConstantInstruction>(proto.shape());
}
break;
}
@@ -294,15 +299,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "CrossReplicaSum should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- tensorflow::gtl::optional<int64> all_reduce_id;
+ absl::optional<int64> all_reduce_id;
if (proto.all_reduce_id() > 0) {
all_reduce_id = proto.all_reduce_id();
}
instruction = CreateCrossReplicaSum(
proto.shape(), all_operands(), computations(0),
- /*replica_group_ids=*/
- std::vector<int64>(proto.replica_group_ids().begin(),
- proto.replica_group_ids().end()),
+ /*replica_groups=*/
+ std::vector<ReplicaGroup>(proto.replica_groups().begin(),
+ proto.replica_groups().end()),
/*barrier=*/proto.cross_replica_sum_barrier(),
/*all_reduce_id=*/all_reduce_id);
break;
@@ -312,8 +317,18 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.shape(), all_operands(),
/*replica_groups=*/
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
- proto.replica_groups().end()),
- /*barrier=*/proto.cross_replica_sum_barrier());
+ proto.replica_groups().end()));
+ break;
+ }
+ case HloOpcode::kCollectivePermute: {
+ std::vector<std::pair<int64, int64>> source_target_pairs(
+ proto.source_target_pairs_size());
+ for (int i = 0; i < source_target_pairs.size(); i++) {
+ source_target_pairs[i].first = proto.source_target_pairs(i).source();
+ source_target_pairs[i].second = proto.source_target_pairs(i).target();
+ }
+ instruction = CreateCollectivePermute(proto.shape(), operands(0),
+ source_target_pairs);
break;
}
case HloOpcode::kConvolution:
@@ -361,11 +376,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.convolution_dimension_numbers());
}
break;
- case HloOpcode::kHostCompute:
- instruction =
- CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(),
- proto.cost_estimate_ns());
- break;
case HloOpcode::kPad:
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Pad instruction should have 2 operands but sees "
@@ -379,7 +389,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "DynamicSlice instruction should have 2 operands but sees "
<< proto.operand_ids_size();
std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
- c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
+ absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1),
slice_sizes);
break;
@@ -391,7 +401,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_gather_dimension_numbers())
<< "Gather instruction should have GatherDimensionNumbers set.";
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
- MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
+ absl::make_unique<GatherDimensionNumbers>(
+ proto.gather_dimension_numbers());
std::vector<int64> gather_slice_sizes;
for (int64 bound : proto.gather_slice_sizes()) {
gather_slice_sizes.push_back(bound);
@@ -409,15 +420,22 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Scatter instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- auto scatter_dimension_numbers = MakeUnique<ScatterDimensionNumbers>(
- proto.scatter_dimension_numbers());
+ auto scatter_dimension_numbers =
+ absl::make_unique<ScatterDimensionNumbers>(
+ proto.scatter_dimension_numbers());
instruction =
CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
computations(0), *scatter_dimension_numbers);
break;
}
+ case HloOpcode::kIota:
+ TF_RET_CHECK(proto.dimensions_size() <= 1)
+ << "Iota instruction should have at most 1 dimension but sees "
+ << proto.dimensions_size();
+ instruction = CreateIota(proto.shape(), proto.dimensions(0));
+ break;
default: {
- instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
+ instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
<< "No instruction with id " << operand_id;
@@ -445,10 +463,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
+ instruction->precision_config_ = proto.precision_config();
if (proto.has_dot_dimension_numbers()) {
instruction->dot_dimension_numbers_ =
- MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
+ absl::make_unique<DotDimensionNumbers>(proto.dot_dimension_numbers());
}
if (proto.has_sharding()) {
@@ -462,34 +481,36 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
int64 parameter_number, const Shape& shape, const string& name) {
- return MakeUnique<HloParameterInstruction>(parameter_number, shape, name);
+ return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
+ name);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
const string& tag, HloInstruction* operand) {
- return MakeUnique<HloTraceInstruction>(tag, operand);
+ return absl::make_unique<HloTraceInstruction>(tag, operand);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
std::unique_ptr<Literal> literal) {
- return MakeUnique<HloConstantInstruction>(std::move(literal));
+ return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
- const Shape& shape) {
- return WrapUnique(new HloInstruction(HloOpcode::kIota, shape));
+ const Shape& shape, int64 iota_dimension) {
+ return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
HloInstruction* operand, int64 index) {
- return MakeUnique<HloGetTupleElementInstruction>(shape, operand, index);
+ return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
+ index);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
const Shape& shape, RandomDistribution distribution,
tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
- return MakeUnique<HloRngInstruction>(shape, distribution, parameters);
+ return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
@@ -499,7 +520,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
// It is impossible to copy an opaque shape, we don't know how big it is.
CHECK(!ShapeUtil::IsOpaque(shape));
}
- auto instruction = WrapUnique(new HloInstruction(opcode, shape));
+ auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
@@ -604,31 +625,33 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* map_computation) {
- return MakeUnique<HloMapInstruction>(shape, operands, map_computation);
+ return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count) {
- return MakeUnique<HloConvolutionInstruction>(
+ return absl::make_unique<HloConvolutionInstruction>(
shape, lhs, rhs, window, dimension_numbers, feature_group_count);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length) {
- return MakeUnique<HloFftInstruction>(shape, operand, fft_type, fft_length);
+ return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
+ fft_length);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
instruction->AppendOperand(lhs);
instruction->AppendOperand(rhs);
instruction->dot_dimension_numbers_ =
- MakeUnique<DotDimensionNumbers>(dimension_numbers);
+ absl::make_unique<DotDimensionNumbers>(dimension_numbers);
return instruction;
}
@@ -637,10 +660,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
instruction->AppendOperand(lhs);
instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ = MakeUnique<DotDimensionNumbers>();
+ instruction->dot_dimension_numbers_ =
+ absl::make_unique<DotDimensionNumbers>();
instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
return instruction;
@@ -651,7 +676,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction* operand,
const int exponent_bits,
const int mantissa_bits) {
- return MakeUnique<HloReducePrecisionInstruction>(
+ return absl::make_unique<HloReducePrecisionInstruction>(
shape, operand, exponent_bits, mantissa_bits);
}
@@ -659,40 +684,47 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction::CreateCrossReplicaSum(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id) {
- return MakeUnique<HloAllReduceInstruction>(
- shape, operands, reduce_computation, replica_group_ids, barrier,
+ const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
+ const absl::optional<int64>& all_reduce_id) {
+ return absl::make_unique<HloAllReduceInstruction>(
+ shape, operands, reduce_computation, replica_groups, barrier,
all_reduce_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier) {
- return MakeUnique<HloAllToAllInstruction>(shape, operands, replica_groups,
- barrier);
+ const std::vector<ReplicaGroup>& replica_groups) {
+ return absl::make_unique<HloAllToAllInstruction>(shape, operands,
+ replica_groups);
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateCollectivePermute(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs) {
+ return absl::make_unique<HloCollectivePermuteInstruction>(
+ shape, operand, source_target_pairs);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
const Shape& infeed_shape, HloInstruction* token_operand,
const string& config) {
- return MakeUnique<HloInfeedInstruction>(infeed_shape, token_operand, config);
+ return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
+ config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) {
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand,
- token_operand, outfeed_config);
+ HloInstruction* token_operand, absl::string_view outfeed_config) {
+ return absl::make_unique<HloOutfeedInstruction>(
+ outfeed_shape, operand, token_operand, outfeed_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, HloInstruction* token, int64 channel_id,
bool is_host_transfer) {
- return MakeUnique<HloSendInstruction>(operand, token, channel_id,
- is_host_transfer);
+ return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
@@ -700,14 +732,15 @@ HloInstruction::CreateCrossReplicaSum(
auto send_operand = DynCast<HloSendInstruction>(operand);
CHECK(send_operand != nullptr)
<< "SendDone must take the context operand from Send";
- return MakeUnique<HloSendDoneInstruction>(send_operand, is_host_transfer);
+ return absl::make_unique<HloSendDoneInstruction>(send_operand,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, HloInstruction* token, int64 channel_id,
bool is_host_transfer) {
- return MakeUnique<HloRecvInstruction>(shape, token, channel_id,
- is_host_transfer);
+ return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
@@ -715,19 +748,20 @@ HloInstruction::CreateCrossReplicaSum(
auto recv_operand = DynCast<HloRecvInstruction>(operand);
CHECK(recv_operand != nullptr)
<< "RecvDone must take the context operand from Recv";
- return MakeUnique<HloRecvDoneInstruction>(recv_operand, is_host_transfer);
+ return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return MakeUnique<HloReverseInstruction>(shape, operand, dimensions);
+ return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
CHECK(!operands.empty());
- auto instruction = WrapUnique(
+ auto instruction = absl::WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
for (auto operand : operands) {
instruction->AppendOperand(operand);
@@ -736,14 +770,15 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
- return WrapUnique(
+ return absl::WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
instruction->AppendOperand(init);
// Body comes before condition computation in the vector.
instruction->called_computations_.push_back(body);
@@ -756,7 +791,7 @@ HloInstruction::CreateCrossReplicaSum(
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation) {
auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
instruction->AppendOperand(pred);
instruction->AppendOperand(true_computation_arg);
instruction->AppendOperand(false_computation_arg);
@@ -773,15 +808,15 @@ HloInstruction::CreateCrossReplicaSum(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- return MakeUnique<HloSliceInstruction>(shape, operand, start_indices,
- limit_indices, strides);
+ return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
+ limit_indices, strides);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return MakeUnique<HloDynamicSliceInstruction>(shape, operand, start_indices,
- slice_sizes);
+ return absl::make_unique<HloDynamicSliceInstruction>(
+ shape, operand, start_indices, slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -789,8 +824,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
HloInstruction* operand,
HloInstruction* update,
HloInstruction* start_indices) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
+ auto instruction = absl::WrapUnique(
+ new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(update);
instruction->AppendOperand(start_indices);
@@ -800,12 +835,14 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
int64 dimension) {
- return MakeUnique<HloConcatenateInstruction>(shape, operands, dimension);
+ return absl::make_unique<HloConcatenateInstruction>(shape, operands,
+ dimension);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
const Shape& shape, HloInstruction* operand) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
@@ -814,7 +851,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
HloInstruction::CreateBitcastConvert(const Shape& shape,
HloInstruction* operand) {
auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
@@ -823,7 +860,7 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
- auto instruction = WrapUnique(new HloReduceInstruction(
+ auto instruction = absl::WrapUnique(new HloReduceInstruction(
shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
return std::move(instruction);
}
@@ -837,15 +874,15 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
all_args.reserve(operands.size() * 2);
all_args.insert(all_args.end(), operands.begin(), operands.end());
all_args.insert(all_args.end(), init_values.begin(), init_values.end());
- return MakeUnique<HloReduceInstruction>(shape, all_args, dimensions_to_reduce,
- reduce_computation);
+ return absl::make_unique<HloReduceInstruction>(
+ shape, all_args, dimensions_to_reduce, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
const Window& window, HloComputation* reduce_computation) {
- return MakeUnique<HloReduceWindowInstruction>(shape, operand, init_value,
- window, reduce_computation);
+ return absl::make_unique<HloReduceWindowInstruction>(
+ shape, operand, init_value, window, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -854,7 +891,7 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
HloInstruction* scale,
HloInstruction* offset, float epsilon,
int64 feature_index) {
- return MakeUnique<HloBatchNormTrainingInstruction>(
+ return absl::make_unique<HloBatchNormTrainingInstruction>(
shape, operand, scale, offset, epsilon, feature_index);
}
@@ -863,7 +900,7 @@ HloInstruction::CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index) {
- return MakeUnique<HloBatchNormInferenceInstruction>(
+ return absl::make_unique<HloBatchNormInferenceInstruction>(
shape, operand, scale, offset, mean, variance, epsilon, feature_index);
}
@@ -873,9 +910,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* variance,
HloInstruction* grad_output, float epsilon,
int64 feature_index) {
- return MakeUnique<HloBatchNormGradInstruction>(shape, operand, scale, mean,
- variance, grad_output, epsilon,
- feature_index);
+ return absl::make_unique<HloBatchNormGradInstruction>(
+ shape, operand, scale, mean, variance, grad_output, epsilon,
+ feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -883,15 +920,15 @@ HloInstruction::CreateSelectAndScatter(
const Shape& shape, HloInstruction* operand, HloComputation* select,
const Window& window, HloInstruction* source, HloInstruction* init_value,
HloComputation* scatter) {
- return MakeUnique<HloSelectAndScatterInstruction>(
+ return absl::make_unique<HloSelectAndScatterInstruction>(
shape, operand, select, window, source, init_value, scatter);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return MakeUnique<HloBroadcastInstruction>(shape, operand,
- broadcast_dimensions);
+ return absl::make_unique<HloBroadcastInstruction>(shape, operand,
+ broadcast_dimensions);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -949,8 +986,8 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
const PaddingConfig& padding_config) {
- return MakeUnique<HloPadInstruction>(shape, operand, padding_value,
- padding_config);
+ return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
+ padding_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
@@ -959,7 +996,8 @@ HloInstruction::CreateBroadcastSequence(
ShapeUtil::ElementsIn(operand->shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(operand->shape());
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
instruction->AppendOperand(operand);
return instruction;
}
@@ -967,26 +1005,27 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions);
+ return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
const Shape& shape, int64 dimension, HloInstruction* keys,
HloInstruction* values) {
- return MakeUnique<HloSortInstruction>(shape, dimension, keys, values);
+ return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root);
+ return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
+ fused_root);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind,
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* fusion_computation) {
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind, operands,
- fusion_computation);
+ return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
+ fusion_computation);
}
void HloInstruction::set_single_sharding(const HloSharding& sharding) {
@@ -1006,6 +1045,7 @@ void HloInstruction::SetupDerivedInstruction(
derived_instruction->clear_sharding();
}
derived_instruction->set_metadata(metadata_);
+ derived_instruction->set_precision_config(precision_config_);
}
bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -1018,7 +1058,6 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
- case HloOpcode::kHostCompute:
return true;
case HloOpcode::kCrossReplicaSum:
return all_reduce_id().has_value();
@@ -1044,7 +1083,7 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* computation) {
std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
@@ -1054,16 +1093,9 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target) {
- return MakeUnique<HloCustomCallInstruction>(shape, operands,
- custom_call_target);
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) {
- return MakeUnique<HloHostComputeInstruction>(shape, operands, channel_name,
- cost_estimate_ns);
+ absl::string_view custom_call_target) {
+ return absl::make_unique<HloCustomCallInstruction>(shape, operands,
+ custom_call_target);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
@@ -1080,8 +1112,8 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return MakeUnique<HloGatherInstruction>(shape, operand, start_indices,
- gather_dim_numbers, slice_sizes);
+ return absl::make_unique<HloGatherInstruction>(
+ shape, operand, start_indices, gather_dim_numbers, slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
@@ -1089,16 +1121,17 @@ bool HloInstruction::HasSideEffect() const {
HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers) {
- return MakeUnique<HloScatterInstruction>(shape, operand, scatter_indices,
- updates, update_computation,
- scatter_dim_numbers);
+ return absl::make_unique<HloScatterInstruction>(
+ shape, operand, scatter_indices, updates, update_computation,
+ scatter_dim_numbers);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
instruction->operand_side_metadata_ = std::move(operand_side_metadata);
instruction->user_side_metadata_ = std::move(user_side_metadata);
instruction->AppendOperand(operand);
@@ -1146,13 +1179,13 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kReducePrecision:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
- case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kSort:
@@ -1274,6 +1307,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
}
break;
}
+  // SetupDerivedInstruction will set up the precision_config_ field.
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
clone->set_raw_backend_config_string(backend_config_);
@@ -1339,7 +1373,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(
  // If the name ends with .suffix[0-9]+ then replace the suffix with the
  // numeric value incremented.
int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
clone->name_ =
StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
} else {
@@ -1614,11 +1648,11 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
- case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
@@ -1812,7 +1846,7 @@ void HloInstruction::set_false_computation(HloComputation* false_computation) {
string HloInstruction::SignatureString() const {
string operands =
- Join(operands_, ", ", [](string* out, HloInstruction* operand) {
+ StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
StrAppend(out, ShapeUtil::HumanString(operand->shape()));
});
return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
@@ -1832,7 +1866,7 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
}
bool HloInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
switch (opcode_) {
// Unary elementwise operations.
case HloOpcode::kAbs:
@@ -1959,7 +1993,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
slice.size() > kMaxOperandsToShowIfCompact) {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
- operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
+ operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
    // If the operand has already been deleted, put `null` in the string output.
if (operand == nullptr) {
StrAppend(out, "null ");
@@ -1979,7 +2013,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
} else if (!options.compact_operands()) {
str.push_back(PrintName(operand->name(), options));
}
- StrAppend(out, Join(str, " "));
+ StrAppend(out, StrJoin(str, " "));
});
const int64 remaining = operands_.size() - slice.size();
if (slice.size() != operands_.size()) {
@@ -1996,6 +2030,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(DotDimensionNumbersToString());
}
+ string precision_config_string = PrecisionConfigToString();
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
if (opcode() == HloOpcode::kWhile) {
@@ -2021,11 +2060,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
extra.push_back(StrCat(
- "calls=", Join(called_computations(), ", ",
- [&](string* out, const HloComputation* computation) {
- StrAppend(out,
- PrintName(computation->name(), options));
- })));
+ "calls=",
+ StrJoin(called_computations(), ", ",
+ [&](string* out, const HloComputation* computation) {
+ StrAppend(out, PrintName(computation->name(), options));
+ })));
}
} else if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
@@ -2058,12 +2097,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
break;
default:
if (!called_computations().empty()) {
- extra.push_back(
- StrCat("calls=\n",
- Join(called_computations(), ", ",
- [&](string* out, const HloComputation* computation) {
- StrAppend(out, computation->ToString(new_options));
- })));
+ extra.push_back(StrCat(
+ "calls=\n",
+ StrJoin(called_computations(), ", ",
+ [&](string* out, const HloComputation* computation) {
+ StrAppend(out, computation->ToString(new_options));
+ })));
}
break;
}
@@ -2074,11 +2113,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}
if (!control_predecessors_.empty()) {
extra.push_back(StrCat("control-predecessors={",
- Join(control_predecessors_, ", ",
- [&](string* out, HloInstruction* pre) {
- StrAppend(out,
- PrintName(pre->name(), options));
- }),
+ StrJoin(control_predecessors_, ", ",
+ [&](string* out, HloInstruction* pre) {
+ StrAppend(out,
+ PrintName(pre->name(), options));
+ }),
"}"));
}
if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
@@ -2092,10 +2131,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
string HloInstruction::ToShortString() const {
return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
- Join(operands_, ", ",
- [](string* out, HloInstruction* operand) {
- StrAppend(out, "%", operand->name());
- }),
+ StrJoin(operands_, ", ",
+ [](string* out, HloInstruction* operand) {
+ StrAppend(out, "%", operand->name());
+ }),
")");
}
@@ -2117,6 +2156,7 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
+ *proto.mutable_precision_config() = precision_config_;
if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
@@ -2155,7 +2195,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
-bool HloInstruction::IsFusable() const {
+bool HloInstruction::IsFusible() const {
// Instructions which are traced should not be fused.
if (tracing()) {
return false;
@@ -2261,6 +2301,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleCrossReplicaSum(this);
case HloOpcode::kAllToAll:
return visitor->HandleAllToAll(this);
+ case HloOpcode::kCollectivePermute:
+ return visitor->HandleCollectivePermute(this);
case HloOpcode::kTuple:
return visitor->HandleTuple(this);
case HloOpcode::kMap:
@@ -2329,8 +2371,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleInfeed(this);
case HloOpcode::kOutfeed:
return visitor->HandleOutfeed(this);
- case HloOpcode::kHostCompute:
- return visitor->HandleHostCompute(this);
case HloOpcode::kRng:
return visitor->HandleRng(this);
case HloOpcode::kWhile:
@@ -2369,15 +2409,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return InternalError(
"Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
"please file a bug for XLA.",
- HloOpcodeString(opcode_).c_str());
+ HloOpcodeString(opcode_));
}
// Explicit instantiations.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
-using DFSStack =
- tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
+using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
@@ -2453,7 +2492,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
- current_node->ToString().c_str());
+ current_node->ToString());
}
}
@@ -2462,7 +2501,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
- current_node->ToString().c_str());
+ current_node->ToString());
}
}
}
@@ -2622,7 +2661,7 @@ bool HloInstruction::IsElementwiseBinary() const {
}
bool HloInstruction::IsElementwise() const {
- return IsElementwiseImpl(tensorflow::gtl::nullopt);
+ return IsElementwiseImpl(absl::nullopt);
}
bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
@@ -2778,7 +2817,7 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
if (kind_name == "kCustom") {
return HloInstruction::FusionKind::kCustom;
}
- return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
+ return InvalidArgument("Unknown fusion kind: %s", kind_name);
}
string PaddingConfigToString(const PaddingConfig& padding) {
@@ -2787,7 +2826,7 @@ string PaddingConfigToString(const PaddingConfig& padding) {
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.interior_padding() != 0;
});
- return Join(
+ return StrJoin(
padding.dimensions(), "x",
[&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
StrAppend(
@@ -2811,11 +2850,15 @@ string OpMetadataToString(const OpMetadata& metadata) {
if (metadata.source_line() != 0) {
result.push_back(StrCat("source_line=", metadata.source_line()));
}
- return Join(result, " ");
+ return StrJoin(result, " ");
}
string RandomDistributionToString(const RandomDistribution& distribution) {
- return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
+ return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
+}
+
+string PrecisionToString(const PrecisionConfigProto::Precision& precision) {
+ return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision));
}
string ConvolutionDimensionNumbersToString(
@@ -2843,8 +2886,8 @@ string ConvolutionDimensionNumbersToString(
output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
}
- return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->",
- Join(output_dims, ""));
+ return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
+ StrJoin(output_dims, ""));
}
string HloInstruction::DotDimensionNumbersToString() const {
@@ -2855,19 +2898,21 @@ string HloInstruction::DotDimensionNumbersToString() const {
const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
if (!dnums.lhs_batch_dimensions().empty()) {
result.push_back(StrCat("lhs_batch_dims={",
- Join(dnums.lhs_batch_dimensions(), ","), "}"));
+ StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
}
result.push_back(StrCat("lhs_contracting_dims={",
- Join(dnums.lhs_contracting_dimensions(), ","), "}"));
+ StrJoin(dnums.lhs_contracting_dimensions(), ","),
+ "}"));
if (!dnums.rhs_batch_dimensions().empty()) {
result.push_back(StrCat("rhs_batch_dims={",
- Join(dnums.rhs_batch_dimensions(), ","), "}"));
+ StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
}
result.push_back(StrCat("rhs_contracting_dims={",
- Join(dnums.rhs_contracting_dimensions(), ","), "}"));
+ StrJoin(dnums.rhs_contracting_dimensions(), ","),
+ "}"));
- return Join(result, ", ");
+ return StrJoin(result, ", ");
}
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
@@ -2881,7 +2926,44 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
}
return map;
}();
- auto found = map->find(tensorflow::str_util::Lowercase(name));
+ auto found = map->find(absl::AsciiStrToLower(name));
+ if (found == map->end()) {
+ return InvalidArgument("Unknown distribution");
+ }
+ return found->second;
+}
+
+string HloInstruction::PrecisionConfigToString() const {
+ if (precision_config_.operand_precision().empty()) {
+ return "";
+ }
+ return StrCat(
+ "operand_precision={",
+ StrJoin(precision_config_.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfigProto::Precision_IsValid(precision))
+ << precision;
+ StrAppend(out, PrecisionToString(
+ static_cast<PrecisionConfigProto::Precision>(
+ precision)));
+ }),
+ "}");
+}
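For reference only (not part of the change; `instr` is an assumed instruction whose operand_precision list is {HIGHEST, DEFAULT}): the helper prints the lowercase enum names, and StringToPrecision below performs the reverse lookup on the same names.

  string s = instr->PrecisionConfigToString();
  // s == "operand_precision={highest,default}"  (assuming the enum values above)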
+
+StatusOr<PrecisionConfigProto::Precision> StringToPrecision(
+ const string& name) {
+ static std::unordered_map<string, PrecisionConfigProto::Precision>* map = [] {
+ static auto* map =
+ new std::unordered_map<string, PrecisionConfigProto::Precision>;
+ for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) {
+ if (PrecisionConfigProto::Precision_IsValid(i)) {
+ auto value = static_cast<PrecisionConfigProto::Precision>(i);
+ (*map)[PrecisionToString(value)] = value;
+ }
+ }
+ return map;
+ }();
+ auto found = map->find(absl::AsciiStrToLower(name));
if (found == map->end()) {
return InvalidArgument("Unknown distribution");
}
@@ -3131,31 +3213,25 @@ const string& HloInstruction::outfeed_config() const {
return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}
-const std::vector<int64>& HloInstruction::replica_group_ids() const {
- return Cast<HloAllReduceInstruction>(this)->replica_group_ids();
+const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
+ return Cast<HloCollectiveInstruction>(this)->replica_groups();
}
-const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
- return Cast<HloAllToAllInstruction>(this)->replica_groups();
+const std::vector<std::pair<int64, int64>>&
+HloInstruction::source_target_pairs() const {
+ return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
}
string HloInstruction::cross_replica_sum_barrier() const {
- if (opcode() == HloOpcode::kCrossReplicaSum) {
- return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
- }
- return Cast<HloAllToAllInstruction>(this)->cross_replica_sum_barrier();
+ return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
}
void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
- if (opcode() == HloOpcode::kCrossReplicaSum) {
- return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
- barrier);
- }
- return Cast<HloAllToAllInstruction>(this)->set_cross_replica_sum_barrier(
+ return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
barrier);
}
-tensorflow::gtl::optional<int64> HloInstruction::all_reduce_id() const {
+absl::optional<int64> HloInstruction::all_reduce_id() const {
return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
}
@@ -3205,10 +3281,6 @@ const string& HloInstruction::custom_call_target() const {
return Cast<HloCustomCallInstruction>(this)->custom_call_target();
}
-const string& HloInstruction::channel_name() const {
- return Cast<HloHostComputeInstruction>(this)->channel_name();
-}
-
const PaddingConfig& HloInstruction::padding_config() const {
return Cast<HloPadInstruction>(this)->padding_config();
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 8d8f149ee3..4a424cebc0 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -32,6 +32,10 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -45,10 +49,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -101,6 +103,7 @@ class HloPrintOptions {
return HloPrintOptions()
.set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
.set_print_metadata(false)
+ .set_print_backend_config(false)
.set_compact_operands(true)
.set_print_operand_shape(true)
.set_print_program_shape(false)
@@ -182,7 +185,7 @@ class HloPrintOptions {
return print_subcomputation_mode_;
}
bool print_metadata() const { return print_metadata_; }
- bool print_backend_config() const { return print_metadata_; }
+ bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
@@ -220,7 +223,7 @@ class CanonicalNameMap {
return iter->second;
}
- string new_name = tensorflow::strings::StrCat("tmp_", index++);
+ string new_name = absl::StrCat("tmp_", index++);
canonical_name_map[old_name] = new_name;
return new_name;
}
@@ -347,7 +350,8 @@ class HloInstruction {
std::unique_ptr<Literal> literal);
// Creates an Iota instruction.
- static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape);
+ static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
+ int64 iota_dimension);
// Creates a get tuple element instruction.
static std::unique_ptr<HloInstruction> CreateGetTupleElement(
@@ -433,9 +437,10 @@ class HloInstruction {
//
// `reduction_computation`: the reduction function.
//
- // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
- // replicas belong to one group. Allreduce will be applied within subgroups.
- // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
+  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
+  // empty, all replicas belong to one group in the order 0 through (n-1).
+  // Allreduce will be applied within subgroups.
+  // For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
// `all_reduce_id`: for Allreduce nodes from different modules, if they have
@@ -446,9 +451,8 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id);
+ const std::vector<ReplicaGroup>& replica_groups,
+ absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
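As an illustrative sketch (not part of the change itself): the new `replica_groups` argument is built from ReplicaGroup protos; `shape`, `operand`, and `add_computation` below are assumed to already exist.

  // Two subgroups {0,2} and {1,3}, matching the example in the comment above.
  std::vector<ReplicaGroup> groups(2);
  groups[0].add_replica_ids(0);
  groups[0].add_replica_ids(2);
  groups[1].add_replica_ids(1);
  groups[1].add_replica_ids(3);
  auto crs = HloInstruction::CreateCrossReplicaSum(
      shape, {operand}, add_computation, groups,
      /*barrier=*/"", /*all_reduce_id=*/absl::nullopt);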
// This op handles the communication of an Alltoall operation. On each core,
// the operands are N ops in the same shape, where N is the number of cores
@@ -463,12 +467,18 @@ class HloInstruction {
// within replica 1, 2, 3, and in the gather phase, the received blocks will
// be concatenated in the order of 1, 2, 3; another Alltoall will be applied
// within replica 4, 5, 0, and the concatenation order is 4, 5, 0.
- //
- // TODO(b/110096724): This is NOT YET ready to use.
static std::unique_ptr<HloInstruction> CreateAllToAll(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier);
+ const std::vector<ReplicaGroup>& replica_groups);
+
+  // Creates a communication instruction that permutes data across replicas.
+  // Data is sent/received according to the (source_replica_id,
+  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not
+  // a target_replica_id in any pair, the output on that replica is a tensor
+  // consisting of 0(s) in `shape`.
+ static std::unique_ptr<HloInstruction> CreateCollectivePermute(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
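A minimal usage sketch for the new op (illustrative only; `shape` and `operand` are assumed to exist). Each pair routes data from its source replica to its target replica, here a rotation across four replicas:

  std::vector<std::pair<int64, int64>> pairs = {{0, 1}, {1, 2}, {2, 3}, {3, 0}};
  auto permute =
      HloInstruction::CreateCollectivePermute(shape, operand, pairs);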
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
@@ -493,7 +503,7 @@ class HloInstruction {
// which is a TOKEN.
static std::unique_ptr<HloInstruction> CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config);
+ HloInstruction* token_operand, absl::string_view outfeed_config);
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
@@ -706,13 +716,7 @@ class HloInstruction {
// to the given operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target);
-
- // Creates a HostCompute instruction, which records host-side control and
- // data dependencies for use in instruction scheduling.
- static std::unique_ptr<HloInstruction> CreateHostCompute(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
+ absl::string_view custom_call_target);
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
@@ -766,7 +770,7 @@ class HloInstruction {
int64 operand_count() const { return operands_.size(); }
// Returns the vector of operands of this instruction.
- using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
+ using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
const InstructionVector& operands() const { return operands_; }
// Returns the vector of unique operands, in the same order they are found
@@ -863,6 +867,11 @@ class HloInstruction {
return false;
}
+ if (!ContainersEqual(precision_config_.operand_precision(),
+ other.precision_config_.operand_precision())) {
+ return false;
+ }
+
return IdenticalSlowPath(other, eq_computations);
}
@@ -1030,7 +1039,7 @@ class HloInstruction {
// Returns true if this instruction can be legally fused into a fusion
// instruction.
- bool IsFusable() const;
+ bool IsFusible() const;
// Returns the sharding applied to this operator.
// REQUIRES: has_sharding() is true.
@@ -1038,21 +1047,26 @@ class HloInstruction {
CHECK(has_sharding());
return *sharding_;
}
+ std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
+
// Returns the sharding applied to this operator, or default_ if none exists.
const HloSharding& sharding_or_default(const HloSharding& default_) const {
return sharding_ ? *sharding_ : default_;
}
// Returns the sharding unique device, if any.
- tensorflow::gtl::optional<int64> sharding_unique_device() const {
+ absl::optional<int64> sharding_unique_device() const {
if (sharding_ == nullptr) {
- return tensorflow::gtl::optional<int64>();
+ return absl::optional<int64>();
}
return sharding_->UniqueDevice();
}
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
void set_sharding(const HloSharding& sharding) {
- sharding_ = MakeUnique<HloSharding>(sharding);
+ sharding_ = std::make_shared<const HloSharding>(sharding);
+ }
+ void set_sharding(std::shared_ptr<const HloSharding> sharding) {
+ sharding_ = std::move(sharding);
}
void set_single_sharding(const HloSharding& sharding);
// Sets a sharding that assigns the current instruction to device.
@@ -1088,19 +1102,6 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
- // TODO(b/80249101): Remove these methods once HLO scheduling and copy
- // insertion are integrated, and we don't need to run a separate pass
- // of copy elision anymore.
- bool CopyElisionAllowed() const {
- CHECK_EQ(HloOpcode::kCopy, opcode_);
- return copy_elision_allowed_;
- }
-
- void SetCopyElisionAllowed(bool value) {
- CHECK_EQ(HloOpcode::kCopy, opcode_);
- copy_elision_allowed_ = value;
- }
-
// Returns data on the dimension numbers used for a dot operation.
const DotDimensionNumbers& dot_dimension_numbers() const {
CHECK(dot_dimension_numbers_ != nullptr);
@@ -1110,6 +1111,9 @@ class HloInstruction {
// Returns the dump string of the dot dimension numbers.
string DotDimensionNumbersToString() const;
+ // Returns the dump string of the precision configuration.
+ string PrecisionConfigToString() const;
+
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1253,6 +1257,20 @@ class HloInstruction {
static StatusOr<string> BackendConfigToRawString(
const tensorflow::protobuf::Message& proto);
+  // Returns the information used to tell the implementation what sort of
+  // precision is requested. The meaning of the field is backend specific. At
+  // the moment, it is only supported for kConvolution and kDot.
+  // Transformations of one kDot or kConvolution into another will preserve
+  // this information. Transformations to other HLOs will not preserve this
+  // information, but it is presumed that the alternate lowering is strictly
+  // superior.
+ const PrecisionConfigProto& precision_config() const {
+ return precision_config_;
+ }
+ void set_precision_config(const PrecisionConfigProto& precision_config) {
+ precision_config_ = precision_config;
+ }
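Illustrative sketch of the new accessor pair (not part of the diff; `dot` is an assumed existing kDot instruction, and the enum values are assumed from PrecisionConfigProto):

  PrecisionConfigProto precision;
  precision.add_operand_precision(PrecisionConfigProto::HIGHEST);
  precision.add_operand_precision(PrecisionConfigProto::DEFAULT);
  dot->set_precision_config(precision);
  // Clones created via CloneWithNewOperands go through SetupDerivedInstruction,
  // so they carry the same precision_config() as `dot`.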
+
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }
@@ -1421,18 +1439,18 @@ class HloInstruction {
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const;
- // Delegates to HloAllReduceInstruction::replica_group_ids.
- const std::vector<int64>& replica_group_ids() const;
-
- // Delegates to HloAllToAllInstruction::replica_groups.
+ // Delegates to HloCollectiveInstruction::replica_groups.
const std::vector<ReplicaGroup>& replica_groups() const;
+ // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
+ const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
+
// Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
string cross_replica_sum_barrier() const;
void set_cross_replica_sum_barrier(const string& barrier);
// Delegates to HloAllReduceInstruction::all_reduce_id.
- tensorflow::gtl::optional<int64> all_reduce_id() const;
+ absl::optional<int64> all_reduce_id() const;
// Returns data on the window in a windowed operation such as
// convolution.
@@ -1475,9 +1493,6 @@ class HloInstruction {
// Delegates to HloCustomCallInstruction::custom_call_target.
const string& custom_call_target() const;
- // Delegates to HloHostComputeInstruction::channel_name.
- const string& channel_name() const;
-
// Delegates to HloPadInstruction::padding_config.
const PaddingConfig& padding_config() const;
@@ -1565,7 +1580,7 @@ class HloInstruction {
// NOTE: For all instructions other than kFusion, being elementwise on one of
// the operands is equivalent to being elementwise on all the operands.
virtual bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const;
+ const absl::optional<int64>& operand_idx) const;
// Prints an instruction to a string.
//
// The canonical string representation needs to name operands and instruction
@@ -1642,7 +1657,10 @@ class HloInstruction {
bool copy_elision_allowed_ = true;
// The sharding, if one exists.
- std::unique_ptr<HloSharding> sharding_;
+  // Uses std::shared_ptr to allow reuse of the same sharding object between
+  // HloInstructions and other components, since HloSharding can be very large
+  // for tuples with many elements.
+ std::shared_ptr<const HloSharding> sharding_;
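A sketch of the intended reuse (not part of the diff; `a` and `b` are assumed instructions that should share one sharding object):

  a->set_sharding(HloSharding::Replicate());  // allocates one HloSharding
  b->set_sharding(a->sharding_ptr());         // shares it instead of copying a
                                              // potentially large tuple sharding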
// Fields used by the kDomain instruction.
std::unique_ptr<DomainMetadata> operand_side_metadata_;
@@ -1661,6 +1679,10 @@ class HloInstruction {
// HLO. See the documentation on backend_config().
string backend_config_;
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfigProto precision_config_;
+
// String identifier for instruction.
string name_;
@@ -1683,10 +1705,12 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
+string PrecisionToString(const PrecisionConfigProto::Precision& precision);
string ConvolutionDimensionNumbersToString(
const ConvolutionDimensionNumbers& dnums);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
+StatusOr<PrecisionConfigProto::Precision> StringToPrecision(const string& name);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 504b13043f..8b0b90dfb3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -53,7 +53,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
public:
Status DefaultAction(HloInstruction* hlo_instruction) override {
return Unimplemented("not implemented %s",
- HloOpcodeString(hlo_instruction->opcode()).c_str());
+ HloOpcodeString(hlo_instruction->opcode()));
}
Status HandleParameter(HloInstruction* parameter) override {
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 4fdf4360e6..ffc74cfedd 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -17,6 +17,12 @@ limitations under the License.
#include <deque>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -27,10 +33,10 @@ limitations under the License.
namespace xla {
namespace {
-using ::tensorflow::str_util::CEscape;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::CEscape;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrJoin;
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
const HloInstruction* operand) {
@@ -89,7 +95,7 @@ HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloBatchNormTrainingInstruction>(
+ return absl::make_unique<HloBatchNormTrainingInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
feature_index());
}
@@ -111,7 +117,7 @@ HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
- return MakeUnique<HloBatchNormInferenceInstruction>(
+ return absl::make_unique<HloBatchNormInferenceInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
new_operands[4], epsilon(), feature_index());
}
@@ -133,7 +139,7 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
- return MakeUnique<HloBatchNormGradInstruction>(
+ return absl::make_unique<HloBatchNormGradInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
new_operands[4], epsilon(), feature_index());
}
@@ -158,7 +164,7 @@ HloInstructionProto HloFftInstruction::ToProto() const {
std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {StrCat("fft_type=", FftType_Name(fft_type())),
- StrCat("fft_length={", Join(fft_length(), ","), "}")};
+ StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
}
bool HloFftInstruction::IdenticalSlowPath(
@@ -175,8 +181,8 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloFftInstruction>(shape, new_operands[0], fft_type_,
- fft_length_);
+ return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
+ fft_length_);
}
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
@@ -230,8 +236,8 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1],
- channel_id(), is_host_transfer());
+ return absl::make_unique<HloSendInstruction>(
+ new_operands[0], new_operands[1], channel_id(), is_host_transfer());
}
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
@@ -248,7 +254,7 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSendDoneInstruction>(
+ return absl::make_unique<HloSendDoneInstruction>(
Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
}
@@ -269,7 +275,7 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloRecvInstruction>(
+ return absl::make_unique<HloRecvInstruction>(
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
is_host_transfer());
}
@@ -291,31 +297,67 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloRecvDoneInstruction>(
+ return absl::make_unique<HloRecvDoneInstruction>(
Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
}
+HloCollectiveInstruction::HloCollectiveInstruction(
+ HloOpcode opcode, const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const std::vector<ReplicaGroup>& replica_groups)
+ : HloInstruction(opcode, shape), replica_groups_(replica_groups) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+HloInstructionProto HloCollectiveInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_replica_groups() = {replica_groups_.begin(),
+ replica_groups_.end()};
+ return proto;
+}
+
+std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& /*options*/) const {
+ std::vector<string> result;
+ std::vector<string> replica_group_str;
+ for (const ReplicaGroup& group : replica_groups()) {
+ replica_group_str.push_back(
+ StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
+ }
+ result.push_back(
+ StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}"));
+ return result;
+}
+
+bool HloCollectiveInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ /*eq_computations*/) const {
+ const auto& casted_other =
+ static_cast<const HloCollectiveInstruction&>(other);
+ return ContainersEqual(replica_groups(), casted_other.replica_groups(),
+ [](const ReplicaGroup& a, const ReplicaGroup& b) {
+ return ContainersEqual(a.replica_ids(),
+ b.replica_ids());
+ });
+}
+
HloAllReduceInstruction::HloAllReduceInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id)
- : HloInstruction(HloOpcode::kCrossReplicaSum, shape),
- replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()),
- cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
+ const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
+ const absl::optional<int64>& all_reduce_id)
+ : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands,
+ replica_groups),
+ cross_replica_sum_barrier_(barrier),
all_reduce_id_(all_reduce_id) {
- for (auto operand : operands) {
- AppendOperand(operand);
- }
AppendComputation(reduce_computation);
}
HloInstructionProto HloAllReduceInstruction::ToProto() const {
- HloInstructionProto proto = HloInstruction::ToProto();
- for (int64 i : replica_group_ids_) {
- proto.add_replica_group_ids(i);
- }
+ HloInstructionProto proto = HloCollectiveInstruction::ToProto();
// Proto3 is so sad.
if (all_reduce_id_) {
proto.set_all_reduce_id(*all_reduce_id_);
@@ -325,9 +367,9 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const {
}
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
- const HloPrintOptions& /*options*/) const {
- std::vector<string> result = {
- StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")};
+ const HloPrintOptions& options) const {
+ std::vector<string> result =
+ HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
if (!cross_replica_sum_barrier().empty()) {
result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
}
@@ -342,7 +384,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
- return replica_group_ids() == casted_other.replica_group_ids() &&
+ return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
eq_computations(to_apply(), casted_other.to_apply()) &&
cross_replica_sum_barrier() ==
casted_other.cross_replica_sum_barrier() &&
@@ -354,70 +396,76 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* /*context*/) const {
- return MakeUnique<HloAllReduceInstruction>(
- shape, new_operands, to_apply(), replica_group_ids(),
+ return absl::make_unique<HloAllReduceInstruction>(
+ shape, new_operands, to_apply(), replica_groups(),
cross_replica_sum_barrier(), all_reduce_id());
}
HloAllToAllInstruction::HloAllToAllInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier)
- : HloInstruction(HloOpcode::kAllToAll, shape),
- replica_groups_(replica_groups),
- cross_replica_sum_barrier_(barrier.begin(), barrier.end()) {
- for (auto operand : operands) {
- AppendOperand(operand);
- }
-}
-
-bool HloAllToAllInstruction::IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const {
- const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
- return ContainersEqual(replica_groups(), casted_other.replica_groups(),
- [](const ReplicaGroup& a, const ReplicaGroup& b) {
- return ContainersEqual(a.replica_ids(),
- b.replica_ids());
- }) &&
- cross_replica_sum_barrier() ==
- casted_other.cross_replica_sum_barrier();
-}
+ const std::vector<ReplicaGroup>& replica_groups)
+ : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
+ replica_groups) {}
std::unique_ptr<HloInstruction>
HloAllToAllInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* /*context*/) const {
- return MakeUnique<HloAllToAllInstruction>(
- shape, new_operands, replica_groups(), cross_replica_sum_barrier());
+ return absl::make_unique<HloAllToAllInstruction>(shape, new_operands,
+ replica_groups());
}
-std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
- const HloPrintOptions& options) const {
- std::vector<string> result;
- std::vector<string> replica_group_str;
- for (const ReplicaGroup& group : replica_groups()) {
- replica_group_str.push_back(
- StrCat("{", Join(group.replica_ids(), ","), "}"));
- }
- result.push_back(
- StrCat("replica_groups={", Join(replica_group_str, ","), "}"));
+HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs)
+ : HloInstruction(HloOpcode::kCollectivePermute, shape),
+ source_target_pairs_(source_target_pairs) {
+ AppendOperand(operand);
+}
- if (!cross_replica_sum_barrier().empty()) {
- result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
+HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (const auto& pair : source_target_pairs()) {
+ auto* proto_pair = proto.add_source_target_pairs();
+ proto_pair->set_source(pair.first);
+ proto_pair->set_target(pair.second);
}
+ return proto;
+}
+std::vector<string>
+HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& /*options*/) const {
+ std::vector<string> result;
+ std::vector<string> strs;
+ for (const auto& pair : source_target_pairs()) {
+ strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
+ }
+ result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
return result;
}
-HloInstructionProto HloAllToAllInstruction::ToProto() const {
- HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_replica_groups() = {replica_groups_.begin(),
- replica_groups_.end()};
- proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_);
- return proto;
+bool HloCollectivePermuteInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ /*eq_computations*/) const {
+ const auto& casted_other =
+ static_cast<const HloCollectivePermuteInstruction&>(other);
+ return ContainersEqual(
+ source_target_pairs(), casted_other.source_target_pairs(),
+ [](const std::pair<int64, int64>& a, const std::pair<int64, int64>& b) {
+ return a == b;
+ });
+}
+
+std::unique_ptr<HloInstruction>
+HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* /*context*/) const {
+ return absl::make_unique<HloCollectivePermuteInstruction>(
+ shape, new_operands[0], source_target_pairs());
}
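
For context on the attribute string produced above: the new kCollectivePermute op carries a list of (source, target) replica pairs, and ExtraAttributesToStringImpl renders them as "source_target_pairs={{0,1},{1,2}}". A minimal standalone sketch of that formatting using only the absl string helpers the diff switches to; the pair values here are made up for illustration:

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  // Hypothetical permutation: replica 0 -> 1, 1 -> 2, 2 -> 0.
  std::vector<std::pair<int64_t, int64_t>> pairs = {{0, 1}, {1, 2}, {2, 0}};

  // Mirror the printing pattern: render each pair as "{src,tgt}" and
  // join the rendered pieces with commas.
  std::vector<std::string> strs;
  for (const auto& pair : pairs) {
    strs.push_back(absl::StrCat("{", pair.first, ",", pair.second, "}"));
  }
  std::cout << absl::StrCat("source_target_pairs={",
                            absl::StrJoin(strs, ","), "}")
            << "\n";  // prints: source_target_pairs={{0,1},{1,2},{2,0}}
  return 0;
}
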
HloReverseInstruction::HloReverseInstruction(
@@ -438,7 +486,7 @@ HloInstructionProto HloReverseInstruction::ToProto() const {
std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloReverseInstruction::IdenticalSlowPath(
@@ -454,8 +502,8 @@ std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloReverseInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
+ dimensions());
}
HloConcatenateInstruction::HloConcatenateInstruction(
@@ -477,7 +525,7 @@ HloInstructionProto HloConcatenateInstruction::ToProto() const {
std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloConcatenateInstruction::IdenticalSlowPath(
@@ -494,8 +542,8 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloConcatenateInstruction>(shape, new_operands,
- dimensions(0));
+ return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
+ dimensions(0));
}
HloReduceInstruction::HloReduceInstruction(
@@ -520,7 +568,7 @@ HloInstructionProto HloReduceInstruction::ToProto() const {
std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloReduceInstruction::IdenticalSlowPath(
@@ -539,8 +587,8 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloReduceInstruction>(shape, new_operands, dimensions(),
- to_apply());
+ return absl::make_unique<HloReduceInstruction>(shape, new_operands,
+ dimensions(), to_apply());
}
HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
@@ -563,7 +611,7 @@ HloInstructionProto HloSortInstruction::ToProto() const {
std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloSortInstruction::IdenticalSlowPath(
@@ -580,7 +628,8 @@ std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
HloInstruction* keys = new_operands[0];
HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
- return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values);
+ return absl::make_unique<HloSortInstruction>(shape, dimensions(0), keys,
+ values);
}
HloTransposeInstruction::HloTransposeInstruction(
@@ -595,7 +644,7 @@ HloTransposeInstruction::HloTransposeInstruction(
Permute(dimensions, shape.dimensions()).begin()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< ", operand->shape(): " << ShapeUtil::HumanString(shape)
- << ", dimensions: {" << Join(dimensions, ", ") << "}";
+ << ", dimensions: {" << StrJoin(dimensions, ", ") << "}";
AppendOperand(operand);
}
@@ -616,7 +665,7 @@ HloInstructionProto HloTransposeInstruction::ToProto() const {
std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloTransposeInstruction::IdenticalSlowPath(
@@ -633,8 +682,8 @@ HloTransposeInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloTransposeInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
+ dimensions());
}
HloBroadcastInstruction::HloBroadcastInstruction(
@@ -655,7 +704,7 @@ HloInstructionProto HloBroadcastInstruction::ToProto() const {
std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloBroadcastInstruction::IdenticalSlowPath(
@@ -672,8 +721,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloBroadcastInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
+ dimensions());
}
HloMapInstruction::HloMapInstruction(
@@ -699,7 +748,7 @@ HloInstructionProto HloMapInstruction::ToProto() const {
}
bool HloMapInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
if (!dimensions().empty()) {
// Check that the map is executed in elementwise compatible dimensions.
if (dimensions().size() != shape().dimensions_size()) {
@@ -716,7 +765,7 @@ bool HloMapInstruction::IsElementwiseImpl(
std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloMapInstruction::IdenticalSlowPath(
@@ -730,7 +779,7 @@ std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloMapInstruction>(shape, new_operands, to_apply());
+ return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
}
HloSliceInstruction::HloSliceInstruction(
@@ -774,7 +823,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
bounds.push_back(
StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
}
- return {StrCat("slice={", Join(bounds, ", "), "}")};
+ return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
}
bool HloSliceInstruction::IdenticalSlowPath(
@@ -792,8 +841,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSliceInstruction>(shape, new_operands[0], slice_starts_,
- slice_limits_, slice_strides_);
+ return absl::make_unique<HloSliceInstruction>(
+ shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}
HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
@@ -812,7 +861,7 @@ HloInstructionProto HloConstantInstruction::ToProto() const {
}
bool HloConstantInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
return true;
}
@@ -845,7 +894,7 @@ HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloConstantInstruction>(literal_->CloneToUnique());
+ return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
}
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@@ -860,7 +909,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
// lines. Compact this into one line by stripping out white space.
string tmp = literal().ToString();
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
- std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
+ std::vector<string> v = absl::StrSplit(tmp, ' ');
bool first = true;
// Concatenate elements in "v" with spaces separating them, but ignoring
// empty entries.
@@ -952,7 +1001,7 @@ HloInstructionProto HloFusionInstruction::ToProto() const {
}
bool HloFusionInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
if (!operand_idx.has_value()) {
for (auto* fused : fused_instructions()) {
if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
@@ -1155,7 +1204,7 @@ HloInstruction* HloFusionInstruction::FuseInstructionInternal(
HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
HloInstruction* instruction_to_fuse, bool add_output) {
- CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString();
+ CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
HloInstruction* clone = nullptr;
if (called_computations().empty()) {
@@ -1339,8 +1388,8 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
new_fused_computation = module->AddEmbeddedComputation(
fused_instructions_computation()->Clone("clone", context));
}
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind(), new_operands,
- new_fused_computation);
+ return absl::make_unique<HloFusionInstruction>(
+ shape, fusion_kind(), new_operands, new_fused_computation);
}
Status HloFusionInstruction::DeduplicateFusionOperands() {
@@ -1384,7 +1433,7 @@ std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
}
bool HloRngInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
return true;
}
@@ -1399,7 +1448,8 @@ std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloRngInstruction>(shape, distribution_, new_operands);
+ return absl::make_unique<HloRngInstruction>(shape, distribution_,
+ new_operands);
}
HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
@@ -1435,7 +1485,8 @@ HloParameterInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloParameterInstruction>(parameter_number_, shape, name());
+ return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
+ name());
}
HloGetTupleElementInstruction::HloGetTupleElementInstruction(
@@ -1471,8 +1522,8 @@ HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloGetTupleElementInstruction>(shape, new_operands[0],
- tuple_index());
+ return absl::make_unique<HloGetTupleElementInstruction>(
+ shape, new_operands[0], tuple_index());
}
HloReducePrecisionInstruction::HloReducePrecisionInstruction(
@@ -1514,7 +1565,7 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloReducePrecisionInstruction>(
+ return absl::make_unique<HloReducePrecisionInstruction>(
shape, new_operands[0], exponent_bits(), mantissa_bits());
}
@@ -1555,16 +1606,17 @@ std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloInfeedInstruction>(infeed_shape(), new_operands[0],
- infeed_config());
+ return absl::make_unique<HloInfeedInstruction>(
+ infeed_shape(), new_operands[0], infeed_config());
}
-HloOutfeedInstruction::HloOutfeedInstruction(
- const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config)
+HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
+ HloInstruction* operand,
+ HloInstruction* token_operand,
+ absl::string_view outfeed_config)
: HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
outfeed_shape_(outfeed_shape),
- outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
+ outfeed_config_(outfeed_config) {
CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
<< "Outfeed shape " << outfeed_shape
<< " must be compatible with operand shape " << operand->shape();
@@ -1600,8 +1652,8 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
- new_operands[1], outfeed_config());
+ return absl::make_unique<HloOutfeedInstruction>(
+ outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
}
HloConvolutionInstruction::HloConvolutionInstruction(
@@ -1671,7 +1723,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloConvolutionInstruction>(
+ return absl::make_unique<HloConvolutionInstruction>(
shape, new_operands[0], new_operands[1], window(),
convolution_dimension_numbers_, feature_group_count_);
}
@@ -1716,7 +1768,7 @@ HloReduceWindowInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloReduceWindowInstruction>(
+ return absl::make_unique<HloReduceWindowInstruction>(
shape, new_operands[0], new_operands[1], window(), to_apply());
}
@@ -1765,14 +1817,14 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloSelectAndScatterInstruction>(
+ return absl::make_unique<HloSelectAndScatterInstruction>(
shape, new_operands[0], select(), window(), new_operands[1],
new_operands[2], scatter());
}
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target)
+ absl::string_view custom_call_target)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(),
custom_call_target.end()) {
@@ -1840,8 +1892,8 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- auto cloned = MakeUnique<HloCustomCallInstruction>(shape, new_operands,
- custom_call_target());
+ auto cloned = absl::make_unique<HloCustomCallInstruction>(
+ shape, new_operands, custom_call_target());
if (window_ != nullptr) {
cloned->set_window(*window_);
}
@@ -1851,41 +1903,6 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
return std::move(cloned);
}
-HloHostComputeInstruction::HloHostComputeInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns)
- : HloInstruction(HloOpcode::kHostCompute, shape),
- channel_name_(channel_name.begin(), channel_name.end()),
- cost_estimate_ns_(cost_estimate_ns) {
- for (auto operand : operands) {
- AppendOperand(operand);
- }
-}
-
-HloInstructionProto HloHostComputeInstruction::ToProto() const {
- HloInstructionProto proto = HloInstruction::ToProto();
- proto.set_channel_name(channel_name_);
- proto.set_cost_estimate_ns(cost_estimate_ns_);
- return proto;
-}
-
-bool HloHostComputeInstruction::IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const {
- // Not yet supported.
- return false;
-}
-
-std::unique_ptr<HloInstruction>
-HloHostComputeInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloCloneContext* context) const {
- return MakeUnique<HloHostComputeInstruction>(
- shape, new_operands, channel_name_, cost_estimate_ns_);
-}
-
HloPadInstruction::HloPadInstruction(const Shape& shape,
HloInstruction* operand,
HloInstruction* padding_value,
@@ -1920,8 +1937,8 @@ std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloPadInstruction>(shape, new_operands[0], new_operands[1],
- padding_config_);
+ return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
+ new_operands[1], padding_config_);
}
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
@@ -1943,8 +1960,8 @@ HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {
- StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")};
+ return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
+ "}")};
}
bool HloDynamicSliceInstruction::IdenticalSlowPath(
@@ -1960,7 +1977,7 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloDynamicSliceInstruction>(
+ return absl::make_unique<HloDynamicSliceInstruction>(
shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
}
@@ -1972,25 +1989,25 @@ HloGatherInstruction::HloGatherInstruction(
AppendOperand(operand);
AppendOperand(start_indices);
gather_dimension_numbers_ =
- MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
- c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
+ absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
+ absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
}
string HloGatherInstruction::GatherDimensionNumbersToString() const {
CHECK(gather_dimension_numbers_ != nullptr);
string offset_dims =
StrCat("offset_dims={",
- Join(gather_dimension_numbers_->offset_dims(), ","), "}");
- string collapsed_slice_dims =
- StrCat("collapsed_slice_dims={",
- Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
+ StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}");
+ string collapsed_slice_dims = StrCat(
+ "collapsed_slice_dims={",
+ StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
string start_index_map =
StrCat("start_index_map={",
- Join(gather_dimension_numbers_->start_index_map(), ","), "}");
+ StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}");
string index_vector_dim = StrCat(
"index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
- return Join<std::initializer_list<string>>(
+ return StrJoin<std::initializer_list<string>>(
{offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
", ");
}
@@ -2027,7 +2044,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const {
std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {GatherDimensionNumbersToString(),
- StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")};
+ StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
}
bool HloGatherInstruction::IdenticalSlowPath(
@@ -2046,7 +2063,7 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloGatherInstruction>(
+ return absl::make_unique<HloGatherInstruction>(
shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
gather_slice_sizes());
}
@@ -2062,24 +2079,24 @@ HloScatterInstruction::HloScatterInstruction(
AppendOperand(updates);
AppendComputation(update_computation);
scatter_dimension_numbers_ =
- MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers);
+ absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
}
string HloScatterInstruction::ScatterDimensionNumbersToString() const {
- string update_window_dims =
- StrCat("update_window_dims={",
- Join(scatter_dimension_numbers().update_window_dims(), ","), "}");
+ string update_window_dims = StrCat(
+ "update_window_dims={",
+ StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}");
string inserted_window_dims = StrCat(
"inserted_window_dims={",
- Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
+ StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
string scatter_dims_to_operand_dims = StrCat(
"scatter_dims_to_operand_dims={",
- Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
+ StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
"}");
string index_vector_dim = StrCat(
"index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
- return Join<std::initializer_list<string>>(
+ return StrJoin<std::initializer_list<string>>(
{update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
index_vector_dim},
", ");
@@ -2133,9 +2150,39 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloScatterInstruction>(
+ return absl::make_unique<HloScatterInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
scatter_dimension_numbers());
}
+HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
+ : HloInstruction(HloOpcode::kIota, shape),
+ iota_dimension_(iota_dimension) {}
+
+HloInstructionProto HloIotaInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.add_dimensions(iota_dimension());
+ return proto;
+}
+
+std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("iota_dimension=", iota_dimension())};
+}
+
+bool HloIotaInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
+ return iota_dimension() == casted_other.iota_dimension();
+}
+
+std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
+}
+
} // namespace xla
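
The newly added HloIotaInstruction only records the dimension along which values increase. As a rough illustration of the intended semantics (not XLA's evaluator), an iota of shape [3,4] with iota_dimension=1 assigns each element its index along dimension 1; a standalone sketch on a row-major buffer:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch: fill a row-major [rows, cols] buffer the way an iota with
// iota_dimension = 1 would, i.e. element (r, c) gets the value c.
int main() {
  const int64_t rows = 3, cols = 4;
  std::vector<int64_t> data(rows * cols);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      data[r * cols + c] = c;  // value is the index along the iota dimension
    }
  }
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      std::cout << data[r * cols + c] << (c + 1 == cols ? '\n' : ' ');
    }
  }
  return 0;
}
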
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 803dbeabeb..ee6e337b6a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
@@ -217,19 +218,37 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
HloCloneContext* context) const override;
};
-class HloAllReduceInstruction : public HloInstruction {
+class HloCollectiveInstruction : public HloInstruction {
+ public:
+ const std::vector<ReplicaGroup>& replica_groups() const {
+ return replica_groups_;
+ }
+
+ protected:
+ explicit HloCollectiveInstruction(
+ HloOpcode opcode, const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const std::vector<ReplicaGroup>& replica_groups);
+
+ HloInstructionProto ToProto() const override;
+
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ std::vector<ReplicaGroup> replica_groups_;
+};
+
+class HloAllReduceInstruction : public HloCollectiveInstruction {
public:
explicit HloAllReduceInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id);
-
- // Returns the group ids of each replica for CrossReplicaSum op.
- const std::vector<int64>& replica_group_ids() const {
- return replica_group_ids_;
- }
+ const std::vector<ReplicaGroup>& replica_groups,
+ absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
// Returns the barrier config used for the CrossReplicaSum implementation of
// each backend.
@@ -240,9 +259,7 @@ class HloAllReduceInstruction : public HloInstruction {
cross_replica_sum_barrier_ = barrier;
}
- tensorflow::gtl::optional<int64> all_reduce_id() const {
- return all_reduce_id_;
- }
+ absl::optional<int64> all_reduce_id() const { return all_reduce_id_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -261,37 +278,40 @@ class HloAllReduceInstruction : public HloInstruction {
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const override;
- // The group id of each replica for CrossReplicaSum.
- std::vector<int64> replica_group_ids_;
-
// The string representation of the barrier config used for CrossReplicaSum.
string cross_replica_sum_barrier_;
// For Allreduce nodes from different modules, if they have the same
// all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied cross modules.
- tensorflow::gtl::optional<int64> all_reduce_id_;
+ absl::optional<int64> all_reduce_id_;
};
-class HloAllToAllInstruction : public HloInstruction {
+class HloAllToAllInstruction : public HloCollectiveInstruction {
public:
explicit HloAllToAllInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier);
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const std::vector<ReplicaGroup>& replica_groups);
- const std::vector<ReplicaGroup>& replica_groups() const {
- return replica_groups_;
- }
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
- // TODO(b/110096724): rename this.
- void set_cross_replica_sum_barrier(string barrier) {
- cross_replica_sum_barrier_ = barrier;
- }
- string cross_replica_sum_barrier() const {
- return cross_replica_sum_barrier_;
+class HloCollectivePermuteInstruction : public HloInstruction {
+ public:
+ explicit HloCollectivePermuteInstruction(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
+ const std::vector<std::pair<int64, int64>>& source_target_pairs() const {
+ return source_target_pairs_;
}
+ // Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
private:
@@ -308,10 +328,7 @@ class HloAllToAllInstruction : public HloInstruction {
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const override;
- std::vector<ReplicaGroup> replica_groups_;
-
- // The string representation of the barrier config.
- string cross_replica_sum_barrier_;
+ const std::vector<std::pair<int64, int64>> source_target_pairs_;
};
class HloReverseInstruction : public HloInstruction {
@@ -507,7 +524,7 @@ class HloMapInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
@@ -600,7 +617,7 @@ class HloConstantInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
@@ -751,7 +768,7 @@ class HloFusionInstruction : public HloInstruction {
bool add_output = false);
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
@@ -780,7 +797,7 @@ class HloRngInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
@@ -920,7 +937,7 @@ class HloOutfeedInstruction : public HloInstruction {
explicit HloOutfeedInstruction(const Shape& outfeed_shape,
HloInstruction* operand,
HloInstruction* token_operand,
- tensorflow::StringPiece outfeed_config);
+ absl::string_view outfeed_config);
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
@@ -1073,14 +1090,14 @@ class HloCustomCallInstruction : public HloInstruction {
public:
explicit HloCustomCallInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target);
+ absl::string_view custom_call_target);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
}
void set_window(const Window& window) override {
- window_ = MakeUnique<Window>(window);
+ window_ = absl::make_unique<Window>(window);
}
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -1091,7 +1108,7 @@ class HloCustomCallInstruction : public HloInstruction {
void set_convolution_dimension_numbers(
const ConvolutionDimensionNumbers& dnums) {
convolution_dimension_numbers_ =
- MakeUnique<ConvolutionDimensionNumbers>(dnums);
+ absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
const string& custom_call_target() const { return custom_call_target_; }
// Returns a serialized representation of this instruction.
@@ -1117,33 +1134,6 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
};
-class HloHostComputeInstruction : public HloInstruction {
- public:
- explicit HloHostComputeInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
- // Returns the channel name associated with the instruction. The name is
- // used to identify host Send/Recv operations.
- const string& channel_name() const { return channel_name_; }
- // Returns a serialized representation of this instruction.
- HloInstructionProto ToProto() const override;
-
- private:
- bool IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const override;
- // Implementation for non-common logic of CloneWithNewOperands.
- std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloCloneContext* context) const override;
- // Name to use for host send/recv channels.
- string channel_name_;
- // Estimate of the duration of a host computation in nanoseconds.
- int64 cost_estimate_ns_ = 0;
-};
-
class HloPadInstruction : public HloInstruction {
public:
explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
@@ -1289,6 +1279,30 @@ class HloScatterInstruction : public HloInstruction {
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
};
+class HloIotaInstruction : public HloInstruction {
+ public:
+ explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ int64 iota_dimension() const { return iota_dimension_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ const int64 iota_dimension_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
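
The header change above hoists the replica_groups_ field, together with its proto/printing/equality plumbing, out of HloAllReduceInstruction into a shared HloCollectiveInstruction base that kAllToAll now also derives from. A much-simplified standalone sketch of that refactor shape, with toy class names standing in for the real HLO types:

#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"

// Toy stand-ins: the base owns the attribute shared by all collective ops,
// and each subclass only adds what is specific to it.
struct ReplicaGroupSketch { std::vector<int> replica_ids; };

class CollectiveBase {
 public:
  const std::vector<ReplicaGroupSketch>& replica_groups() const {
    return replica_groups_;
  }

 protected:
  explicit CollectiveBase(std::vector<ReplicaGroupSketch> groups)
      : replica_groups_(std::move(groups)) {}

  std::vector<ReplicaGroupSketch> replica_groups_;
};

class AllToAllSketch : public CollectiveBase {
 public:
  explicit AllToAllSketch(std::vector<ReplicaGroupSketch> groups)
      : CollectiveBase(std::move(groups)) {}
};

int main() {
  auto op = absl::make_unique<AllToAllSketch>(
      std::vector<ReplicaGroupSketch>{{{0, 1}}, {{2, 3}}});
  return op->replica_groups().size() == 2 ? 0 : 1;
}
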
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 8e0d38b6a6..8350285e67 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -17,20 +17,20 @@ limitations under the License.
#include <unordered_map>
+#include "absl/strings/escaping.h"
+#include "absl/strings/numbers.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
-
-using ::tensorflow::StringPiece;
-
namespace {
+using absl::string_view;
+
constexpr int kEOF = -1;
constexpr int kError = -2;
@@ -66,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const {
return ptr < buf_.end() && ptr >= buf_.begin();
}
-tensorflow::StringPiece HloLexer::StringPieceFromPointers(
- const char* begin, const char* end) const {
+absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
+ const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
- return tensorflow::StringPiece(begin, end - begin);
+ return absl::string_view(begin, end - begin);
}
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
@@ -235,7 +235,7 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kAttributeName;
}
- tensorflow::StringPiece identifier =
+ absl::string_view identifier =
StringPieceFromPointers(token_start_, current_ptr_);
// See if this is a keyword.
@@ -269,7 +269,7 @@ TokKind HloLexer::LexIdentifier() {
}
}
- str_val_ = std::string(identifier);
+ str_val_ = string(identifier);
return TokKind::kIdent;
}
@@ -306,8 +306,7 @@ TokKind HloLexer::LexNumberOrPattern() {
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
- tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
- &decimal_val_);
+ CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_));
return TokKind::kDecimal;
}
@@ -339,7 +338,7 @@ TokKind HloLexer::LexNumberOrPattern() {
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
auto slice = StringPieceFromPointers(token_start_, current_ptr_);
- if (tensorflow::strings::safe_strto64(slice, &int64_val_)) {
+ if (absl::SimpleAtoi(slice, &int64_val_)) {
return TokKind::kInt;
}
LOG(ERROR) << "Failed to parse int literal: " << slice;
@@ -365,6 +364,7 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
line_no = line_no_cache_.line_no_of_query;
}
for (; ptr != location; ptr++) {
+ CHECK_LT(ptr, buf_.end());
if (*ptr == '\n') {
line_no++;
}
@@ -374,24 +374,24 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
line_no_cache_.last_query = ptr;
line_no_cache_.line_no_of_query = line_no;
size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
- if (line_offset == tensorflow::StringPiece::npos) {
+ if (line_offset == absl::string_view::npos) {
line_offset = 0;
}
return {line_no, ptr - start - line_offset};
}
-tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const {
+absl::string_view HloLexer::GetLine(LocTy loc) const {
if (!CanDereference(loc)) {
return "LINE OUT OF RANGE";
}
size_t line_start =
StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
- const char* start = line_start == tensorflow::StringPiece::npos
+ const char* start = line_start == absl::string_view::npos
? buf_.begin()
: buf_.begin() + line_start + 1;
size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
const char* end =
- line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end;
+ line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
return StringPieceFromPointers(start, end);
}
@@ -403,10 +403,14 @@ TokKind HloLexer::LexString() {
static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
if (RE2::Consume(&consumable, *escaping_pattern)) {
current_ptr_ = consumable.begin();
- tensorflow::StringPiece raw =
+ absl::string_view raw =
StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1);
string error;
- if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) {
+ // TODO(b/113077997): Change to absl::CUnescape once it works properly with
+ // copy-on-write std::string implementations.
+ if (!tensorflow::str_util::CUnescape( // non-absl ok
+ tensorflow::StringPiece(raw.data(), raw.size()), // non-absl ok
+ &str_val_, &error)) {
LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
return TokKind::kError;
}
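
The lexer hunks above swap tensorflow::strings::safe_strto64/safe_strtod for absl::SimpleAtoi/absl::SimpleAtod, which take an absl::string_view plus an output pointer and report success through their bool return value. A small sketch of that calling convention; the token strings are made up:

#include <cstdint>
#include <iostream>
#include <string>

#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"

int main() {
  absl::string_view int_token = "42";
  absl::string_view decimal_token = "-3.5e2";

  int64_t int_val = 0;
  double decimal_val = 0.0;

  // Both helpers return false instead of crashing on malformed input,
  // so the caller decides whether to CHECK, log, or emit an error token.
  if (!absl::SimpleAtoi(int_token, &int_val)) {
    std::cerr << "failed to parse int literal: " << int_token << "\n";
    return 1;
  }
  if (!absl::SimpleAtod(decimal_token, &decimal_val)) {
    std::cerr << "failed to parse decimal literal: " << decimal_token << "\n";
    return 1;
  }
  std::cout << int_val << " " << decimal_val << "\n";  // 42 -350
  return 0;
}
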
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h
index 003ac34ace..3e2f8bcd52 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.h
+++ b/tensorflow/compiler/xla/service/hlo_lexer.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_token.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
@@ -34,7 +34,7 @@ namespace xla {
// it directly.
class HloLexer {
public:
- explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
+ explicit HloLexer(absl::string_view buf) : buf_(buf) {
current_ptr_ = buf_.begin();
}
@@ -77,7 +77,7 @@ class HloLexer {
std::pair<unsigned, unsigned> GetLineAndColumn(LocTy location) const;
// Returns the whole line given the location.
- tensorflow::StringPiece GetLine(LocTy loc) const;
+ absl::string_view GetLine(LocTy loc) const;
private:
// Returns the current character. If it's neither the end of input buffer nor
@@ -89,8 +89,8 @@ class HloLexer {
// Creates StringPiece with the given begin and end. Exits if the begin > end,
// or it's out of the range of the current buffer.
- tensorflow::StringPiece StringPieceFromPointers(const char* begin,
- const char* end) const;
+ absl::string_view StringPieceFromPointers(const char* begin,
+ const char* end) const;
tensorflow::RegexpStringPiece RegexpStringPieceFromPointers(
const char* begin, const char* end) const;
@@ -107,11 +107,11 @@ class HloLexer {
TokKind LexNumberOrPattern();
TokKind LexString();
- const tensorflow::StringPiece buf_;
+ const absl::string_view buf_;
const char* current_ptr_;
// Information about the current token.
- const char* token_start_;
+ const char* token_start_ = nullptr;
TokKind current_kind_;
string str_val_;
Shape shape_val_;
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 43c41ece6e..3a1dd471c6 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -17,8 +17,9 @@ limitations under the License.
#include <deque>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -29,17 +30,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
+namespace {
using Worklist = std::deque<const HloInstruction*>;
using Workset = std::unordered_set<const HloInstruction*>;
-namespace {
-
void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
Workset* workset) {
if (workset->count(instruction) == 0) {
@@ -296,7 +294,7 @@ StatusOr<std::unique_ptr<HloLivenessAnalysis>> HloLivenessAnalysis::Run(
VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
- auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module));
+ auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module));
liveness_analysis->RunAnalysis();
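
absl::WrapUnique is used above rather than absl::make_unique, presumably because HloLivenessAnalysis is constructed only through this factory and its constructor is not publicly reachable; WrapUnique simply adopts an already-`new`ed pointer. A toy illustration of when each helper applies; the Widget class is invented:

#include <memory>

#include "absl/memory/memory.h"

class Widget {
 public:
  // Public constructor: absl::make_unique works directly.
  Widget() = default;

  // Factory for the private constructor: only the class itself can call
  // `new Widget(int)`, and absl::WrapUnique adopts the raw pointer.
  static std::unique_ptr<Widget> Create(int value) {
    return absl::WrapUnique(new Widget(value));
  }

 private:
  explicit Widget(int value) : value_(value) {}
  int value_ = 0;
};

int main() {
  std::unique_ptr<Widget> a = absl::make_unique<Widget>();
  std::unique_ptr<Widget> b = Widget::Create(42);
  return (a && b) ? 0 : 1;
}
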
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index 7e4b883435..5269cad94d 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -15,15 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace testing {
-using ::tensorflow::str_util::Join;
-
bool HloMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
@@ -210,8 +208,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain(
dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
*listener << instruction->ToString()
<< " has wrong lhs_contracting_dimensions (got {"
- << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {"
- << lhs_contracting_dim_ << "})";
+ << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",")
+ << "} want {" << lhs_contracting_dim_ << "})";
return false;
}
@@ -219,8 +217,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain(
dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
*listener << instruction->ToString()
<< " has wrong rhs_contracting_dimensions (got {"
- << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {"
- << rhs_contracting_dim_ << "})";
+ << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",")
+ << "} want {" << rhs_contracting_dim_ << "})";
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index c577b4359a..5502e565b6 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
namespace testing {
@@ -120,8 +120,7 @@ class HloShapeAndLayoutMatcher
class HloShardingMatcher
: public ::testing::MatcherInterface<const HloInstruction*> {
public:
- explicit HloShardingMatcher(
- const tensorflow::gtl::optional<HloSharding>& sharding)
+ explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding)
: sharding_(sharding) {}
bool MatchAndExplain(const HloInstruction* instruction,
@@ -129,7 +128,7 @@ class HloShardingMatcher
void DescribeTo(std::ostream* os) const override;
private:
- tensorflow::gtl::optional<HloSharding> sharding_;
+ absl::optional<HloSharding> sharding_;
};
// Matches a Dot HLO instruction with specific LHS and RHS contracting
@@ -189,6 +188,7 @@ HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
HLO_MATCHER(AfterAll);
HLO_MATCHER(Gt);
+HLO_MATCHER(Iota);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
HLO_MATCHER(Le);
@@ -307,7 +307,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
- tensorflow::StringPiece shape) {
+ absl::string_view shape) {
return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(
ShapeUtil::ParseShapeString(shape).ValueOrDie()));
}
@@ -317,7 +317,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
new ::xla::testing::HloShapeAndLayoutMatcher(shape));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
- tensorflow::StringPiece shape) {
+ absl::string_view shape) {
return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
ShapeUtil::ParseShapeString(shape).ValueOrDie()));
}
@@ -330,14 +330,14 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
}
// Matcher for Sharding from sharding string
inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
- tensorflow::StringPiece sharding) {
+ absl::string_view sharding) {
return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
ParseSharding(sharding).ValueOrDie()));
}
// Verifies that no HloSharding is set for an HLO instruction.
inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
return ::testing::MakeMatcher(
- new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt));
+ new ::xla::testing::HloShardingMatcher(absl::nullopt));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
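
HloShardingMatcher now stores an absl::optional<HloSharding>, with absl::nullopt meaning "expect no sharding at all" rather than "expect some default sharding". A small sketch of that optional-as-tristate pattern, with a plain int64_t device id standing in for HloSharding:

#include <cstdint>
#include <iostream>

#include "absl/types/optional.h"

// Stand-in matcher: an engaged optional means "expect exactly this device",
// absl::nullopt means "expect no device assignment at all".
bool Matches(const absl::optional<int64_t>& expected,
             const absl::optional<int64_t>& actual) {
  if (!expected.has_value()) return !actual.has_value();
  return actual.has_value() && *actual == *expected;
}

int main() {
  std::cout << Matches(absl::nullopt, absl::nullopt) << "\n";  // 1
  std::cout << Matches(int64_t{3}, absl::nullopt) << "\n";     // 0
  std::cout << Matches(int64_t{3}, int64_t{3}) << "\n";        // 1
  return 0;
}
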
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 55ff073d3f..78167335c8 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -22,12 +22,13 @@ limitations under the License.
#include <unordered_set>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -274,7 +275,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
TF_RET_CHECK(entry != nullptr);
- auto module = MakeUnique<HloModule>(proto.name(), module_config);
+ auto module = absl::make_unique<HloModule>(proto.name(), module_config);
// Sort the computations in the proto id's order.
std::sort(computations.begin(), computations.end(),
@@ -409,7 +410,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation(
string error_message =
"The subcomputation to outline has multiple outputs:\n";
for (HloInstruction* output : outputs) {
- tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
+ absl::StrAppend(&error_message, output->ToString(), "\n");
}
LOG(FATAL) << error_message;
}
@@ -507,7 +508,7 @@ std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
- auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_);
+ auto module = absl::make_unique<HloModule>(name_ + "-" + suffix, config_);
HloCloneContext context(module.get(), suffix);
auto cloned_computation = entry_computation_->Clone(suffix, &context);
@@ -535,12 +536,11 @@ uint64 HloModule::RandomNew64() const {
return rng_();
}
-HloComputation* HloModule::GetComputationWithName(
- tensorflow::StringPiece name) {
+HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
auto computations_in_module = computations();
- auto it = c_find_if(computations_in_module, [&](HloComputation* computation) {
- return computation->name() == name;
- });
+ auto it = absl::c_find_if(
+ computations_in_module,
+ [&](HloComputation* computation) { return computation->name() == name; });
return it == computations_in_module.end() ? nullptr : *it;
}
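
GetComputationWithName above now goes through absl::c_find_if from absl/algorithm/container.h, which is std::find_if applied to a whole container. A minimal sketch of the same lookup-by-name pattern on plain strings; the data here is made up:

#include <iostream>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"

int main() {
  std::vector<std::string> names = {"entry", "fused_computation", "body"};
  const std::string target = "body";

  // Container-based find_if: no begin()/end() boilerplate at the call site.
  auto it = absl::c_find_if(
      names, [&](const std::string& name) { return name == target; });

  if (it == names.end()) {
    std::cout << "not found\n";
  } else {
    std::cout << "found at index " << (it - names.begin()) << "\n";  // 2
  }
  return 0;
}
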
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index d2e726a0db..cf129b835d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
@@ -142,7 +142,7 @@ class HloModule {
// Returns the computation in this module that has the name `name`. Returns
// null if there is no such computation.
- HloComputation* GetComputationWithName(tensorflow::StringPiece name);
+ HloComputation* GetComputationWithName(absl::string_view name);
// Gets the number of computations in this module.
int64 computation_count() const { return computations_.size(); }
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 07a8c798db..9bfa3a5f45 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -18,15 +18,15 @@ limitations under the License.
#include <atomic>
#include <vector>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::strings::StrAppend;
+using absl::StrAppend;
HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
bool ignore_layouts)
@@ -39,15 +39,14 @@ void HloModuleConfig::SetDefaultComputationLayout(
}
string HloModuleConfig::compilation_cache_key() const {
- string key =
- tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled());
+ string key = absl::StrCat("profiling=", hlo_profiling_enabled());
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
entry_computation_layout_->parameter_layouts()) {
params.push_back(param_layout.shape().DebugString());
}
- StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
+ StrAppend(&key, absl::StrJoin(params, ", "), ") => ",
entry_computation_layout_->result_shape().SerializeAsString());
if (seed() != 0) {
// TODO(b/32083678): force recompilation to reset global state.
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 074e9c9070..3f1e1cc73e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -18,11 +18,11 @@ limitations under the License.
#include <string>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -72,15 +72,6 @@ class HloModuleConfig {
return debug_options_.xla_hlo_profile();
}
- // Sets/returns whether this is a "host module". Host modules are used to
- // record the data- and control-flow dependencies of host side computation
- // that communicates with compiled code. They are used for analysis and
- // scheduling purposes, but no code is generated.
- bool is_host_module() const { return is_host_module_; }
- void set_is_host_module(bool is_host_module) {
- is_host_module_ = is_host_module;
- }
-
// Sets/returns the module seed set during execution.
void set_seed(uint64 seed) { seed_ = seed; }
uint64 seed() const { return seed_; }
@@ -113,7 +104,7 @@ class HloModuleConfig {
private:
// If you add new members, be sure to update compilation_cache_key.
- tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
+ absl::optional<ComputationLayout> entry_computation_layout_;
// Whether this is a 'host module'.
bool is_host_module_ = false;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 29024085c1..12ca2340a6 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -31,7 +31,7 @@ namespace xla {
class HloModuleDCE : public HloPassInterface {
public:
~HloModuleDCE() override {}
- tensorflow::StringPiece name() const override { return "hlo-module-dce"; }
+ absl::string_view name() const override { return "hlo-module-dce"; }
// Run the pass on the given module. Returns whether the module was changed
// (instructions were removed).
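
Pass names like the one above can safely be returned as absl::string_view even though no std::string owns the bytes: a string literal has static storage duration, so the view never dangles. A quick sketch of that idiom; the PassSketch class is invented and is not part of the XLA pass interface:

#include <iostream>

#include "absl/strings/string_view.h"

class PassSketch {
 public:
  // Safe: the view points into a string literal with static storage,
  // not into a temporary std::string.
  absl::string_view name() const { return "hlo-module-dce"; }
};

int main() {
  PassSketch pass;
  std::cout << pass.name() << "\n";
  return 0;
}
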
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 10bf9ffd6c..9c01862a4b 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -19,9 +19,10 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -59,7 +60,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
/* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) {
- auto metadata = MakeUnique<HloModuleGroupMetadata>(modules);
+ auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules);
TF_RETURN_IF_ERROR(metadata->Build());
return std::move(metadata);
}
@@ -131,6 +132,14 @@ Status HloModuleGroupMetadata::Build() {
if (VLOG_IS_ON(4)) {
DumpCollectedStats();
}
+
+ for (HloModule* module : modules_) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+ points_to_analyses_[module] = std::move(points_to_analysis);
+ }
+
return Status::OK();
}
@@ -163,7 +172,7 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const {
ss << " " << hlo->name() << std::endl;
}
ss << "has multiple instructions on the same device";
- return FailedPrecondition("%s", ss.str().c_str());
+ return FailedPrecondition("%s", ss.str());
}
}
}
@@ -204,6 +213,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
return channels_[channel_id_map_.at(channel_id)];
}
+bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
+ return channel_id_map_.find(channel_id) != channel_id_map_.end();
+}
+
HloComputation* HloModuleGroupMetadata::PeerComputation(
const HloInstruction* instruction) const {
CHECK(IsChannelInstruction(instruction));
@@ -267,15 +280,14 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
LOG(FATAL) << "unknown module";
}
-tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
+absl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
const HloInstruction& instruction) const {
// The module group metadata can be created in both "single module, multiple
// devices" and "multiple modules, no explicit devices" fashions.
// The API returns an optional even though the current implementation always
// returns a device, to account for cases where we cannot guess a device.
// In such cases the VerifyChannelInstructions() will return proper errors.
- tensorflow::gtl::optional<int64> device =
- instruction.sharding_unique_device();
+ absl::optional<int64> device = instruction.sharding_unique_device();
if (!device) {
device = GetModuleId(instruction.parent()->parent());
}
@@ -283,10 +295,7 @@ tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
}
int64 HloModuleGroupMetadata::GetDeviceModulesCount() const {
- return std::count_if(modules_.begin(), modules_.end(),
- [](const HloModule* module) {
- return !module->config().is_host_module();
- });
+ return modules_.size();
}
Status HloModuleGroupMetadata::RecordInstructions() {
@@ -383,7 +392,7 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
if (!ContainsKey(companion_set_index_, instruction1) &&
!ContainsKey(companion_set_index_, instruction2)) {
companion_sets_.push_back(
- tensorflow::MakeUnique<std::unordered_set<HloInstruction*>>());
+ absl::make_unique<std::unordered_set<HloInstruction*>>());
auto companion_set = companion_sets_.back().get();
companion_set->insert(instruction1);
companion_set->insert(instruction2);
@@ -411,16 +420,16 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
Status HloModuleGroupMetadata::VerifyChannelInstructions() {
for (const Channel& channel : channels_) {
if (channel.send == nullptr) {
- return FailedPrecondition("missing send for id : %lld", channel.id);
+ return FailedPrecondition("missing send for id : %d", channel.id);
}
if (channel.recv == nullptr) {
- return FailedPrecondition("missing recv for id : %lld", channel.id);
+ return FailedPrecondition("missing recv for id : %d", channel.id);
}
if (channel.send_done == nullptr) {
- return FailedPrecondition("missing send-done for id : %lld", channel.id);
+ return FailedPrecondition("missing send-done for id : %d", channel.id);
}
if (channel.recv_done == nullptr) {
- return FailedPrecondition("missing recv-done for id : %lld", channel.id);
+ return FailedPrecondition("missing recv-done for id : %d", channel.id);
}
}
@@ -436,33 +445,33 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
auto send_done_device = GetInstructionDevice(*channel.send_done);
if (!send_device) {
return FailedPrecondition("send instruction must have a device: %s",
- channel.send->ToString().c_str());
+ channel.send->ToString());
}
if (!send_done_device) {
return FailedPrecondition("send_done instruction must have a device: %s",
- channel.send_done->ToString().c_str());
+ channel.send_done->ToString());
}
if (*send_device != *send_done_device) {
return FailedPrecondition(
- "send and send-done (channel=%lld) must be on the same device: %lld "
- "vs. %lld",
+ "send and send-done (channel=%d) must be on the same device: %d "
+ "vs. %d",
channel.id, *send_device, *send_done_device);
}
auto recv_device = GetInstructionDevice(*channel.recv);
auto recv_done_device = GetInstructionDevice(*channel.recv_done);
if (!recv_done_device) {
return FailedPrecondition("recv_done instruction must have a device: %s",
- channel.recv_done->ToString().c_str());
+ channel.recv_done->ToString());
}
if (*recv_device != *recv_done_device) {
return FailedPrecondition(
- "recv and recv-done (channel=%lld) must be on the same device: %lld "
- "vs. %lld",
+ "recv and recv-done (channel=%d) must be on the same device: %d "
+ "vs. %d",
channel.id, *recv_device, *recv_done_device);
}
if (*send_device == *recv_device) {
return FailedPrecondition(
- "send and recv (channel=%lld) must be on different devices: %lld",
+ "send and recv (channel=%d) must be on different devices: %d",
channel.id, *send_device);
}
}
@@ -483,7 +492,7 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
!CheckCompanionPathsCompatibility(
path, GetCompanionsPath(channel.recv_done))) {
return FailedPrecondition(
- "Nest companion paths do not match for channel %lld", channel.id);
+ "Nest companion paths do not match for channel %d", channel.id);
}
}
return Status::OK();
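
The error-message changes in this file are consistent with the status helpers (FailedPrecondition and friends) now taking absl::StrFormat-style, type-checked arguments: %d accepts any integer width, so %lld is unnecessary, and %s accepts std::string directly, so the .c_str() calls are dropped. A self-contained sketch of that behavior with a hypothetical DescribeChannel helper:

#include <cstdint>
#include <string>

#include "absl/strings/str_format.h"

std::string DescribeChannel(int64_t channel_id, const std::string& send_name) {
  // %d works for int64_t and %s takes the std::string as-is.
  return absl::StrFormat("channel=%d send=%s", channel_id, send_name);
}
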
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 1b256cd00e..768b0c7eb3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -22,14 +22,15 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -125,6 +126,9 @@ class HloModuleGroupMetadata {
// Returns the Channel instance for the given channel id.
const Channel& GetChannel(int64 channel_id) const;
+ // Returns if the given channel id exists in metadata.
+ bool HasChannel(int64 channel_id) const;
+
// Returns the all-reduce instructions with the same all_reduce_id.
const std::vector<HloInstruction*>& GetAllReduceGroup(
int64 all_reduce_id) const;
@@ -156,7 +160,7 @@ class HloModuleGroupMetadata {
// Retrieves the device an instruction is assigned to. Either from the
// sharding information, or from the ordinal of the module the instruction
// is in.
- tensorflow::gtl::optional<int64> GetInstructionDevice(
+ absl::optional<int64> GetInstructionDevice(
const HloInstruction& instruction) const;
// Returns the number of modules for devices (excluding the host module).
@@ -194,6 +198,10 @@ class HloModuleGroupMetadata {
// Returns the maximum channel id or all_reduce_id used in the module group.
int64 max_channel_id() const { return max_channel_id_; }
+ TuplePointsToAnalysis* points_to_analysis(HloModule* module) const {
+ return points_to_analyses_.at(module).get();
+ }
+
private:
Status Build();
@@ -268,6 +276,9 @@ class HloModuleGroupMetadata {
// The modules that this metadata was built from.
const std::vector<HloModule*>& modules_;
+
+ tensorflow::gtl::FlatMap<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
+ points_to_analyses_;
};
} // namespace xla
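
A hypothetical caller-side sketch of the HasChannel query added above: probe for membership first, since GetChannel assumes the channel id is present.

#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"

void MaybeInspectChannel(const xla::HloModuleGroupMetadata& metadata,
                         tensorflow::int64 channel_id) {
  if (metadata.HasChannel(channel_id)) {
    // Safe to look up: GetChannel assumes the id exists.
    const auto& channel = metadata.GetChannel(channel_id);
    (void)channel;  // e.g. inspect channel.send / channel.recv here
  }
}
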
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 0dc5676148..d70328c8a3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -22,7 +22,10 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -30,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -94,12 +96,14 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
add_unique_predecessor(control_predecessor);
}
}
- if (instruction->opcode() == HloOpcode::kRecvDone) {
+ if (instruction->opcode() == HloOpcode::kRecvDone &&
+ !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
// Send is a remote predecessor of RecvDone.
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
add_unique_predecessor(send);
}
- if (instruction->opcode() == HloOpcode::kSend) {
+ if (instruction->opcode() == HloOpcode::kSend &&
+ !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// Recv is a remote predecessor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
@@ -170,14 +174,16 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
add_unique_successor(control_successor);
}
}
- if (instruction->opcode() == HloOpcode::kRecv) {
+ if (instruction->opcode() == HloOpcode::kRecv &&
+ !DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
// Send is a remote successor of Recv.
const HloInstruction* recv_done = instruction->users().front();
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
add_unique_successor(send);
}
- if (instruction->opcode() == HloOpcode::kSend) {
+ if (instruction->opcode() == HloOpcode::kSend &&
+ !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// RecvDone is a remote successor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
@@ -264,8 +270,8 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
string cyclic_instructions;
for (const auto& state : *visit_state) {
if (state.second == VisitState::kVisiting) {
- tensorflow::strings::StrAppend(&cyclic_instructions,
- state.first->ToString(), "\n");
+ absl::StrAppend(&cyclic_instructions, state.first->ToString(),
+ "\n");
}
}
// TODO(b/64305524): Improve the error message to print out the
@@ -276,7 +282,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
"following nodes. Note that the order of the nodes is arbitrary "
"and that the list may include nodes that are not part of the "
"cycle.\n%s",
- predecessor->ToString().c_str(), cyclic_instructions.c_str());
+ predecessor->ToString(), cyclic_instructions);
}
stack.push(predecessor);
}
@@ -332,7 +338,7 @@ HloModuleGroupUtil::ComputeReachability(
TF_RETURN_IF_ERROR(
VisitTopologicalOrder(&visit_states, visit_function, root));
}
- auto reachability = MakeUnique<HloReachabilityMap>(post_order);
+ auto reachability = absl::make_unique<HloReachabilityMap>(post_order);
for (HloInstruction* hlo : post_order) {
reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
}
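
The guards added above keep cross-module ordering edges only for device-to-device Send/Recv pairs. A small sketch of the same predicate, assuming the DynCast helper and is_host_transfer() accessor used in the diff:

#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

// True only for Send instructions that target another device, not the host.
bool IsDeviceToDeviceSend(const xla::HloInstruction* instruction) {
  return instruction->opcode() == xla::HloOpcode::kSend &&
         !xla::DynCast<xla::HloSendInstruction>(instruction)
              ->is_host_transfer();
}
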
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 236f450086..209ad5e58c 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index d1eaf35785..2d4e38589f 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -39,7 +39,7 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
});
auto it = opcode_map->find(opcode_name);
if (it == opcode_map->end()) {
- return InvalidArgument("Unknown opcode: %s", opcode_name.c_str());
+ return InvalidArgument("Unknown opcode: %s", opcode_name);
}
return it->second;
}
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index ec279867e5..e6bfb8025d 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -58,6 +58,7 @@ namespace xla {
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil") \
V(kClamp, "clamp") \
+ V(kCollectivePermute, "collective-permute") \
V(kClz, "count-leading-zeros") \
V(kComplex, "complex") \
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
@@ -85,7 +86,6 @@ namespace xla {
V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
- V(kHostCompute, "host-compute") \
V(kImag, "imag") \
V(kInfeed, "infeed") \
V(kIota, "iota") \
@@ -156,7 +156,7 @@ enum HloOpcodeProperty {
// Returns a string representation of the opcode.
string HloOpcodeString(HloOpcode opcode);
-// Returns a string representation of the opcode.
+// Retrieves the opcode enum by name if the opcode exists.
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name);
inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
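
A quick, hypothetical round-trip through the opcode name table, exercising the collective-permute entry added above:

#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/platform/logging.h"

void CheckCollectivePermuteName() {
  xla::StatusOr<xla::HloOpcode> opcode =
      xla::StringToHloOpcode("collective-permute");
  CHECK(opcode.ok());
  // The string form round-trips through HloOpcodeString.
  CHECK_EQ(xla::HloOpcodeString(opcode.ValueOrDie()), "collective-permute");
}
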
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 6c1e015f77..0581d5c404 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -25,8 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -254,6 +254,10 @@ bool HloOrdering::LiveRangeStrictlyBefore(
}
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
+ if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
+ use.instruction)) {
+ continue;
+ }
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
VLOG(4) << "use of " << a << " (" << use << ") not before " << b
<< " is defined";
@@ -302,22 +306,20 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const {
std::vector<string> pieces;
pieces.push_back(name);
for (auto* computation : module_->MakeNonfusionComputations()) {
- pieces.push_back(tensorflow::strings::Printf("computation %s:",
- computation->name().c_str()));
+ pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
const auto all = computation->MakeInstructionPostOrder();
for (auto instruction : all) {
- pieces.push_back(tensorflow::strings::Printf(
- " %s predecessors:", instruction->name().c_str()));
+ pieces.push_back(
+ absl::StrFormat(" %s predecessors:", instruction->name()));
for (auto predecessor : all) {
if (predecessors_.at(computation)
->IsReachable(predecessor, instruction)) {
- pieces.push_back(
- tensorflow::strings::Printf(" %s", predecessor->name().c_str()));
+ pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
}
}
}
}
- return tensorflow::str_util::Join(pieces, "\n");
+ return absl::StrJoin(pieces, "\n");
}
DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
@@ -368,8 +370,8 @@ string SequentialHloOrdering::ToString() const {
std::vector<string> pieces;
pieces.push_back("SequentialHloOrdering");
for (auto* computation : module_->computations()) {
- pieces.push_back(tensorflow::strings::Printf("computation %s order:",
- computation->name().c_str()));
+ pieces.push_back(
+ absl::StrFormat("computation %s order:", computation->name()));
// Gather all instructions in the module sequence for this computation and
// sort them by their position.
std::vector<const HloInstruction*> instructions;
@@ -384,11 +386,10 @@ string SequentialHloOrdering::ToString() const {
return order_position_.at(a) < order_position_.at(b);
});
for (auto instruction : instructions) {
- pieces.push_back(
- tensorflow::strings::Printf(" %s", instruction->name().c_str()));
+ pieces.push_back(absl::StrFormat(" %s", instruction->name()));
}
}
- return tensorflow::str_util::Join(pieces, "\n");
+ return absl::StrJoin(pieces, "\n");
}
std::ostream& operator<<(
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ab57a8b07f..eae4508b24 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -15,6 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
@@ -24,21 +30,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
namespace {
-using ::tensorflow::StringPiece;
-using ::tensorflow::gtl::optional;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::str_util::Split;
-using ::tensorflow::str_util::SplitAndParseAsInts;
-using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::nullopt;
+using absl::optional;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrFormat;
+using absl::StrJoin;
const double kF16max = 65504;
@@ -47,7 +49,7 @@ class HloParser {
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(StringPiece str, const HloModuleConfig& config)
+ explicit HloParser(absl::string_view str, const HloModuleConfig& config)
: lexer_(str), config_(config) {}
// Runs the parser. Returns false if an error occurred.
@@ -57,14 +59,28 @@ class HloParser {
std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
// Returns the error information.
- string GetError() const { return Join(error_, "\n"); }
+ string GetError() const { return StrJoin(error_, "\n"); }
// Stand alone parsing utils for various aggregate data types.
StatusOr<HloSharding> ParseShardingOnly();
StatusOr<Window> ParseWindowOnly();
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
+  // Stand-alone parsing utility for a single instruction's worth of text.
+ Status ParseSingleInstruction(HloComputation::Builder* builder,
+ string* root_name);
+
private:
+ // Locates an instruction with the given name in the instruction_pool_ or
+ // returns nullptr.
+ //
+ // If the missing_instruction_hook_ is registered and a "shape" is provided,
+ // the hook will be called and may satisfy the request for the given
+ // instruction. This is useful when we reify parameters as they're resolved;
+ // i.e. for ParseSingleInstruction.
+ std::pair<HloInstruction*, LocTy>* FindInstruction(
+ const string& name, const optional<Shape>& shape = nullopt);
+
// ParseXXX returns false if an error occurred.
bool ParseHloModule();
bool ParseComputations();
@@ -138,6 +154,7 @@ class HloParser {
kFusionKind,
kDistribution,
kDomain,
+ kPrecisionList,
};
struct AttrConfig {
@@ -203,6 +220,7 @@ class HloParser {
bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad);
bool ParseSliceRanges(SliceRanges* result);
+ bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
@@ -221,6 +239,7 @@ class HloParser {
bool ParseFftType(FftType* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
+ bool ParsePrecision(PrecisionConfigProto::Precision* result);
bool ParseInt64(tensorflow::int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
@@ -233,8 +252,8 @@ class HloParser {
bool CanBeParamListToShape();
// Logs the current parsing line and the given message. Always returns false.
- bool TokenError(StringPiece msg);
- bool Error(LocTy loc, StringPiece msg);
+ bool TokenError(absl::string_view msg);
+ bool Error(LocTy loc, absl::string_view msg);
// If the current token is 'kind', eats it (i.e. lexes the next token) and
// returns true.
@@ -265,24 +284,55 @@ class HloParser {
std::vector<std::unique_ptr<HloComputation>> computations_;
const HloModuleConfig config_;
std::vector<string> error_;
+
+  // Function that gets invoked when we try to resolve an instruction in the
+  // instruction_pool_ but fail to do so.
+ std::function<std::pair<HloInstruction*, LocTy>*(string,
+ const optional<Shape>&)>
+ missing_instruction_hook_;
};
-bool HloParser::Error(LocTy loc, StringPiece msg) {
+bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
+ for (const auto& split : absl::StrSplit(s, delim)) {
+ int64 val;
+ if (!absl::SimpleAtoi(split, &val)) {
+ return false;
+ }
+ out->push_back(val);
+ }
+ return true;
+}
+
+// Creates replica groups from the provided nested array. groups[i] represents
+// the replica ids for group 'i'.
+std::vector<ReplicaGroup> CreateReplicaGroups(
+ tensorflow::gtl::ArraySlice<std::vector<int64>> groups) {
+ std::vector<ReplicaGroup> replica_groups;
+ absl::c_transform(groups, std::back_inserter(replica_groups),
+ [](const std::vector<int64>& ids) {
+ ReplicaGroup group;
+ *group.mutable_replica_ids() = {ids.begin(), ids.end()};
+ return group;
+ });
+ return replica_groups;
+}
+
+bool HloParser::Error(LocTy loc, absl::string_view msg) {
auto line_col = lexer_.GetLineAndColumn(loc);
const unsigned line = line_col.first;
const unsigned col = line_col.second;
std::vector<string> error_lines;
error_lines.push_back(
StrCat("was parsing ", line, ":", col, ": error: ", msg));
- error_lines.push_back(std::string(lexer_.GetLine(loc)));
+ error_lines.emplace_back(lexer_.GetLine(loc));
error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));
- error_.push_back(Join(error_lines, "\n"));
+ error_.push_back(StrJoin(error_lines, "\n"));
VLOG(1) << "Error: " << error_.back();
return false;
}
-bool HloParser::TokenError(StringPiece msg) {
+bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
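
A quick check of the SplitToInt64s helper defined above, as it might be exercised inside this translation unit (it replaces the old SplitAndParseAsInts):

void SplitToInt64sSketch() {
  std::vector<int64> dims;
  CHECK(SplitToInt64s("2x3x4", 'x', &dims));  // "2x3x4" -> {2, 3, 4}
  CHECK_EQ(dims.size(), 3);
  CHECK_EQ(dims[0], 2);
}
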
@@ -291,6 +341,17 @@ bool HloParser::Run() {
return ParseHloModule();
}
+std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
+ const string& name, const optional<Shape>& shape) {
+ std::pair<HloInstruction*, LocTy>* instr =
+ tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ // Potentially call the missing instruction hook.
+ if (instr == nullptr && missing_instruction_hook_ != nullptr) {
+ return missing_instruction_hook_(name, shape);
+ }
+ return instr;
+}
+
// ::= 'HloModule' name computations
bool HloParser::ParseHloModule() {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
@@ -304,7 +365,7 @@ bool HloParser::ParseHloModule() {
return false;
}
- module_ = MakeUnique<HloModule>(name, config_);
+ module_ = absl::make_unique<HloModule>(name, config_);
return ParseComputations();
}
@@ -357,7 +418,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
if (!ParseName(&name)) {
return false;
}
- auto builder = MakeUnique<HloComputation::Builder>(name);
+ auto builder = absl::make_unique<HloComputation::Builder>(name);
LocTy shape_loc = nullptr;
Shape shape;
@@ -370,8 +431,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
return false;
}
- std::pair<HloInstruction*, LocTy>* root_node =
- tensorflow::gtl::FindOrNull(instruction_pool_, root_name);
+ std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name);
// This means some instruction was marked as ROOT but we didn't find it in the
// pool, which should not happen.
if (!root_name.empty() && root_node == nullptr) {
@@ -469,6 +529,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
+
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -498,11 +562,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kIota: {
+ optional<tensorflow::int64> iota_dimension;
+ attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
+ &iota_dimension};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateIota(shape));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateIota(shape, *iota_dimension));
break;
}
// Unary ops.
@@ -597,31 +665,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kCrossReplicaSum: {
+ optional<std::vector<std::vector<int64>>> tmp_groups;
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<string> barrier;
optional<int64> all_reduce_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
- attrs["replica_group_ids"] = {
- /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids};
+ attrs["replica_groups"] = {/*required=*/false,
+ AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
&all_reduce_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- if (replica_group_ids) {
- instruction =
- builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
- shape, operands, *to_apply, *replica_group_ids,
- barrier ? *barrier : "", all_reduce_id));
- } else {
- instruction =
- builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
- shape, operands, *to_apply, {}, barrier ? *barrier : "",
- all_reduce_id));
+ std::vector<ReplicaGroup> replica_groups;
+ if (tmp_groups) {
+ replica_groups = CreateReplicaGroups(*tmp_groups);
}
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
+ shape, operands, *to_apply, replica_groups,
+ barrier ? *barrier : "", all_reduce_id));
break;
}
case HloOpcode::kAllToAll: {
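
The replica_groups attribute parsed above is converted with the CreateReplicaGroups helper defined earlier in this file; a small sketch of what it produces for {{0,1},{2,3}}:

void CreateReplicaGroupsSketch() {
  std::vector<std::vector<int64>> ids = {{0, 1}, {2, 3}};
  std::vector<ReplicaGroup> groups = CreateReplicaGroups(ids);
  CHECK_EQ(groups.size(), 2);             // two groups of two replicas each
  CHECK_EQ(groups[1].replica_ids(0), 2);  // second group starts at replica 2
}
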
@@ -629,21 +695,36 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<string> barrier;
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
- attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
- c_transform(*tmp_groups, std::back_inserter(replica_groups),
- [](const std::vector<int64>& ids) {
- ReplicaGroup group;
- *group.mutable_replica_ids() = {ids.begin(), ids.end()};
- return group;
- });
+ replica_groups = CreateReplicaGroups(*tmp_groups);
}
- instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
- shape, operands, replica_groups, barrier ? *barrier : ""));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateAllToAll(shape, operands, replica_groups));
+ break;
+ }
+ case HloOpcode::kCollectivePermute: {
+ optional<std::vector<std::vector<int64>>> source_targets;
+ attrs["source_target_pairs"] = {
+ /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ std::vector<std::pair<int64, int64>> pairs(source_targets->size());
+ for (int i = 0; i < pairs.size(); i++) {
+ if ((*source_targets)[i].size() != 2) {
+ return TokenError(
+ "expects 'source_target_pairs=' to be a list of pairs");
+ }
+ pairs[i].first = (*source_targets)[i][0];
+ pairs[i].second = (*source_targets)[i][1];
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateCollectivePermute(shape, operands[0], pairs));
break;
}
case HloOpcode::kReshape: {
@@ -1177,20 +1258,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
break;
}
- case HloOpcode::kHostCompute: {
- optional<string> channel_name;
- optional<tensorflow::int64> cost_estimate_ns;
- attrs["channel_name"] = {/*required=*/true, AttrTy::kString,
- &channel_name};
- attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64,
- &cost_estimate_ns};
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateHostCompute(
- shape, operands, *channel_name, *cost_estimate_ns));
- break;
- }
case HloOpcode::kDot: {
optional<std::vector<tensorflow::int64>> lhs_contracting_dims;
attrs["lhs_contracting_dims"] = {
@@ -1346,6 +1413,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
+ if (operand_precision) {
+ PrecisionConfigProto precision_config;
+ *precision_config.mutable_operand_precision() = {operand_precision->begin(),
+ operand_precision->end()};
+ instruction->set_precision_config(precision_config);
+ }
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
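
The operand_precision attribute handled above ends up as a PrecisionConfigProto on the instruction. A hypothetical sketch of the resulting proto for a two-operand instruction; the textual spelling of the attribute values is an assumption based on the grammar:

// Roughly what an attribute like operand_precision={high,default} would yield.
PrecisionConfigProto MakePrecisionConfigSketch() {
  PrecisionConfigProto precision_config;
  precision_config.add_operand_precision(PrecisionConfigProto::HIGH);
  precision_config.add_operand_precision(PrecisionConfigProto::DEFAULT);
  return precision_config;
}
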
@@ -1509,14 +1582,14 @@ bool HloParser::ParseDomain(DomainData* domain) {
return false;
}
if (*kind == ShardingMetadata::KindName()) {
- auto entry_sharding_ptr = MakeUnique<HloSharding>(
+ auto entry_sharding_ptr = absl::make_unique<HloSharding>(
HloSharding::FromProto(*entry_sharding).ValueOrDie());
- auto exit_sharding_ptr = MakeUnique<HloSharding>(
+ auto exit_sharding_ptr = absl::make_unique<HloSharding>(
HloSharding::FromProto(*exit_sharding).ValueOrDie());
domain->entry_metadata =
- MakeUnique<ShardingMetadata>(std::move(entry_sharding_ptr));
+ absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
domain->exit_metadata =
- MakeUnique<ShardingMetadata>(std::move(exit_sharding_ptr));
+ absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
} else {
return TokenError(StrCat("unsupported domain kind: ", *kind));
}
@@ -1536,11 +1609,9 @@ bool HloParser::ParseInstructionNames(
if (!ParseName(&name)) {
return Error(loc, "expects a instruction name");
}
- std::pair<HloInstruction*, LocTy>* instr =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
if (!instr) {
- return TokenError(
- Printf("instruction '%s' is not defined", name.c_str()));
+ return TokenError(StrFormat("instruction '%s' is not defined", name));
}
instructions->push_back(instr->first);
} while (EatIfPresent(TokKind::kComma));
@@ -1769,10 +1840,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
std::vector<tensorflow::int64> elems_seen_until_dim(
elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim);
return StrCat("[",
- Join(elems_seen_until_dim, ",",
- [](string* out, const tensorflow::int64& num_elems) {
- StrAppend(out, num_elems - 1);
- }),
+ StrJoin(elems_seen_until_dim, ",",
+ [](string* out, const tensorflow::int64& num_elems) {
+ StrAppend(out, num_elems - 1);
+ }),
"]");
};
do {
@@ -1782,17 +1853,17 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
case TokKind::kLbrace: {
nest_level++;
if (nest_level > rank) {
- return TokenError(Printf(
- "expects nested array in rank %lld, but sees larger", rank));
+ return TokenError(absl::StrFormat(
+ "expects nested array in rank %d, but sees larger", rank));
}
if (nest_level > 1) {
elems_seen_per_dim[nest_level - 2]++;
if (elems_seen_per_dim[nest_level - 2] >
shape.dimensions(nest_level - 2)) {
- return TokenError(Printf(
- "expects %lld elements in the %sth element, but sees more",
+ return TokenError(absl::StrFormat(
+ "expects %d elements in the %sth element, but sees more",
shape.dimensions(nest_level - 2),
- get_index_str(nest_level - 2).c_str()));
+ get_index_str(nest_level - 2)));
}
}
lexer_.Lex();
@@ -1801,9 +1872,9 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
case TokKind::kRbrace: {
nest_level--;
if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
- return TokenError(Printf(
- "expects %lld elements in the %sth element, but sees %lld",
- shape.dimensions(nest_level), get_index_str(nest_level).c_str(),
+ return TokenError(absl::StrFormat(
+ "expects %d elements in the %sth element, but sees %d",
+ shape.dimensions(nest_level), get_index_str(nest_level),
elems_seen_per_dim[nest_level]));
}
elems_seen_per_dim[nest_level] = 0;
@@ -1824,15 +1895,15 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
if (rank > 0) {
if (nest_level != rank) {
return TokenError(
- Printf("expects nested array in rank %lld, but sees %lld", rank,
- nest_level));
+ absl::StrFormat("expects nested array in rank %d, but sees %d",
+ rank, nest_level));
}
elems_seen_per_dim[rank - 1]++;
if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
- return TokenError(
- Printf("expects %lld elements on the minor-most dimension, but "
- "sees more",
- shape.dimensions(rank - 1)));
+ return TokenError(absl::StrFormat(
+ "expects %d elements on the minor-most dimension, but "
+ "sees more",
+ shape.dimensions(rank - 1)));
}
}
if (lexer_.GetKind() == TokKind::kw_true ||
@@ -1925,7 +1996,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
tensorflow::int64 rank = ShapeUtil::Rank(shape);
- *literal = MakeUnique<Literal>(shape);
+ *literal = absl::make_unique<Literal>(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@@ -1959,7 +2030,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return Error(
index_loc,
StrCat("invalid multi-dimension index for shape with rank ", rank,
- ": [", Join(index, ", "), "]"));
+ ": [", StrJoin(index, ", "), "]"));
}
}
if (!ParseToken(TokKind::kColon,
@@ -2020,6 +2091,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
// ::= operand (, operand)*
// operand ::= (shape)? name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
+ CHECK(operands != nullptr);
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
return false;
@@ -2030,9 +2102,10 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
do {
LocTy loc = lexer_.GetLoc();
string name;
+ optional<Shape> shape;
if (CanBeShape()) {
- Shape shape;
- if (!ParseShape(&shape)) {
+ shape.emplace();
+ if (!ParseShape(&shape.value())) {
return false;
}
}
@@ -2040,8 +2113,8 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
return false;
}
std::pair<HloInstruction*, LocTy>* instruction =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
- if (!instruction) {
+ FindInstruction(name, shape);
+ if (instruction == nullptr) {
return Error(loc, StrCat("instruction does not exist: ", name));
}
operands->push_back(instruction->first);
@@ -2052,6 +2125,7 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size) {
+ CHECK(operands != nullptr);
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(operands)) {
return false;
@@ -2085,8 +2159,8 @@ bool HloParser::ParseSubAttributes(
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
- return Error(loc, Printf("sub-attribute %s is expected but not seen",
- attr_it.first.c_str()));
+ return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
+ attr_it.first));
}
}
return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
@@ -2106,8 +2180,8 @@ bool HloParser::ParseAttributes(
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
- return Error(loc, Printf("attribute %s is expected but not seen",
- attr_it.first.c_str()));
+ return Error(loc, StrFormat("attribute %s is expected but not seen",
+ attr_it.first));
}
}
return true;
@@ -2123,7 +2197,7 @@ bool HloParser::ParseAttributeHelper(
}
VLOG(1) << "Parsing attribute " << name;
if (!seen_attrs->insert(name).second) {
- return Error(loc, Printf("attribute %s already exists", name.c_str()));
+ return Error(loc, StrFormat("attribute %s already exists", name));
}
auto attr_it = attrs.find(name);
if (attr_it == attrs.end()) {
@@ -2133,13 +2207,13 @@ bool HloParser::ParseAttributeHelper(
} else {
allowed_attrs = StrCat(
"Allowed attributes: ",
- Join(attrs, ", ",
- [&](string* out, const std::pair<string, AttrConfig>& kv) {
- StrAppend(out, kv.first);
- }));
+ StrJoin(attrs, ", ",
+ [&](string* out, const std::pair<string, AttrConfig>& kv) {
+ StrAppend(out, kv.first);
+ }));
}
- return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(),
- allowed_attrs.c_str()));
+ return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
+ allowed_attrs));
}
AttrTy attr_type = attr_it->second.attr_type;
void* attr_out_ptr = attr_it->second.result;
@@ -2321,10 +2395,20 @@ bool HloParser::ParseAttributeHelper(
case AttrTy::kDomain: {
return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
}
+ case AttrTy::kPrecisionList: {
+ std::vector<PrecisionConfigProto::Precision> result;
+ if (!ParsePrecisionList(&result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>(
+ attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
}
}();
if (!success) {
- return Error(loc, Printf("error parsing attribute %s", name.c_str()));
+ return Error(loc, StrFormat("error parsing attribute %s", name));
}
return true;
}
@@ -2439,20 +2523,24 @@ bool HloParser::ParseConvolutionDimensionNumbers(
}
string str = lexer_.GetStrVal();
- // The str is expected to have 3 items, lhs, rhs, out, and it must looks like
+ // The str is expected to have 3 items, lhs, rhs, out, and it must look like
// lhs_rhs->out, that is, the first separator is "_" and the second is "->".
- // So we replace the "->" with "_" and then split on "_".
- str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
- /*newsub=*/"_",
- /*replace_all=*/false);
- std::vector<string> lhs_rhs_out = Split(str, "_");
- if (lhs_rhs_out.size() != 3) {
+ std::vector<string> split1 = absl::StrSplit(str, "_");
+ if (split1.size() != 2) {
+ LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
+ << str;
+ }
+ std::vector<string> split2 = absl::StrSplit(split1[1], "->");
+ if (split2.size() != 2) {
LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
<< str;
}
+ absl::string_view lhs = split1[0];
+ absl::string_view rhs = split2[0];
+ absl::string_view out = split2[1];
- const tensorflow::int64 rank = lhs_rhs_out[0].length();
- if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
+ const tensorflow::int64 rank = lhs.length();
+ if (rank != rhs.length() || rank != out.length()) {
return TokenError(
"convolution lhs, rhs, and output must have the same rank");
}
@@ -2467,8 +2555,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
// lhs
{
- const string& lhs = lhs_rhs_out[0];
- if (!is_unique(lhs)) {
+ if (!is_unique(string(lhs))) {
return TokenError(
StrCat("expects unique lhs dimension numbers, but sees ", lhs));
}
@@ -2485,14 +2572,13 @@ bool HloParser::ParseConvolutionDimensionNumbers(
dnums->set_input_spatial_dimensions(c - '0', i);
} else {
return TokenError(
- Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
+ StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1));
}
}
}
// rhs
{
- const string& rhs = lhs_rhs_out[1];
- if (!is_unique(rhs)) {
+ if (!is_unique(string(rhs))) {
return TokenError(
StrCat("expects unique rhs dimension numbers, but sees ", rhs));
}
@@ -2509,14 +2595,13 @@ bool HloParser::ParseConvolutionDimensionNumbers(
dnums->set_kernel_spatial_dimensions(c - '0', i);
} else {
return TokenError(
- Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
+ StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1));
}
}
}
// output
{
- const string& out = lhs_rhs_out[2];
- if (!is_unique(out)) {
+ if (!is_unique(string(out))) {
return TokenError(
StrCat("expects unique output dimension numbers, but sees ", out));
}
@@ -2532,8 +2617,8 @@ bool HloParser::ParseConvolutionDimensionNumbers(
} else if (c < '0' + rank && c >= '0') {
dnums->set_output_spatial_dimensions(c - '0', i);
} else {
- return TokenError(
- Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
+ return TokenError(StrFormat(
+ "expects [0-%dbf] in output dimension numbers", rank - 1));
}
}
}
@@ -2579,9 +2664,10 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
}
const auto& range = ranges.back();
if (range.size() != 2 && range.size() != 3) {
- return Error(loc, Printf("expects [start:limit:step] or [start:limit], "
- "but sees %ld elements.",
- range.size()));
+ return Error(loc,
+ StrFormat("expects [start:limit:step] or [start:limit], "
+ "but sees %d elements.",
+ range.size()));
}
} while (EatIfPresent(TokKind::kComma));
@@ -2593,6 +2679,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
}
+// precisionlist ::= start precision_elements end
+// precision_elements
+// ::= /*empty*/
+// ::= precision_val (delim precision_val)*
+bool HloParser::ParsePrecisionList(
+ std::vector<PrecisionConfigProto::Precision>* result) {
+ auto parse_and_add_item = [&]() {
+ PrecisionConfigProto::Precision item;
+ if (!ParsePrecision(&item)) {
+ return false;
+ }
+ result->push_back(item);
+ return true;
+ };
+ return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item);
+}
+
// int64list ::= start int64_elements end
// int64_elements
// ::= /*empty*/
@@ -2749,14 +2853,13 @@ bool HloParser::ParseDxD(const string& name,
std::vector<tensorflow::int64>* result) {
LocTy loc = lexer_.GetLoc();
if (!result->empty()) {
- return Error(loc,
- Printf("sub-attribute '%s=' already exists", name.c_str()));
+ return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
}
// 1D
if (lexer_.GetKind() == TokKind::kInt) {
tensorflow::int64 number;
if (!ParseInt64(&number)) {
- return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str()));
+ return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
}
result->push_back(number);
return true;
@@ -2764,9 +2867,8 @@ bool HloParser::ParseDxD(const string& name,
// 2D or higher.
if (lexer_.GetKind() == TokKind::kDxD) {
string str = lexer_.GetStrVal();
- if (!SplitAndParseAsInts(str, 'x', result)) {
- return Error(loc,
- Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
+ if (!SplitToInt64s(str, 'x', result)) {
+ return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
}
lexer_.Lex();
return true;
@@ -2784,10 +2886,9 @@ bool HloParser::ParseWindowPad(
return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
}
string str = lexer_.GetStrVal();
- std::vector<string> padding_str = Split(str, 'x');
- for (int i = 0; i < padding_str.size(); i++) {
+ for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<tensorflow::int64> low_high;
- if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
+ if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
low_high.size() != 2) {
return Error(loc,
"expects padding_low and padding_high separated by '_'");
@@ -2808,10 +2909,9 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
}
LocTy loc = lexer_.GetLoc();
string str = lexer_.GetStrVal();
- std::vector<string> padding_str = Split(str, 'x');
- for (const auto& padding_dim_str : padding_str) {
+ for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<tensorflow::int64> padding_dim;
- if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
+ if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
(padding_dim.size() != 2 && padding_dim.size() != 3)) {
return Error(loc,
"expects padding config pattern like 'low_high_interior' or "
@@ -2863,9 +2963,8 @@ bool HloParser::ParseOpcode(HloOpcode* result) {
string val = lexer_.GetStrVal();
auto status_or_result = StringToHloOpcode(val);
if (!status_or_result.ok()) {
- return TokenError(
- Printf("expects opcode but sees: %s, error: %s", val.c_str(),
- status_or_result.status().error_message().c_str()));
+ return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
+ status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
@@ -2879,7 +2978,7 @@ bool HloParser::ParseFftType(FftType* result) {
}
string val = lexer_.GetStrVal();
if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
- return TokenError(Printf("expects fft type but sees: %s", val.c_str()));
+ return TokenError(StrFormat("expects fft type but sees: %s", val));
}
lexer_.Lex();
return true;
@@ -2893,9 +2992,9 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
string val = lexer_.GetStrVal();
auto status_or_result = StringToFusionKind(val);
if (!status_or_result.ok()) {
- return TokenError(
- Printf("expects fusion kind but sees: %s, error: %s", val.c_str(),
- status_or_result.status().error_message().c_str()));
+ return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
+ val,
+ status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
@@ -2911,8 +3010,25 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
auto status_or_result = StringToRandomDistribution(val);
if (!status_or_result.ok()) {
return TokenError(
- Printf("expects random distribution but sees: %s, error: %s",
- val.c_str(), status_or_result.status().error_message().c_str()));
+ StrFormat("expects random distribution but sees: %s, error: %s", val,
+ status_or_result.status().error_message()));
+ }
+ *result = status_or_result.ValueOrDie();
+ lexer_.Lex();
+ return true;
+}
+
+bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) {
+ VLOG(1) << "ParsePrecision";
+ if (lexer_.GetKind() != TokKind::kIdent) {
+    return TokenError("expects precision");
+ }
+ string val = lexer_.GetStrVal();
+ auto status_or_result = StringToPrecision(val);
+ if (!status_or_result.ok()) {
+ return TokenError(StrFormat("expects precision but sees: %s, error: %s",
+ val,
+ status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
@@ -3006,7 +3122,7 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
lexer_.Lex();
OpSharding op_sharding;
if (!ParseSharding(&op_sharding)) {
- return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after sharding");
@@ -3018,7 +3134,7 @@ StatusOr<Window> HloParser::ParseWindowOnly() {
lexer_.Lex();
Window window;
if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
- return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after window");
@@ -3031,7 +3147,7 @@ HloParser::ParseConvolutionDimensionNumbersOnly() {
lexer_.Lex();
ConvolutionDimensionNumbers dnums;
if (!ParseConvolutionDimensionNumbers(&dnums)) {
- return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument(
@@ -3040,37 +3156,83 @@ HloParser::ParseConvolutionDimensionNumbersOnly() {
return dnums;
}
+Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
+ string* root_name) {
+ TF_RET_CHECK(missing_instruction_hook_ == nullptr);
+
+ // The missing instruction hook we register creates the shaped instruction on
+ // the fly as a parameter and returns it.
+ int64 parameter_count = 0;
+ missing_instruction_hook_ =
+ [this, builder, &parameter_count](
+ string name,
+ const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* {
+ if (!shape.has_value()) {
+ Error(lexer_.GetLoc(),
+ StrCat("Operand ", name,
+ " had no shape in HLO text; cannot create parameter for "
+ "single-instruction module."));
+ return nullptr;
+ }
+ HloInstruction* parameter = builder->AddInstruction(
+ HloInstruction::CreateParameter(parameter_count++, *shape, name));
+ instruction_pool_[name] = {parameter, lexer_.GetLoc()};
+ return tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ };
+
+ // Prime the lexer.
+ lexer_.Lex();
+
+ // Parse the instruction with the registered hook.
+ if (!ParseInstruction(builder, root_name)) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ return Status::OK();
+}
+
} // namespace
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str, const HloModuleConfig& config) {
+ absl::string_view str, const HloModuleConfig& config) {
HloParser parser(str, config);
if (!parser.Run()) {
- return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str());
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
}
return parser.ConsumeHloModule();
}
-StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
HloModuleConfig config;
return ParseHloString(str, config);
}
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
+ absl::string_view str, absl::string_view name) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ auto builder = absl::make_unique<HloComputation::Builder>(string(name));
+ string root_name;
+ TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
+ std::unique_ptr<HloComputation> computation = builder->Build();
+ auto module = absl::make_unique<HloModule>(string(name), config);
+ module->AddEntryComputation(std::move(computation));
+ return std::move(module);
+}
+
+StatusOr<HloSharding> ParseSharding(absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseShardingOnly();
}
-StatusOr<Window> ParseWindow(tensorflow::StringPiece str) {
+StatusOr<Window> ParseWindow(absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
- tensorflow::StringPiece str) {
+ absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseConvolutionDimensionNumbersOnly();
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 3f3a51215e..0c64b50481 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -16,7 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
@@ -32,27 +33,31 @@ namespace xla {
// The api of the hlo parser. Given a string in the HloModule::ToString()
// format, parses the string and creates a HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str, const HloModuleConfig& config);
+ absl::string_view str, const HloModuleConfig& config);
+
+// Parses the text for a single HLO operation into an HLO module with a function
+// that runs that operation (with its operands reified as parameters) as its
+// entry computation.
+StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
+ absl::string_view str, absl::string_view name = "single_op");
// The api of the hlo parser. Given a string in the HloModule::ToString()
// format, parses the string and creates a HloModule with default config.
-StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str);
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+StatusOr<HloSharding> ParseSharding(absl::string_view str);
// Parses the result of window_util::ToString(const Window&).
-StatusOr<Window> ParseWindow(tensorflow::StringPiece str);
+StatusOr<Window> ParseWindow(absl::string_view str);
// Parses the result of ConvolutionDimensionNumbersToString(), e.g.
// "b0f_0io->b0f".
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
- tensorflow::StringPiece str);
+ absl::string_view str);
// ParseHloString sharding from str. str is supposed to contain the body of the
// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+StatusOr<HloSharding> ParseSharding(absl::string_view str);
} // namespace xla
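
A hypothetical caller of the new single-op entry point: shaped operands become parameters of the entry computation, while unshaped operands are rejected (see the tests below).

#include <memory>

#include "tensorflow/compiler/xla/service/hlo_parser.h"

xla::StatusOr<std::unique_ptr<xla::HloModule>> ParseSingleAdd() {
  return xla::ParseHloOpToModule(
      "%add = f32[4]{0} add(f32[4]{0} %x, f32[4]{0} %y)");
}
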
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 0d7919346b..ba07ec432e 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -16,17 +16,19 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <string>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
-
namespace {
-using ::tensorflow::StringPiece;
+namespace op = ::xla::testing::opcode_matchers;
+using absl::string_view;
struct TestData {
string test_name;
@@ -1049,7 +1051,7 @@ add {
ENTRY CRS {
input = f32[8]{0} parameter(0)
- ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add
+ ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add
}
)"
@@ -1067,7 +1069,7 @@ add {
ENTRY CrossReplicaSumWithSubgroups {
input = f32[128,32]{0,1} parameter(0)
- ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add
+ ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add
}
)"
@@ -1091,7 +1093,19 @@ R"(HloModule AllToAllWithSubgroups
ENTRY AllToAllWithSubgroups {
input = f32[128,32]{0,1} parameter(0)
- ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc"
+ ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}
+}
+
+)"
+},
+// collective-permute
+{
+"CollectivePermute",
+R"(HloModule CollectivePermute
+
+ENTRY CollectivePermute {
+ input = f32[128,32]{0,1} parameter(0)
+ ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
}
)"
@@ -1102,7 +1116,7 @@ ENTRY AllToAllWithSubgroups {
R"(HloModule iota
ENTRY Iota {
- ROOT iota = f32[100]{0} iota()
+ ROOT iota = f32[100]{0} iota(), iota_dimension=0
}
)"
@@ -1125,8 +1139,8 @@ ENTRY Computation {
class HloParserTest : public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
- static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected))
+ static void ExpectHasSubstr(string_view s, string_view expected) {
+ EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
@@ -1390,15 +1404,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
)";
- ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat(
- prefix, ",dim_labels=00_01_10", suffix))
- .status()
- .error_message(),
- "expects dim labels pattern");
+ ExpectHasSubstr(
+ ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix))
+ .status()
+ .error_message(),
+ "expects dim labels pattern");
ExpectHasSubstr(
- ParseHloString(tensorflow::strings::StrCat(
- prefix, ",dim_labels=010_1100->010", suffix))
+ ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix))
.status()
.error_message(),
"must have the same rank");
@@ -1722,5 +1735,26 @@ ENTRY nontuple_infeed {
"infeed must have a non-empty tuple shape");
}
+TEST(HloParserSingleOpTest, SingleOp) {
+ const string text =
+ "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
+ "f32[2,4]{1,0} %x)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Multiply(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) {
+ const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
+ ASSERT_TRUE(!module.status().ok());
+ LOG(INFO) << "Status: " << module.status();
+ EXPECT_THAT(
+ module.status().ToString(),
+ ::testing::HasSubstr("Operand broadcast had no shape in HLO text"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index 0cddf8fb8f..f1ad0f9b01 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -29,7 +29,7 @@ namespace xla {
class HloPassInterface {
public:
virtual ~HloPassInterface() = default;
- virtual tensorflow::StringPiece name() const = 0;
+ virtual absl::string_view name() const = 0;
// Run the pass on the given HLO module. Return whether it modified the
// module.
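To illustrate the updated interface, a hedged sketch of a trivial pass overriding the new absl::string_view name(); the StatusOr<bool> Run signature is assumed from the comment above and from HloPassPipeline::Run in the next file.

// Sketch of a no-op pass against the updated interface (Run signature assumed).
class NoOpPass : public HloPassInterface {
 public:
  absl::string_view name() const override { return "no-op"; }
  StatusOr<bool> Run(HloModule* module) override {
    return false;  // The module was not modified.
  }
};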
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index d8f1ab916b..6e4ed0de62 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,22 +17,23 @@ limitations under the License.
#include <functional>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+
+using absl::StrAppend;
+using absl::StrCat;
+
void DumpModuleGraph(const HloModule& module, const string& message) {
hlo_graph_dumper::MaybeDumpHloModule(module, message);
VLOG(3) << "HLO " << message << ":";
@@ -48,9 +49,9 @@ void DumpModuleProto(const HloModule& module, const string& dump_to,
tensorflow::mutex_lock lock(mu);
const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
- const string mod_name = SanitizeFileName(tensorflow::strings::Printf(
- "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number,
- pipeline_name.c_str(), pass_name.c_str()));
+ const string mod_name = SanitizeFileName(
+ absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+ pass_number, pipeline_name, pass_name));
TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
dump_to, mod_name));
@@ -68,7 +69,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
repeated_field.end());
if (!disabled_passes.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
- << tensorflow::str_util::Join(disabled_passes, ", ");
+ << absl::StrJoin(disabled_passes, ", ");
}
auto run_invariant_checkers = [this,
@@ -90,7 +91,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
return Status::OK();
};
- string prefix = std::string(name()) + ": pipeline start";
+ string prefix = StrCat(name(), ": pipeline start");
bool changed = false;
string message;
TF_RETURN_IF_ERROR(
@@ -98,12 +99,12 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
const string xla_dump_per_pass_hlo_proto_to =
module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to,
- std::string(name()), "pipeline_start");
+ DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
+ "pipeline_start");
}
for (auto& pass : passes_) {
- if (disabled_passes.count(std::string(pass->name())) > 0) {
+ if (disabled_passes.count(string(pass->name())) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name()
<< ", disabled by --xla_disable_hlo_passes";
continue;
@@ -120,8 +121,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
TF_RETURN_IF_ERROR(
run_invariant_checkers(StrCat("after running pass: ", pass->name())));
if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to,
- std::string(name()), std::string(pass->name()));
+ DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
+ string(pass->name()));
}
changed |= changed_this_pass;
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index a42d7e59fe..1d41a4dac1 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -34,7 +34,7 @@ namespace xla {
class HloPassPipeline : public HloPassInterface {
public:
explicit HloPassPipeline(const string& name) : name_(name) {}
- tensorflow::StringPiece name() const override { return name_; }
+ absl::string_view name() const override { return name_; }
// Add a pass to the pipeline. It should be called with the arguments for the
// pass constructor:
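For context, a hedged sketch of assembling and running a pipeline; the templated AddPass call is assumed from the truncated comment above, and HloDCE is used only because it appears elsewhere in this change.

// Sketch: build a small pipeline and run it on a module (names are placeholders).
HloPassPipeline pipeline("example-pipeline");
pipeline.AddPass<HloDCE>();  // forwards its arguments to the pass constructor
StatusOr<bool> changed = pipeline.Run(module);  // module is an HloModule*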
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
index b9cca13870..c3cacd7ce6 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index cf0be30c7a..569d2e5d2d 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -20,6 +20,10 @@ limitations under the License.
#include <set>
#include <string>
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
@@ -37,17 +41,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::HumanReadableNumBytes;
-
namespace xla {
-
namespace {
+using ::tensorflow::strings::HumanReadableNumBytes;
+
// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
// of candidates.
@@ -88,7 +88,7 @@ bool CanBeRematerialized(
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
-using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;
+using BufferIdList = absl::InlinedVector<BufferId, 3>;
// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
@@ -123,7 +123,7 @@ struct Item {
int64 position;
};
-using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>;
+using ItemList = absl::InlinedVector<Item*, 3>;
// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
@@ -206,11 +206,10 @@ class InstructionList {
Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
<< " before {"
- << tensorflow::str_util::Join(before_instructions, ", ",
- [](string* out, Item* item) {
- tensorflow::strings::StrAppend(
- out, item->instruction->name());
- })
+ << absl::StrJoin(before_instructions, ", ",
+ [](string* out, Item* item) {
+ absl::StrAppend(out, item->instruction->name());
+ })
<< "}";
// Find the minimal position number of any instruction in
@@ -393,10 +392,9 @@ class MemoryUsageTracker {
int64 unfinished_user_count;
string ToString() const {
- return tensorflow::strings::StrCat(
- "Buffer ", id, " (defined by ",
- defining_instruction->instruction->name(), ", size ", size,
- " bytes)");
+ return absl::StrCat("Buffer ", id, " (defined by ",
+ defining_instruction->instruction->name(), ", size ",
+ size, " bytes)");
}
};
@@ -740,29 +738,27 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
}
string MemoryUsageTracker::ToString() const {
- string output = tensorflow::strings::StrCat("MemoryUsageTracker for ",
- computation_->name(), "\n");
- tensorflow::strings::StrAppend(
- &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
- memory_usage(), " bytes)");
+ string output =
+ absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
+ absl::StrAppend(&output,
+ "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
+ memory_usage(), " bytes)");
for (auto* item = instruction_list_.first(); item != nullptr;
item = instruction_list_.next(item)) {
const HloInstruction* instruction = item->instruction;
string inprogress = item == in_progress_item_ ? " in-progress" : "";
string placed = item->placed ? " placed" : "";
- tensorflow::strings::StrAppend(&output, " ", instruction->name(),
- inprogress, placed, "\n Defines:\n");
+ absl::StrAppend(&output, " ", instruction->name(), inprogress, placed,
+ "\n Defines:\n");
for (BufferId buffer_id : item->buffers_defined) {
const Buffer& buffer = buffers_[buffer_id];
string live = IsCurrentlyLive(buffer_id) ? " live" : "";
- tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live,
- ", ", buffer.unfinished_user_count,
- " unfinished uses\n");
+ absl::StrAppend(&output, " ", buffer.ToString(), live, ", ",
+ buffer.unfinished_user_count, " unfinished uses\n");
}
- tensorflow::strings::StrAppend(&output, " Uses:\n");
+ absl::StrAppend(&output, " Uses:\n");
for (BufferId buffer_id : item->buffers_used) {
- tensorflow::strings::StrAppend(&output, " ",
- buffers_[buffer_id].ToString(), "\n");
+ absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
}
}
return output;
@@ -780,10 +776,9 @@ bool MemoryUsageTracker::Check() const {
CHECK(elements_are_unique(defined_buffers))
<< "Instruction " << instruction->name()
<< " does not have unique defined buffers: "
- << tensorflow::str_util::Join(
+ << absl::StrJoin(
defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
- tensorflow::strings::StrAppend(
- out, buffers_.at(buffer_id).ToString());
+ absl::StrAppend(out, buffers_.at(buffer_id).ToString());
});
for (const Buffer& buffer : buffers_) {
@@ -803,10 +798,9 @@ bool MemoryUsageTracker::Check() const {
CHECK(elements_are_unique(used_buffers))
<< "Instruction " << instruction->name()
<< " does not have unique used buffers: "
- << tensorflow::str_util::Join(
+ << absl::StrJoin(
used_buffers, ", ", [this](string* out, BufferId buffer_id) {
- tensorflow::strings::StrAppend(
- out, buffers_.at(buffer_id).ToString());
+ absl::StrAppend(out, buffers_.at(buffer_id).ToString());
});
}
for (const Buffer& buffer : buffers_) {
@@ -1209,6 +1203,49 @@ StatusOr<bool> HloRematerialization::Run(
VLOG(1) << "HloRematerialization() with memory limit of "
<< HumanReadableNumBytes(memory_limit_bytes);
+ XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
+
+ // Create initial sequence of HLO instructions.
+ TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule(
+ *module,
+ [this](const BufferValue& buffer) {
+ return size_function_(buffer.shape());
+ },
+ scheduler_algorithm_));
+ if (copy_insertion) {
+ // We run a separate pass of copy elision here because the sequential
+ // ordering from the HLO schedule allows for more copies to be eliminated.
+ // TODO(b/80249101): Instead of a separate copy elision pass, use the
+ // ordering from the HLO schedule directly for copy insertion.
+
+ // First create a copy of the schedule which contains HloInstruction unique
+ // ids instead of HloInstruction*. This is necessary for updating the
+ // schedule below.
+ // TODO(b/113175018): Remove this when the HLO schedule is self-contained
+ // and can update itself.
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(*sequence);
+
+ SequentialHloOrdering ordering(module, *sequence);
+ TF_RETURN_IF_ERROR(
+ copy_insertion->RemoveUnnecessaryCopies(ordering, module));
+
+ // RemoveUnnecessaryCopies only considers interference when determining
+ // whether it is legal to remove a copy. However, copies in the graph may be
+  // necessary for other reasons such as preventing a constant from being live
+ // out of the graph. So run AddSpecialCaseCopies to re-insert these copies.
+ // TODO(b/80249101): Break copy insertion into several passes and run each
+ // one once in the regular HLO pipeline.
+ TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module));
+
+  // The passes above can add and remove copies, so update the schedule to
+ // account for these transformations. Newly added instructions will be
+ // placed ASAP in the schedule.
+ TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence));
+
+ TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference(
+ SequentialHloOrdering(module, *sequence), module));
+ }
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
@@ -1230,24 +1267,6 @@ StatusOr<bool> HloRematerialization::Run(
<< HumanReadableNumBytes(module_output_size)
<< "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
- XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
- // Create initial sequence of HLO instructions.
- TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule(
- *module,
- [this](const BufferValue& buffer) {
- return size_function_(buffer.shape());
- },
- scheduler_algorithm_));
- if (copy_insertion) {
- // We run a separate pass of copy elision here because the sequential
- // ordering from the HLO schedule allows for more copies to be eliminated.
- // TODO(b/80249101): Instead of a separate copy elision pass, use the
- // ordering from the HLO schedule directly for copy insertion.
- SequentialHloOrdering ordering(module, *sequence);
- TF_RETURN_IF_ERROR(
- copy_insertion->RemoveUnnecessaryCopies(ordering, module));
- }
-
// Compute peak memory usage of all computations in the module called in a
// sequential context.
call_graph_ = CallGraph::Build(module);
@@ -1334,12 +1353,11 @@ StatusOr<bool> HloRematerialization::Run(
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
if (current_peak_memory > memory_limit_bytes) {
- LOG(WARNING) << tensorflow::strings::Printf(
- "Can't reduce memory use below %s (%lld bytes) by rematerialization; "
- "only reduced to %s (%lld bytes)",
- HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes,
- HumanReadableNumBytes(current_peak_memory).c_str(),
- current_peak_memory);
+ LOG(WARNING) << absl::StrFormat(
+ "Can't reduce memory use below %s (%d bytes) by rematerialization; "
+ "only reduced to %s (%d bytes)",
+ HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes,
+ HumanReadableNumBytes(current_peak_memory), current_peak_memory);
}
return changed;
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index b2725e2918..7bd8a4a544 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -32,7 +32,7 @@ limitations under the License.
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
-HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
+HloRunner::CreateModuleFromString(const absl::string_view hlo_string,
const DebugOptions& debug_options) {
HloModuleConfig config;
config.set_debug_options(debug_options);
@@ -233,7 +233,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
int64 device = device_assignment(i, 0);
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device));
- streams.push_back(MakeUnique<se::Stream>(executor));
+ streams.push_back(absl::make_unique<se::Stream>(executor));
streams.back()->Init();
service_run_options.emplace_back(GetServiceRunOptionsForDevice(
device, streams.back().get(), &device_assignment));
@@ -260,7 +260,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
num_threads += options.num_replicas;
}
if (num_threads > 0) {
- pool = MakeUnique<tensorflow::thread::ThreadPool>(
+ pool = absl::make_unique<tensorflow::thread::ThreadPool>(
tensorflow::Env::Default(), "infeed_outfeed",
/*num_threads=*/num_threads);
}
@@ -291,7 +291,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
- auto literal = MakeUnique<Literal>();
+ auto literal = absl::make_unique<Literal>();
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, options.outfeed_shape, literal.get()));
if (options.outfeed_values != nullptr) {
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 65537f07f5..cfc519063e 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -87,8 +87,7 @@ class HloRunner {
// Converts an HloModule from the given hlo textual IR string (in
// HloModule::ToString format).
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
- const tensorflow::StringPiece hlo_string,
- const DebugOptions& debug_options);
+ const absl::string_view hlo_string, const DebugOptions& debug_options);
// Reads the proto file in xla.HloProto format, creates and returns the
// HloModule.
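A short sketch of the updated CreateModuleFromString signature in use; hlo_text is a placeholder string in HloModule::ToString() format and default DebugOptions are assumed to be acceptable.

// Sketch: build an HloModule from HLO text via HloRunner.
DebugOptions debug_options;  // defaults assumed sufficient for this example
StatusOr<std::unique_ptr<HloModule>> module =
    HloRunner::CreateModuleFromString(hlo_text, debug_options);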
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 27cc5361cd..0fc3b268c0 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include <map>
+#include <queue>
#include <utility>
#include <vector>
@@ -28,16 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::HumanReadableNumBytes;
-
namespace xla {
-
namespace {
+using ::tensorflow::strings::HumanReadableNumBytes;
+
// Class implementing a list scheduler of HLO instructions which produces a
// sequence which minimizes memory usage by preferring to schedule the node that
// frees bigger buffer and defines smaller outputs.
@@ -582,4 +581,187 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
size_function, nullptr, empty_map);
}
+tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) {
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> id_sequence;
+ for (const auto& computation_sequence : sequence) {
+ for (const HloInstruction* instruction : computation_sequence.second) {
+ id_sequence[computation_sequence.first].push_back(
+ instruction->unique_id());
+ }
+ }
+ return id_sequence;
+}
+
+Status UpdateSchedule(
+ const HloModule& module,
+ const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
+ id_sequence,
+ SequentialHloOrdering::HloModuleSequence* sequence) {
+ // Map from unique ID to HloInstruction pointer for instructions in the
+ // module.
+ tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
+ // Set of all HloInstructions in the schedule.
+ tensorflow::gtl::FlatSet<int> ids_in_schedule;
+ std::vector<HloComputation*> nonfusion_computations =
+ module.MakeNonfusionComputations();
+ for (const HloComputation* computation : nonfusion_computations) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ TF_RET_CHECK(
+ id_to_instruction.insert({instruction->unique_id(), instruction})
+ .second);
+ }
+ for (int id : id_sequence.at(computation)) {
+ ids_in_schedule.insert(id);
+ }
+ }
+
+ // Map from HloInstruction X to newly added instructions (instruction is in
+ // module, but not in schedule) which use X. If an instruction is not in the
+ // map, then it has no users which are newly added instructions.
+ tensorflow::gtl::FlatMap<const HloInstruction*,
+ std::vector<const HloInstruction*>>
+ new_instruction_uses;
+
+ // For each newly added instruction, this is the count of the instruction's
+ // operands that have not yet been scheduled. When this value reaches zero,
+ // then the instruction may be placed in the schedule.
+ tensorflow::gtl::FlatMap<const HloInstruction*, int>
+ unscheduled_operand_count;
+ // For each computation, this is the set of newly added instructions which
+ // have no operands. These must be handled specially and are added to the
+ // beginning of the schedule.
+ tensorflow::gtl::FlatMap<const HloComputation*,
+ std::vector<const HloInstruction*>>
+ new_zero_operand_instructions;
+ for (const HloComputation* computation : nonfusion_computations) {
+ new_zero_operand_instructions[computation] = {};
+ for (const HloInstruction* instruction : computation->instructions()) {
+ if (ids_in_schedule.count(instruction->unique_id()) == 0) {
+ // This is a newly added instruction which is not in the schedule.
+ for (const HloInstruction* operand : instruction->operands()) {
+ new_instruction_uses[operand].push_back(instruction);
+ }
+ if (instruction->operands().empty()) {
+ new_zero_operand_instructions[computation].push_back(instruction);
+ }
+ unscheduled_operand_count[instruction] = instruction->operand_count();
+ }
+ }
+ }
+
+ // Update the schedule with the newly added instructions, and remove any
+ // instructions no longer in the graph.
+ for (const HloComputation* computation : nonfusion_computations) {
+ std::vector<const HloInstruction*> old_computation_sequence =
+ std::move(sequence->at(computation));
+ sequence->at(computation).clear();
+
+ // Create a worklist of newly added instructions which are ready to be added
+ // to the schedule. Initialize worklist with those that have zero operands.
+ std::queue<const HloInstruction*> worklist;
+ for (const HloInstruction* instruction :
+ new_zero_operand_instructions.at(computation)) {
+ worklist.push(instruction);
+ }
+
+ // Lambda which schedules all instructions on the worklist.
+ auto schedule_worklist = [&]() {
+ while (!worklist.empty()) {
+ const HloInstruction* instruction = worklist.front();
+ worklist.pop();
+ sequence->at(computation).push_back(instruction);
+ std::vector<const HloInstruction*>* new_users =
+ tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
+ if (new_users != nullptr) {
+ // This just-scheduled instruction has users which are newly added to
+ // the module. Update the number of unscheduled operands and push the
+ // newly added instruction to the worklist if it is ready to
+ // schedule.
+ for (const HloInstruction* new_user : *new_users) {
+ unscheduled_operand_count.at(new_user)--;
+ CHECK_GE(unscheduled_operand_count.at(new_user), 0);
+ if (unscheduled_operand_count.at(new_user) == 0) {
+ worklist.push(new_user);
+ }
+ }
+ }
+ }
+ };
+
+ schedule_worklist();
+ for (int id : id_sequence.at(computation)) {
+ auto it = id_to_instruction.find(id);
+ if (it == id_to_instruction.end()) {
+ // This instruction in the schedule is no longer in the module.
+ continue;
+ }
+ const HloInstruction* instruction = it->second;
+ worklist.push(instruction);
+ schedule_worklist();
+ }
+ }
+
+ TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence));
+ return Status::OK();
+}
+
+Status VerifySchedule(
+ const HloModule& module,
+ const SequentialHloOrdering::HloModuleSequence& sequence) {
+ VLOG(2) << "VerifySchedule()";
+ XLA_VLOG_LINES(2, module.ToString());
+ VLOG(2) << sequence;
+
+ // Verify the set of computations in the sequence is exactly the set of
+ // computations in the module.
+ std::vector<HloComputation*> nonfusion_computations =
+ module.MakeNonfusionComputations();
+ TF_RET_CHECK(nonfusion_computations.size() == sequence.size());
+ tensorflow::gtl::FlatSet<const HloComputation*> computations_in_module(
+ module.computations().begin(), module.computations().end());
+ for (const auto& computation_sequence : sequence) {
+ TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1);
+ }
+
+ // For each computation verify the set of instructions is the same and that
+ // each dependency and control edge is honored.
+ for (const HloComputation* computation : nonfusion_computations) {
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
+ int pos = 0;
+ for (const HloInstruction* instruction : sequence.at(computation)) {
+ TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
+ << "Instruction " << instruction->name()
+ << " appears more than once in the schedule";
+ pos++;
+ }
+
+ TF_RET_CHECK(instruction_position.size() ==
+ computation->instruction_count());
+ for (const HloInstruction* instruction : computation->instructions()) {
+ TF_RET_CHECK(instruction_position.count(instruction) == 1)
+ << "Instruction " << instruction->name() << " is not in schedule";
+ }
+
+ for (const HloInstruction* instruction : computation->instructions()) {
+ for (const HloInstruction* operand : instruction->operands()) {
+ TF_RET_CHECK(instruction_position.at(operand) <
+ instruction_position.at(instruction))
+ << "Instruction " << instruction->name()
+ << " is not scheduled after its operand " << operand->name();
+ }
+
+ for (const HloInstruction* pred : instruction->control_predecessors()) {
+ TF_RET_CHECK(instruction_position.at(pred) <
+ instruction_position.at(instruction))
+ << "Instruction " << instruction->name()
+ << " is not scheduled after its control predecessor "
+ << pred->name();
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
index 2b33ccc8bf..d06b8d9a5c 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -85,6 +85,43 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
+// Transforms the given schedule such that it is (again) a valid schedule for
+// the module. This is used to update a schedule after the HLO module has been
+// transformed in some way. In general, the only transformations to the module
+// for which a schedule can be updated are the addition or removal of
+// instructions to/from the module. Updating the schedule after new
+// dependencies have been added between existing instructions in the module is
+// not supported and may result in an error status being returned.
+//
+// Instructions in the module which also exist in the given schedule will remain
+// in the same order in the updated schedule. Instructions which exist in the
+// module but not in the given schedule will be placed as early as possible in
+// the updated schedule.
+//
+// 'id_sequence' is a mirror of the given schedule 'sequence' but with
+// HloInstruction ids rather than HloInstruction pointers. This should be
+// constructed using ComputeIdSchedule below after the schedule is constructed
+// but before the HLO module is transformed.
+Status UpdateSchedule(
+ const HloModule& module,
+ const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
+ id_sequence,
+ SequentialHloOrdering::HloModuleSequence* sequence);
+
+// Constructs a copy of the given schedule but with HloInstruction unique ids
+// rather than HloInstruction pointers. This is necessary for updating a
+// schedule as HloInstruction pointers in the schedule may become invalid if
+// instructions are removed from the module. Used by UpdateSchedule above.
+// TODO(b/113175018): Remove this function when HLO schedule is its own class.
+tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence);
+
+// Verifies that the given schedule is valid for the given module. Specifically,
+// the schedule contains exactly the instructions in the module and every
+// dependency in the module is satisfied in the schedule.
+Status VerifySchedule(const HloModule& module,
+ const SequentialHloOrdering::HloModuleSequence& sequence);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
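A minimal sketch of the intended call sequence for keeping a schedule valid across module edits, mirroring the tests added below; module is a placeholder HloModule* and the size function simply uses ShapeUtil::ByteSizeOf.

// Sketch: snapshot the schedule as ids, mutate the module, then update/verify.
TF_ASSIGN_OR_RETURN(
    SequentialHloOrdering::HloModuleSequence sequence,
    ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
      return ShapeUtil::ByteSizeOf(buffer.shape());
    }));
auto id_sequence = ComputeIdSchedule(sequence);
// ... add or remove instructions in *module ...
TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, &sequence));
TF_RETURN_IF_ERROR(VerifySchedule(*module, sequence));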
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 9ec983c2bc..930801288a 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -244,9 +246,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
*entry_computation, sequence.at(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
- // HeapSimulator accounts for subcomputations. The max mem doesn't change
- // because the while body isn't live during the peak.
- EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+ // HeapSimulator accounts for subcomputations. The output buffer is aliased,
+ // so we don't double count.
+ EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
*entry_computation, sequence.at(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
@@ -350,7 +352,6 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
auto module = CreateNewModule();
const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
- const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
// param != 0
// Needs 17 bytes
@@ -408,12 +409,259 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
*entry_computation, sequence.at(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
- // HeapSimulator accounts for subcomputations
- EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation(
+ // HeapSimulator accounts for subcomputations. Cond is the largest one.
+ // The output buffer of the while is aliased.
+ EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
*entry_computation, sequence.at(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
}
+TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) {
+ // Updating the schedule of an unchanged HLO module should not affect the
+ // schedule at all.
+ const string module_str = R"(
+HloModule UpdateScheduleUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+ std::vector<const HloInstruction*> entry_schedule = sequence.begin()->second;
+
+ EXPECT_EQ(entry_schedule.size(), 6);
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(entry_schedule, sequence.begin()->second);
+}
+
+TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) {
+ // Add some additional instructions to a module and verify the schedule can be
+ // updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithNewInstructions
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+
+ HloComputation* entry = module->entry_computation();
+ const Shape shape = entry->root_instruction()->shape();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
+ entry->set_root_instruction(sub);
+
+ auto in_schedule = [&](const HloInstruction* hlo) {
+ return std::find(sequence.at(entry).begin(), sequence.at(entry).end(),
+ hlo) != sequence.at(entry).end();
+ };
+
+ EXPECT_EQ(sequence.at(entry).size(), 6);
+ EXPECT_FALSE(in_schedule(constant));
+ EXPECT_FALSE(in_schedule(sub));
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(sequence.at(entry).size(), 8);
+ EXPECT_TRUE(in_schedule(constant));
+ EXPECT_TRUE(in_schedule(sub));
+}
+
+TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) {
+ // Add and delete some instructions from a module and verify that the schedule
+ // can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithAddedAndDeletedInstruction
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+
+ // Set the entry root to some expression containing just a parameter and a
+ // constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* new_root = entry->AddInstruction(
+ HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
+ constant, entry->parameter_instruction(0)));
+ entry->set_root_instruction(new_root);
+
+ // DCE should remove everything but the parameters and the newly added code.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(sequence.at(entry).size(), 6);
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(sequence.at(entry).size(), 4);
+}
+
+TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) {
+ // Completely replace a module with an entirely new set of instructions and
+ // verify that the schedule can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithCompletelyReplacedModule
+
+ENTRY main {
+ a = f32[] constant(42.0)
+ b = f32[] constant(123.0)
+ ROOT sum = f32[] add(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+
+ // Replace the entry computation with the negation of a constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ entry->set_root_instruction(new_root);
+
+ // DCE the old instructions.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(sequence.at(entry).size(), 3);
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(sequence.at(entry).size(), 2);
+}
+
+TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) {
+ // Create changes to more than one computation in an HLO module and verify
+ // that the schedule can be updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+
+ const HloInstruction* xla_while =
+ module->entry_computation()->root_instruction()->operand(0);
+ HloComputation* body = xla_while->while_body();
+ HloComputation* cond = xla_while->while_condition();
+
+ // Negate the root of the cond.
+ cond->set_root_instruction(cond->AddInstruction(
+ HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
+ HloOpcode::kNot, cond->root_instruction())));
+
+ // Replace the body with a computation which just passes through its
+ // parameter.
+ body->set_root_instruction(body->parameter_instruction(0));
+
+ // DCE the dead code in the body.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(sequence.at(body).size(), 7);
+ EXPECT_EQ(sequence.at(cond).size(), 4);
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(sequence.at(body).size(), 1);
+ EXPECT_EQ(sequence.at(cond).size(), 5);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 0cba9ebbcb..980dae07ce 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -15,13 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
+using absl::StrJoin;
HloSharding HloSharding::AssignDevice(int64 device_id) {
return HloSharding(device_id);
@@ -71,12 +72,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
const HloSharding& sharding) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
CHECK(!sharding.IsTuple()) << sharding.ToString();
- int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
+ int64 leaf_count = RequiredLeaves(tuple_shape);
std::vector<HloSharding> flattened_list;
- flattened_list.reserve(leaf_count);
- for (int64 i = 0; i < leaf_count; ++i) {
- flattened_list.push_back(sharding);
- }
+ flattened_list.resize(leaf_count, sharding);
return HloSharding(flattened_list);
}
@@ -92,7 +90,7 @@ string HloSharding::ToString() const {
for (const HloSharding& element : tuple_elements_) {
parts.push_back(element.ToString());
}
- return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
+ return StrCat("{", absl::StrJoin(parts, ", "), "}");
}
if (replicated_) {
@@ -101,8 +99,8 @@ string HloSharding::ToString() const {
return StrCat(
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
} else {
- return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]",
- Join(tile_assignment_, ","), "}");
+ return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","),
+ "]", StrJoin(tile_assignment_, ","), "}");
}
}
@@ -244,16 +242,16 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
return Tuple(ShapeTree<HloSharding>(shape, *this));
}
-tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
+absl::optional<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) {
if (tuple_elements_.empty()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
- tensorflow::gtl::optional<int64> unique_device;
+ absl::optional<int64> unique_device;
for (auto& tuple_sharding : tuple_elements_) {
auto device = tuple_sharding.UniqueDevice();
if (!device || (unique_device && *device != *unique_device)) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
unique_device = device;
}
@@ -262,7 +260,7 @@ tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
if (!replicated_ && maximal_) {
return static_cast<int64>(*tile_assignment_.begin());
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
int64 HloSharding::GetUniqueDevice() const {
@@ -439,14 +437,13 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape,
: sub_shape_tree.element(ShapeIndex({}));
}
-tensorflow::gtl::optional<HloSharding> HloSharding::ExtractSingleSharding()
- const {
+absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
if (!IsTuple()) {
return *this;
}
for (int64 i = 1; i < tuple_elements_.size(); ++i) {
if (tuple_elements_[0] != tuple_elements_[i]) {
- return tensorflow::gtl::optional<HloSharding>();
+ return absl::nullopt;
}
}
return tuple_elements_.front();
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 894783e5d1..be51c3f55b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -151,7 +151,7 @@ class HloSharding {
// span a single device, the return value will be empty.
// In order for a sharding to span a single device, every leaf sharding must
// be maximal and not replicated, and the used device must match.
- tensorflow::gtl::optional<int64> UniqueDevice() const;
+ absl::optional<int64> UniqueDevice() const;
// Retrieves the unique device or fails with a CHECK.
int64 GetUniqueDevice() const;
@@ -182,7 +182,7 @@ class HloSharding {
// be returned. If it is a tuple, and all the tuple elements are common, the
// common element will be returned. Otherwise the optional will contain no
// value.
- tensorflow::gtl::optional<HloSharding> ExtractSingleSharding() const;
+ absl::optional<HloSharding> ExtractSingleSharding() const;
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
@@ -260,9 +260,9 @@ class HloSharding {
bool maximal_;
bool tuple_;
Array<int64> tile_assignment_;
- // Only non-empty when tuple_ is true, but because empty tuples are allowed
- // may also be empty even then. This is a flattened list of all the leaf
- // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
+ // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
+ // present for the root. This is a flattened list of all the leaf shardings in
+ // a tuple shape, by pre-order walk (ShapeTree iterator order).
std::vector<HloSharding> tuple_elements_;
};
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index a2c1d39d0d..6e9b96488c 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -23,6 +24,23 @@ namespace xla {
namespace {
+// AssignmentKind and kUnassignedDevice are used during tuple domain sharding
+// propagation in order to distinguish among three cases:
+// kUnassigned: no assignment has occurred
+//   kAssigned: at least one assignment has occurred
+// kConflict: no assignment has occurred because of conflicting propagations,
+// which occurs when multiple users of an instruction have different
+// shardings.
+enum class AssignmentKind { kUnassigned, kAssigned, kConflict };
+
+// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate
+// absence of sharding information for that particular sub-sharding during
+// sharding propagation. It makes it possible to express tuple shardings with
+// partial information. At the end of the propagation, the sharding of
+// tuple-shaped instructions that still use kUnassignedDevice is cleared.
+// TODO(b/112883246): Centralized enum of reserved devices.
+constexpr int64 kUnassignedDevice = -2;
+
struct PassThrough {
PassThrough(HloInstruction* user, HloInstruction* operand)
: user(user), operand(operand) {}
@@ -117,13 +135,17 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
return Status::OK();
}
-std::unique_ptr<HloSharding> CloneShardingForDomain(
- const HloSharding& sharding) {
- auto single_sharding = sharding.ExtractSingleSharding();
+// For tuple shardings, if every element has the same sharding then we want to
+// treat them as single-element shardings, to insert fewer domain separations,
+// since a domain can prevent some optimizations and we want to minimize that
+// from happening.
+std::shared_ptr<const HloSharding> CloneShardingForDomain(
+ std::shared_ptr<const HloSharding> sharding) {
+ auto single_sharding = sharding->ExtractSingleSharding();
if (!single_sharding) {
- return MakeUnique<HloSharding>(sharding);
+ return sharding;
}
- return MakeUnique<HloSharding>(*single_sharding);
+ return std::make_shared<const HloSharding>(*single_sharding);
}
Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
@@ -142,108 +164,174 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
return Status::OK();
}
-// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree.
-// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate()
-// sharding will be returned.
-ShapeTree<HloSharding> GetTupleSharding(HloInstruction* tuple) {
- if (tuple->has_sharding()) {
- return tuple->sharding().GetAsShapeTree(tuple->shape());
+// Return the ShapeTree<HloSharding> of the user argument. The user argument
+// is assumed to be a user of the instruction argument.
+// If user is a tuple instruction, return the tuple subsharding corresponding to
+// the operand matching the instruction argument, because that is the
+// subsharding corresponding to instruction.
+ShapeTree<HloSharding> GetShardingTreeFromUser(
+ const HloInstruction& instruction, const HloInstruction& user) {
+ if (user.opcode() == HloOpcode::kTuple) {
+ return user.sharding()
+ .GetSubSharding(user.shape(), {user.operand_index(&instruction)})
+ .GetAsShapeTree(instruction.shape());
+ }
+ return user.sharding().GetAsShapeTree(user.shape());
+}
+
+// Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice)
+// then no assignment is made. Therefore kUnassignedDevice is never propagated.
+// kConflict is returned if lhs is already assigned and rhs is assigned to a
+// different device.
+StatusOr<AssignmentKind> AssignLeafSharding(HloSharding* lhs,
+ const HloSharding& rhs) {
+ TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple());
+ if (rhs.UsesDevice(kUnassignedDevice)) {
+ return AssignmentKind::kUnassigned;
+ }
+ if (lhs->UsesDevice(kUnassignedDevice)) {
+ *lhs = rhs;
+ return AssignmentKind::kAssigned;
+ }
+ return lhs->UniqueDevice() != rhs.UniqueDevice()
+ ? AssignmentKind::kConflict
+ : AssignmentKind::kUnassigned;
+}
+
+// Assigns the whole rhs tree to lhs_tree, starting at lhs_it.
+// In case of conflicting assignment AssignmentKind::kConflict is returned. In
+// this case lhs_tree is partially assigned, up to the conflicting leaf. It is
+// up to the caller to discard the partial assignment in case of conflict.
+StatusOr<AssignmentKind> AssignTreeSharding(
+ ShapeTree<HloSharding>* lhs_tree, ShapeTree<HloSharding>::iterator lhs_it,
+ const ShapeTree<HloSharding>& rhs_tree) {
+ AssignmentKind assigned = AssignmentKind::kUnassigned;
+ auto rhs_it = rhs_tree.begin();
+ for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end();
+ ++lhs_it, ++rhs_it) {
+ // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it)
+ if (rhs_tree.IsLeaf(rhs_it->first)) {
+ TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first));
+ TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned,
+ AssignLeafSharding(&lhs_it->second, rhs_it->second));
+ if (sub_assigned == AssignmentKind::kConflict) {
+ // In case of conflict we return conflict to the caller. At this point
+ // partial assignments to lhs_tree may have been made already. It is up
+ // to the caller to discard the partial assignment in case of conflict.
+ return AssignmentKind::kConflict;
+ } else if (sub_assigned == AssignmentKind::kAssigned) {
+ assigned = sub_assigned;
+ }
+ }
}
- return ShapeTree<HloSharding>(tuple->shape(), HloSharding::Replicate());
+ TF_RET_CHECK(rhs_it == rhs_tree.end());
+ return assigned;
}
-// Retrieves the sharding of operand, asked from a user instruction which is
-// within domain. If operand is a kDomain, it means that sharding argument is
-// the operand sharding, otherwise the operand's own sharding will be returned.
-const HloSharding* GetOperandSharding(const HloInstruction* operand,
+StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction,
const DomainMetadata::Domain& domain,
- const HloSharding& sharding) {
- // Here the user of operand is within the domain instruction set, and since it
- // is user of operand, we need to look into the enter_domains set. If this is
- // not a kDomain within the user domains set, then return the operand
- // sharding, if any.
- if (operand->opcode() != HloOpcode::kDomain ||
- domain.enter_domains.count(const_cast<HloInstruction*>(operand)) == 0) {
- return operand->has_sharding() ? &operand->sharding() : nullptr;
+ const HloSharding& domain_sharding) {
+ if (instruction->users().empty()) {
+ // No sharding from users, use domain_sharding, after checking
+ // compatibility.
+ TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) &&
+ ShapeUtil::GetLeafCount(instruction->shape()) ==
+ domain_sharding.tuple_elements().size());
+ instruction->set_sharding(domain_sharding);
+ return true;
+ }
+ AssignmentKind assigned = AssignmentKind::kUnassigned;
+ // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple
+ // subshardings can result in a final sharding assignment containing
+ // kUnassignedDevice leaves, in case some tuple indexes are not used, or are
+ // used by users that don't have a sharding.
+ // Non-tuple shardings are either assigned to a real sharding, or are not
+ // assigned at all. As such they will never get assigned to kUnassignedDevice.
+ // In any case, kUnassignedDevice is never propagated, from the implementation
+ // of AssignLeafSharding.
+ ShapeTree<HloSharding> sharding_tree(
+ instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
+ for (HloInstruction* user : instruction->users()) {
+ if (user->opcode() == HloOpcode::kDomain &&
+ domain.exit_domains.count(const_cast<HloInstruction*>(user)) > 0) {
+ // If a user is a domain and it is registered in the domain exits, then
+ // the instruction sharding is taken directly from the domain, and no
+ // further users need to be visited.
+ instruction->set_sharding(domain_sharding);
+ return true;
+ }
+ if (!user->has_sharding()) {
+ continue;
+ }
+ AssignmentKind sub_assigned = AssignmentKind::kUnassigned;
+ ShapeTree<HloSharding> user_sharding_tree =
+ GetShardingTreeFromUser(*instruction, *user);
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ // For tuple-shaped instructions collect individual tuple subshardings
+ // from the uses, and then combine them into the tuple sharding.
+ // If the user is a GTE its sharding concerns only the subtree of
+ // sharding_tree at index user->tuple_index, otherwise the whole
+ // sharding_tree is affected.
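+ // For example (hypothetical): if instruction = tuple(a, b) and its users are
+ // gte0 = get-tuple-element(instruction) with index=0 sharded on device 0 and
+ // gte1 with index=1 sharded on device 1, the two users fill leaves {0} and
+ // {1} of sharding_tree, yielding the tuple sharding {device 0, device 1}.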
+ ShapeTree<HloSharding>::iterator sharding_tree_begin =
+ user->opcode() == HloOpcode::kGetTupleElement
+ ? sharding_tree.find({user->tuple_index()})
+ : sharding_tree.begin();
+ TF_ASSIGN_OR_RETURN(
+ sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin,
+ user_sharding_tree));
+ } else {
+ // Non-tuple shape: assign common users sharding.
+ TF_RET_CHECK(user_sharding_tree.leaf_count() == 1)
+ << "Expected non-tuple user sharding";
+ TF_ASSIGN_OR_RETURN(
+ sub_assigned,
+ AssignTreeSharding(&sharding_tree, sharding_tree.begin(),
+ user_sharding_tree));
+ }
+
+ if (sub_assigned == AssignmentKind::kConflict) {
+ // In case of conflict we don't assign any sharding.
+ return false;
+ } else if (sub_assigned == AssignmentKind::kAssigned) {
+ assigned = sub_assigned;
+ }
+ }
+
+ if (assigned == AssignmentKind::kAssigned) {
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ instruction->set_sharding(HloSharding::Tuple(sharding_tree));
+ } else {
+ TF_RET_CHECK(sharding_tree.leaf_count() == 1);
+ instruction->set_sharding(sharding_tree.leaf_begin()->second);
+ }
+ return true;
}
- // At this point operand is a kDomain of the currently processed domain, so we
- // can refer to sharding as the domain sharding.
- return &sharding;
+ return false;
}
// Tries to propagate the sharding information into the instructions that are
-// part of the domain, in a post order manner (operand propagate to user).
+// part of the domain, in a reverse post order manner (users propagate their
+// sharding to the instruction).
StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
- const HloSharding& sharding) {
+ const HloSharding& domain_sharding) {
int64 assigned = 0;
- for (HloInstruction* instruction : domain.instructions) {
+ // domain.instructions are ordered in a post-order manner. As we do
+ // user->operand propagation we process instructions in reverse order. In so
+ // doing we are guaranteed to process all users before their operands.
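+ // For example (sketch): a post order [a, b, t = tuple(a, b)] is walked as
+ // t, then b, then a, so each instruction is visited only after all of its
+ // users within the domain have been processed.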
+ for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend();
+ ++it) {
+ HloInstruction* instruction = *it;
if (instruction->has_sharding()) {
continue;
}
- if (instruction->opcode() == HloOpcode::kGetTupleElement) {
- HloInstruction* tuple = instruction->mutable_operand(0);
- const HloSharding* tuple_sharding =
- GetOperandSharding(tuple, domain, sharding);
- if (tuple_sharding != nullptr) {
- if (tuple_sharding->IsTuple()) {
- HloSharding sub_sharding = tuple_sharding->GetSubSharding(
- tuple->shape(), {instruction->tuple_index()});
- VLOG(4) << " " << instruction->name() << " to sharding "
- << sub_sharding;
- instruction->set_sharding(sub_sharding);
- } else {
- SetSingleSharding(instruction, *tuple_sharding);
- }
- ++assigned;
- }
- } else if (instruction->opcode() == HloOpcode::kTuple) {
- int64 tuple_assigned = 0;
- ShapeTree<HloSharding> shape_tree = GetTupleSharding(instruction);
- for (int64 i = 0; i < instruction->operand_count(); ++i) {
- const HloSharding* operand_sharding =
- GetOperandSharding(instruction->operand(i), domain, sharding);
- if (operand_sharding != nullptr) {
- HloSharding operand_subsharding = HloSharding::Replicate();
- if (operand_sharding == &sharding) {
- operand_subsharding =
- sharding.GetSubSharding(instruction->shape(), {i});
- operand_sharding = &operand_subsharding;
- }
- if (shape_tree.element({i}) != *operand_sharding) {
- *shape_tree.mutable_element({i}) = *operand_sharding;
- ++tuple_assigned;
- }
- }
- }
- if (tuple_assigned > 0) {
- HloSharding tuple_sharding = HloSharding::Tuple(shape_tree);
- VLOG(4) << " " << instruction->name() << " to sharding "
- << tuple_sharding;
- instruction->set_sharding(tuple_sharding);
- ++assigned;
- }
- } else {
- // If all the operand of the given instruction has the same single device
- // assignment, assign that device to this instruction as well.
- const HloSharding* common_sharding = nullptr;
- for (const HloInstruction* operand : instruction->operands()) {
- const HloSharding* operand_sharding =
- GetOperandSharding(operand, domain, sharding);
- if (operand_sharding != nullptr) {
- if (common_sharding != nullptr &&
- *common_sharding != *operand_sharding) {
- common_sharding = nullptr;
- break;
- }
- common_sharding = operand_sharding;
- }
- }
- if (common_sharding != nullptr) {
- VLOG(4) << " " << instruction->name() << " to sharding "
- << *common_sharding;
- instruction->set_sharding(*common_sharding);
- ++assigned;
- }
+ // Take the sharding from the users.
+ TF_ASSIGN_OR_RETURN(
+ bool instruction_assigned,
+ ApplyShardingFromUsers(instruction, domain, domain_sharding));
+ if (instruction_assigned) {
+ ++assigned;
+ VLOG(4) << " " << instruction->name() << " to sharding "
+ << instruction->sharding();
}
}
return assigned;
@@ -261,83 +349,40 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
return ApplyDomainSingleSharding(domain, *single_sharding);
}
VLOG(1) << "Assigning non-trivial sharding " << sharding;
- for (;;) {
- TF_ASSIGN_OR_RETURN(int64 assigned,
- ApplyDomainShardingPass(domain, sharding));
- if (assigned == 0) {
- break;
- }
- }
+ TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status());
+
int64 unassigned = 0;
for (HloInstruction* instruction : domain.instructions) {
if (!instruction->has_sharding()) {
LOG(WARNING) << "Unassigned instruction: " << instruction->ToString();
++unassigned;
+ } else {
+ // Un-set sharding of tuples whose sub-shardings are assigned to
+ // kUnassignedDevice. Indeed, in case of doubt it is better to leave the
+ // entire tuple unassigned and let the device placer decide for it.
+ if (instruction->sharding().UsesDevice(kUnassignedDevice)) {
+ TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()))
+ << "Only tuples can have kUnassignedDevice sub shardings";
+ instruction->clear_sharding();
+ }
}
}
// Should we error out if unassigned > 0?
return Status::OK();
}
-// Creates a kDomain instruction to be placed between instruction and operand.
-// The kDomain instruction will be created only if the sharding differ between
-// the instruction and the operand.
-std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
- HloInstruction* operand) {
- const HloSharding* instruction_sharding =
- instruction->has_sharding() ? &instruction->sharding() : nullptr;
- const HloSharding* operand_sharding =
- operand->has_sharding() ? &operand->sharding() : nullptr;
- // No need for domain if they both have no sharding.
- if (instruction_sharding == nullptr && operand_sharding == nullptr) {
- return nullptr;
- }
- // No need for domain if they match.
- if (instruction_sharding != nullptr && operand_sharding != nullptr &&
- ShardingMatches(*instruction_sharding, *operand_sharding)) {
- return nullptr;
- }
- std::unique_ptr<HloSharding> real_instruction_sharding;
- std::unique_ptr<HloSharding> real_operand_sharding;
- if (instruction_sharding != nullptr) {
- real_instruction_sharding = CloneShardingForDomain(*instruction_sharding);
- }
- if (operand_sharding != nullptr) {
- real_operand_sharding = CloneShardingForDomain(*operand_sharding);
- }
- VLOG(3) << "Creating domain:";
- VLOG(3) << " Instruction: " << instruction->name();
- VLOG(3) << " Operand: " << operand->name();
- VLOG(3) << " User side sharding: "
- << (real_instruction_sharding != nullptr
- ? real_instruction_sharding->ToString()
- : "None");
- VLOG(3) << " Operand side sharding: "
- << (real_operand_sharding != nullptr
- ? real_operand_sharding->ToString()
- : "None");
-
- std::unique_ptr<DomainMetadata> operand_side_metadata =
- MakeUnique<ShardingMetadata>(std::move(real_operand_sharding));
- std::unique_ptr<DomainMetadata> user_side_metadata =
- MakeUnique<ShardingMetadata>(std::move(real_instruction_sharding));
- return HloInstruction::CreateDomain(operand->shape(), operand,
- std::move(operand_side_metadata),
- std::move(user_side_metadata));
-}
-
-StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
+StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
// If we are here, all the instructions being passed had the same sharding
// (or no sharding), by the means of the ShardingMatches() API.
// As such, no kDomain was inserted, and here we are asked to extract the
// original common sharding.
// All the instructions passed to this API are part of the same computation.
- const HloSharding* sharding = nullptr;
+ std::shared_ptr<const HloSharding> sharding;
for (HloInstruction* instruction : instructions) {
if (instruction->has_sharding()) {
if (sharding == nullptr) {
- sharding = &instruction->sharding();
+ sharding = instruction->sharding_ptr();
} else {
TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding()))
<< "Sharding " << *sharding << " does not match the one in "
@@ -346,10 +391,10 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
}
}
if (sharding == nullptr) {
- return std::unique_ptr<HloSharding>();
+ return std::shared_ptr<const HloSharding>();
}
VLOG(4) << "Extracted sharding is " << *sharding;
- return CloneShardingForDomain(*sharding);
+ return CloneShardingForDomain(sharding);
}
} // namespace
@@ -357,9 +402,9 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
std::unique_ptr<HloSharding> sharding;
if (sharding_ != nullptr) {
- sharding = MakeUnique<HloSharding>(*sharding_);
+ sharding = absl::make_unique<HloSharding>(*sharding_);
}
- return MakeUnique<ShardingMetadata>(std::move(sharding));
+ return absl::make_unique<ShardingMetadata>(std::move(sharding));
}
bool ShardingMetadata::Matches(const DomainMetadata& other) const {
@@ -403,7 +448,7 @@ Status ShardingMetadata::NormalizeShardingDomain(
TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding));
}
} else {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> sharding,
+ TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> sharding,
ExtractOriginalCommonSharding(domain.instructions));
if (sharding != nullptr) {
VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString();
@@ -415,9 +460,75 @@ Status ShardingMetadata::NormalizeShardingDomain(
return Status::OK();
}
-std::unique_ptr<HloInstruction> CreateShardingDomain(
- HloInstruction* instruction, HloInstruction* operand) {
- return CreateDomain(instruction, operand);
+// Creates a kDomain instruction to be placed between instruction and operand.
+// The kDomain instruction will be created only if the shardings differ between
+// the instruction and the operand.
+HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction,
+ HloInstruction* root,
+ HloInstruction* operand) {
+ auto instruction_sharding = instruction->sharding_ptr();
+ auto root_sharding = root->sharding_ptr();
+ // No need for domain if they both have no sharding.
+ if (instruction_sharding == nullptr && root_sharding == nullptr) {
+ return nullptr;
+ }
+ // No need for domain if they match.
+ if (instruction_sharding != nullptr && root_sharding != nullptr &&
+ ShardingMatches(*instruction_sharding, *root_sharding)) {
+ return nullptr;
+ }
+
+ if (instruction_sharding != nullptr) {
+ instruction_sharding = CloneShardingForDomain(instruction_sharding);
+ }
+ if (root_sharding != nullptr) {
+ root_sharding = CloneShardingForDomain(root_sharding);
+ }
+
+ auto it = domain_cse_map_.find({operand, instruction_sharding});
+ if (it != domain_cse_map_.end()) {
+ return it->second;
+ }
+
+ VLOG(3) << "Creating domain:";
+ VLOG(3) << " Instruction: " << instruction->name();
+ VLOG(3) << " Operand: " << operand->name();
+ VLOG(3) << " User side sharding: "
+ << (instruction_sharding != nullptr ? instruction_sharding->ToString()
+ : "None");
+ VLOG(3) << " Operand side sharding: "
+ << (root_sharding != nullptr ? root_sharding->ToString() : "None");
+
+ HloInstruction* domain =
+ operand->parent()->AddInstruction(HloInstruction::CreateDomain(
+ operand->shape(), operand,
+ absl::make_unique<ShardingMetadata>(root_sharding),
+ absl::make_unique<ShardingMetadata>(instruction_sharding)));
+ domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding},
+ domain);
+ return domain;
+}
+
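+// CSE illustration (sketch): two calls for the same operand that require the
+// same user-side sharding map to equal DomainCseMapKey values, so the second
+// call returns the kDomain created by the first instead of adding another.
+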
+bool ShardingDomainCreator::DomainCseMapKey::operator==(
+ const ShardingDomainCreator::DomainCseMapKey& other) const {
+ if (instruction != other.instruction) {
+ return false;
+ }
+ if (sharding == nullptr && other.sharding == nullptr) {
+ return true;
+ }
+ if (sharding == nullptr || other.sharding == nullptr) {
+ return false;
+ }
+ return *sharding == *other.sharding;
+}
+
+size_t ShardingDomainCreator::DomainCseMapHasher::operator()(
+ const ShardingDomainCreator::DomainCseMapKey& key) const {
+ return tensorflow::Hash64Combine(
+ std::hash<const HloInstruction*>{}(key.instruction),
+ key.sharding ? key.sharding->Hash()
+ : static_cast<size_t>(0x297814aaad196e6dULL));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
index 5e01fc0e22..7a6b0d9abc 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -27,12 +27,12 @@ namespace xla {
// A DomainMetadata implementation that internally wraps a sharding attribute.
class ShardingMetadata : public DomainMetadata {
public:
- explicit ShardingMetadata(std::unique_ptr<HloSharding> sharding)
+ explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding)
: sharding_(std::move(sharding)) {}
std::unique_ptr<DomainMetadata> Clone() const override;
- tensorflow::StringPiece Kind() const override { return KindName(); }
+ absl::string_view Kind() const override { return KindName(); }
bool Matches(const DomainMetadata& other) const override;
@@ -40,7 +40,7 @@ class ShardingMetadata : public DomainMetadata {
const HloSharding* sharding() const { return sharding_.get(); }
- static tensorflow::StringPiece KindName() { return "sharding"; }
+ static absl::string_view KindName() { return "sharding"; }
static StatusOr<const ShardingMetadata*> ToShardingMetadata(
const DomainMetadata* metadata);
@@ -55,15 +55,33 @@ class ShardingMetadata : public DomainMetadata {
const DomainMetadata* metadata);
private:
- std::unique_ptr<HloSharding> sharding_;
+ std::shared_ptr<const HloSharding> sharding_;
};
-// Given an HLO graph edge between instruction and one of its operands, creates
-// a ShardingMetadata based kDomain instruction if the sharding between
-// instruction and operand changes. Returns nullptr if there is no need for a
-// domain separation.
-std::unique_ptr<HloInstruction> CreateShardingDomain(
- HloInstruction* instruction, HloInstruction* operand);
+// If the sharding between root and instruction changes then returns a
+// ShardingMetadata-based kDomain instruction that can be used to separate
+// operand and instruction.
+// Returns nullptr if there is no need for a domain separation.
+class ShardingDomainCreator {
+ public:
+ HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root,
+ HloInstruction* operand);
+
+ private:
+ // Map from an instruction and its user-side sharding to the kDomain created
+ // for that pair, used to CSE identical domains.
+ struct DomainCseMapKey {
+ const HloInstruction* instruction;
+ std::shared_ptr<const HloSharding> sharding;
+
+ bool operator==(const DomainCseMapKey& other) const;
+ };
+ struct DomainCseMapHasher {
+ size_t operator()(const DomainCseMapKey& key) const;
+ };
+ std::unordered_map<DomainCseMapKey, HloInstruction*, DomainCseMapHasher>
+ domain_cse_map_;
+};
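+// Minimal usage sketch (hypothetical driver code, not part of this change):
+//   ShardingDomainCreator create_domain;
+//   if (HloInstruction* domain = create_domain(user, root, operand)) {
+//     // Splice `domain` between `operand` and `user` on this edge.
+//   }
+// A nullptr result means no domain separation is needed for that edge.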
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 45fc300fca..2341f8ada0 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) {
}
}
+// Tests that empty tuple is supported.
+TEST_F(HloShardingTest, EmptySingleTuple) {
+ HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}),
+ HloSharding::AssignDevice(0));
+ EXPECT_TRUE(sharding.ExtractSingleSharding());
+}
+
TEST_F(HloShardingTest, NestedTuple) {
// nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index 2ef38821af..d1cf644f82 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -24,7 +24,7 @@ namespace xla {
// one arbitrarily to use and delete the others.
class HloSubcomputationUnification : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "subcomputation-unification";
}
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index b78bfa0cdf..4876533449 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -21,28 +23,25 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-using ::tensorflow::GraphDef;
-using ::tensorflow::NodeDef;
-using ::tensorflow::TensorShapeProto;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-using ::tensorflow::str_util::Join;
namespace xla {
namespace hlo_graph_dumper {
namespace {
+using absl::StrAppend;
+using absl::StrCat;
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+using tensorflow::TensorShapeProto;
+
string GetOpDefName(const HloInstruction* instruction) {
string name = StrCat("hlo-", HloOpcodeString(instruction->opcode()));
- tensorflow::str_util::TitlecaseString(&name, "-");
+ tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok
name.erase(std::remove(name.begin(), name.end(), '-'), name.end());
if (instruction->opcode() == HloOpcode::kFusion) {
string fusion_name = ToString(instruction->fusion_kind());
- StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1));
+ StrAppend(&name, absl::string_view(fusion_name).substr(1));
}
return name;
}
@@ -166,7 +165,9 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
} else {
layout_string = StrCat(
- "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}");
+ "{",
+ absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","),
+ "}");
}
attrs["layout"].set_s(layout_string);
}
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 7fd99fc930..e0c1326177 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include <algorithm>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -30,16 +32,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
const Shape& HloPosition::shape() const {
return ShapeUtil::GetSubshape(instruction->shape(), index);
@@ -216,10 +215,11 @@ void HloValueSet::SortAndUniquifyValues() {
}
string HloValueSet::ToString() const {
- return StrCat("HloValueSet: ",
- Join(values_, ", ", [](string* result, const HloValue* value) {
- result->append(value->ToShortString());
- }));
+ return StrCat(
+ "HloValueSet: ",
+ absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
+ result->append(value->ToShortString());
+ }));
}
bool HloValueSet::AssignUnionOf(
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index ac1a663633..f1b29c2559 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -15,11 +15,13 @@ limitations under the License.
#include <set>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -115,6 +117,11 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
ShapeInference::InferAllToAllTupleShape(operand_shapes));
}
+Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
+ return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
+ hlo->operand(0)->shape()));
+}
+
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
reduce_precision->operand(0)->shape(),
@@ -122,39 +129,32 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
-namespace {
-
-Status CheckIsTokenOperand(const HloInstruction* instruction,
- int64 operand_no) {
+Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no) {
const HloInstruction* token = instruction->operand(operand_no);
if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
return InternalError(
- "Expected operand %lld to be token-shaped, actual shape is "
+ "Expected operand %d to be token-shaped, actual shape is "
"%s:\n%s",
- operand_no, ShapeUtil::HumanString(token->shape()).c_str(),
- instruction->ToString().c_str());
+ operand_no, StringifyShape(token->shape()), instruction->ToString());
}
return Status::OK();
}
-Status CheckOperandAndParameter(const HloInstruction* instruction,
- int64 operand_number,
- const HloComputation* computation,
- int64 parameter_number) {
+Status ShapeVerifier::CheckOperandAndParameter(
+ const HloInstruction* instruction, int64 operand_number,
+ const HloComputation* computation, int64 parameter_number) {
const HloInstruction* operand = instruction->operand(operand_number);
const HloInstruction* parameter =
computation->parameter_instruction(parameter_number);
- if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
+ if (!ShapesSame(operand->shape(), parameter->shape())) {
return InternalError("Operand %s shape does not match parameter's %s in %s",
- operand->ToString().c_str(),
- parameter->ToString().c_str(),
- instruction->ToString().c_str());
+ operand->ToString(), parameter->ToString(),
+ instruction->ToString());
}
return Status::OK();
}
-} // namespace
-
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
@@ -171,22 +171,16 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
// Outfeed has a separate shape field for the value which is outfed to the
// host. The shape of the instruction itself is always a token.
- if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
- outfeed->operand(0)->shape())) {
+ if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
return InternalError(
- "Expected outfeed shape to be compatible with operand's shape %s, "
+ "Expected outfeed shape to be equal to operand's shape %s, "
"actual shape is %s:\n%s",
- ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
- outfeed->ToString().c_str());
+ StringifyShape(outfeed->operand(0)->shape()),
+ StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
}
return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}
-Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
- return Status::OK();
-}
-
bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
const Shape& shape_1,
const Shape& result_shape) {
@@ -200,7 +194,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
if (instruction->operand_count() != 2) {
return InternalError("Expected two operands for Rng instruction: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
const Shape& shape_0 = instruction->operand(0)->shape();
@@ -208,14 +202,14 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
return InternalError(
"Expected scalar types for the two operands of Rng instruction: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
return InternalError(
"Expected compatible element types for the result and the two operands"
" of Rng instruction: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
PrimitiveType element_type = shape_0.element_type();
@@ -228,7 +222,7 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
"Element type not supported."
" Expected element to be of floating point type, integral type or"
" predicate type for RngUniform: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
break;
@@ -237,13 +231,13 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
return InternalError(
"Element type not supported."
" Expected element to be FloatingPointType for RngNormal: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
break;
default:
return InternalError(
"Invalid Rng distribution %s",
- RandomDistribution_Name(instruction->random_distribution()).c_str());
+ RandomDistribution_Name(instruction->random_distribution()));
}
return Status::OK();
@@ -262,8 +256,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
return InternalError(
"Expected sort to have to have the same dimensions for the keys and "
"the values. Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(sort->operand(1)->shape()).c_str());
+ StringifyShape(sort->operand(0)->shape()),
+ StringifyShape(sort->operand(1)->shape()));
}
return CheckVariadicShape(sort);
}
@@ -272,10 +266,18 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
return CheckShape(constant, constant->literal().shape());
}
-Status ShapeVerifier::HandleIota(HloInstruction* iota) {
- return ShapeUtil::Rank(iota->shape()) == 1
- ? Status::OK()
- : InternalError("Iota only supports arrays of rank 1.");
+Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ const int64 rank = ShapeUtil::Rank(iota->shape());
+ if (rank == 0) {
+ return InternalError("Iota does not support scalars.");
+ }
+ int64 iota_dimension = iota->iota_dimension();
+ if (iota_dimension >= rank) {
+ return InternalError(
+ "The iota dimension cannot go beyond the operation rank.");
+ }
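+ // Illustrative example (hypothetical): an iota producing f32[4,8] may use
+ // iota_dimension 0 or 1; a scalar iota, or iota_dimension >= 2 for this
+ // shape, fails verification here.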
+ return Status::OK();
}
Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
@@ -337,7 +339,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
return Status::OK();
}
-Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
+Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
+ for (HloInstruction* fused_param : fusion->fused_parameters()) {
+ int64 param_no = fused_param->parameter_number();
+ if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
+ return InternalError(
+ "Shape mismatch between parameter number %d and its operand in "
+ "%s.",
+ param_no, fusion->ToString().c_str());
+ }
+ }
+ return Status::OK();
+}
Status ShapeVerifier::HandleCall(HloInstruction* call) {
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
@@ -419,12 +432,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape();
- if (!ShapeUtil::Compatible(conditional_shape,
- ShapeUtil::MakeShape(PRED, {}))) {
+ if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"Conditional computation shape does not lead to a scalar predicate "
"shape: %s",
- ShapeUtil::HumanString(conditional_shape).c_str());
+ StringifyShape(conditional_shape));
}
// The shape of kWhile should match the shape of the body computation it
// calls.
@@ -555,7 +567,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
return InternalError(
"Seen floating point types of different precisions in "
"%s, but mixed precision is disallowed.",
- instruction->ToString().c_str());
+ instruction->ToString());
}
return Status::OK();
}));
@@ -602,53 +614,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
// Check if the output shape matches the expected shape.
- bool compatible;
+ //
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
// but only when the instruction defines the BF16/F32 buffer.
- switch (instruction->opcode()) {
- case HloOpcode::kTupleSelect:
- // TupleSelect only defines the top-level buffer, which in this case is
- // the tuple, so we cannot allow mixed precision.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- case HloOpcode::kGetTupleElement:
- case HloOpcode::kTuple:
- // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
- // precision is disallowed.
- case HloOpcode::kConstant:
- case HloOpcode::kBitcast:
- case HloOpcode::kBitcastConvert:
- case HloOpcode::kCall:
- case HloOpcode::kConditional:
- case HloOpcode::kConvert:
- case HloOpcode::kCustomCall:
- case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
- case HloOpcode::kParameter:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
- case HloOpcode::kSend:
- case HloOpcode::kSendDone:
- case HloOpcode::kWhile:
- // The above opcodes should match the expected shapes exactly.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- default:
- if (allow_mixed_precision_) {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
- instruction->shape(), inferred_shape);
- } else {
- compatible =
- ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- }
- }
- if (!compatible) {
+ bool equal = [&] {
+ switch (instruction->opcode()) {
+ // The opcodes below can't have implicit layout conversions, nor can they
+ // implicitly transform f32 -> bf16. Fundamentally these are either
+ // reinterpreting existing data (e.g. kBitcast) or shuffling data around
+ // without modifying it (e.g. kGetTupleElement, kTupleSelect).
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConstant:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kTuple:
+ case HloOpcode::kTupleSelect:
+ case HloOpcode::kWhile:
+ return ShapesSame(instruction->shape(), inferred_shape);
+
+ // We allow arbitrary layout and f32->bf16 transformations on all other
+ // instructions, although this may be made more strict pending discussion
+ // in b/112709536.
+ default:
+ if (allow_mixed_precision_) {
+ return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
+ inferred_shape);
+ } else {
+ return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+ }
+ }
+ }();
+ if (!equal) {
return InternalError(
- "Expected instruction to have shape compatible with %s, actual "
+ "Expected instruction to have shape equal to %s, actual "
"shape is %s:\n%s",
- ShapeUtil::HumanString(inferred_shape).c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
- instruction->ToString().c_str());
+ StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
+ instruction->ToString());
}
return Status::OK();
}
@@ -692,10 +702,10 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
string ComputationsToString(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
- return tensorflow::str_util::Join(
- computations, ",", [](string* s, const HloComputation* computation) {
- s->append(computation->name());
- });
+ return absl::StrJoin(computations, ",",
+ [](string* s, const HloComputation* computation) {
+ s->append(computation->name());
+ });
}
// Verifies various invariants about the structure of the HLO:
@@ -713,23 +723,23 @@ Status VerifyHloStructure(HloModule* module) {
for (const HloComputation* computation : module->computations()) {
if (computation->parent() == nullptr) {
return InternalError("Computation %s has a null parent pointer",
- computation->name().c_str());
+ computation->name());
}
if (computation->parent() != module) {
return InternalError(
"Computation %s parent() does not point to parent module",
- computation->name().c_str());
+ computation->name());
}
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->parent() == nullptr) {
return InternalError("Instruction %s has a null parent pointer",
- instruction->name().c_str());
+ instruction->name());
}
if (instruction->parent() != computation) {
return InternalError(
"Instruction %s parent() does not point to parent computation",
- instruction->name().c_str());
+ instruction->name());
}
}
}
@@ -746,9 +756,8 @@ Status VerifyHloStructure(HloModule* module) {
return InternalError(
"Operand %d (%s) of instruction %s is in a different "
"computation: %s vs %s",
- i, operand->name().c_str(), instruction->name().c_str(),
- operand->parent()->name().c_str(),
- instruction->parent()->name().c_str());
+ i, operand->name(), instruction->name(),
+ operand->parent()->name(), instruction->parent()->name());
}
}
}
@@ -764,7 +773,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
"Instruction of fused computation does not match expected "
"instruction "
"%s.",
- fusion->ToString().c_str());
+ fusion->ToString());
}
// Fused root instruction and fused parameters must all be owned by the
@@ -778,7 +787,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
if (fused_root == instruction) {
if (root_owned) {
return InternalError("Root appears more than once in %s.",
- fusion->ToString().c_str());
+ fusion->ToString());
}
root_owned = true;
}
@@ -786,7 +795,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
if (fused_parameters[i] == instruction) {
if (parameter_owned[i]) {
return InternalError("Parameter appears more than once in %s.",
- fusion->ToString().c_str());
+ fusion->ToString());
}
parameter_owned[i] = true;
}
@@ -794,20 +803,19 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
}
if (!root_owned) {
return InternalError("Root not found in computation of %s.",
- fusion->ToString().c_str());
+ fusion->ToString());
}
// Make sure all the parameter_owned entries are set
for (int i = 0; i < parameter_owned.size(); i++) {
if (!parameter_owned[i]) {
return InternalError("Parameter %d not found in computation of %s.", i,
- fusion->ToString().c_str());
+ fusion->ToString());
}
}
// Fused root must have no users.
if (fused_root->user_count() != 0) {
- return InternalError("Root of %s may not have users.",
- fusion->ToString().c_str());
+ return InternalError("Root of %s may not have users.", fusion->ToString());
}
// All uses of fused instructions must be in the fusion computation, and
@@ -817,54 +825,46 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
if (instruction != fused_root) {
if (instruction->user_count() == 0) {
return InternalError("Non-root instruction %s in %s must have users.",
- instruction->ToString().c_str(),
- fusion->ToString().c_str());
+ instruction->ToString(), fusion->ToString());
}
for (auto& user : instruction->users()) {
if (fused_computation != user->parent()) {
return InternalError(
"Non-root instruction %s in %s may not have external users.",
- instruction->ToString().c_str(), fusion->ToString().c_str());
+ instruction->ToString(), fusion->ToString());
}
}
}
}
// Fused parameter instructions must be numbered contiguously and match up
- // (shapes compatible) with their respective operand.
+ // (shapes equal) with their respective operand.
CHECK_EQ(fusion->operands().size(), fused_parameters.size());
std::vector<bool> parameter_numbers(fused_parameters.size(), false);
for (auto fused_param : fused_parameters) {
int64 param_no = fused_param->parameter_number();
if (param_no < 0) {
- return InternalError("Unexpected negative parameter number %lld in %s.",
- param_no, fusion->ToString().c_str());
+ return InternalError("Unexpected negative parameter number %d in %s.",
+ param_no, fusion->ToString());
}
if (param_no >= fused_parameters.size()) {
return InternalError(
- "Unexpected parameter number %lld in %s: higher then number of "
+ "Unexpected parameter number %d in %s: higher then number of "
"parameters %lu.",
- param_no, fusion->ToString().c_str(), fused_parameters.size());
+ param_no, fusion->ToString(), fused_parameters.size());
}
if (parameter_numbers[param_no]) {
return InternalError(
- "Did not expect parameter number %lld more than once in %s.",
- param_no, fusion->ToString().c_str());
+ "Did not expect parameter number %d more than once in %s.", param_no,
+ fusion->ToString());
}
parameter_numbers[param_no] = true;
- if (!ShapeUtil::Compatible(fused_param->shape(),
- fusion->operand(param_no)->shape())) {
- return InternalError(
- "Shape mismatch between parameter number %lld and its operand in "
- "%s.",
- param_no, fusion->ToString().c_str());
- }
}
// Make sure all the parameter_numbers entries were seen.
for (int i = 0; i < parameter_numbers.size(); i++) {
if (!parameter_numbers[i]) {
return InternalError("Did not see parameter number %d in %s.", i,
- fusion->ToString().c_str());
+ fusion->ToString());
}
}
@@ -879,18 +879,18 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
auto* while_body = instruction->while_body();
if (while_cond->num_parameters() != 1) {
return FailedPrecondition(
- "While condition must have exactly 1 parameter; had %lld : %s",
- while_cond->num_parameters(), while_cond->ToString().c_str());
+ "While condition must have exactly 1 parameter; had %d : %s",
+ while_cond->num_parameters(), while_cond->ToString());
}
if (while_body->num_parameters() != 1) {
return FailedPrecondition(
- "While body must have exactly 1 parameter; had %lld : %s",
- while_body->num_parameters(), while_body->ToString().c_str());
+ "While body must have exactly 1 parameter; had %d : %s",
+ while_body->num_parameters(), while_body->ToString());
}
if (instruction->operand_count() != 1) {
return FailedPrecondition(
- "While loop must have exactly one operand; had %lld : %s",
- instruction->operand_count(), instruction->ToString().c_str());
+ "While loop must have exactly one operand; had %d : %s",
+ instruction->operand_count(), instruction->ToString());
}
return Status::OK();
}
@@ -898,16 +898,14 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
if (instruction->true_computation()->num_parameters() != 1) {
return FailedPrecondition(
- "True computation %s of %s must have 1 parameter insted of %lld",
- instruction->true_computation()->name().c_str(),
- instruction->ToString().c_str(),
+ "True computation %s of %s must have 1 parameter insted of %d",
+ instruction->true_computation()->name(), instruction->ToString(),
instruction->true_computation()->num_parameters());
}
if (instruction->false_computation()->num_parameters() != 1) {
return FailedPrecondition(
- "False computation %s of %s must have 1 parameter insted of %lld",
- instruction->false_computation()->name().c_str(),
- instruction->ToString().c_str(),
+ "False computation %s of %s must have 1 parameter insted of %d",
+ instruction->false_computation()->name(), instruction->ToString(),
instruction->false_computation()->num_parameters());
}
return Status::OK();
@@ -920,11 +918,11 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
return FailedPrecondition(
"Implicit broadcast is not allowed in HLO."
- "Found non-compatible shapes for instruction %s.\n"
+ "Found different shapes for instruction %s.\n"
"output: %s\noperand: %s\n",
- HloOpcodeString(instruction->opcode()).c_str(),
- ShapeUtil::HumanString(out_shape).c_str(),
- ShapeUtil::HumanString(operand_shape).c_str());
+ HloOpcodeString(instruction->opcode()),
+ ShapeUtil::HumanString(out_shape),
+ ShapeUtil::HumanString(operand_shape));
}
}
return Status::OK();
@@ -955,7 +953,7 @@ Status VerifyEntryAndExitShapes(const HloModule& module) {
if (ShapeContainsToken(param->shape())) {
return InternalError(
"Entry parameter %d is or contains a token shape: %s", i,
- ShapeUtil::HumanString(param->shape()).c_str());
+ ShapeUtil::HumanString(param->shape()));
}
}
return Status::OK();
@@ -967,9 +965,9 @@ Status CheckSameChannel(const HloInstruction* instr1,
if (instr1->channel_id() != instr2->channel_id()) {
return InternalError(
"Expected to have the same channel id, actual channel ids are: %s "
- "(%lld), %s (%lld)",
- instr1->ToString().c_str(), instr1->channel_id(),
- instr2->ToString().c_str(), instr2->channel_id());
+ "(%d), %s (%d)",
+ instr1->ToString(), instr1->channel_id(), instr2->ToString(),
+ instr2->channel_id());
}
return Status::OK();
}
@@ -990,7 +988,7 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
"Expected instructions to have the same is-host-transfer property: "
"%s, "
"%s ",
- instr1->ToString().c_str(), instr2->ToString().c_str());
+ instr1->ToString(), instr2->ToString());
}
return Status::OK();
}
@@ -1007,12 +1005,12 @@ Status VerifySendsAndRecvs(const HloModule& module) {
host_channels.insert({sendrecv->channel_id(), sendrecv});
if (!it_inserted.second) {
return FailedPrecondition(
- "Channel %lld is used for multiple host send/recv instructions: "
+ "Channel %d is used for multiple host send/recv instructions: "
"%s "
"and "
"%s",
- sendrecv->channel_id(), sendrecv->ToString().c_str(),
- it_inserted.first->second->ToString().c_str());
+ sendrecv->channel_id(), sendrecv->ToString(),
+ it_inserted.first->second->ToString());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index c942fab08e..42e3027bf1 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
namespace xla {
@@ -27,9 +28,9 @@ namespace xla {
// TODO(b/26024837): Check output shape for all instruction types.
class ShapeVerifier : public DfsHloVisitor {
public:
- explicit ShapeVerifier() : allow_mixed_precision_(false) {}
- explicit ShapeVerifier(bool allow_mixed_precision)
- : allow_mixed_precision_(allow_mixed_precision) {}
+ explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_(allow_mixed_precision) {}
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
@@ -46,6 +47,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleFft(HloInstruction* fft) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleAllToAll(HloInstruction* hlo) override;
+ Status HandleCollectivePermute(HloInstruction* hlo) override;
Status HandleReducePrecision(HloInstruction* reduce_precision) override;
Status HandleInfeed(HloInstruction*) override;
Status HandleOutfeed(HloInstruction*) override;
@@ -63,7 +65,6 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleFusion(HloInstruction*) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction*) override;
- Status HandleHostCompute(HloInstruction*) override;
Status HandleSlice(HloInstruction* slice) override;
Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
Status HandleDynamicUpdateSlice(
@@ -106,13 +107,42 @@ class ShapeVerifier : public DfsHloVisitor {
Status CheckVariadicShape(const HloInstruction* instruction);
private:
- // Return true if the shapes of the two operands have the same element type,
- // and the result shape either has the same element type as the operand
- // shapes or mixed precision is allowed and the result shape and the operand
- // shapes have floating point element types.
+ // Helpers that switch on layout_sensitive_.
+ bool ShapesSame(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::Equal(a, b)
+ : ShapeUtil::Compatible(a, b);
+ }
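+ // For example (illustrative): with a = f32[2,3]{1,0} and b = f32[2,3]{0,1},
+ // ShapesSame(a, b) is false when layout_sensitive_ is true (the layouts
+ // differ) and true otherwise (same element type and dimensions).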
+ bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b)
+ : ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
+ }
+ string StringifyShape(const Shape& s) {
+ return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
+ : ShapeUtil::HumanString(s);
+ }
+
+ // Checks that the given operand of the given instruction is of type TOKEN.
+ Status CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no);
+
+ // Checks that the shape of the given operand of the given instruction matches
+ // the given parameter of the given computation.
+ Status CheckOperandAndParameter(const HloInstruction* instruction,
+ int64 operand_number,
+ const HloComputation* computation,
+ int64 parameter_number);
+
+ // Returns true if the shapes of the two operands have the same element type,
+ // and the result shape either has the same element type as the operand shapes
+ // or mixed precision is allowed and the result shape and the operand shapes
+ // have floating point element types.
bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
const Shape& result_shape);
+ // If the verifier is layout-sensitive, shapes must be equal to what's
+ // expected. Otherwise, the shapes must simply be compatible.
+ bool layout_sensitive_;
+
// Whether the inputs and output of an instruction can contain both F32s and
// BF16s. Tuples that include both F32s and BF16s are allowed regardless of
// this flag.
@@ -125,14 +155,10 @@ class HloVerifier : public HloPassInterface {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
- // Uses standard shape inference.
- explicit HloVerifier()
- : shape_verifier_factory_(
- [] { return MakeUnique<ShapeVerifier>(false); }) {}
-
- explicit HloVerifier(bool allow_mixed_precision)
- : shape_verifier_factory_([allow_mixed_precision] {
- return MakeUnique<ShapeVerifier>(allow_mixed_precision);
+ explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
+ return absl::make_unique<ShapeVerifier>(layout_sensitive,
+ allow_mixed_precision);
}) {}
// Uses custom shape verification.
@@ -140,10 +166,9 @@ class HloVerifier : public HloPassInterface {
: shape_verifier_factory_(std::move(shape_verifier_factory)) {}
~HloVerifier() override = default;
- tensorflow::StringPiece name() const override { return "verifier"; }
+ absl::string_view name() const override { return "verifier"; }
- // Note: always returns false (no instructions are ever modified by this
- // pass).
+ // Never returns true; no instructions are ever modified by this pass.
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index d764964f3c..fc1f81bdd2 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -37,13 +37,15 @@ using ::testing::HasSubstr;
class HloVerifierTest : public HloTestBase {
public:
HloVerifierTest()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/false) {}
};
class HloVerifierTestAllowMixedPrecision : public HloTestBase {
public:
HloVerifierTestAllowMixedPrecision()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
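// A minimal sketch (assumed usage, mirroring the flags passed above): the
// verifier can also be constructed directly with the two booleans, e.g.
//   HloVerifier verifier(/*layout_sensitive=*/false,
//                        /*allow_mixed_precision=*/false);
//   TF_ASSERT_OK(verifier.Run(module.get()).status());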
TEST_F(HloVerifierTest, NullInstructionParent) {
@@ -275,5 +277,84 @@ TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported"));
}
+TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
+ // This testcase can't be written using textual HLO, because it doesn't parse
+ // negative interior padding. That's probably a feature. :)
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {100}), "param"));
+ PaddingConfig padding_config;
+ padding_config.add_dimensions()->set_interior_padding(-1);
+ builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {100}), param,
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(F32).CloneToUnique())),
+ padding_config));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Interior padding cannot be negative"));
+}
+
+TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
+ // This testcase can't be written using textual HLO, because it doesn't parse
+ // negative interior padding. That's probably a feature. :)
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {100}), "param"));
+ PaddingConfig padding_config;
+ padding_config.add_dimensions()->set_interior_padding(-1);
+ builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {100}), param,
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(F32).CloneToUnique())),
+ padding_config));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("Interior padding cannot be negative"));
+}
+
+// Simple module containing a convolution as the root.
+static const char* const kConvHloString = R"(
+HloModule module
+ENTRY entry_computation {
+ param0 = f16[128,128,56,56] parameter(0)
+ param1 = f16[3,3,128,128] parameter(1)
+ zero_f16 = f16[] constant(0)
+ ROOT conv = f16[128,128,28,28] convolution(param0, param1),
+ window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
+})";
+
+TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
+ auto* conv = module->entry_computation()->root_instruction();
+ Window w = conv->window();
+ w.mutable_dimensions(0)->set_window_dilation(-1);
+ conv->set_window(w);
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("non-positive window dilation factor"));
+}
+
+TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
+ auto* conv = module->entry_computation()->root_instruction();
+ Window w = conv->window();
+ w.mutable_dimensions(0)->set_base_dilation(-1);
+ conv->set_window(w);
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("non-positive base area dilation factor"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
index bb5b40a8a8..e76b93107c 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
@@ -14,27 +14,27 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/metric_table_report.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
-using tensorflow::strings::Appendf;
+using absl::StrAppend;
+using absl::StrAppendFormat;
+using absl::StrCat;
+using absl::StrFormat;
using tensorflow::strings::HumanReadableElapsedTime;
using tensorflow::strings::HumanReadableNumBytes;
-using tensorflow::strings::Printf;
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
string HumanReadableProfileBuilder::ToString() const {
string s;
- Appendf(&s, "Execution profile for %s: (%s @ f_nom)\n",
- computation_name_.c_str(),
- HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str());
+ StrAppendFormat(&s, "Execution profile for %s: (%s @ f_nom)\n",
+ computation_name_,
+ HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)));
int64 cumulative_cycles = 0;
auto print_op = [&](const OpInfo& op, bool is_total = false) {
@@ -56,7 +56,7 @@ string HumanReadableProfileBuilder::ToString() const {
if (op.bytes_accessed > op.cycles) {
bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle");
} else {
- bytes_per_cycle = Printf("%.3fB/cycle", bpc);
+ bytes_per_cycle = StrFormat("%.3fB/cycle", bpc);
}
}
@@ -77,27 +77,24 @@ string HumanReadableProfileBuilder::ToString() const {
// columns in the output.
cycles_percent_str = "100.% 100Σ";
} else {
- cycles_percent_str =
- Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent);
+ cycles_percent_str = StrFormat("%5.2f%% %2.0fΣ", cycles_percent,
+ cumulative_cycles_percent);
}
double nsecs = op.cycles / clock_rate_ghz_;
- Appendf(
+ StrAppendFormat(
&s,
- "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: "
+ "%15d cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: "
"%16s :: %s\n",
- op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles),
+ op.cycles, cycles_percent_str, CyclesToMicroseconds(op.cycles),
op.optimal_seconds < 0
? ""
- : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(),
- op.flop_count <= 0
- ? ""
- : HumanReadableNumFlops(op.flop_count, nsecs).c_str(),
+ : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6),
+ op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs),
op.transcendental_count <= 0
? ""
- : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs)
- .c_str(),
- bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str());
+ : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs),
+ bytes_per_sec, bytes_per_cycle, op.name);
};
float optimal_seconds_sum = 0.0;
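
The hunks above replace tensorflow::strings::Printf/Appendf with absl::StrFormat/StrAppendFormat. As a minimal standalone sketch of why the .c_str() calls and %lld specifiers could be dropped (this snippet is illustrative and not part of the change): absl's formatting is type-checked, %d accepts any integer width, and %s accepts std::string or string_view directly.

#include <cstdint>
#include <cstdio>
#include <string>

#include "absl/strings/str_format.h"

int main() {
  std::string s;
  int64_t cycles = 1234567;       // %d works for any integral width.
  std::string name = "fusion.3";  // %s takes std::string/string_view as-is.
  absl::StrAppendFormat(&s, "%15d cycles :: %s\n", cycles, name);
  printf("%s", s.c_str());
}
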
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
index 6f56c3aa82..925111fa1f 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -29,10 +29,10 @@ namespace xla {
// computation, suitable for consumption by humans.
class HumanReadableProfileBuilder {
public:
- explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name,
+ explicit HumanReadableProfileBuilder(absl::string_view computation_name,
int64 total_cycles,
double clock_rate_ghz)
- : computation_name_(std::string(computation_name)),
+ : computation_name_(computation_name),
total_cycles_(total_cycles),
clock_rate_ghz_(clock_rate_ghz) {
CHECK_GE(clock_rate_ghz, 1e-9);
@@ -43,15 +43,13 @@ class HumanReadableProfileBuilder {
// Adds an operation to the profile. If you don't know the number of
// floating-point ops or bytes touched by the op, or if you don't know how
// fast it would run optimally, pass -1 for that param.
- void AddOp(tensorflow::StringPiece op_name,
- tensorflow::StringPiece short_name,
- tensorflow::StringPiece category, int64 cycles, int64 flop_count,
+ void AddOp(absl::string_view op_name, absl::string_view short_name,
+ absl::string_view category, int64 cycles, int64 flop_count,
int64 transcendental_count, int64 bytes_accessed,
float optimal_seconds) {
- op_infos_.push_back({std::string(op_name), std::string(short_name),
- std::string(category), cycles, flop_count,
- transcendental_count, bytes_accessed,
- optimal_seconds});
+ op_infos_.push_back({string(op_name), string(short_name), string(category),
+ cycles, flop_count, transcendental_count,
+ bytes_accessed, optimal_seconds});
}
// Gets the human-readable profile.
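
Putting the post-change interface together, a hypothetical caller might look like the sketch below; the op names, cycle counts, and clock rate are invented for illustration, and only the constructor, AddOp, and ToString signatures visible in this diff are relied on.

#include <cstdio>

#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"

void DumpToyProfile() {
  xla::HumanReadableProfileBuilder builder("entry_computation",
                                           /*total_cycles=*/20000,
                                           /*clock_rate_ghz=*/1.5);
  // Per the comment on AddOp, -1 marks a quantity (flops, bytes, optimal
  // time) that the caller does not know.
  builder.AddOp("%dot.1", "dot.1", "dot", /*cycles=*/15000,
                /*flop_count=*/1 << 20, /*transcendental_count=*/0,
                /*bytes_accessed=*/-1, /*optimal_seconds=*/-1);
  builder.AddOp("%add.2", "add.2", "add", /*cycles=*/5000, /*flop_count=*/-1,
                /*transcendental_count=*/-1, /*bytes_accessed=*/-1,
                /*optimal_seconds=*/-1);
  printf("%s", builder.ToString().c_str());
}
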
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
index aa325dc8a3..85bb4a8b24 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
@@ -30,7 +30,7 @@ class ImplicitBroadcastRemover : public HloPassInterface {
ImplicitBroadcastRemover() {}
~ImplicitBroadcastRemover() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "implicit-broadcast-remover";
}
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
index f85d31d522..df88587492 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
@@ -26,6 +26,11 @@ namespace xla {
namespace {
class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase {
+ public:
+ ImplicitBroadcastRemoverTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
ImplicitBroadcastRemover remover_;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 8d17c03afc..43ef30d1eb 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -14,13 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/optional.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace gtl = ::tensorflow::gtl;
@@ -31,32 +34,30 @@ using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
+using absl::StrJoin;
using tensorflow::gtl::ArraySlice;
-using tensorflow::str_util::Join;
} // namespace
string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
switch (root->kind()) {
case Array::kUnknown: {
auto* unknown_tensor = root->as<UnknownArray>();
- return tensorflow::strings::StrCat("%",
- unknown_tensor->instruction().name());
+ return absl::StrCat("%", unknown_tensor->instruction().name());
}
case Array::kConstant: {
if (print_constants) {
string contents = root->as<ConstantArray>()->literal()->ToString();
- return tensorflow::strings::StrCat(
- "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents,
- ")");
+ return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
+ " ", contents, ")");
}
- return tensorflow::strings::StrCat(
- "(constant ", ShapeUtil::HumanString(root->shape()), ")");
+ return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
+ ")");
}
case Array::kReshaped: {
ReshapedArray* reshaped_array = root->as<ReshapedArray>();
- return tensorflow::strings::StrCat(
+ return absl::StrCat(
"(reshape ", ToString(reshaped_array->operand(), print_constants),
" to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
}
@@ -67,11 +68,11 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
string name = root->kind() == Array::kScalarIndexedConstant
? "scalar-indexed-const"
: "scalar-indexed";
- return tensorflow::strings::StrCat(
+ return absl::StrCat(
"(", name, " ", ToString(indexed_array->source(), print_constants),
" ", ToString(indexed_array->indices(), print_constants), " ",
indexed_array->source_dim(), "->[",
- Join(indexed_array->output_dims(), ","), "])");
+ StrJoin(indexed_array->output_dims(), ","), "])");
}
}
}
@@ -92,7 +93,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
// Depth first search over the DAG, invoking ComputeArrayFor in post order.
// The HLO instructions already in the cache are considered leaves.
- gtl::InlinedVector<const HloInstruction*, 4> stack;
+ absl::InlinedVector<const HloInstruction*, 4> stack;
enum DfsState { kDiscovered, kVisited };
gtl::FlatMap<const HloInstruction*, DfsState> dfs_state_map;
@@ -290,13 +291,13 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
int64 source_dim = dim_numbers.start_index_map(0);
std::vector<int64> output_dims;
for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.offset_dims(), i)) {
+ if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
output_dims.push_back(i);
}
}
if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
- if (c_linear_search(indexed->output_dims(), source_dim)) {
+ if (absl::c_linear_search(indexed->output_dims(), source_dim)) {
return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
shape);
}
@@ -314,7 +315,7 @@ namespace {
// [values.begin()+index, values.end()) is equal to `product`. If there is no
// such index, return -1. All integers in `values` must be positive.
int64 FindSuffixWithProduct(ArraySlice<int64> values, int64 product) {
- DCHECK(c_all_of(values, [](int64 value) { return value > 0; }));
+ DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; }));
int64 current_product = 1;
int64 i;
@@ -377,8 +378,8 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
CHECK_NE(candidate_operand_dim, 0)
<< "result_dim = " << result_dim
<< ", result_subarray_size = " << result_subarray_size
- << ", result_shape = [" << Join(result_shape, ",") << "]"
- << ", operand_shape = [" << Join(operand_shape, ",") << "]";
+ << ", result_shape = [" << StrJoin(result_shape, ",") << "]"
+ << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]";
if (candidate_operand_dim != -1 &&
result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
@@ -388,26 +389,27 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
result_subarray_size *= result_shape[result_dim];
}
- c_reverse(result);
+ absl::c_reverse(result);
if (VLOG_IS_ON(3)) {
std::vector<string> result_strings;
- c_transform(result, std::back_inserter(result_strings),
- [](ReshapePassthroughDimPair value) {
- return tensorflow::strings::StrCat(value.result_dim, "->",
- value.operand_dim);
- });
- VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to ["
- << Join(result_shape, ",") << "] passthrough indices are ["
- << Join(result_strings, ",") << "] (legend: `result`->`operand`)";
+ absl::c_transform(result, std::back_inserter(result_strings),
+ [](ReshapePassthroughDimPair value) {
+ return absl::StrCat(value.result_dim, "->",
+ value.operand_dim);
+ });
+ VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to ["
+ << StrJoin(result_shape, ",") << "] passthrough indices are ["
+ << StrJoin(result_strings, ",")
+ << "] (legend: `result`->`operand`)";
}
- DCHECK(c_is_sorted(
+ DCHECK(absl::c_is_sorted(
result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
return lhs.result_dim < rhs.result_dim;
}));
- DCHECK(c_is_sorted(
+ DCHECK(absl::c_is_sorted(
result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
return lhs.operand_dim < rhs.operand_dim;
}));
@@ -419,20 +421,20 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
- return c_any_of(passthrough_dims,
- [&](ReshapePassthroughDimPair passthrough_dim_pair) {
- return passthrough_dim_pair.operand_dim == dim;
- });
+ return absl::c_any_of(passthrough_dims,
+ [&](ReshapePassthroughDimPair passthrough_dim_pair) {
+ return passthrough_dim_pair.operand_dim == dim;
+ });
}
// Maps `operand_dim` which must be a passthrough operand dimension to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64 MapPassthroughOperandDimToResultDim(
ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 operand_dim) {
- auto it = c_find_if(passthrough_dims,
- [&](ReshapePassthroughDimPair passthrough_dim_pair) {
- return passthrough_dim_pair.operand_dim == operand_dim;
- });
+ auto it = absl::c_find_if(
+ passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
+ return passthrough_dim_pair.operand_dim == operand_dim;
+ });
CHECK(it != passthrough_dims.end());
return it->result_dim;
}
@@ -441,7 +443,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
ArraySlice<int64> result_shape,
int64 source_passthrough_dim) {
VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
- << Join(operand_shape, ",") << "], [" << Join(result_shape, ",")
+ << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
<< "], " << source_passthrough_dim << ")";
int64 indexed_source_subarray_size =
@@ -453,8 +455,8 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
Shape StripDegenerateDimensions(const Shape& shape) {
DimensionVector new_dims;
- c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
- [](int64 dim) { return dim != 1; });
+ absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
+ [](int64 dim) { return dim != 1; });
return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}
}; // namespace
@@ -530,7 +532,7 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
// element is true iff the i'th component of the result index is an output
// index.
- gtl::InlinedVector<bool, 6> output_dims_bitvector(
+ absl::InlinedVector<bool, 6> output_dims_bitvector(
operand->shape().dimensions_size());
for (int64 output_dim : operand->output_dims()) {
output_dims_bitvector[output_dim] = true;
@@ -552,8 +554,8 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
}();
DimensionVector new_result_shape_dims;
- c_copy(operand->shape().dimensions(),
- std::back_inserter(new_result_shape_dims));
+ absl::c_copy(operand->shape().dimensions(),
+ std::back_inserter(new_result_shape_dims));
for (int64 degenerate_dim : degenerate_dims) {
InsertAt(&new_result_shape_dims, degenerate_dim, 1);
}
@@ -694,8 +696,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
operand_dim);
};
- if (!c_all_of(scalar_indexed->output_dims(),
- is_reshape_passthrough_operand_dim)) {
+ if (!absl::c_all_of(scalar_indexed->output_dims(),
+ is_reshape_passthrough_operand_dim)) {
VLOG(3) << "Not all output dims are passthrough dims "
<< ToString(scalar_indexed);
return nullptr;
@@ -753,9 +755,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
if (source_dim_for_new_scalar_indexed_node == -1) {
VLOG(3) << "Could not compute the source dim for the new scalar indexed "
"node: scalar_indexed_source_shape = ["
- << Join(scalar_indexed_source_shape.dimensions(), ",")
+ << StrJoin(scalar_indexed_source_shape.dimensions(), ",")
<< "] and new_scalar_indexed_source_shape = ["
- << Join(new_scalar_indexed_source_shape, ",") << "]";
+ << StrJoin(new_scalar_indexed_source_shape, ",") << "]";
return nullptr;
}
@@ -763,8 +765,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
&new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));
- CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL,
- std::multiplies<int64>()),
+ CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL,
+ std::multiplies<int64>()),
ShapeUtil::ElementsIn(scalar_indexed_source_shape));
CHECK(IsReshapePassthroughOperandDim(
@@ -780,9 +782,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
};
std::vector<int64> output_dims_for_new_scalar_indexed_node;
- c_transform(scalar_indexed->output_dims(),
- std::back_inserter(output_dims_for_new_scalar_indexed_node),
- map_passthrough_operand_dim_to_result_dim);
+ absl::c_transform(scalar_indexed->output_dims(),
+ std::back_inserter(output_dims_for_new_scalar_indexed_node),
+ map_passthrough_operand_dim_to_result_dim);
TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
TakeOwnership(scalar_indexed->literal().Reshape(
@@ -873,11 +875,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions();
auto is_broadcasted_dim = [&](int64 output_dim) {
- return c_find(broadcast_dims, output_dim) == broadcast_dims.end();
+ return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
};
// All of the output dims must be "broadcasted" dims for the other operand.
- if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) {
+ if (!absl::c_all_of(scalar_indexed_const->output_dims(),
+ is_broadcasted_dim)) {
return nullptr;
}
@@ -969,15 +972,15 @@ namespace {
// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
-gtl::optional<int64> GetOnlyNonContractingNonBatchDim(
+absl::optional<int64> GetOnlyNonContractingNonBatchDim(
int64 rank, ArraySlice<int64> contracting_dims,
ArraySlice<int64> batch_dims) {
- gtl::optional<int64> result;
+ absl::optional<int64> result;
for (int64 dim = 0; dim < rank; dim++) {
if (!ArrayContains(contracting_dims, dim) &&
!ArrayContains(batch_dims, dim)) {
if (result.has_value()) {
- return gtl::nullopt;
+ return absl::nullopt;
}
result = dim;
}
@@ -994,10 +997,9 @@ gtl::optional<int64> GetOnlyNonContractingNonBatchDim(
// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
- tensorflow::StringPiece tag,
- Analysis::ScalarIndexedConstantArray* indexed_array,
+ absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) {
- gtl::optional<int64> non_contracting_non_batch_dim =
+ absl::optional<int64> non_contracting_non_batch_dim =
GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
contracting_dims, batch_dims);
if (!non_contracting_non_batch_dim.has_value()) {
@@ -1132,7 +1134,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
return nullptr;
}
-tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
+absl::string_view IndexedArrayAnalysisPrinterPass::name() const {
return "indexed-array-analysis-printer-pass";
}
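
This file's changes are mostly a mechanical move from the project-local c_* algorithm wrappers and tensorflow::gtl/str_util helpers to their absl counterparts. A standalone sketch of the same idioms (container-level algorithms plus StrCat/StrJoin), using made-up data and assuming only the absl headers added above:

#include <cstdio>
#include <iterator>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  std::vector<int> dims = {0, 2, 3};
  // absl::c_* algorithms take the whole container instead of begin()/end().
  bool all_non_negative = absl::c_all_of(dims, [](int d) { return d >= 0; });
  std::vector<std::string> parts;
  absl::c_transform(dims, std::back_inserter(parts),
                    [](int d) { return absl::StrCat(d, "->", d + 1); });
  // Prints "0->1,2->3,3->4 1".
  printf("%s %d\n", absl::StrJoin(parts, ",").c_str(), all_non_negative);
}
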
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 675eb31d26..3fa7d749e1 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -371,7 +371,7 @@ class IndexedArrayAnalysis {
// unconditionally add to the regular HLO pass pipeline.
class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override;
+ absl::string_view name() const override;
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 97052edf7d..c34c32f7d3 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -22,6 +22,11 @@ limitations under the License.
namespace xla {
namespace {
class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
+ public:
+ IndexedArrayAnalysisTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {
@@ -634,9 +639,9 @@ ENTRY main {
AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
(scalar-indexed-const (constant f32[3,4] f32[3,4] {
- { 0.761594176, 0.964027584, 0.995054781, 0.999329329 },
- { 0.761594176, 0.995054781, 0.964027584, 0.999329329 },
- { 0.999329329, 0.995054781, 0.964027584, 0.761594176 }
+ { 0.761594, 0.964028, 0.995055, 0.999329 },
+ { 0.761594, 0.995055, 0.964028, 0.999329 },
+ { 0.999329, 0.995055, 0.964028, 0.761594 }
}) %indices 0->[0]))");
}
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
index a523811f6c..efa8ed3abc 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -27,7 +27,7 @@ namespace xla {
class Inliner : public HloPassInterface {
public:
~Inliner() override = default;
- tensorflow::StringPiece name() const override { return "inline"; }
+ absl::string_view name() const override { return "inline"; }
// Run inlining on the given computation. Returns whether the computation was
// changed.
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 32937b33b3..5695bc2420 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index f33942d679..83313c7ec1 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -121,6 +122,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kConvolution:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kCustomCall:
case HloOpcode::kDivide:
case HloOpcode::kDomain:
@@ -130,7 +132,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kFft:
case HloOpcode::kFusion:
case HloOpcode::kGather:
- case HloOpcode::kHostCompute:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kMap:
@@ -189,13 +190,13 @@ bool InstructionFusion::CanFuseOnAllPaths(
if (consumer == producer) {
return true;
}
- if (!consumer->IsFusable()) {
+ if (!consumer->IsFusible()) {
return false;
}
for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
auto* consumer_operand = consumer->mutable_operand(i);
// If the operand is not on a path to the producer, it doesn't matter
- // whether it's fusable.
+ // whether it's fusible.
if (!reachability_->IsReachable(producer, consumer_operand)) {
continue;
}
@@ -205,7 +206,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
}
// The producer is reachable from consumer_operand which means we need
// to be able to fuse consumer_operand into consumer in order for
- // producer to be fusable into consumer on all paths.
+ // producer to be fusible into consumer on all paths.
// Perform the recursive step: make sure producer can be fused into
// consumer_operand on all paths.
if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) {
@@ -216,7 +217,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
}
InstructionFusion::HloInstructionSet
-InstructionFusion::ComputeGloballyUnfusable(
+InstructionFusion::ComputeGloballyUnfusible(
tensorflow::gtl::ArraySlice<HloInstruction*> post_order) {
// Forbid fusion of producers that:
// a) Need to be duplicated, unless they can be fused into all consumers
@@ -270,19 +271,19 @@ InstructionFusion::ComputeGloballyUnfusable(
// all of its consumers on all paths.
//
// That means, that for:
- // A --> B (fusable)
- // \-> C (non-fusable)
+ // A --> B (fusible)
+ // \-> C (non-fusible)
// A will be not allowed to be fused into B, as it cannot be fused into C.
//
// Similarly, for:
// A -------------> B
// \-> C -> D -/
// If:
- // - A is fusable into B and C, and D is fusable into B
- // - C is *not* fusable into D
+ // - A is fusible into B and C, and D is fusible into B
+ // - C is *not* fusible into D
// A will be not allowed to be fused into B, as it cannot be fused via
// all paths.
- if (producer->IsFusable() &&
+ if (producer->IsFusible() &&
CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) {
continue;
}
@@ -318,7 +319,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
InsertOrDie(&post_order_index, post_order[i], i);
}
- HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order);
+ HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order);
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
@@ -341,7 +342,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// consistent.
post_order_index.erase(instruction);
- if (!instruction->IsFusable() &&
+ if (!instruction->IsFusible() &&
instruction->opcode() != HloOpcode::kFusion) {
continue;
}
@@ -413,7 +414,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
for (int64 i : sorted_operand_numbers) {
HloInstruction* operand = instruction->mutable_operand(i);
- if (!operand->IsFusable()) {
+ if (!operand->IsFusible()) {
continue;
}
@@ -497,7 +498,7 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput(
bool InstructionFusion::MultiOutputFusionCreatesCycle(
HloInstruction* producer, HloInstruction* consumer) {
- return c_any_of(
+ return absl::c_any_of(
consumer->operands(), [&](const HloInstruction* consumer_operand) {
// The fusion algorithm traverses the HLO graph in reverse post order.
// Thus `consumer` is visited before its operands (including
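
The comments in ComputeGloballyUnfusible above describe the "fusible on all paths" rule with the A/B/C diagrams. A deliberately simplified, self-contained sketch of that rule over a toy consumer graph follows; it omits the do_not_duplicate bookkeeping of the real CanFuseOnAllPaths, and the types and names are invented:

#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Toy graph: each node lists its consumers; `fusible` marks whether the node
// may participate in fusion at all.
struct ToyNode {
  std::vector<std::string> consumers;
  bool fusible = true;
};

// A producer may be duplicated into its consumers only if every node reachable
// below it is fusible; one non-fusible consumer (the `C` in the diagram above)
// blocks fusion along all paths.
bool FusibleOnAllPaths(const std::map<std::string, ToyNode>& graph,
                       const std::string& producer) {
  for (const std::string& consumer : graph.at(producer).consumers) {
    if (!graph.at(consumer).fusible || !FusibleOnAllPaths(graph, consumer)) {
      return false;
    }
  }
  return true;
}

int main() {
  std::map<std::string, ToyNode> graph;
  graph["A"].consumers = {"B", "C"};
  graph["B"];                   // Fusible leaf.
  graph["C"].fusible = false;   // Non-fusible, so A must not be duplicated.
  printf("%d\n", FusibleOnAllPaths(graph, "A"));  // Prints 0.
}
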
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index f73ca9adf7..9802d4cfc1 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -36,7 +36,7 @@ class InstructionFusion : public HloPassInterface {
bool may_duplicate = true)
: is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
~InstructionFusion() override = default;
- tensorflow::StringPiece name() const override { return "fusion"; }
+ absl::string_view name() const override { return "fusion"; }
// Run instruction fusion on the given computation. Returns whether the
// computation was changed (instructions were fused).
@@ -122,7 +122,7 @@ class InstructionFusion : public HloPassInterface {
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
- HloInstructionSet ComputeGloballyUnfusable(
+ HloInstructionSet ComputeGloballyUnfusible(
tensorflow::gtl::ArraySlice<HloInstruction*> post_order);
// Used to determine if an HLO is expensive. Expensive operations will not be
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index 9e7a15f033..da1ad90959 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -158,7 +158,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
.ValueOrDie());
}
-TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
+TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) {
HloComputation::Builder builder(TestName());
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
auto param0 =
@@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
}
-TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
+TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
// Make sure we do not duplicate the add, as we cannot fuse through the rng.
//
// p0 -> add -------------------------> sub
@@ -309,7 +309,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
// A variant of the above that allows the algorithm to put add2 into the set
- // of unfusable ops to short-circuit the decision whether add1 should be fused
+ // of unfusible ops to short-circuit the decision whether add1 should be fused
// into sub2.
//
// /---------------\
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 8652599dc6..581f8d2e92 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -12,12 +12,11 @@ cc_library(
srcs = ["interpreter_transfer_manager.cc"],
hdrs = ["interpreter_transfer_manager.h"],
deps = [
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:generic_transfer_manager",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/service/interpreter:platform_id",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains per-platform transfer manager registration
)
@@ -32,8 +31,6 @@ cc_library(
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
@@ -54,6 +51,7 @@ cc_library(
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/core:lib",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains compiler registration
)
@@ -79,7 +77,6 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
@@ -91,6 +88,7 @@ cc_library(
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 9f8f4bda87..bb69cb9c47 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -69,8 +69,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
// Create executable from only the Hlo module.
std::unique_ptr<Executable> executable =
- xla::MakeUnique<InterpreterExecutable>(std::move(hlo_module),
- xla::MakeUnique<HloEvaluator>());
+ absl::make_unique<InterpreterExecutable>(
+ std::move(hlo_module), absl::make_unique<HloEvaluator>());
return std::move(executable);
}
@@ -103,11 +103,11 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
se::interpreter::kXlaInterpreterPlatformId, []() {
- return xla::MakeUnique<xla::interpreter::InterpreterCompiler>();
+ return absl::make_unique<xla::interpreter::InterpreterCompiler>();
});
xla::ComputationPlacer::RegisterComputationPlacer(
se::interpreter::kXlaInterpreterPlatformId,
- []() { return xla::MakeUnique<xla::ComputationPlacer>(); });
+ []() { return absl::make_unique<xla::ComputationPlacer>(); });
return true;
}
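
The xla::MakeUnique -> absl::make_unique replacements here (and in the other interpreter files below) are drop-in: absl::make_unique forwards its arguments to the constructor exactly like std::make_unique. A tiny standalone sketch with an invented type, assuming only absl/memory/memory.h:

#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"

struct Widget {
  Widget(int id, std::string label) : id(id), label(std::move(label)) {}
  int id;
  std::string label;
};

int main() {
  // Equivalent to std::make_unique<Widget>(42, "conv"), but also usable from
  // C++11 code.
  std::unique_ptr<Widget> w = absl::make_unique<Widget>(42, "conv");
  return w->id == 42 ? 0 : 1;
}
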
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 8d40c08d55..2259dc1083 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
index d27cd7502f..7955ee5cf3 100644
--- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -31,7 +31,7 @@ InterpreterTransferManager::InterpreterTransferManager()
static std::unique_ptr<xla::TransferManager>
CreateInterpreterTransferManager() {
- return xla::MakeUnique<xla::InterpreterTransferManager>();
+ return absl::make_unique<xla::InterpreterTransferManager>();
}
static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h
index 2b44f30821..b732230fdd 100644
--- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/core/platform/macros.h"
@@ -33,4 +33,4 @@ class InterpreterTransferManager : public GenericTransferManager {
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc
index 42c2c28997..c9b40d3c61 100644
--- a/tensorflow/compiler/xla/service/interpreter/platform.cc
+++ b/tensorflow/compiler/xla/service/interpreter/platform.cc
@@ -17,13 +17,14 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/status_macros.h"
-#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
@@ -70,15 +71,15 @@ port::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
XlaInterpreterPlatform::GetUncachedExecutor(
const StreamExecutorConfig& config) {
- auto executor = MakeUnique<StreamExecutor>(
- this, MakeUnique<XlaInterpreterExecutor>(config.plugin_config));
+ auto executor = absl::make_unique<StreamExecutor>(
+ this, absl::make_unique<XlaInterpreterExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
port::error::INTERNAL,
- port::Printf(
+ absl::StrFormat(
"failed initializing StreamExecutor for device ordinal %d: %s",
- config.ordinal, init_status.ToString().c_str())};
+ config.ordinal, init_status.ToString())};
}
return std::move(executor);
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 805fdb2d5b..5e5c93e3a2 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -26,9 +26,12 @@ limitations under the License.
#include <string>
#include <tuple>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -49,20 +52,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
-// For now moving only one API here, but we should have a single top level
-// anonymous namespace, instead of three or four spread all over this file.
-namespace {
-
-} // namespace
-
std::ostream& operator<<(std::ostream& out,
const LayoutConstraint& constraint) {
out << constraint.ToString();
@@ -77,9 +71,8 @@ BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
}
string BufferLayoutConstraint::ToString() const {
- return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s",
- buffer_->ToString().c_str(),
- LayoutUtil::HumanString(layout_).c_str());
+ return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
+ LayoutUtil::HumanString(layout_));
}
OperandLayoutConstraint::OperandLayoutConstraint(
@@ -98,15 +91,14 @@ OperandLayoutConstraint::OperandLayoutConstraint(
}
string OperandLayoutConstraint::ToString() const {
- return tensorflow::strings::Printf(
- "OperandLayoutConstraint %s, operand %lld: %s",
- instruction_->name().c_str(), operand_no_,
- shape_layout_.ToString().c_str());
+ return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
+ instruction_->name(), operand_no_,
+ shape_layout_.ToString());
}
string ResultLayoutConstraint::ToString() const {
- return tensorflow::strings::Printf("ResultLayoutConstraint: %s",
- shape_layout_.ToString().c_str());
+ return absl::StrFormat("ResultLayoutConstraint: %s",
+ shape_layout_.ToString());
}
LayoutConstraints::LayoutConstraints(
@@ -137,7 +129,7 @@ PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
}
auto& buffer_set =
buffer_sets_cache_
- .emplace(instruction, MakeUnique<PointsToSet::BufferSet>())
+ .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
.first->second;
const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
points_to_set.ForEachElement(
@@ -174,8 +166,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
return FailedPrecondition(
"Layout of buffer %s cannot be constrained because buffer is not "
"array-shaped, has shape: %s",
- buffer.ToString().c_str(),
- ShapeUtil::HumanString(buffer.shape()).c_str());
+ buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
}
TF_RETURN_IF_ERROR(
LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
@@ -191,9 +182,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
return FailedPrecondition(
"Buffer %s already has the layout constraint %s, cannot add "
"incompatible constraint %s",
- buffer.ToString().c_str(),
- LayoutUtil::HumanString(curr_constraint.layout()).c_str(),
- LayoutUtil::HumanString(layout).c_str());
+ buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
+ LayoutUtil::HumanString(layout));
}
iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
} else {
@@ -227,11 +217,11 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
}
if (curr_shape_layout->mandatory()) {
return FailedPrecondition(
- "Operand %lld of instruction %s already has a layout constraint "
+ "Operand %d of instruction %s already has a layout constraint "
"%s, cannot add incompatible constraint %s",
- operand_no, instruction->name().c_str(),
- curr_shape_layout->shape_layout().ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ operand_no, instruction->name(),
+ curr_shape_layout->shape_layout().ToString(),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
}
@@ -240,9 +230,9 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
// layouts beyond this immediate use and is complicated to handle.
if (OperandBufferForwarded(instruction, operand_no)) {
return FailedPrecondition(
- "Cannot constraint layout of operand %lld of instruction %s "
+ "Cannot constrain layout of operand %d of instruction %s "
"because instruction forwards operand's LogicalBuffer(s)",
- operand_no, instruction->name().c_str());
+ operand_no, instruction->name());
}
auto key = std::make_pair(instruction, operand_no);
@@ -284,8 +274,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
return FailedPrecondition(
"Result of computation %s already has the layout constraint %s, "
"cannot add incompatible constraint %s",
- computation_->name().c_str(), curr_shape_layout->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ computation_->name(), curr_shape_layout->ToString(),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
@@ -307,9 +297,8 @@ Status LayoutConstraints::SetInstructionLayout(
if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
return FailedPrecondition(
"Instruction %s of shape %s cannot be assigned incompatible layout %s",
- instruction->name().c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ instruction->name(), ShapeUtil::HumanString(instruction->shape()),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
// Create a BufferLayoutConstraint for each array shape in the output of the
@@ -368,31 +357,27 @@ const ShapeLayout* LayoutConstraints::ResultLayout() const {
string LayoutConstraints::ToString() const {
string output;
- tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ",
- computation_->name(), ":\n");
+ absl::StrAppend(&output, "LayoutConstraints for computation ",
+ computation_->name(), ":\n");
for (auto* instruction : computation_->MakeInstructionPostOrder()) {
- tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(),
- "\n");
+ absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
for (int64 i = 0; i < instruction->operand_count(); ++i) {
if (OperandLayout(instruction, i) != nullptr) {
- tensorflow::strings::StrAppend(
- &output, " operand (", i,
- "): ", OperandLayout(instruction, i)->ToString(), "\n");
+ absl::StrAppend(&output, " operand (", i,
+ "): ", OperandLayout(instruction, i)->ToString(), "\n");
}
}
for (const LogicalBuffer* buffer :
points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
if (BufferLayout(*buffer) != nullptr) {
- tensorflow::strings::StrAppend(
- &output, " ", buffer->ToString(), " : ",
- LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
+ absl::StrAppend(&output, " ", buffer->ToString(), " : ",
+ LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
}
}
}
if (ResultLayout() != nullptr) {
- tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(),
- "\n");
+ absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n");
}
return output;
}
@@ -763,7 +748,7 @@ Status CheckParameterLayout(HloInstruction* parameter,
return InternalError(
"parameter instruction %s does not match layout of computation "
"shape: %s",
- parameter->ToString().c_str(), parameter_layout.ToString().c_str());
+ parameter->ToString(), parameter_layout.ToString());
}
return Status::OK();
}
@@ -774,8 +759,8 @@ Status CheckConstantLayout(HloInstruction* constant) {
constant->shape())) {
return InternalError(
"constant instruction %s does not match the layout of its literal %s",
- constant->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str());
+ constant->ToString(),
+ ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
}
return Status::OK();
}
@@ -908,13 +893,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
return InternalError(
"Layout of instruction %s at index {%s} does not match "
"source LogicalBuffer %s: %s vs %s",
- instruction->name().c_str(),
- tensorflow::str_util::Join(index, ",").c_str(),
- buffer->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(instruction_subshape)
- .c_str(),
- ShapeUtil::HumanStringWithLayout(buffer->shape())
- .c_str());
+ instruction->name(), absl::StrJoin(index, ","),
+ buffer->ToString(),
+ ShapeUtil::HumanStringWithLayout(instruction_subshape),
+ ShapeUtil::HumanStringWithLayout(buffer->shape()));
}
}
}
@@ -998,17 +980,18 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
CHECK(ShapeUtil::IsArray(instruction->shape()));
CHECK(ShapeUtil::IsArray(operand->shape()));
- if (instruction->IsElementwiseOnOperand(operand_no) &&
- !ShapeUtil::IsScalar(operand->shape()) &&
+ if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
- ShapeUtil::Rank(instruction->shape())) {
- // Assign operands the same layout as the instruction, so that
+ ShapeUtil::Rank(instruction->shape()) &&
+ InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) {
+ // Propagate the result layout to the operand layout if the instruction
+ // requires the same layout for the result and the operand.
+ //
+ // For elementwise operations, using the same layout for the operands and
+ // the result also has the following benefits:
// 1) the elementwise operation can reuse its operand's buffer, and
// 2) the input and output elements can reuse the same linear index.
- //
- // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit
- // from assigning the same layout to input and output.
- return MakeUnique<Layout>(output_layout);
+ return absl::make_unique<Layout>(output_layout);
}
if (instruction->opcode() == HloOpcode::kReshape) {
@@ -1031,13 +1014,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
*operand_shape.mutable_layout() =
LayoutUtil::GetDefaultLayoutForShape(operand_shape);
if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) {
- return MakeUnique<Layout>(operand_shape.layout());
+ return absl::make_unique<Layout>(operand_shape.layout());
}
if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) {
*operand_shape.mutable_layout() = output_layout;
if (ShapeUtil::ReshapeIsBitcast(operand_shape,
output_shape_with_layout)) {
- return MakeUnique<Layout>(output_layout);
+ return absl::make_unique<Layout>(output_layout);
}
}
auto aligned_operand_shape =
@@ -1046,7 +1029,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
auto operand_layout = aligned_operand_shape.value().layout();
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
}
@@ -1062,7 +1045,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
return nullptr;
@@ -1076,11 +1059,11 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
CHECK(ShapeUtil::IsArray(user->shape()) &&
ShapeUtil::IsArray(operand->shape()));
- if (user->IsElementwiseOnOperand(operand_no) &&
- !ShapeUtil::IsScalar(operand->shape()) &&
- ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) {
+ if (!ShapeUtil::IsScalar(operand->shape()) &&
+ ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) &&
+ InstructionRequiresInputLayoutEqualToOutputLayout(user)) {
// Assign users the same layout as the operand.
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
if (user->opcode() == HloOpcode::kReshape) {
@@ -1103,13 +1086,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
*output_shape.mutable_layout() =
LayoutUtil::GetDefaultLayoutForShape(output_shape);
if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) {
- return MakeUnique<Layout>(output_shape.layout());
+ return absl::make_unique<Layout>(output_shape.layout());
}
if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) {
*output_shape.mutable_layout() = operand_layout;
if (ShapeUtil::ReshapeIsBitcast(output_shape,
operand_shape_with_layout)) {
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
}
auto aligned_user_shape =
@@ -1118,7 +1101,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
auto user_layout = aligned_user_shape.value().layout();
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
- return MakeUnique<Layout>(user_layout);
+ return absl::make_unique<Layout>(user_layout);
}
}
@@ -1134,7 +1117,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
}
Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
- return MakeUnique<Layout>(user_layout);
+ return absl::make_unique<Layout>(user_layout);
}
return nullptr;
@@ -1385,7 +1368,7 @@ StatusOr<Layout> InferArrayLayout(
// This should not happen because we've assigned layouts to all
// instructions preceding this one.
return InternalError("LogicalBuffer %s does not have a layout",
- source_buffer->ToString().c_str());
+ source_buffer->ToString());
}
if (first_buffer_layout == nullptr) {
@@ -1400,9 +1383,8 @@ StatusOr<Layout> InferArrayLayout(
return FailedPrecondition(
"Array at index {%s} in instruction %s aliases buffers %s "
"and %s which have different layouts",
- tensorflow::str_util::Join(index, ",").c_str(),
- instruction->name().c_str(), source_buffers[0]->ToString().c_str(),
- source_buffer->ToString().c_str());
+ absl::StrJoin(index, ","), instruction->name(),
+ source_buffers[0]->ToString(), source_buffer->ToString());
}
}
@@ -1570,7 +1552,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
// present in the IR before layout assignment is a bug.
return InternalError(
"Unexpected bitcast operation seen during layout assignment: %s.",
- instruction->ToString().c_str());
+ instruction->ToString());
}
if (instruction->opcode() != HloOpcode::kInfeed) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
@@ -1822,6 +1804,107 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
return true;
}
+bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
+ const HloInstruction* instruction) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kAbs:
+ case HloOpcode::kAdd:
+ case HloOpcode::kAnd:
+ case HloOpcode::kAtan2:
+ case HloOpcode::kBitcastConvert:
+ case HloOpcode::kCeil:
+ case HloOpcode::kClamp:
+ case HloOpcode::kClz:
+ case HloOpcode::kComplex:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCos:
+ case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kDivide:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ case HloOpcode::kEq:
+ case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
+ case HloOpcode::kFft:
+ case HloOpcode::kFloor:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kImag:
+ case HloOpcode::kIsFinite:
+ case HloOpcode::kLe:
+ case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
+ case HloOpcode::kLt:
+ case HloOpcode::kMap:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kNegate:
+ case HloOpcode::kNot:
+ case HloOpcode::kOr:
+ case HloOpcode::kXor:
+ case HloOpcode::kPad:
+ case HloOpcode::kPower:
+ case HloOpcode::kReal:
+ case HloOpcode::kReducePrecision:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kReverse:
+ case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kSelect:
+ case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
+ case HloOpcode::kSign:
+ case HloOpcode::kSin:
+ case HloOpcode::kSlice:
+ case HloOpcode::kSort:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kTanh:
+ case HloOpcode::kTupleSelect:
+ case HloOpcode::kWhile:
+ return true;
+ case HloOpcode::kBatchNormGrad:
+ case HloOpcode::kBatchNormInference:
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kCall:
+ case HloOpcode::kConstant:
+ case HloOpcode::kConvolution:
+ case HloOpcode::kCopy:
+ case HloOpcode::kDomain:
+ case HloOpcode::kDot:
+ case HloOpcode::kFusion:
+ case HloOpcode::kGather:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kIota:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReshape:
+ case HloOpcode::kRng:
+ case HloOpcode::kScatter:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kAfterAll:
+ case HloOpcode::kTrace:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kTuple:
+ return false;
+ }
+}
+
Status LayoutAssignment::Init() {
computation_layouts_.clear();
*entry_computation_layout_ = saved_entry_computation_layout_;
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index f9e8dbea2f..cf545031d3 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -297,12 +297,17 @@ class LayoutAssignment : public HloPassInterface {
ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr);
~LayoutAssignment() override {}
- tensorflow::StringPiece name() const override { return "layout-assignment"; }
+ absl::string_view name() const override { return "layout-assignment"; }
// Assign layouts to the given module. Returns whether the module was changed
// (any layouts were changed).
StatusOr<bool> Run(HloModule* module) override;
+  // Returns true if the instruction requires operands with the same rank as
+  // the output to have the same layout as the output.
+ virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
+ const HloInstruction* instruction);
+
protected:
// These methods, invoked by PropagateConstraints, propagate a layout
// constraint to its neighbors (i.e. operands and users) in order to minimize
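For reference, a minimal sketch of how a backend could use the new virtual hook (the subclass and its opcode choice below are illustrative, not part of this change):

class MyBackendLayoutAssignment : public LayoutAssignment {
 public:
  using LayoutAssignment::LayoutAssignment;

  // Relax the operand-layout constraint for an opcode this backend is
  // assumed to handle with mismatched layouts; defer to the base list for
  // everything else.
  bool InstructionRequiresInputLayoutEqualToOutputLayout(
      const HloInstruction* instruction) override {
    if (instruction->opcode() == HloOpcode::kConcatenate) {
      return false;
    }
    return LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
        instruction);
  }
};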
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index a16fa75e30..7505d7a5b3 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -59,7 +59,7 @@ class LayoutAssignmentTest : public HloTestBase {
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
- std::vector<int64> LayoutOf(HloModule* module, tensorflow::StringPiece name) {
+ std::vector<int64> LayoutOf(HloModule* module, absl::string_view name) {
auto minor_to_major =
FindInstruction(module, name)->shape().layout().minor_to_major();
return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
@@ -861,5 +861,115 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
+TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
+ const char* module_str = R"(
+ HloModule CopySliceOperandToAvoidImplicitLayoutChange
+
+ ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
+ par0 = f32[3,4]{1,0} parameter(0)
+ par1 = f32[4,5]{0,1} parameter(1)
+ slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
+ ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
+ }
+ )";
+
+ auto module = ParseHloString(module_str).ValueOrDie();
+ module =
+ backend()
+ .compiler()
+ ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+
+ auto copy = FindInstruction(module.get(), "copy.1");
+ auto slice = FindInstruction(module.get(), "slice0");
+ EXPECT_EQ(slice->operand(0), copy);
+ EXPECT_TRUE(
+ LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout()));
+}
+
+TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
+ const char* module_str = R"(
+ HloModule CopyDSliceOperandToAvoidImplicitLayoutChange
+
+ ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
+ par0 = f32[3,4]{1,0} parameter(0)
+ par1 = f32[4,5]{0,1} parameter(1)
+ par2 = s32[2] parameter(2)
+ dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4}
+ ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
+ }
+ )";
+
+ auto module = ParseHloString(module_str).ValueOrDie();
+ module =
+ backend()
+ .compiler()
+ ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+
+ auto copy = FindInstruction(module.get(), "copy.1");
+ auto dslice = FindInstruction(module.get(), "dslice0");
+ EXPECT_EQ(dslice->operand(0), copy);
+ EXPECT_TRUE(
+ LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout()));
+}
+
+TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
+ const char* module_str = R"(
+ HloModule CopyConcatOperandToAvoidImplicitLayoutChange
+
+ ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
+ par0 = f32[3,8]{1,0} parameter(0)
+ par1 = f32[3,5]{0,1} parameter(1)
+ par2 = f32[3,3]{1,0} parameter(2)
+ concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
+ dimensions={1}
+ ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
+ }
+ )";
+
+ auto module = ParseHloString(module_str).ValueOrDie();
+ module =
+ backend()
+ .compiler()
+ ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+
+ auto copy = FindInstruction(module.get(), "copy.1");
+ auto concat = FindInstruction(module.get(), "concat0");
+ EXPECT_EQ(concat->operand(0), copy);
+ EXPECT_TRUE(
+ LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout()));
+}
+
+TEST_F(LayoutAssignmentTest,
+ ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
+ const char* module_str = R"(
+ HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied
+
+ ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
+ par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
+ par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
+ ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
+ window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
+ feature_group_count=1
+ }
+ )";
+
+ auto module = ParseHloString(module_str).ValueOrDie();
+ module =
+ backend()
+ .compiler()
+ ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+
+ auto copy = FindInstruction(module.get(), "copy.1");
+ EXPECT_EQ(copy, nullptr);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index cdd3daf73b..be12d7c90c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -38,6 +38,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -69,6 +70,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
@@ -88,6 +90,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -103,6 +107,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -120,6 +125,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:core",
],
)
@@ -133,9 +139,7 @@ cc_library(
":llvm_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"@llvm//:core",
@@ -193,6 +197,8 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter",
"//tensorflow/compiler/xla/service/gpu:partition_assignment",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
"@llvm//:core",
],
)
@@ -219,7 +225,7 @@ cc_library(
deps = [
":llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -230,6 +236,7 @@ cc_library(
hdrs = ["buffer_assignment_util.h"],
deps = [
"//tensorflow/compiler/xla/service:buffer_assignment",
+ "@com_google_absl//absl/strings",
],
)
@@ -242,3 +249,12 @@ cc_library(
"@llvm//:core",
],
)
+
+cc_library(
+ name = "ir_builder_mixin",
+ srcs = [],
+ hdrs = ["ir_builder_mixin.h"],
+ deps = [
+ "@llvm//:core",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
index fe9eab93aa..8d9fa99d82 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
+#include "absl/strings/str_cat.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace llvm_ir {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
index 4eb5d9fb47..bdce4a171b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
+#include "absl/strings/str_cat.h"
namespace xla {
namespace llvm_ir {
@@ -48,7 +49,7 @@ string ConstantBufferAllocationToGlobalName(
c = '_';
}
}
- return tensorflow::strings::StrCat("buffer_for_", instr_name);
+ return absl::StrCat("buffer_for_", instr_name);
}
const Literal& LiteralForConstantAllocation(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
index 27fbb11e2e..ad350613dd 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
@@ -40,7 +40,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
const Shape& update_shape, const ElementGenerator& start_indices_generator,
bool is_signed, ElementGenerator update_array_generator,
const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
- tensorflow::StringPiece name, llvm::IRBuilder<>* b) {
+ absl::string_view name, llvm::IRBuilder<>* b) {
const Shape& output_shape = output_array.GetShape();
// Read start indices from start_indices_generator.
@@ -101,8 +101,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
Status EmitDynamicUpdateSliceInPlace(
tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b) {
+ const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) {
VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;
// No need to use operand_arrays[0], the input array of the
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
index 3502577d23..e1631a62ae 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
@@ -65,8 +65,7 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace(
// modify the input/output buffer without touching any of the other elements.
Status EmitDynamicUpdateSliceInPlace(
tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b);
+ const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b);
// Given a loop-fusion node whose root is a dynamic-update-slice op whose
// array-to-be-updated and output share the same buffer slice, emits
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index 72ede377e1..6d637cad6d 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -98,7 +98,7 @@ Status FusedIrEmitter::HandleGetTupleElement(
return Unimplemented(
"GetTupleElement fusion currently only supports"
" parameter operands, but found operand: %s",
- operand->name().c_str());
+ operand->name());
}
// Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 2b6caee6aa..6971220022 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -342,9 +342,9 @@ llvm::Value* IrArray::Index::Linearize(
return logical_linear_index;
}
-llvm::Value* IrArray::EmitArrayElementAddress(
- const IrArray::Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name) const {
+llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
+ llvm::IRBuilder<>* b,
+ absl::string_view name) const {
if (ShapeUtil::IsScalar(*shape_)) {
// Special handling of scalars: a scalar pretends to have the same value for
// every index, thus effectively implementing broadcasting of its value
@@ -402,7 +402,7 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata(
llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
llvm::IRBuilder<>* b,
- tensorflow::StringPiece name) const {
+ absl::string_view name) const {
llvm::Value* element_address = EmitArrayElementAddress(index, b, name);
llvm::LoadInst* load = b->CreateLoad(element_address);
AnnotateLoadStoreInstructionWithMetadata(load);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index 28ca793e3e..e913c109b3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -19,12 +19,13 @@ limitations under the License.
#include <map>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -81,7 +82,7 @@ class IrArray {
}
}
CHECK_NE(index_type_, nullptr);
- CHECK(c_all_of(multidim, [&](llvm::Value* v) {
+ CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) {
return index_type_ == v->getType();
}));
}
@@ -240,7 +241,7 @@ class IrArray {
// The optional name is useful for debugging when looking at
// the emitted LLVM IR.
llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name = "") const;
+ absl::string_view name = "") const;
// Attach metadata this IrArray instance knows about to "instruction".
void AnnotateLoadStoreInstructionWithMetadata(
@@ -254,7 +255,7 @@ class IrArray {
// The optional name is useful for debugging when looking at
// the emitted LLVM IR.
llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name = "") const;
+ absl::string_view name = "") const;
// Emit IR to write the given value to the array element at the given index.
void EmitWriteArrayElement(const Index& index, llvm::Value* value,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
new file mode 100644
index 0000000000..abc06fb7b4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
@@ -0,0 +1,400 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
+
+#include "llvm/IR/IRBuilder.h"
+
+namespace xla {
+
+// Mixin class that injects more ergonomic versions of llvm::IRBuilder methods
+// into a class. Intended to be used as a CRTP base class, like:
+//
+// class MyIrEmitter : public IrBuilderMixin<MyIrEmitter> {
+// llvm::IRBuilder<>* builder() { return builder_; }
+//
+// void EmitFoo(HloInstruction* foo) {
+// Add(Mul(...), FPToUI(...));
+// }
+// };
+
+template <typename Derived>
+class IrBuilderMixin {
+ protected:
+ template <class... Args>
+ llvm::Value* Add(Args&&... args) {
+ return mixin_builder()->CreateAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::LoadInst* AlignedLoad(Args&&... args) {
+ return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::StoreInst* AlignedStore(Args&&... args) {
+ return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::AllocaInst* Alloca(Args&&... args) {
+ return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* And(Args&&... args) {
+ return mixin_builder()->CreateAnd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AtomicCmpXchg(Args&&... args) {
+ return mixin_builder()->CreateAtomicCmpXchg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AtomicRMW(Args&&... args) {
+ return mixin_builder()->CreateAtomicRMW(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* BitCast(Args&&... args) {
+ return mixin_builder()->CreateBitCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Br(Args&&... args) {
+ return mixin_builder()->CreateBr(std::forward<Args>(args)...);
+ }
+
+ llvm::CallInst* Call(llvm::Value* callee,
+ llvm::ArrayRef<llvm::Value*> args = llvm::None,
+ const llvm::Twine& name = "",
+ llvm::MDNode* fp_math_tag = nullptr) {
+ return mixin_builder()->CreateCall(callee, args, name, fp_math_tag);
+ }
+
+ template <class... Args>
+ llvm::BranchInst* CondBr(Args&&... args) {
+ return mixin_builder()->CreateCondBr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ConstInBoundsGEP1_32(Args&&... args) {
+ return mixin_builder()->CreateConstInBoundsGEP1_32(
+ std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FAdd(Args&&... args) {
+ return mixin_builder()->CreateFAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FMul(Args&&... args) {
+ return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
+ }
+
+ llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef<llvm::Value*> idx_list,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateGEP(ptr, idx_list, name);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpEQ(Args&&... args) {
+ return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpNE(Args&&... args) {
+ return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpULE(Args&&... args) {
+ return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpULT(Args&&... args) {
+ return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...);
+ }
+
+ llvm::Value* InBoundsGEP(llvm::Value* ptr,
+ llvm::ArrayRef<llvm::Value*> idx_list,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name);
+ }
+
+ llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef<unsigned> idxs,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateExtractValue(agg, idxs, name);
+ }
+
+ llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val,
+ llvm::ArrayRef<unsigned> idxs,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateInsertValue(agg, val, idxs, name);
+ }
+
+ template <class... Args>
+ llvm::Value* IntToPtr(Args&&... args) {
+ return mixin_builder()->CreateIntToPtr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::LoadInst* Load(Args&&... args) {
+ return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::CallInst* MemCpy(Args&&... args) {
+ return mixin_builder()->CreateMemCpy(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Mul(Args&&... args) {
+ return mixin_builder()->CreateMul(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWAdd(Args&&... args) {
+ return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWMul(Args&&... args) {
+ return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWSub(Args&&... args) {
+ return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Or(Args&&... args) {
+ return mixin_builder()->CreateOr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* PointerCast(Args&&... args) {
+ return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* PtrToInt(Args&&... args) {
+ return mixin_builder()->CreatePtrToInt(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SDiv(Args&&... args) {
+ return mixin_builder()->CreateSDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Select(Args&&... args) {
+ return mixin_builder()->CreateSelect(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SRem(Args&&... args) {
+ return mixin_builder()->CreateSRem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::StoreInst* Store(Args&&... args) {
+ return mixin_builder()->CreateStore(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* UDiv(Args&&... args) {
+ return mixin_builder()->CreateUDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* URem(Args&&... args) {
+ return mixin_builder()->CreateURem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* VectorSplat(Args&&... args) {
+ return mixin_builder()->CreateVectorSplat(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ZExtOrTrunc(Args&&... args) {
+ return mixin_builder()->CreateZExtOrTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AShr(Args&&... args) {
+ return mixin_builder()->CreateAShr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpOEQ(Args&&... args) {
+ return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpOLT(Args&&... args) {
+ return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpONE(Args&&... args) {
+ return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpUNE(Args&&... args) {
+ return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FDiv(Args&&... args) {
+ return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FNeg(Args&&... args) {
+ return mixin_builder()->CreateFNeg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPCast(Args&&... args) {
+ return mixin_builder()->CreateFPCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPToSI(Args&&... args) {
+ return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPToUI(Args&&... args) {
+ return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPTrunc(Args&&... args) {
+ return mixin_builder()->CreateFPTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FRem(Args&&... args) {
+ return mixin_builder()->CreateFRem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FSub(Args&&... args) {
+ return mixin_builder()->CreateFSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpSGE(Args&&... args) {
+ return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpSLT(Args&&... args) {
+ return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* IntCast(Args&&... args) {
+ return mixin_builder()->CreateIntCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* LShr(Args&&... args) {
+ return mixin_builder()->CreateLShr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* MemSet(Args&&... args) {
+ return mixin_builder()->CreateMemSet(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Neg(Args&&... args) {
+ return mixin_builder()->CreateNeg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Not(Args&&... args) {
+ return mixin_builder()->CreateNot(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::PHINode* PHI(Args&&... args) {
+ return mixin_builder()->CreatePHI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* RetVoid(Args&&... args) {
+ return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SExtOrTrunc(Args&&... args) {
+ return mixin_builder()->CreateSExtOrTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Shl(Args&&... args) {
+ return mixin_builder()->CreateShl(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SIToFP(Args&&... args) {
+ return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Sub(Args&&... args) {
+ return mixin_builder()->CreateSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Trunc(Args&&... args) {
+ return mixin_builder()->CreateTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* UIToFP(Args&&... args) {
+ return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Unreachable(Args&&... args) {
+ return mixin_builder()->CreateUnreachable(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Xor(Args&&... args) {
+ return mixin_builder()->CreateXor(std::forward<Args>(args)...);
+ }
+
+ private:
+ llvm::IRBuilder<>* mixin_builder() {
+ return static_cast<Derived*>(this)->builder();
+ }
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
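For reference, a minimal sketch of an emitter consuming the mixin (the class and its members are illustrative only; builder() must be reachable from IrBuilderMixin<Derived>, so it is public here):

class ExampleIrEmitter : public IrBuilderMixin<ExampleIrEmitter> {
 public:
  explicit ExampleIrEmitter(llvm::IRBuilder<>* b) : b_(b) {}

  // The mixin reaches back into the derived class through this accessor.
  llvm::IRBuilder<>* builder() { return b_; }

  // Emits (a + b) * c using the shorter mixin spellings instead of
  // b_->CreateAdd(...) / b_->CreateMul(...).
  llvm::Value* EmitMulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
    return Mul(Add(a, b), c);
  }

 private:
  llvm::IRBuilder<>* b_;
};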
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
index b79567369a..bd0139f85b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace xla {
Status KernelSupportLibrary::For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value*, bool)>& for_body_generator) {
return If(b_->CreateICmpSLT(start, end), [&]() -> Status {
@@ -30,7 +30,7 @@ Status KernelSupportLibrary::For(
}
Status KernelSupportLibrary::For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
const std::function<Status(llvm::Value*, llvm::Value*)>&
for_body_generator) {
@@ -56,7 +56,7 @@ Status KernelSupportLibrary::For(
}
Status KernelSupportLibrary::If(
- tensorflow::StringPiece name, llvm::Value* condition,
+ absl::string_view name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_);
@@ -70,7 +70,7 @@ Status KernelSupportLibrary::If(
void KernelSupportLibrary::EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name,
+ absl::string_view kernel_name,
KernelSupportLibrary::ArgumentVector arguments,
const std::function<void(KernelSupportLibrary::ArgumentVector)>&
kernel_body_generator) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index b00f903d56..b152cf9275 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
#include <string>
+#include "absl/strings/string_view.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
// A thin wrapper around llvm_loop.h to make code generating structured control
@@ -49,13 +49,13 @@ class KernelSupportLibrary {
//     `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
// }
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value* ind_var,
bool is_first_iteration)>& for_body_generator);
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
@@ -67,7 +67,7 @@ class KernelSupportLibrary {
}));
}
- Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ Status For(absl::string_view name, int64 start, int64 end, int64 step,
const std::function<Status(llvm::Value* ind_var,
bool is_first_iteration)>&
for_body_generator) {
@@ -77,7 +77,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
ForReturnVoid(name, /*start=*/b_->getInt64(start),
@@ -99,13 +99,13 @@ class KernelSupportLibrary {
// for (i64 i = `start`; i s< `end`; i += `step`)
//     `for_body_generator(/*ind_var=*/i,
//                         /*is_first_iteration=*/(i != `start`))`;
- Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ Status For(absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
const std::function<Status(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
for_body_generator);
- void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ void ForReturnVoid(absl::string_view name, llvm::Value* start,
llvm::Value* end, llvm::Value* step,
bool peel_first_iteration,
const std::function<void(llvm::Value* ind_var,
@@ -119,7 +119,7 @@ class KernelSupportLibrary {
}));
}
- Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ Status For(absl::string_view name, llvm::Value* start, llvm::Value* end,
int64 step, bool peel_first_iteration,
const std::function<Status(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
@@ -129,7 +129,7 @@ class KernelSupportLibrary {
peel_first_iteration, for_body_generator);
}
- void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ void ForReturnVoid(absl::string_view name, llvm::Value* start,
llvm::Value* end, int64 step, bool peel_first_iteration,
const std::function<void(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
@@ -140,7 +140,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, start, end, step,
@@ -151,7 +151,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, start, end, step,
@@ -162,8 +162,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- int64 step,
+ absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, start, end, llvm::ConstantInt::get(start->getType(), step),
/*peel_first_iteration=*/false,
@@ -173,8 +172,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- int64 step,
+ absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, start, end,
llvm::ConstantInt::get(start->getType(), step),
@@ -182,7 +180,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, /*start=*/b_->getInt64(start),
/*end=*/b_->getInt64(end),
@@ -190,7 +188,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, /*start=*/b_->getInt64(start),
/*end=*/b_->getInt64(end),
@@ -203,7 +201,7 @@ class KernelSupportLibrary {
// `true_block_generator()`;
// else
// `false_block_generator()`;
- Status If(tensorflow::StringPiece name, llvm::Value* condition,
+ Status If(absl::string_view name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator =
[]() -> Status { return Status::OK(); });
@@ -222,7 +220,7 @@ class KernelSupportLibrary {
IfReturnVoid("", condition, true_block_generator, false_block_generator);
}
- void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition,
+ void IfReturnVoid(absl::string_view name, llvm::Value* condition,
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {
}) {
@@ -259,13 +257,13 @@ class KernelSupportLibrary {
// Currently we only support at most one nullptr value in `arguments`.
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, ArgumentVector arguments,
+ absl::string_view kernel_name, ArgumentVector arguments,
const std::function<void(ArgumentVector)>& kernel_body_generator);
// Thin wrappers around the more general EmitAndCallOutlinedKernel above.
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1,
+ absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
llvm::Value* arg2,
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
kernel_body_generator) {
@@ -278,7 +276,7 @@ class KernelSupportLibrary {
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1,
+ absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
llvm::Value* arg2, llvm::Value* arg3,
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
llvm::Value*)>& kernel_body_generator) {
@@ -296,4 +294,4 @@ class KernelSupportLibrary {
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
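For reference, a minimal sketch of call sites against the updated absl::string_view signatures (assumes `ksl` is a KernelSupportLibrary already wired to the current llvm::IRBuilder and `flag` is an i1 llvm::Value*; both are illustrative):

Status EmitExample(KernelSupportLibrary* ksl, llvm::Value* flag) {
  // Constant-bound loop: the int64 overload materializes start/end/step
  // as i64 constants.
  TF_RETURN_IF_ERROR(ksl->For(
      "zero_out", /*start=*/0, /*end=*/1024, /*step=*/1,
      [&](llvm::Value* ind_var) -> Status {
        // ... emit the body for iteration `ind_var` ...
        return Status::OK();
      }));
  // Conditional region; the false branch falls back to the default empty
  // generator.
  return ksl->If("maybe_fixup", flag, [&]() -> Status {
    // ... emit the fix-up path ...
    return Status::OK();
  });
}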
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
index 35b3941272..cb4d1db997 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
@@ -55,10 +55,10 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
}
} // namespace
-tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(
- const Shape& a, const Shape& b) {
+absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
+ const Shape& b) {
if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
std::vector<int64> perm(a.dimensions().size());
@@ -88,7 +88,7 @@ tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(
return dims_021;
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
IrArray::Index GetUnreducedOutputIndex(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
index ccb9b8ba3e..8bd06c42c3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -36,8 +36,8 @@ namespace llvm_ir {
// If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
// reduced shape of `b` or the 0-2-1 shape.
-tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
- const Shape& b);
+absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
+ const Shape& b);
// Return the unreduced output index corresponding to the given reduced output
// index.
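For reference, a minimal sketch of consuming the new absl::optional return value (the shapes are illustrative; whether a given pair actually hits the 0-2-1 case depends on their layouts):

void ExampleUse(const Shape& a, const Shape& b) {
  if (absl::optional<std::vector<int64>> dims =
          llvm_ir::FindTranspose021(a, b)) {
    // *dims describes the collapsed dimensions of the 0-2-1 transpose.
  } else {
    // absl::nullopt: `b` is not a 0-2-1 transpose of `a`.
  }
}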
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index ba7f94834c..9f3329e7f0 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
@@ -25,19 +26,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace llvm_ir {
-ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
+ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
llvm::Value* start_index, llvm::Value* end_index,
llvm::Value* step, UnrollMode unroll_mode,
bool prevent_vectorization)
- : prefix_(std::string(prefix)),
- suffix_(std::string(suffix)),
+ : prefix_(prefix),
+ suffix_(suffix),
start_index_(start_index),
end_index_(end_index),
step_(step),
@@ -46,9 +45,9 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
prevent_vectorization_(prevent_vectorization) {}
/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
- tensorflow::StringPiece prefix, llvm::Value* start_index,
- llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b,
- UnrollMode unroll_mode, bool prevent_vectorization) {
+ absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
+ llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
+ bool prevent_vectorization) {
std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
end_index, step, unroll_mode,
prevent_vectorization));
@@ -168,16 +167,16 @@ std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
return result;
}
-string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
+string ForLoop::GetQualifiedName(absl::string_view name) {
return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
}
-llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
+llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
llvm::IRBuilder<>* b) {
return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
}
-std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
+std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
llvm::Value* start_index,
llvm::Value* end_index,
UnrollMode unroll_mode,
@@ -186,12 +185,9 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
unroll_mode, prevent_vectorization);
}
-std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
- llvm::Value* start_index,
- llvm::Value* end_index,
- llvm::Value* stride,
- UnrollMode unroll_mode,
- bool prevent_vectorization) {
+std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
+ absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
+ llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
if (inner_loop_body_bb_ != nullptr) {
// Create this loop inside the previous one.
b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
@@ -216,7 +212,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
@@ -227,7 +223,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 stride,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
@@ -238,7 +234,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
}
IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
- tensorflow::StringPiece suffix) {
+ absl::string_view suffix) {
std::vector<int64> dimensions(ShapeUtil::Rank(shape));
std::iota(dimensions.begin(), dimensions.end(), 0);
return AddLoopsForShapeOnDimensions(shape, dimensions, suffix);
@@ -246,14 +242,14 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::StringPiece suffix) {
+ absl::string_view suffix) {
llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size());
for (int64 dimension : dimensions) {
std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
/*start_index=*/0,
/*end_index=*/shape.dimensions(dimension),
/*suffix=*/
- llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension)));
+ llvm_ir::IrName(suffix, absl::StrCat(dimension)));
index[dimension] = loop->GetIndVarValue();
}
return index;
@@ -261,7 +257,7 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
IrArray::Index ForLoopNest::EmitOperandArrayLoopNest(
const llvm_ir::IrArray& operand_array, int64 dimension_to_skip,
- tensorflow::StringPiece name_suffix) {
+ absl::string_view name_suffix) {
// Prepares the dimension list we will use to emit the loop nest. Outermost
// loops are added first. Add loops in major-to-minor order, and skip the
// 'dimension_to_skip' dimension.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index a4fed5c8dc..0a406bd90b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -19,15 +19,15 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -78,7 +78,7 @@ class ForLoop {
// `unroll_mode` specifies the desired LLVM unrolling behavior for generated
// loop.
static std::unique_ptr<ForLoop> EmitForLoop(
- tensorflow::StringPiece prefix, llvm::Value* start_index,
+ absl::string_view prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b,
UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -133,19 +133,18 @@ class ForLoop {
// Allow ForLoopNest to call this private constructor.
friend class ForLoopNest;
- ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
+ ForLoop(absl::string_view prefix, absl::string_view suffix,
llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step,
UnrollMode unroll_mode, bool prevent_vectorization);
// Emit the loop at the insert point of the builder.
void Emit(llvm::IRBuilder<>* b);
- llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name,
- llvm::IRBuilder<>* b);
+ llvm::BasicBlock* CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b);
// Creates a name for an LLVM construct, appending prefix_ and suffix_, if
// they are set.
- string GetQualifiedName(tensorflow::StringPiece name);
+ string GetQualifiedName(absl::string_view name);
// Return a list of metadata nodes that should be associated with the
// llvm::Loop for this `ForLoop`.
@@ -182,9 +181,9 @@ class ForLoopNest {
SetIndexType(index_ty);
}
- ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b,
+ ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b,
llvm::Type* index_ty = nullptr)
- : name_(std::string(name)),
+ : name_(name),
outer_loop_preheader_bb_(nullptr),
outer_loop_exit_bb_(nullptr),
inner_loop_body_bb_(nullptr),
@@ -197,14 +196,14 @@ class ForLoopNest {
// been added then emit loop inside the body of the last added loop.
// unroll_mode is used to emit metadata that controls LLVM unrolling.
std::unique_ptr<ForLoop> AddLoop(
- tensorflow::StringPiece suffix, llvm::Value* start_index,
+ absl::string_view suffix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* stride,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
std::unique_ptr<ForLoop> AddLoop(
- tensorflow::StringPiece suffix, llvm::Value* start_index,
+ absl::string_view suffix, llvm::Value* start_index,
llvm::Value* end_index,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -213,13 +212,13 @@ class ForLoopNest {
// end index are constant.
std::unique_ptr<ForLoop> AddLoop(
int64 start_index, int64 end_index, int64 stride,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
std::unique_ptr<ForLoop> AddLoop(
- int64 start_index, int64 end_index, tensorflow::StringPiece suffix,
+ int64 start_index, int64 end_index, absl::string_view suffix,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -234,8 +233,7 @@ class ForLoopNest {
// within the shape. One possible order for that sequence would be:
//
// (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
- IrArray::Index AddLoopsForShape(const Shape& shape,
- tensorflow::StringPiece suffix);
+ IrArray::Index AddLoopsForShape(const Shape& shape, absl::string_view suffix);
// Add a loop for each dimension in "dimensions". "suffix" is the
// name suffix of the indvar and basic blocks in this new loop nest.
@@ -245,7 +243,7 @@ class ForLoopNest {
// dimension that is not in "dimensions".
IrArray::Index AddLoopsForShapeOnDimensions(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::StringPiece suffix);
+ absl::string_view suffix);
// Emits a series of nested loops for iterating over an operand array. Loops
// are constructed in major to minor dimension layout order. No loop is
@@ -256,7 +254,7 @@ class ForLoopNest {
// basic blocks) constructed by this method.
IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array,
int64 dimension_to_skip,
- tensorflow::StringPiece name_suffix);
+ absl::string_view name_suffix);
// Convenience methods which return particular basic blocks of the outermost
// or innermost loops. These methods return nullptr if no loops have been
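For reference, a minimal sketch against the updated ForLoopNest interface (the shape and builder are assumed to come from the surrounding emitter; names are illustrative):

void EmitElementwiseLoops(const Shape& shape, llvm::IRBuilder<>* b) {
  llvm_ir::ForLoopNest loop_nest("elementwise", b);
  // One loop per dimension of `shape`; the returned index addresses the
  // element visited by the innermost iteration.
  llvm_ir::IrArray::Index index =
      loop_nest.AddLoopsForShape(shape, /*suffix=*/"dim");
  // ... move the insert point into the innermost body and emit the element
  // computation using `index` ...
}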
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index e6126881af..f0db2a3761 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/MDBuilder.h"
@@ -34,8 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -61,7 +61,7 @@ string AsString(const std::string& str) {
return string(str.data(), str.length());
}
-llvm::StringRef AsStringRef(tensorflow::StringPiece str) {
+llvm::StringRef AsStringRef(absl::string_view str) {
return llvm::StringRef(str.data(), str.size());
}
@@ -262,15 +262,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
}
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b,
int alignment) {
return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}
-llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
- llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b, int alignment) {
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
+ llvm::Value* element_count,
+ absl::string_view name,
+ llvm::IRBuilder<>* b,
+ int alignment) {
llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP();
llvm::Function* function = b->GetInsertBlock()->getParent();
b->SetInsertPoint(&function->getEntryBlock(),
@@ -285,7 +287,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
}
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b) {
return llvm::BasicBlock::Create(
/*Context=*/b->getContext(),
@@ -294,27 +296,25 @@ llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
/*InsertBefore*/ insert_before);
}
-LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
llvm::IRBuilder<>* b, bool emit_else) {
llvm_ir::LlvmIfData if_data;
if_data.if_block = b->GetInsertBlock();
if_data.true_block =
- CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b);
+ CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
if_data.false_block =
- emit_else ? CreateBasicBlock(
- nullptr, tensorflow::strings::StrCat(name, "-false"), b)
+ emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
: nullptr;
// Add a terminator to the if block, if necessary.
if (if_data.if_block->getTerminator() == nullptr) {
b->SetInsertPoint(if_data.if_block);
- if_data.after_block = CreateBasicBlock(
- nullptr, tensorflow::strings::StrCat(name, "-after"), b);
+ if_data.after_block =
+ CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
b->CreateBr(if_data.after_block);
} else {
if_data.after_block = if_data.if_block->splitBasicBlock(
- b->GetInsertPoint(),
- AsStringRef(tensorflow::strings::StrCat(name, "-after")));
+ b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after")));
}
// Our basic block should now end with an unconditional branch. Remove it;
@@ -413,14 +413,14 @@ string IrName(string a) {
return a;
}
-string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) {
+string IrName(absl::string_view a, absl::string_view b) {
if (!a.empty() && !b.empty()) {
- return IrName(tensorflow::strings::StrCat(a, ".", b));
+ return IrName(absl::StrCat(a, ".", b));
}
- return IrName(tensorflow::strings::StrCat(a, b));
+ return IrName(absl::StrCat(a, b));
}
-string IrName(const HloInstruction* a, tensorflow::StringPiece b) {
+string IrName(const HloInstruction* a, absl::string_view b) {
return IrName(a->name(), b);
}
@@ -556,7 +556,7 @@ std::map<int, llvm::MDNode*> MergeMetadata(
return result;
}
-static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) {
+static string GetProcessUniqueIrFileName(absl::string_view prefix) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-");
@@ -584,18 +584,16 @@ Status DumpIRToDirectory(const string& directory_name,
// XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
// dumped from the same process in such cases.
string unique_and_safe_file_name = GetProcessUniqueIrFileName(
- tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-",
- optimized ? "with" : "no", "-opt"));
+ absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-",
+ optimized ? "with" : "no", "-opt"));
string ir_file_name = tensorflow::io::JoinPath(
- directory_name,
- tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll"));
+ directory_name, absl::StrCat(unique_and_safe_file_name, ".ll"));
// For some models the embedded constants can be huge, so also dump the module
// with the constants stripped to get IR that is easier to manipulate.
string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath(
- directory_name,
- tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll"));
+ directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll"));
TF_RETURN_IF_ERROR(CreateAndWriteStringToFile(
directory_name, ir_file_name, DumpModuleToString(llvm_module)));
@@ -607,8 +605,7 @@ Status DumpIRToDirectory(const string& directory_name,
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
bool enable_fast_math, bool optimize_for_size,
- tensorflow::StringPiece name,
- llvm::Module* module) {
+ absl::string_view name, llvm::Module* module) {
llvm::Function* function =
llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
function->setCallingConv(llvm::CallingConv::C);
@@ -638,7 +635,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
fake_argv_storage.push_back("");
for (const auto& it : options) {
// Skip options the XLA backend itself consumes.
- if (!tensorflow::str_util::StartsWith(it.first, "xla_")) {
+ if (!absl::StartsWith(it.first, "xla_")) {
if (it.second.empty()) {
fake_argv_storage.push_back(it.first);
} else {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 0958398534..dde50e19d1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
@@ -47,11 +47,11 @@ namespace llvm_ir {
// Convert a std::string (used by LLVM's interfaces) to string.
string AsString(const std::string& str);
-// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both
-// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a
+// Convert an absl::string_view to a llvm::StringRef. Note: both
+// absl::string_view and llvm::StringRef are non-owning pointers into a
// string in memory. This method is used to feed strings to LLVM
// & Clang APIs that expect llvm::StringRef.
-llvm::StringRef AsStringRef(tensorflow::StringPiece str);
+llvm::StringRef AsStringRef(absl::string_view str);
template <typename T>
llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
@@ -88,8 +88,8 @@ string DumpModuleToString(const llvm::Module& module);
// - removing all '%'s.
//
string IrName(string a);
-string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b);
-string IrName(const HloInstruction* a, tensorflow::StringPiece b = "");
+string IrName(absl::string_view a, absl::string_view b);
+string IrName(const HloInstruction* a, absl::string_view b = "");
// Removes special characters from a function name.
//
@@ -164,21 +164,23 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
// This can be useful to avoid e.g. executing an alloca every time
// through a loop.
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b,
int alignment = 0);
// As EmitAllocaAtFunctionEntry, but allocates element_count entries
// instead of a single element.
-llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
- llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b, int alignment = 0);
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
+ llvm::Value* element_count,
+ absl::string_view name,
+ llvm::IRBuilder<>* b,
+ int alignment = 0);
// Creates a basic block with the same context and function as for the
// builder. Inserts at the end of the function if insert_before is
// null.
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b);
// Struct with data on a conditional branch in a diamond shape created
@@ -210,7 +212,7 @@ struct LlvmIfData {
// Currently the insertion point of the builder must be a well-formed
// block with a terminator. If you need to use this for a
// non-terminated block, just make the function able to do that too.
-LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
llvm::IRBuilder<>* b, bool emit_else = true);
// Emits a compare operation between "lhs" and "rhs" with the given predicate,
@@ -285,8 +287,7 @@ Status DumpIRToDirectory(const string& directory_name,
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
bool enable_fast_math, bool optimize_for_size,
- tensorflow::StringPiece name,
- llvm::Module* module);
+ absl::string_view name, llvm::Module* module);
// Extracts the xla_backend_extra_options from `config` and passes those that
// don't start with xla_ to LLVM.
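The header above now takes absl::string_view at the AsStringRef boundary. Below is a minimal sketch of what such a conversion amounts to, assuming only llvm::StringRef's (const char*, size_t) constructor; the ToStringRef name and the example namespace are hypothetical stand-ins, not the actual llvm_util implementation.

// Sketch: adapting an absl::string_view to llvm::StringRef. Both types are
// non-owning (pointer, length) views, so the conversion is just a re-wrap.
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"

namespace example {

inline llvm::StringRef ToStringRef(absl::string_view str) {
  // llvm::StringRef(const char*, size_t) does not copy; the caller must keep
  // the underlying buffer alive for as long as the StringRef is in use.
  return llvm::StringRef(str.data(), str.size());
}

}  // namespace example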
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 36f5fa1952..1553b4fc91 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -18,13 +18,13 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -86,7 +86,7 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
}
std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ absl::string_view loop_name, llvm::Type* index_type) {
CHECK_NE(index_type, nullptr);
if (ShapeUtil::IsScalar(shape_)) {
// No loop needed, so set exit_bb_ to nullptr.
@@ -105,7 +105,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
std::unique_ptr<ForLoop> loop = loop_nest.AddLoop(
/*start_index=*/0,
/*end_index=*/shape_.dimensions(dimension),
- /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension));
+ /*suffix=*/absl::StrFormat("dim.%d", dimension));
array_index[dimension] = loop->GetIndVarValue();
}
@@ -122,7 +122,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
return {array_index};
}
-Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name,
+Status LoopEmitter::EmitLoop(absl::string_view loop_name,
llvm::Type* index_type) {
if (index_type == nullptr) {
index_type = b_->getInt64Ty();
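The loop-suffix change above ("dim.%lld" with Printf becoming "dim.%d" with absl::StrFormat) works because StrFormat type-checks its arguments and lets %d bind to any integral width. A standalone sketch, independent of the XLA sources:

#include <cstdint>
#include <iostream>
#include <string>

#include "absl/strings/str_format.h"

int main() {
  // absl::StrFormat verifies the format string against the argument types, so
  // "%d" accepts any integral type (including int64), unlike printf-style
  // APIs that needed "%lld" for 64-bit values.
  int64_t dimension = 3;
  std::string suffix = absl::StrFormat("dim.%d", dimension);
  std::cout << suffix << std::endl;  // prints "dim.3"
  return 0;
}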
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index c4f5c82086..57d9d8bbc6 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -69,10 +69,10 @@ class LoopEmitter {
}
virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type);
+ absl::string_view loop_name, llvm::Type* index_type);
// Emits a complete loop nest for every element in the given shape.
- Status EmitLoop(tensorflow::StringPiece loop_name = "",
+ Status EmitLoop(absl::string_view loop_name = "",
llvm::Type* index_type = nullptr);
protected:
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
index e546f5cc4a..00dd3f1638 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
@@ -29,8 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -42,7 +42,7 @@ namespace {
void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
const IrArray::Index& compare_keys_index,
const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
+ const absl::optional<IrArray>& values_array,
llvm::IRBuilder<>* b) {
// if (is_smaller_index &&
// compare_keys[dimension_to_sort] < dimension_to_sort_bound)
@@ -87,8 +87,8 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
} // namespace
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
- tensorflow::StringPiece name, llvm::Value* xor_mask,
+ const absl::optional<IrArray>& values_array,
+ absl::string_view name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions) {
const Shape& keys_shape = keys_array.GetShape();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
index 8458744c6b..527ed10374 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -31,8 +31,8 @@ namespace llvm_ir {
// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr,
// the inner compare loop will not be parallelized.
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
- tensorflow::StringPiece name, llvm::Value* xor_mask,
+ const absl::optional<IrArray>& values_array,
+ absl::string_view name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions);
} // namespace llvm_ir
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 5e02096ee5..768105d9e1 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -19,10 +19,12 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
@@ -37,7 +39,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -73,7 +74,7 @@ namespace {
// If the parameter number is invalid for this computation, nullopt is
// returned. When the return value has_value(), nullptr will never be
// the held value.
-tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
+absl::optional<const OpMetadata*> ParameterMetadata(
const XlaComputation& computation, int parameter_number) {
for (const HloComputationProto& comp : computation.proto().computations()) {
if (comp.id() == computation.proto().entry_computation_id()) {
@@ -81,14 +82,14 @@ tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
instr.parameter_number() == parameter_number) {
if (!instr.has_metadata()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
return &instr.metadata();
}
}
}
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
ExecutionOptions CreateExecutionOptions(
@@ -149,7 +150,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
// Validate incoming layouts.
if (argument_layouts.size() != program_shape.parameters_size()) {
return InvalidArgument(
- "Invalid number of arguments for computation: expected %d, got %zu.",
+ "Invalid number of arguments for computation: expected %d, got %u.",
program_shape.parameters_size(), argument_layouts.size());
}
@@ -158,7 +159,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
- tensorflow::gtl::optional<const OpMetadata*> metadata =
+ absl::optional<const OpMetadata*> metadata =
ParameterMetadata(computation, /*parameter_number=*/i);
auto metadata_string = [&metadata]() -> string {
if (!metadata.has_value()) {
@@ -167,16 +168,15 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
CHECK(metadata.value() != nullptr);
const OpMetadata& m = *metadata.value();
if (!m.source_file().empty()) {
- return tensorflow::strings::Printf(
- " (%s:%d)", m.source_file().c_str(), m.source_line());
+ return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line());
}
return "";
};
return InvalidArgument(
"Invalid argument shape for argument %d%s, expected %s, got %s.", i,
- metadata_string().c_str(),
- ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
- ShapeUtil::HumanString(argument_shape).c_str());
+ metadata_string(),
+ ShapeUtil::HumanString(program_shape.parameters(i)),
+ ShapeUtil::HumanString(argument_shape));
}
}
if (build_options.result_layout() != nullptr) {
@@ -214,7 +214,7 @@ StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
if (replica_number >= buffers.size()) {
return InvalidArgument(
- "replica_number %d out of range; must be less than num_replicas = %zu.",
+ "replica_number %d out of range; must be less than num_replicas = %u.",
replica_number, buffers.size());
}
return buffers[replica_number];
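ParameterMetadata above keeps the same optional-returning shape, only swapping tensorflow::gtl::optional for absl::optional. The following self-contained sketch shows the same pattern with a hypothetical FindLabel helper and a plain string payload instead of OpMetadata:

#include <iostream>
#include <string>

#include "absl/types/optional.h"

// Hypothetical lookup that may or may not find a value.
absl::optional<std::string> FindLabel(int parameter_number) {
  if (parameter_number == 0) {
    return std::string("input");
  }
  return absl::nullopt;  // drop-in replacement for tensorflow::gtl::nullopt
}

int main() {
  absl::optional<std::string> label = FindLabel(1);
  if (!label.has_value()) {
    std::cout << "no metadata" << std::endl;
  }
  return 0;
}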
diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc
index c742d35a7b..e1f56727bd 100644
--- a/tensorflow/compiler/xla/service/logical_buffer.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer.cc
@@ -15,11 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -34,11 +34,10 @@ LogicalBuffer::~LogicalBuffer() {}
string LogicalBuffer::ToString() const {
string color_string;
if (has_color()) {
- color_string = tensorflow::strings::StrCat(" @", color().value());
+ color_string = absl::StrCat(" @", color().value());
}
- return tensorflow::strings::StrCat(instruction_->name(), "[",
- tensorflow::str_util::Join(index_, ","),
- "](#", id(), color_string, ")");
+ return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","),
+ "](#", id(), color_string, ")");
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index d631fb5ee4..eaa09591b7 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
@@ -89,7 +90,7 @@ void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction,
const ShapeIndex& index) {
CHECK_EQ(logical_buffers_.size(), next_buffer_id_);
logical_buffers_.emplace_back(
- MakeUnique<LogicalBuffer>(instruction, index, next_buffer_id_));
+ absl::make_unique<LogicalBuffer>(instruction, index, next_buffer_id_));
output_buffers_[std::make_pair(instruction, index)] =
logical_buffers_.back().get();
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 0019cd7254..4c8cb7d379 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -48,9 +48,7 @@ class MultiOutputFusion : public HloPassInterface {
public:
MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
- tensorflow::StringPiece name() const override {
- return "multi_output_fusion";
- }
+ absl::string_view name() const override { return "multi_output_fusion"; }
// Run multi-output fusion on the given module. Returns whether the module
// was changed.
@@ -104,17 +102,17 @@ class MultiOutputFusion : public HloPassInterface {
// InstructionFusion instead.
virtual bool DoProducerConsumerMultiOutputFusion();
- private:
- // Update the internal data structures after instr1 and instr2 are fused into
- // one fusion instruction.
- void Update(HloInstruction* instr1, HloInstruction* instr2);
-
// Optimization fuel is a compiler debugging technique that makes an
// optimization pass stop what it is doing after having made N changes to the
// program, where N is the fuel. By varying N, this can be used to find the
// first single change that makes a test fail.
int64 fuel_;
+ private:
+ // Update the internal data structures after instr1 and instr2 are fused into
+ // one fusion instruction.
+ void Update(HloInstruction* instr1, HloInstruction* instr2);
+
// Computation for the pass.
HloComputation* computation_;
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index f6e7578a89..bd8fb17a23 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -15,8 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -52,8 +53,8 @@ NameUniquer::NameUniquer(const string& separator) {
return result;
}
-string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
- string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix));
+string NameUniquer::GetUniqueName(absl::string_view prefix) {
+ string root = GetSanitizedName(prefix.empty() ? "name" : string(prefix));
// Strip away numeric suffix (if any). Only recognize separator if it is in
// the middle of the name.
@@ -63,20 +64,22 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
if (separator_index != string::npos && (separator_index > 0) &&
(separator_index < root.size() - 1)) {
string after_suffix = root.substr(separator_index + 1);
- if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
has_numeric_suffix = true;
// Remove numeric suffix from root.
root = root.substr(0, separator_index);
+ } else {
+ // absl::SimpleAtoi may modify numeric_suffix even if it returns false.
+ numeric_suffix = 0;
}
}
SequentialIdGenerator& id_generator = generated_names_[root];
numeric_suffix = id_generator.RegisterId(numeric_suffix);
if (numeric_suffix == 0) {
- return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0)
- : root;
+ return has_numeric_suffix ? absl::StrCat(root, separator_, 0) : root;
}
- tensorflow::strings::StrAppend(&root, separator_, numeric_suffix);
+ absl::StrAppend(&root, separator_, numeric_suffix);
return root;
}
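The new else branch in GetUniqueName guards against absl::SimpleAtoi leaving the output variable in an unspecified state when parsing fails. A small standalone sketch of that behavior and the defensive reset:

#include <cstdint>
#include <iostream>
#include <string>

#include "absl/strings/numbers.h"

int main() {
  int64_t numeric_suffix = 0;
  // SimpleAtoi returns true and writes the parsed value only when the whole
  // string is a valid integer; on failure the output should be treated as
  // unspecified, which is why the uniquer resets it to 0 explicitly.
  if (absl::SimpleAtoi("42", &numeric_suffix)) {
    std::cout << "parsed: " << numeric_suffix << std::endl;  // parsed: 42
  }
  if (!absl::SimpleAtoi("not-a-number", &numeric_suffix)) {
    numeric_suffix = 0;  // defensive reset, as in GetUniqueName above
  }
  return 0;
}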
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index 4423d61069..6dd89c240f 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
@@ -38,7 +38,7 @@ class NameUniquer {
// Get a sanitized unique name in a string, with an optional prefix for
// convenience.
- string GetUniqueName(tensorflow::StringPiece prefix = "");
+ string GetUniqueName(absl::string_view prefix = "");
// Sanitizes and returns the name. Unallowed characters will be replaced with
// '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index ac6ea4c72f..4869db79e7 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -622,7 +622,7 @@ template <typename Previous>
class HloInstructionPatternNameImpl {
public:
explicit HloInstructionPatternNameImpl(const Previous& previous,
- tensorflow::StringPiece name)
+ absl::string_view name)
: previous_(previous), name_(name) {}
bool Match(const ::xla::HloInstruction* inst) const {
@@ -631,7 +631,7 @@ class HloInstructionPatternNameImpl {
private:
Previous previous_;
- tensorflow::StringPiece name_;
+ absl::string_view name_;
};
// An HloInstructionPattern implementation that matches only if the instruction
@@ -784,7 +784,7 @@ class HloInstructionPattern {
// Modifies the pattern to match only if the instruction has the given name.
HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>
- WithName(tensorflow::StringPiece name) const {
+ WithName(absl::string_view name) const {
return HloInstructionPattern<HloInstructionType,
HloInstructionPatternNameImpl<Impl>>(
HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_);
@@ -918,6 +918,7 @@ Op(::xla::HloInstruction** matched_inst) {
}
XLA_NULLOP_PATTERN(Constant)
XLA_NULLOP_PATTERN(Parameter)
+XLA_NULLOP_PATTERN(Iota)
#undef XLA_NULLOP_PATTERN
// Helpers for unary instructions.
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 39fe3c7835..ae1e13d8a6 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -19,20 +19,19 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
-using tensorflow::str_util::Lowercase;
-
// Minimum supported CUDA compute capability is 3.5.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;
@@ -43,7 +42,7 @@ constexpr char kInterpreter[] = "interpreter";
namespace {
string CanonicalPlatformName(const string& name) {
- string platform_str = Lowercase(name);
+ string platform_str = absl::AsciiStrToLower(name);
// "cpu" and "host" mean the same thing.
if (platform_str == "cpu") {
platform_str = "host";
@@ -94,12 +93,12 @@ PlatformUtil::GetSupportedPlatforms() {
}
// Multiple platforms present and we can't pick a reasonable default.
- string platforms_string = tensorflow::str_util::Join(
+ string platforms_string = absl::StrJoin(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"must specify platform because more than one platform found: %s",
- platforms_string.c_str());
+ platforms_string);
}
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
@@ -110,21 +109,21 @@ PlatformUtil::GetSupportedPlatforms() {
return platforms[0];
} else if (platforms.size() == 2) {
for (int i = 0; i < 2; i++) {
- if (Lowercase(platforms[i]->Name()) == kInterpreter &&
- Lowercase(platforms[1 - i]->Name()) != kInterpreter) {
+ if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter &&
+ absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) {
return platforms[1 - i];
}
}
}
// Multiple platforms present and we can't pick a reasonable default.
- string platforms_string = tensorflow::str_util::Join(
+ string platforms_string = absl::StrJoin(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"must specify platform because more than one platform (except for the "
"interpreter platform) found: %s",
- platforms_string.c_str());
+ platforms_string);
}
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
@@ -132,11 +131,11 @@ PlatformUtil::GetSupportedPlatforms() {
string platform_str = CanonicalPlatformName(platform_name);
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
for (se::Platform* platform : platforms) {
- if (Lowercase(platform->Name()) == platform_str) {
+ if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
return platform;
}
}
- return InvalidArgument("platform %s not found", platform_name.c_str());
+ return InvalidArgument("platform %s not found", platform_name);
}
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
@@ -146,23 +145,23 @@ PlatformUtil::GetSupportedPlatforms() {
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
std::vector<se::Platform*> matched;
for (se::Platform* platform : platforms) {
- if (Lowercase(platform->Name()) != platform_name) {
+ if (absl::AsciiStrToLower(platform->Name()) != platform_name) {
matched.push_back(platform);
}
}
if (matched.empty()) {
return InvalidArgument("unable to find platform that is not %s",
- platform_name.c_str());
+ platform_name);
}
if (matched.size() == 1) {
return matched[0];
}
- string matched_string = tensorflow::str_util::Join(
+ string matched_string = absl::StrJoin(
matched, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"found multiple platforms %s, but expected one platform except for %s",
- matched_string.c_str(), platform_name.c_str());
+ matched_string, platform_name);
}
// Returns whether the device underlying the given StreamExecutor is supported
@@ -193,7 +192,7 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) {
PlatformUtil::GetStreamExecutors(se::Platform* platform) {
int device_count = platform->VisibleDeviceCount();
if (device_count <= 0) {
- return NotFound("no %s devices found", platform->Name().c_str());
+ return NotFound("no %s devices found", platform->Name());
}
if (platform->id() == se::host::kHostPlatformId) {
// On host "devices", StreamExecutor exports a device for each hardware
@@ -232,7 +231,7 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
if (std::all_of(stream_executors.begin(), stream_executors.end(),
[](se::StreamExecutor* s) { return s == nullptr; })) {
return InternalError("no supported devices found for platform %s",
- platform->Name().c_str());
+ platform->Name());
}
return stream_executors;
}
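platform_util.cc now leans on absl::StrJoin with a custom formatter and on absl::AsciiStrToLower in place of the old str_util helpers. A standalone sketch of both calls, with the se::Platform objects stubbed out as plain strings:

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"

int main() {
  // absl::StrJoin with a formatter callback, mirroring how the platform list
  // is rendered above; real code appends p->Name() for each se::Platform*.
  std::vector<std::string> platforms = {"Host", "CUDA"};
  std::string joined = absl::StrJoin(
      platforms, ", ",
      [](std::string* out, const std::string& name) { out->append(name); });
  std::cout << joined << std::endl;  // Host, CUDA

  // absl::AsciiStrToLower returns a lowered copy (the old Lowercase helper).
  std::cout << absl::AsciiStrToLower("CUDA") << std::endl;  // cuda
  return 0;
}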
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index afde3cf95c..256b231e3a 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -59,7 +59,7 @@ class ReducePrecisionInsertion : public HloPassInterface {
~ReducePrecisionInsertion() override{};
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "reduce-precision-insertion";
}
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index ca86c5d13e..4df746fca9 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -38,6 +38,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include <algorithm>
+
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -374,7 +376,7 @@ StatusOr<bool> TryReshapeMoveOnCandidates(
removed = false;
for (auto operand : nontrivial_operands) {
- if (c_any_of(operand->users(), [&](HloInstruction* user) {
+ if (absl::c_any_of(operand->users(), [&](HloInstruction* user) {
return !reshape_candidates->count(user);
})) {
for (auto* user : operand->users()) {
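The only functional change in reshape_mover.cc is the switch from the unqualified c_any_of helper to absl::c_any_of from absl/algorithm/container.h. A minimal standalone example of the container-based algorithm:

#include <iostream>
#include <vector>

#include "absl/algorithm/container.h"

int main() {
  // absl::c_any_of is the container-based counterpart of std::any_of: it
  // takes the whole range instead of a begin/end iterator pair.
  std::vector<int> users = {1, 2, 3};
  bool has_even = absl::c_any_of(users, [](int u) { return u % 2 == 0; });
  std::cout << std::boolalpha << has_even << std::endl;  // true
  return 0;
}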
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1f59e3b314..1e86a0823a 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -26,7 +26,7 @@ namespace xla {
// them inputward also.
class ReshapeMover : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "reshape-mover"; }
+ absl::string_view name() const override { return "reshape-mover"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index ccb9fb3e3a..a395dd5333 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -28,13 +28,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using ReshapeMoverTest = HloVerifiedTestBase;
+
+namespace op = xla::testing::opcode_matchers;
+
+class ReshapeMoverTest : public HloVerifiedTestBase {
+ public:
+ ReshapeMoverTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index 45ca731153..2077b57c05 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/scatter_expander.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -92,7 +93,7 @@ static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
permutation.reserve(updates_rank);
for (int64 i = 0; i < updates_rank; ++i) {
- bool is_scatter_dim = !c_binary_search(update_window_dims, i);
+ bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i);
if (is_scatter_dim) {
permutation.push_back(i);
}
@@ -290,7 +291,7 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
return Unimplemented(
"Scatter operations with more than 2147483647 scatter indices are not "
"supported. This error occurred for %s.",
- scatter->ToString().c_str());
+ scatter->ToString());
}
// Canonicalize the scatter_indices, after which the size of its most-major
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 8f735e877d..14f062c89c 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -22,7 +22,7 @@ namespace xla {
class ScatterExpander : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "scatter_expander"; }
+ absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 1dbf540d13..e10c1d9927 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -20,10 +20,12 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -46,8 +48,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -55,13 +55,12 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"
-using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+using absl::StrCat;
+using absl::StrFormat;
+
// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
@@ -148,19 +147,19 @@ Service::Service(const ServiceOptions& options,
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
<< "Requested more replicas than there are devices.";
}
- LOG(INFO) << Printf(
+ LOG(INFO) << StrFormat(
"XLA service %p executing computations on platform %s. Devices:", this,
- execute_backend_->platform()->Name().c_str());
+ execute_backend_->platform()->Name());
for (int i = 0; i < execute_backend_->device_count(); ++i) {
if (execute_backend_->device_ordinal_supported(i)) {
se::StreamExecutor* executor =
execute_backend_->stream_executor(i).ValueOrDie();
const auto& description = executor->GetDeviceDescription();
- LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
- description.name().c_str(),
- description.platform_version().c_str());
+ LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i,
+ description.name(),
+ description.platform_version());
} else {
- LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
+ LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i);
}
}
} else {
@@ -200,8 +199,8 @@ Status Service::ValidateResultShape(const Shape& client_shape,
return InvalidArgument(
"Shape used to set computation result layout %s is not compatible "
"with result shape %s",
- ShapeUtil::HumanStringWithLayout(client_shape).c_str(),
- ShapeUtil::HumanString(result_shape).c_str());
+ ShapeUtil::HumanStringWithLayout(client_shape),
+ ShapeUtil::HumanString(result_shape));
}
return Status::OK();
}
@@ -231,9 +230,9 @@ Service::ResolveAndValidateArguments(
return InvalidArgument(
"argument %lu is on device %s:%d but computation will be executed "
"on device %s",
- i, shaped_buffer->platform()->Name().c_str(),
+ i, shaped_buffer->platform()->Name(),
shaped_buffer->device_ordinal(),
- execute_backend_->device_name(replica_device_ordinal).c_str());
+ execute_backend_->device_name(replica_device_ordinal));
}
replicated_arguments[replica].push_back(shaped_buffer);
}
@@ -245,11 +244,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
const ExecutionOptions* execution_options) {
- auto config = MakeUnique<HloModuleConfig>(program_shape);
+ auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
config->mutable_entry_computation_layout();
if (program_shape.parameters_size() != argument_shapes.size()) {
- return InvalidArgument("computation takes %d parameters, but %zu given",
+ return InvalidArgument("computation takes %d parameters, but %u given",
program_shape.parameters_size(),
argument_shapes.size());
}
@@ -261,8 +260,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
return InvalidArgument(
"Argument does not match shape of computation parameter %d: want "
"%s, got %s",
- i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
- ShapeUtil::HumanString(*argument_shapes[i]).c_str());
+ i, ShapeUtil::HumanString(program_shape.parameters(i)),
+ ShapeUtil::HumanString(*argument_shapes[i]));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
@@ -314,7 +313,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
DeviceMemoryAllocator* device_allocator) {
- VLOG(1) << Printf("BuildExecutable on service %p", this);
+ VLOG(1) << StrFormat("BuildExecutable on service %p", this);
// Dump computation proto state if flag is set.
std::vector<std::unique_ptr<HloSnapshot>> hlo_snapshots;
@@ -326,12 +325,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
if (directory_path.empty() && execution_directory_path.empty()) {
continue;
}
- auto hlo_snapshot = MakeUnique<HloSnapshot>();
+ auto hlo_snapshot = absl::make_unique<HloSnapshot>();
*hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i];
if (!directory_path.empty()) {
- string filename =
- Printf("computation_%lld__%s", module_protos[i]->id(),
- module_protos[i]->entry_computation_name().c_str());
+ string filename = StrFormat("computation_%d__%s", module_protos[i]->id(),
+ module_protos[i]->entry_computation_name());
TF_RETURN_IF_ERROR(
Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot));
}
@@ -409,7 +407,8 @@ Service::ExecuteParallelAndRegisterResult(
streams.push_back(std::move(stream));
if (replica == 0 && profile != nullptr) {
- timers.push_back(MakeUnique<se::Timer>(streams.back()->parent()));
+ timers.push_back(
+ absl::make_unique<se::Timer>(streams.back()->parent()));
streams.back()
->InitTimer(timers.back().get())
.ThenStartTimer(timers.back().get());
@@ -453,8 +452,8 @@ Service::ExecuteParallelAndRegisterResult(
for (int64 i = 0; i < streams.size(); ++i) {
Status block_status = streams[i]->BlockHostUntilDone();
if (!block_status.ok()) {
- return InternalError("failed to complete execution for stream %lld: %s",
- i, block_status.error_message().c_str());
+ return InternalError("failed to complete execution for stream %d: %s", i,
+ block_status.error_message());
}
}
@@ -579,7 +578,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
if (requests_size > 1 && execution_options.device_handles_size() > 1) {
return InvalidArgument(
"Parallel requests with multiple device handles is not supported. "
- "Found %lld parallel requests, with request %lld containing %d device "
+ "Found %d parallel requests, with request %d containing %d device "
"handles.",
requests_size, request_index, execution_options.device_handles_size());
}
@@ -744,8 +743,8 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
}
if (available_device_count < arg->device_count() * replica_count) {
return ResourceExhausted(
- "Requested device count (%lld) exceeds the number of available devices "
- "on the target (%lld)",
+ "Requested device count (%d) exceeds the number of available devices "
+ "on the target (%d)",
arg->device_count(), available_device_count);
}
@@ -795,12 +794,12 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
- VLOG(1) << Printf(
+ VLOG(1) << StrFormat(
"BuildExecutable on service %p with serialized module proto: %s", this,
- module_proto.name().c_str());
+ module_proto.name());
// Dump computation proto state if flag is set.
- auto hlo_snapshot = MakeUnique<HloSnapshot>();
+ auto hlo_snapshot = absl::make_unique<HloSnapshot>();
const string& directory_path =
module_config->debug_options().xla_dump_computations_to();
const string& execution_directory_path =
@@ -808,8 +807,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
if (!directory_path.empty() || !execution_directory_path.empty()) {
*hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto;
if (!directory_path.empty()) {
- string filename = Printf("computation_%lld__%s", module_proto.id(),
- module_proto.entry_computation_name().c_str());
+ string filename = StrFormat("computation_%d__%s", module_proto.id(),
+ module_proto.entry_computation_name());
TF_RETURN_IF_ERROR(
Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot));
}
@@ -954,7 +953,7 @@ namespace {
// shape and DeviceMemoryBase values of the clone are identical to the original.
std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
const ShapedBuffer& shaped_buffer, int device_ordinal) {
- auto clone = MakeUnique<ShapedBuffer>(
+ auto clone = absl::make_unique<ShapedBuffer>(
shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(),
shaped_buffer.platform(), device_ordinal);
clone->buffers() = shaped_buffer.buffers();
@@ -1009,8 +1008,7 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
"%s",
StrCat("The replica_id=", arg->replica_id(),
" on TransferToInfeedRequest not in range [0, replica_count=",
- replica_count, ").")
- .c_str());
+ replica_count, ")."));
}
se::StreamExecutor* executor;
@@ -1036,8 +1034,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
const int64 replica_count = options_.number_of_replicas();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
- "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
- "%lld)",
+ "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
arg->replica_id(), replica_count);
}
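Throughout service.cc the %lld/%zu length modifiers and the .c_str() calls disappear; the diff is consistent with the error helpers (InvalidArgument, FailedPrecondition, ResourceExhausted) now using absl::StrFormat-style type-checked formatting, where %s accepts std::string directly and %d accepts any integral type. A standalone sketch using absl::StrFormat itself, with made-up values:

#include <cstdint>
#include <iostream>
#include <string>

#include "absl/strings/str_format.h"

int main() {
  // No ".c_str()" needed for %s, and %d binds to int64 without "%lld".
  std::string request_name = "TransferFromOutfeedRequest";
  int64_t replica_count = 4;
  std::string msg = absl::StrFormat(
      "The replica_id=%d on %s not in range [0, %d)", int64_t{7}, request_name,
      replica_count);
  std::cout << msg << std::endl;
  return 0;
}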
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index cc1ec1704e..f5217c5a11 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -21,6 +21,11 @@ limitations under the License.
#include <set>
#include <string>
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -28,32 +33,26 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
-using tensorflow::str_util::Join;
-using tensorflow::strings::Printf;
-
namespace xla {
-
namespace {
+using absl::StrFormat;
+using absl::StrJoin;
+
// Returns true if no element is present in slice more than once.
bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
-Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
+Status ExpectArray(const Shape& shape, absl::string_view op_type) {
if (!ShapeUtil::IsArray(shape)) {
return InvalidArgument("Expected array argument for %s, but got %s.",
- std::string(op_type).c_str(),
- ShapeUtil::HumanString(shape).c_str());
+ string(op_type), ShapeUtil::HumanString(shape));
}
return Status::OK();
}
@@ -65,7 +64,7 @@ Status VerifyReducerShape(
int64 inputs) {
if (reducer_shape.parameters_size() != inputs * 2) {
return InvalidArgument(
- "Reduction function must take %lld parameters, but "
+ "Reduction function must take %d parameters, but "
"takes %d parameter(s).",
inputs * 2, reducer_shape.parameters_size());
}
@@ -75,7 +74,7 @@ Status VerifyReducerShape(
if (ShapeUtil::IsArray(accumulator_shape)) {
if (inputs != 1) {
return InvalidArgument(
- "Reduction function must produce a tuple with %lld elements, but "
+ "Reduction function must produce a tuple with %d elements, but "
"produces a scalar",
inputs);
}
@@ -83,8 +82,8 @@ Status VerifyReducerShape(
} else if (ShapeUtil::IsTuple(accumulator_shape)) {
if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
return InvalidArgument(
- "Reduction function must produce a tuple with %lld elements, but has "
- "%lld elements",
+ "Reduction function must produce a tuple with %d elements, but has "
+ "%d elements",
inputs, ShapeUtil::TupleElementCount(accumulator_shape));
}
for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
@@ -94,7 +93,7 @@ Status VerifyReducerShape(
return InvalidArgument(
"Reduction function must produce a scalar or tuple of scalars, but has "
"shape: %s",
- ShapeUtil::HumanString(accumulator_shape).c_str());
+ ShapeUtil::HumanString(accumulator_shape));
}
for (const Shape* element_shape : accumulator_subshapes) {
@@ -102,7 +101,7 @@ Status VerifyReducerShape(
return InvalidArgument(
"Reduction function must return a scalar or tuple of scalars but "
"returns shape: %s",
- ShapeUtil::HumanString(accumulator_shape).c_str());
+ ShapeUtil::HumanString(accumulator_shape));
}
}
@@ -113,19 +112,19 @@ Status VerifyReducerShape(
if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
reducer_shape.parameters(i))) {
return InvalidArgument(
- "Reduction function's %lld-th parameter shape differs from the "
+ "Reduction function's %d-th parameter shape differs from the "
"result shape: %s vs %s",
- i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(),
- ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]));
}
// Check that init_value's shapes are suitable for reducer_shape.
if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
*init_value_shapes[i])) {
return InvalidArgument(
- "Reduction function's accumulator shape at index %lld differs from "
+ "Reduction function's accumulator shape at index %d differs from "
"the init_value shape: %s vs %s",
- i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(),
- ShapeUtil::HumanString(*init_value_shapes[i]).c_str());
+ i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
+ ShapeUtil::HumanString(*init_value_shapes[i]));
}
// Check that the inputs can be passed in as the non-accumulator arguments.
const Shape input_element_shape =
@@ -133,11 +132,11 @@ Status VerifyReducerShape(
if (!ShapeUtil::CompatibleIgnoringFpPrecision(
input_element_shape, reducer_shape.parameters(inputs + i))) {
return InvalidArgument(
- "Reduction function's %lld-th parameter shape differs from the "
+ "Reduction function's %d-th parameter shape differs from the "
"input type element type: %s vs %s",
inputs + i,
- ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
- ShapeUtil::HumanString(input_element_shape).c_str());
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
+ ShapeUtil::HumanString(input_element_shape));
}
// Check that the accumulator and inputs to the reducer function match.
// If the accumulator is scalar, it must have the same type as the inputs
@@ -147,11 +146,11 @@ Status VerifyReducerShape(
if (!ShapeUtil::CompatibleIgnoringFpPrecision(
*accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
return InvalidArgument(
- "Reduction function's %lld-th parameter shape must "
+ "Reduction function's %d-th parameter shape must "
"match the result shape, but got %s vs %s.",
inputs + i,
- ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
- ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]));
}
}
@@ -164,7 +163,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
bool allow_negative_padding) {
if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
return InvalidArgument(
- "Window has dimension %d but base shape has dimension %lld.",
+ "Window has dimension %d but base shape has dimension %d.",
window.dimensions_size(), ShapeUtil::Rank(base_shape));
}
@@ -173,29 +172,29 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
const auto& dim = window.dimensions(i);
if (dim.size() <= 0) {
return InvalidArgument("Window %s has a non-positive dimension.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (dim.stride() <= 0) {
return InvalidArgument("Window %s has a non-positive stride.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (!allow_negative_padding && dim.padding_low() < 0) {
return InvalidArgument("Window %s has a negative low padding.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (!allow_negative_padding && dim.padding_high() < 0) {
return InvalidArgument("Window %s has a negative high padding.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (dim.base_dilation() < 1) {
return InvalidArgument(
"Window %s has a non-positive base area dilation factor.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (dim.window_dilation() < 1) {
return InvalidArgument(
"Window %s has a non-positive window dilation factor.",
- window.DebugString().c_str());
+ window.DebugString());
}
const int64 dilated_base = window_util::DilatedBound(
@@ -233,11 +232,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
switch (opcode) {
case HloOpcode::kFloor:
case HloOpcode::kCeil:
+ case HloOpcode::kRoundNearestAfz:
if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating for floor/ceil "
- "operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ "Expected element type in shape to be floating for %s operation; "
+ "got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
}
return shape;
case HloOpcode::kCos:
@@ -250,9 +250,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
if (!ShapeUtil::ElementIsFloating(shape) &&
!ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating or complex for "
- "sin/cos/exp/log/tanh operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ "Expected element type in shape to be floating or complex for %s "
+ "operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
}
return shape;
case HloOpcode::kReal:
@@ -264,19 +264,47 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
} else {
return InvalidArgument(
"Expected element type in shape to be floating or complex for "
- "real/imag operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ "%s operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
}
case HloOpcode::kAbs:
if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(
shape, primitive_util::ComplexComponentType(shape.element_type()));
+ } else if (ShapeUtil::ElementIsSigned(shape)) {
+ return shape;
+ } else {
+ return InvalidArgument(
+ "Expected element type in shape to be floating or complex for "
+ "%s operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
}
- return shape;
case HloOpcode::kClz:
+ if (!ShapeUtil::ElementIsIntegral(shape)) {
+ return InvalidArgument(
+ "Expected an integral element type in argument to Clz "
+ "operation; got %s.",
+ PrimitiveType_Name(shape.element_type()));
+ }
+ return shape;
case HloOpcode::kNegate:
- case HloOpcode::kRoundNearestAfz:
+ if (!ShapeUtil::ElementIsIntegral(shape) &&
+ !ShapeUtil::ElementIsFloating(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
+ return InvalidArgument(
+ "Expected element type in shape to be integral, floating or "
+ "complex for %s operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
+ }
+ return shape;
case HloOpcode::kSign:
+ if (!ShapeUtil::ElementIsSigned(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
+ return InvalidArgument(
+ "Expected element type in shape to be signed or complex for "
+ "%s operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
+ }
return shape;
case HloOpcode::kNot:
@@ -285,7 +313,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"Expected pred or an integral element type in argument to Not "
"operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()));
}
return shape;
@@ -295,14 +323,14 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
"Expected element type in shape to be floating "
"point for IsFinite "
"operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()));
}
return ShapeUtil::ChangeElementType(shape, PRED);
default:
return InvalidArgument(
"Unknown operation for unary shape inference: \"%s\".",
- HloOpcodeString(opcode).c_str());
+ HloOpcodeString(opcode));
}
}
@@ -313,7 +341,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument("Concatenate expects at least one argument.");
}
if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
- return InvalidArgument("Concatenate dimension out of bounds: %lld.",
+ return InvalidArgument("Concatenate dimension out of bounds: %d.",
dimension);
}
const Shape* arg_shape = nullptr;
@@ -327,17 +355,16 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
return InvalidArgument(
- "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
+ "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
"(%s).",
- ShapeUtil::Rank(*arg_shape),
- ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
- ShapeUtil::HumanString(*shape).c_str());
+ ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape),
+ ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
return InvalidArgument(
"Cannot concatenate arrays with different element types: %s vs %s.",
- PrimitiveType_Name(arg_shape->element_type()).c_str(),
- PrimitiveType_Name(shape->element_type()).c_str());
+ PrimitiveType_Name(arg_shape->element_type()),
+ PrimitiveType_Name(shape->element_type()));
}
for (int64 dimension_number = 0;
dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) {
@@ -350,9 +377,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"Cannot concatenate arrays that differ in dimensions other than "
"the one being concatenated (the other array dimensions must be "
- "the same): %s vs %s in dimension %lld.",
- ShapeUtil::HumanString(*arg_shape).c_str(),
- ShapeUtil::HumanString(*shape).c_str(), dimension);
+ "the same): %s vs %s in dimension %d.",
+ ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
+ dimension);
}
}
element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
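
The checks above encode the concatenate rule: operands must agree on every dimension except the one being concatenated, which is summed in the result. A minimal sketch of the compatibility test under that assumption (ConcatCompatible is an illustrative helper, not the function above):

    #include <cstdint>
    #include <vector>

    // Two operand dimension vectors are concatenable along concat_dim when
    // they match everywhere else.
    bool ConcatCompatible(const std::vector<int64_t>& a,
                          const std::vector<int64_t>& b, int64_t concat_dim) {
      if (a.size() != b.size()) return false;
      for (size_t d = 0; d < a.size(); ++d) {
        if (static_cast<int64_t>(d) != concat_dim && a[d] != b[d]) return false;
      }
      return true;
    }
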
@@ -384,8 +411,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
!primitive_util::IsComplexType(new_element_type)) {
return Unimplemented(
"Conversion from complex to real type %s => %s is not implemented.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ PrimitiveType_Name(new_element_type));
}
if (!ShapeUtil::IsArray(operand_shape) ||
!primitive_util::IsArrayType(new_element_type)) {
@@ -394,8 +421,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
// are valid. For now we just reject them, though.
return InvalidArgument(
"Convert does not allow non-arrays, so cannot convert from %s to %s.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ PrimitiveType_Name(new_element_type));
}
return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
@@ -407,8 +434,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
if (primitive_util::IsComplexType(old_element_type) !=
primitive_util::IsComplexType(new_element_type)) {
return InvalidArgument("Conversion from complex to real type %s => %s.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ PrimitiveType_Name(new_element_type));
}
if (!ShapeUtil::IsArray(operand_shape) ||
!primitive_util::IsArrayType(new_element_type)) {
@@ -417,15 +444,15 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
// are valid. For now we just reject them, though.
return InvalidArgument(
"Cannot convert from or to tuple type; requested conversion: %s => %s.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ PrimitiveType_Name(new_element_type));
}
if (primitive_util::BitWidth(old_element_type) !=
primitive_util::BitWidth(new_element_type)) {
return InvalidArgument(
"Cannot bitcast types with different bit-widths: %s => %s.",
- PrimitiveType_Name(old_element_type).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ PrimitiveType_Name(old_element_type),
+ PrimitiveType_Name(new_element_type));
}
return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
@@ -438,7 +465,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"Expected element type in shape to be floating point for "
"ReducePrecision operation; got %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (exponent_bits < 1) {
// One exponent bit is necessary to distinguish 0 from infinity. Having
@@ -470,21 +497,29 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"The rank of the operand and the padding configuration do not match: "
"%s vs %s.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- padding_config.ShortDebugString().c_str());
+ ShapeUtil::HumanString(operand_shape),
+ padding_config.ShortDebugString());
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
padding_value_shape)) {
return InvalidArgument(
"The element types of the operands to Pad do not match.");
}
+ if (absl::c_any_of(padding_config.dimensions(),
+ [](const PaddingConfig::PaddingConfigDimension& p) {
+ return p.interior_padding() < 0;
+ })) {
+ return InvalidArgument("Interior padding cannot be negative: %s",
+ padding_config.ShortDebugString());
+ }
+
std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
- dimensions[i] = operand_shape.dimensions(i) +
- padding_config.dimensions(i).edge_padding_low() +
- padding_config.dimensions(i).edge_padding_high() +
+ const auto& p = padding_config.dimensions(i);
+ dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
+ p.edge_padding_high() +
std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
- padding_config.dimensions(i).interior_padding();
+ p.interior_padding();
}
return ShapeUtil::MakeShape(
ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
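
The loop above computes each output bound from the edge and interior padding; interior padding only inserts elements between existing ones, hence the max(dim - 1, 0) factor. A minimal standalone sketch of that arithmetic (PaddedDimensionSize is an illustrative name):

    #include <algorithm>
    #include <cstdint>

    int64_t PaddedDimensionSize(int64_t input, int64_t edge_low,
                                int64_t edge_high, int64_t interior) {
      // Interior padding applies only between elements, so a dimension of
      // size n gains (n - 1) * interior extra elements.
      return input + edge_low + edge_high +
             std::max<int64_t>(input - 1, 0) * interior;
    }
    // e.g. PaddedDimensionSize(/*input=*/3, /*edge_low=*/1, /*edge_high=*/2,
    //                          /*interior=*/1) == 3 + 1 + 2 + 2 * 1 == 8.
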
@@ -538,7 +573,7 @@ Status ValidateDotDimensionNumbers(
!dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions,
rhs_batch_dimensions)) {
return InvalidArgument("A dimension number is out of range in Dot: %s.",
- dimension_numbers.DebugString().c_str());
+ dimension_numbers.DebugString());
}
// Check that dimension numbers are unique.
@@ -556,7 +591,7 @@ Status ValidateDotDimensionNumbers(
if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
!dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
return InvalidArgument("A dimension number is not unique in Dot: %s.",
- dimension_numbers.DebugString().c_str());
+ dimension_numbers.DebugString());
}
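
dims_unique above enforces that no dimension number is reused across the contracting and batch dimension lists of a Dot. A minimal sketch of the same invariant with standard containers (DimsUnique is an illustrative helper, not the lambda above):

    #include <cstdint>
    #include <set>
    #include <vector>

    bool DimsUnique(const std::vector<int64_t>& contracting,
                    const std::vector<int64_t>& batch) {
      // If any dimension number repeats, the set ends up smaller than the
      // combined list lengths.
      std::set<int64_t> seen(contracting.begin(), contracting.end());
      seen.insert(batch.begin(), batch.end());
      return seen.size() == contracting.size() + batch.size();
    }
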
// Check that the count of non-contracting-non-batch dimensions is in {0, 1}.
@@ -601,14 +636,13 @@ Status ValidateDotDimensionNumbers(
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
auto fail = [lhs, rhs](const string& addendum) -> Status {
- string message = tensorflow::strings::Printf(
- "Cannot infer shape for dot operation: %s <dot> %s.",
- ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str());
+ string message =
+ StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
if (!addendum.empty()) {
message += " " + addendum;
}
- return InvalidArgument("%s", message.c_str());
+ return InvalidArgument("%s", message);
};
// Check if both element types are the same.
@@ -704,9 +738,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
} else {
return InvalidArgument(
"Binary op %s with incompatible shapes: %s and %s.",
- HloOpcodeString(operation).c_str(),
- ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str());
+ HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
+ ShapeUtil::HumanString(rhs));
}
}
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
@@ -721,14 +754,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// the user to provide an explicit broadcast dimension in this case.
// See b/25177275 for more details.
return InvalidArgument("Automatic shape inference not supported: %s and %s",
- ShapeUtil::HumanString(smaller_shape).c_str(),
- ShapeUtil::HumanString(larger_shape).c_str());
+ ShapeUtil::HumanString(smaller_shape),
+ ShapeUtil::HumanString(larger_shape));
} else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
return InvalidArgument(
"Size of broadcast_dimensions has to match lower-rank operand's "
"rank; "
- " lower-rank operand's rank is %lld, size of broadcast_dimensions is "
- "%zu.",
+ " lower-rank operand's rank is %d, size of broadcast_dimensions is "
+ "%u.",
ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
}
@@ -778,12 +811,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
int64 dimension_to_match = broadcast_dimensions.at(i);
if (dimension_to_match < 0) {
return InvalidArgument(
- "Broadcast dimension number (%lld) cannot be negative.",
+ "Broadcast dimension number (%d) cannot be negative.",
dimension_to_match);
}
if (dimension_to_match >= larger_shape.dimensions_size()) {
return InvalidArgument(
- "Broadcast dimension number (%lld) too large; higher-rank "
+ "Broadcast dimension number (%d) too large; higher-rank "
"operand has rank %d.",
dimension_to_match, larger_shape.dimensions_size());
}
@@ -795,16 +828,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (small_dimension_size != large_dimension_size &&
small_dimension_size != 1 && large_dimension_size != 1) {
return InvalidArgument(
- "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i,
+ "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
small_dimension_size, large_dimension_size,
- ShapeUtil::HumanString(smaller_shape).c_str(),
- ShapeUtil::HumanString(larger_shape).c_str());
+ ShapeUtil::HumanString(smaller_shape),
+ ShapeUtil::HumanString(larger_shape));
}
// Make sure the broadcast dimensions are listed in a strictly increasing
// order.
if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
return InvalidArgument(
- "Broadcast dimensions order is wrong: %lld comes after %lld.",
+ "Broadcast dimensions order is wrong: %d comes after %d.",
dimension_to_match, broadcast_dimensions.at(i - 1));
}
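
The in-dimension broadcast rule being validated above is: for each matched dimension the two bounds must be equal or one of them must be 1, and broadcast_dimensions must be listed in strictly increasing order. A minimal sketch of the per-dimension test (BroadcastCompatible is an illustrative name):

    #include <cstdint>

    // Degenerate-dimension broadcasting: sizes must match unless one is 1.
    bool BroadcastCompatible(int64_t small_dim, int64_t large_dim) {
      return small_dim == large_dim || small_dim == 1 || large_dim == 1;
    }
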
@@ -823,8 +856,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Binary op %s with different element types: %s and %s.",
- HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str());
+ HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
+ ShapeUtil::HumanString(rhs));
}
if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
@@ -874,20 +907,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- VLOG(2) << tensorflow::strings::Printf(
+ VLOG(2) << StrFormat(
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
- HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str(),
- Join(broadcast_dimensions, ", ").c_str());
+ HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
+ ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", "));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
- TF_RETURN_IF_ERROR(
- ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- HloOpcodeString(opcode))));
- TF_RETURN_IF_ERROR(
- ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(ExpectArray(
+ lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(ExpectArray(
+ rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
switch (opcode) {
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -909,7 +939,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected element type in shape to be floating for complex compose "
"operation; got %s.",
- PrimitiveType_Name(lhs.element_type()).c_str());
+ PrimitiveType_Name(lhs.element_type()));
}
TF_ASSIGN_OR_RETURN(const Shape& shape,
InferElementwiseBinaryOpShape(opcode, lhs, rhs,
@@ -928,7 +958,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected pred or integral type in argument to and/or operation; "
"got %s.",
- PrimitiveType_Name(lhs.element_type()).c_str());
+ PrimitiveType_Name(lhs.element_type()));
}
return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
@@ -946,8 +976,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
default:
return Unimplemented(
"Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
- HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(),
- rhs.ShortDebugString().c_str());
+ HloOpcodeString(opcode), lhs.ShortDebugString(),
+ rhs.ShortDebugString());
}
}
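
Most hunks in this file follow the same mechanical migration visible above: tensorflow::strings::Printf becomes absl::StrFormat or absl::StrCat, the ".c_str()" calls disappear, and "%lld"/"%zu" become "%d"/"%u", because Abseil's formatting is type-checked and accepts std::string and int64 arguments directly. A minimal sketch of that property, assuming only Abseil (Example is an illustrative function):

    #include <cstdint>
    #include <string>
    #include "absl/strings/str_format.h"

    std::string Example(const std::string& shape_str, int64_t rank) {
      // %s takes the std::string directly (no .c_str()) and %d adapts to the
      // integral width, so no %lld length modifier is needed.
      return absl::StrFormat("shape %s has rank %d", shape_str, rank);
    }
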
@@ -970,8 +1000,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
case HloOpcode::kTupleSelect:
return InferTupleSelectShape(lhs, rhs, ehs);
default:
- return InvalidArgument("Unknown operation %s.",
- HloOpcodeString(opcode).c_str());
+ return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
}
}
@@ -1010,8 +1039,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Sort keys and values dimensions must match. "
"Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(*operand_shapes[0]).c_str(),
- ShapeUtil::HumanString(*operand_shapes[1]).c_str());
+ ShapeUtil::HumanString(*operand_shapes[0]),
+ ShapeUtil::HumanString(*operand_shapes[1]));
}
return ShapeUtil::MakeTupleShape(
{*operand_shapes[0], *operand_shapes[1]});
@@ -1019,8 +1048,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument("Unexpected number of operands for sort");
}
default:
- return InvalidArgument("Unknown operation %s.",
- HloOpcodeString(opcode).c_str());
+ return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
}
}
@@ -1058,7 +1086,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Map operation requires all operands to have the same shape; got: "
"%s.",
- Join(pieces, ", ").c_str());
+ StrJoin(pieces, ", "));
}
// Check that dimensions.size == arg_shape.dimensions_size() (we currently
@@ -1066,7 +1094,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (dimensions.size() != arg_shape->dimensions_size()) {
return InvalidArgument(
"Map applied to a subset of dimensions currently not supported: "
- "arg_dimension_size: %d, requested_map_dimensions_size: %zu.",
+ "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
arg_shape->dimensions_size(), dimensions.size());
}
@@ -1075,7 +1103,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (dimensions[i] != i) {
return InvalidArgument(
"Map requires monotonically increasing dimension numbers; got: %s.",
- Join(dimensions, ", ").c_str());
+ StrJoin(dimensions, ", "));
}
}
@@ -1083,7 +1111,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (arg_shapes.size() != to_apply.parameters_size()) {
return InvalidArgument(
"Map applied function arity must match number of arguments; got: "
- "arity: %d, arguments: %zu.",
+ "arity: %d, arguments: %u.",
to_apply.parameters_size(), arg_shapes.size());
}
@@ -1092,7 +1120,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::IsScalar(output_shape)) {
return InvalidArgument(
"Mapped computation's result has to be a scalar; got: %s.",
- ShapeUtil::HumanString(output_shape).c_str());
+ ShapeUtil::HumanString(output_shape));
}
for (int i = 0; i < to_apply.parameters_size(); ++i) {
@@ -1102,7 +1130,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Mapped computation's parameter has to be a scalar; "
"got parameter %d shape: %s.",
- i, ShapeUtil::HumanString(parameter_shape).c_str());
+ i, ShapeUtil::HumanString(parameter_shape));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
@@ -1110,8 +1138,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Mapped computation's parameter type has to match argument element "
"type; got parameter %d shape: %s, argument shape: %s.",
- i, ShapeUtil::HumanString(parameter_shape).c_str(),
- ShapeUtil::HumanString(*arg_shape).c_str());
+ i, ShapeUtil::HumanString(parameter_shape),
+ ShapeUtil::HumanString(*arg_shape));
}
}
@@ -1140,35 +1168,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected feature_index of batch-norm-training to be "
"smaller than the rank of operand_shape; "
- "got feature_index %lld, and rank %lld.",
+ "got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
}
if (feature_index < 0) {
return InvalidArgument(
"Expected feature_index of batch-norm-training to "
- "be a non-negative number, got %lld.",
+ "be a non-negative number, got %d.",
feature_index);
}
if (ShapeUtil::Rank(operand_shape) < 1) {
return InvalidArgument(
"Expected the rank of operand to "
- "batch-norm-training to be at least 1; got %lld.",
+ "batch-norm-training to be at least 1; got %d.",
ShapeUtil::Rank(operand_shape));
}
if (ShapeUtil::Rank(offset_shape) != 1) {
return InvalidArgument(
"Offset input of batch-norm-training must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(offset_shape));
}
if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-training must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
}
@@ -1176,7 +1204,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"The operand to batch-norm-training must have a floating point "
"element type, but the shape is %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
@@ -1185,8 +1213,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-training, "
"but the shape of offset factor is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(offset_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(offset_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
@@ -1195,8 +1223,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-training, "
"but the shape of scale factor is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(scale_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(scale_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
const int64 feature_count = operand_shape.dimensions(feature_index);
@@ -1206,16 +1234,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
return InvalidArgument(
"The size of offset factor should be the same as feature count,"
- "but the size of offset factor is %lld "
- "and the feature count is %lld.",
+ "but the size of offset factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(offset_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count,"
- "but the size of scale factor is %lld "
- "and the feature count is %lld.",
+ "but the size of scale factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}
@@ -1250,35 +1278,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected feature_index of batch-norm-inference to be "
"smaller than the rank of operand_shape; "
- "got feature_index %lld, and rank %lld.",
+ "got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
}
if (feature_index < 0) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to "
- "be a non-negative number, got %lld.",
+ "be a non-negative number, got %d.",
feature_index);
}
if (ShapeUtil::Rank(operand_shape) < 1) {
return InvalidArgument(
"Expected the rank of operand to "
- "batch-norm-inference to be at least 1; got %lld.",
+ "batch-norm-inference to be at least 1; got %d.",
ShapeUtil::Rank(operand_shape));
}
if (ShapeUtil::Rank(offset_shape) != 1) {
return InvalidArgument(
"Offset input of batch-norm-inference must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(offset_shape));
}
if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-inference must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
}
@@ -1286,7 +1314,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"The operand to batch-norm-inference must have a floating point "
"element type, but the shape is %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
@@ -1296,8 +1324,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"batch-norm-inference, "
"but the shape of offset factor is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(offset_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(offset_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
@@ -1307,8 +1335,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"batch-norm-inference, "
"but the shape of scale factor is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(scale_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(scale_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
@@ -1318,8 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"batch-norm-inference, "
"but the shape of mean is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(mean_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(mean_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
@@ -1329,8 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"batch-norm-inference, "
"but the shape of variance is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(mean_shape.element_type()).c_str(),
- PrimitiveType_Name(variance_shape.element_type()).c_str());
+ PrimitiveType_Name(mean_shape.element_type()),
+ PrimitiveType_Name(variance_shape.element_type()));
}
const int64 feature_count = operand_shape.dimensions(feature_index);
@@ -1340,32 +1368,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
return InvalidArgument(
"The size of offset factor should be the same as feature count,"
- "but the size of offset factor is %lld "
- "and the feature count is %lld.",
+ "but the size of offset factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(offset_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count,"
- "but the size of scale factor is %lld "
- "and the feature count is %lld.",
+ "but the size of scale factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
return InvalidArgument(
"The size of mean should be the same as feature count,"
- "but the size of mean is %lld "
- "and the feature count is %lld.",
+ "but the size of mean is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(mean_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
return InvalidArgument(
"The size of variance should be the same as feature count,"
- "but the size of variance is %lld "
- "and the feature count is %lld.",
+ "but the size of variance is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(variance_shape, 0), feature_count);
}
@@ -1395,36 +1423,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected feature_index of batch-norm-grad to be "
"smaller than the rank of operand_shape; "
- "got feature_index %lld, and rank %lld.",
+ "got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
}
if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) {
return InvalidArgument(
"Expected operand_shape of batch-norm-grad to have the same rank as"
- " output_grad_shape; got rank(oprand_shape) %lld, and"
- " rank(output_grad_shape) %lld.",
    +        " output_grad_shape; got rank(operand_shape) %d, and"
+ " rank(output_grad_shape) %d.",
ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape));
}
if (ShapeUtil::Rank(mean_shape) != 1) {
return InvalidArgument(
"Mean input of batch-norm-grad must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(mean_shape));
}
if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-grad must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
}
if (ShapeUtil::Rank(var_shape) != 1) {
return InvalidArgument(
"Var input of batch-norm-grad must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(var_shape));
}
@@ -1432,14 +1460,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"The operand to batch-norm-grad must have a floating point "
"element type, but the shape is %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
return InvalidArgument(
"The output_grad to batch-norm-grad must have a floating point "
"element type, but the shape is %s.",
- PrimitiveType_Name(output_grad_shape.element_type()).c_str());
+ PrimitiveType_Name(output_grad_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
@@ -1448,8 +1476,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of output_grad is %s "
"and the element type of operand is %s.",
- PrimitiveType_Name(output_grad_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(output_grad_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
@@ -1458,8 +1486,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of scale factor is %s "
"and the element type of operand is %s.",
- PrimitiveType_Name(scale_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(scale_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
@@ -1468,8 +1496,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of mean is %s "
"and the element type of operand is %s.",
- PrimitiveType_Name(mean_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(mean_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
@@ -1478,8 +1506,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of mean is %s "
"and the element type of operand is %s.",
- PrimitiveType_Name(mean_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(mean_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
const int64 feature_count = operand_shape.dimensions(feature_index);
@@ -1490,24 +1518,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
return InvalidArgument(
"The size of mean should be the same as feature count,"
- "but the size of offset factor is %lld "
- "and the feature count is %lld.",
+ "but the size of offset factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(mean_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count,"
- "but the size of scale factor is %lld "
- "and the feature count is %lld.",
+ "but the size of scale factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
return InvalidArgument(
"The size of variance should be the same as feature count,"
- "but the size of variance is %lld "
- "and the feature count is %lld.",
+ "but the size of variance is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(var_shape, 0), feature_count);
}
@@ -1517,8 +1545,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::GetDimension(output_grad_shape, i)) {
return InvalidArgument(
"The bounds of operand shape should be the same as output_grad's,"
- "but the bound of operand_shape at dimension %lld is %lld "
- "and the bound of output_grad_shape is %lld.",
+ "but the bound of operand_shape at dimension %d is %d "
+ "and the bound of output_grad_shape is %d.",
i, ShapeUtil::GetDimension(operand_shape, i),
ShapeUtil::GetDimension(output_grad_shape, i));
}
@@ -1537,15 +1565,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Convolution with different element types: %s and %s.",
- ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str());
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
}
if (dnums.input_spatial_dimensions_size() !=
dnums.kernel_spatial_dimensions_size()) {
return InvalidArgument(
"Both arguments to convolution must have same number of dimensions.\n"
"Window: %s",
- window.DebugString().c_str());
+ window.DebugString());
}
const int num_spatial_dims = dnums.input_spatial_dimensions_size();
@@ -1553,19 +1580,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Window must have same number of dimensions as dimension numbers.\n"
"Window: %s\nDimension numbers: %s.",
- window.DebugString().c_str(), dnums.DebugString().c_str());
+ window.DebugString(), dnums.DebugString());
}
const int num_dims = num_spatial_dims + 2;
if (ShapeUtil::Rank(lhs) != num_dims) {
return InvalidArgument(
"The LHS argument to a convolution should have rank %d; lhs: %s.",
- num_dims, ShapeUtil::HumanString(lhs).c_str());
+ num_dims, ShapeUtil::HumanString(lhs));
}
if (ShapeUtil::Rank(rhs) != num_dims) {
return InvalidArgument(
"The RHS argument to a convolution should have rank %d; lhs: %s.",
- num_dims, ShapeUtil::HumanString(lhs).c_str());
+ num_dims, ShapeUtil::HumanString(lhs));
}
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
@@ -1602,26 +1629,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
!std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
return InvalidArgument(
"A dimension number is out of range in convolution: %s.",
- dnums.DebugString().c_str());
+ dnums.DebugString());
}
if (input_dnums != expected_dnums) {
return InvalidArgument(
"Input dimensions of convolution must contain each dimension exactly "
"once: %s.",
- dnums.DebugString().c_str());
+ dnums.DebugString());
}
if (window_dnums != expected_dnums) {
return InvalidArgument(
"Window dimensions of convolution must contain each dimension exactly "
"once: %s.",
- dnums.DebugString().c_str());
+ dnums.DebugString());
}
if (output_dnums != expected_dnums) {
return InvalidArgument(
"Output dimensions of convolution must contain each dimension exactly "
"once: %s.",
- dnums.DebugString().c_str());
+ dnums.DebugString());
}
std::vector<int64> input_spatial_dims(num_spatial_dims);
@@ -1642,13 +1669,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
- "Expected LHS feature dimension (value %lld) to match RHS "
- "input feature dimension * feature_group_count (value %lld); "
+ "Expected LHS feature dimension (value %d) to match RHS "
+ "input feature dimension * feature_group_count (value %d); "
"got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
input_features, kernel_input_features * feature_group_count,
- ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
+ dnums.DebugString());
}
std::vector<int64> window_dims(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
@@ -1660,8 +1687,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"RHS shape: %s\n\t"
"Window: {%s}\n\t"
"Dimension numbers: {%s}.",
- ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(),
- dnums.ShortDebugString().c_str());
+ ShapeUtil::HumanString(rhs), window.ShortDebugString(),
+ dnums.ShortDebugString());
}
Shape base_shape =
@@ -1687,29 +1714,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const tensorflow::gtl::ArraySlice<int64> fft_length) {
const int64 fft_rank = fft_length.size();
if (fft_rank < 1 || fft_rank > 3) {
- return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank);
+ return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
}
-#define RET_CHECK_RANK(x) \
- if (x.dimensions_size() < fft_rank) { \
- return InvalidArgument( \
- "FFT of rank %lld requires input of at least " \
- "same rank; got input of rank %d", \
- fft_rank, x.dimensions_size()); \
+#define RET_CHECK_RANK(x) \
+ if (x.dimensions_size() < fft_rank) { \
+ return InvalidArgument( \
+ "FFT of rank %d requires input of at least " \
+ "same rank; got input of rank %d", \
+ fft_rank, x.dimensions_size()); \
}
switch (fft_type) {
case FFT:
case IFFT:
if (in.element_type() != C64) {
return InvalidArgument("%s requires C64 input type, found %s.",
- FftType_Name(fft_type).c_str(),
- PrimitiveType_Name(in.element_type()).c_str());
+ FftType_Name(fft_type),
+ PrimitiveType_Name(in.element_type()));
}
RET_CHECK_RANK(in);
return in;
case RFFT: {
if (in.element_type() != F32) {
return InvalidArgument("RFFT requires F32 input type, found %s.",
- PrimitiveType_Name(in.element_type()).c_str());
+ PrimitiveType_Name(in.element_type()));
}
RET_CHECK_RANK(in);
for (int i = 0; i < fft_rank; i++) {
@@ -1717,7 +1744,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
fft_length[i]) {
return InvalidArgument(
"RFFT requires innermost dimensions match fft_length but "
- "dimension %lld is %lld and should be %lld.",
+ "dimension %d is %d and should be %d.",
in.dimensions_size() - fft_rank + i,
in.dimensions(in.dimensions_size() - fft_rank + i),
fft_length[i]);
@@ -1731,7 +1758,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
case IRFFT: {
if (in.element_type() != C64) {
return InvalidArgument("IRFFT requires C64 input type, found %s.",
- PrimitiveType_Name(in.element_type()).c_str());
+ PrimitiveType_Name(in.element_type()));
}
RET_CHECK_RANK(in);
Shape result = ShapeUtil::ComplexComponentShape(in);
@@ -1740,7 +1767,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
fft_length[i]) {
return InvalidArgument(
"IRFFT requires all but one innermost dimensions match "
- "fft_length, but dimension %lld is %lld and should be %lld.",
+ "fft_length, but dimension %d is %d and should be %d.",
in.dimensions_size() - fft_rank + i,
in.dimensions(in.dimensions_size() - fft_rank + i),
fft_length[i]);
@@ -1750,7 +1777,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
fft_length[fft_rank - 1] / 2 + 1) {
return InvalidArgument(
"IRFFT requires innermost dimension matches fft_length/2+1, but "
- "dimension %d is %lld and should be %lld.",
+ "dimension %d is %d and should be %d.",
in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
fft_length[fft_rank - 1] / 2 + 1);
}
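
The RFFT/IRFFT bounds checked in the hunks above come from Hermitian symmetry: a real FFT of length n keeps only the non-redundant n/2 + 1 complex bins, and IRFFT expects exactly that innermost size back. A minimal sketch of the relationship (an illustrative helper, not XLA code):

    #include <cstdint>

    // Innermost output bound of a real-to-complex FFT over fft_length samples.
    int64_t RfftOutputInnermostDim(int64_t fft_length) {
      return fft_length / 2 + 1;  // e.g. fft_length 8 -> 5 complex bins
    }
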
@@ -1786,18 +1813,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RET_CHECK(split_count > 0);
if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) {
return InvalidArgument(
- "AllToAll split_dimension %lld is out-of-bounds in shape %s.",
- split_dimension, ShapeUtil::HumanString(shape).c_str());
+ "AllToAll split_dimension %d is out-of-bounds in shape %s.",
+ split_dimension, ShapeUtil::HumanString(shape));
}
if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) {
return InvalidArgument(
- "AllToAll concat_dimension %lld is out-of-bounds in shape %s.",
- concat_dimension, ShapeUtil::HumanString(shape).c_str());
+ "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
+ concat_dimension, ShapeUtil::HumanString(shape));
}
if (shape.dimensions(split_dimension) % split_count != 0) {
return InvalidArgument(
- "AllToAll split dimension size %lld must be dividable by split_count "
- "%lld.",
    +        "AllToAll split dimension size %d must be divisible by split_count "
+ "%d.",
shape.dimensions(split_dimension), split_count);
}
std::vector<int64> new_dimensions(shape.dimensions().begin(),
@@ -1817,14 +1844,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"HLO all-to-all has operands with different shapes: the 0th "
"operand shape %s, but the %dth operand has shape %s.",
- ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i,
- ShapeUtil::HumanString(*operand_shapes[i]).c_str());
+ ShapeUtil::HumanString(*operand_shapes[0]), i,
+ ShapeUtil::HumanString(*operand_shapes[i]));
}
}
return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
}
+/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsArray(shape));
+ return shape;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
@@ -1847,9 +1880,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
return InvalidArgument(
"All reduced tensors must have the sime dimension. Tensor 0 has "
- "shape %s, Tensor %lld has shape %s",
- ShapeUtil::HumanString(*reduced_args[0]).c_str(), i,
- ShapeUtil::HumanString(*reduced_args[i]).c_str());
+ "shape %s, Tensor %d has shape %s",
+ ShapeUtil::HumanString(*reduced_args[0]), i,
+ ShapeUtil::HumanString(*reduced_args[i]));
}
}
@@ -1859,9 +1892,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& arg = *reduced_args[0];
for (int64 dimension : dimensions_to_reduce) {
if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
- return InvalidArgument(
- "Reducing out-of-bounds dimension %lld in shape %s.", dimension,
- ShapeUtil::HumanString(arg).c_str());
+ return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
+ dimension, ShapeUtil::HumanString(arg));
}
}
@@ -1934,16 +1966,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Select function's first parameter shape currently must "
"match the operand element shape, but got %s vs %s.",
- ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
- ShapeUtil::HumanString(operand_element_shape).c_str());
+ ShapeUtil::HumanString(select_shape.parameters(0)),
+ ShapeUtil::HumanString(operand_element_shape));
}
if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
select_shape.parameters(1))) {
return InvalidArgument(
"Select function's second parameter shape currently must "
"match the operand element shape, but got %s vs %s.",
- ShapeUtil::HumanString(select_shape.parameters(1)).c_str(),
- ShapeUtil::HumanString(operand_element_shape).c_str());
+ ShapeUtil::HumanString(select_shape.parameters(1)),
+ ShapeUtil::HumanString(operand_element_shape));
}
// Check if the scatter function has a proper shape as a reduction.
@@ -1961,8 +1993,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Source shape does not match the shape of window-reduced operand: "
"source(%s), window-reduced operand(%s).",
- ShapeUtil::HumanString(source_shape).c_str(),
- ShapeUtil::HumanString(window_result_shape).c_str());
+ ShapeUtil::HumanString(source_shape),
+ ShapeUtil::HumanString(window_result_shape));
}
return operand_shape;
}
@@ -1975,29 +2007,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"%s in slice operation; argument shape: %s; starts: {%s}; limits: "
"{%s}; strides: {%s}.",
- message.c_str(), ShapeUtil::HumanString(arg).c_str(),
- Join(starts, ",").c_str(), Join(limits, ",").c_str(),
- Join(strides, ",").c_str());
+ message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
+ StrJoin(limits, ","), StrJoin(strides, ","));
};
TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
- VLOG(2) << tensorflow::strings::Printf(
- "slicing shape %s starts={%s} limits={%s}",
- ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
- Join(limits, ", ").c_str());
+ VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
+ ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
+ StrJoin(limits, ", "));
if (starts.size() != limits.size()) {
- return error(Printf("slice start and limit sizes differ: %zu vs %zu",
- starts.size(), limits.size()));
+ return error(StrFormat("slice start and limit sizes differ: %u vs %u",
+ starts.size(), limits.size()));
}
if (starts.size() != strides.size()) {
- return error(Printf("slice start and strides sizes differ: %zu vs %zu",
- starts.size(), strides.size()));
+ return error(StrFormat("slice start and strides sizes differ: %u vs %u",
+ starts.size(), strides.size()));
}
if (starts.size() != ShapeUtil::Rank(arg)) {
return InvalidArgument(
- "Slice index count does not match argument rank: %zu vs %lld.",
+ "Slice index count does not match argument rank: %u vs %d.",
starts.size(), ShapeUtil::Rank(arg));
}
@@ -2007,27 +2037,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
int64 limit_index = limits[dimension];
int64 stride = strides[dimension];
if (start_index < 0) {
- return InvalidArgument("Negative start index to slice: %lld.",
- start_index);
+ return InvalidArgument("Negative start index to slice: %d.", start_index);
}
if (limit_index > arg.dimensions(dimension)) {
return error(
- Printf("limit index (%lld) must be less than or equal to dimension "
- "size (%lld)",
- limit_index, arg.dimensions(dimension)));
- }
- VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
- start_index);
- VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
- limit_index);
+ StrFormat("limit index (%d) must be less than or equal to dimension "
+ "size (%d)",
+ limit_index, arg.dimensions(dimension)));
+ }
+ VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
+ VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
if (start_index > limit_index) {
return error(
- Printf("limit index (%lld) must be greater or equal to "
- "start index (%lld) in slice with positive stride",
- limit_index, start_index));
+ StrFormat("limit index (%d) must be greater or equal to "
+ "start index (%d) in slice with positive stride",
+ limit_index, start_index));
}
if (stride <= 0) {
- return InvalidArgument("Stride (%lld) must be positive.", stride);
+ return InvalidArgument("Stride (%d) must be positive.", stride);
}
sizes.push_back((limit_index - start_index + stride - 1) / stride);
}
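
The size pushed above is the usual strided-slice extent: a ceiling division of limit - start by the stride, valid once the earlier checks (start <= limit, stride > 0) have passed. A minimal sketch (SliceOutputSize is an illustrative name):

    #include <cstdint>

    int64_t SliceOutputSize(int64_t start, int64_t limit, int64_t stride) {
      // Assumes 0 <= start <= limit and stride > 0, as validated above.
      return (limit - start + stride - 1) / stride;
    }
    // e.g. SliceOutputSize(/*start=*/1, /*limit=*/8, /*stride=*/3) == 3
    //      (elements at indices 1, 4, 7).
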
@@ -2042,15 +2069,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RETURN_IF_ERROR(
ExpectArray(start_indices_shape, "start indices of dynamic slice"));
- VLOG(2) << tensorflow::strings::Printf(
+ VLOG(2) << StrFormat(
"slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
- ShapeUtil::HumanString(operand_shape).c_str(),
- ShapeUtil::HumanString(start_indices_shape).c_str(),
- Join(slice_sizes, ", ").c_str());
+ ShapeUtil::HumanString(operand_shape),
+ ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", "));
if (ShapeUtil::Rank(start_indices_shape) != 1) {
return InvalidArgument(
- "Dynamic slice start indices of rank %lld must be rank1.",
    +        "Dynamic slice start indices of rank %d must be rank 1.",
ShapeUtil::Rank(start_indices_shape));
}
@@ -2062,16 +2088,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 start_num_dims = start_indices_shape.dimensions(0);
if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
return InvalidArgument(
- "Dynamic slice start number of dimensions %lld (%s) must match rank "
- "%lld of slice input (%s).",
- start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
- ShapeUtil::Rank(operand_shape),
- ShapeUtil::HumanString(operand_shape).c_str());
+ "Dynamic slice start number of dimensions %d (%s) must match rank "
+ "%d of slice input (%s).",
+ start_num_dims, ShapeUtil::HumanString(start_indices_shape),
+ ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape));
}
if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
- "Dynamic slice index count does not match argument rank: %zu vs %lld.",
+ "Dynamic slice index count does not match argument rank: %u vs %d.",
slice_sizes.size(), ShapeUtil::Rank(operand_shape));
}
@@ -2079,16 +2104,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 input_dim_size = operand_shape.dimensions(dim);
const int64 slice_dim_size = slice_sizes[dim];
if (slice_dim_size < 0) {
- return InvalidArgument("Negative size index to dynamic slice: %lld.",
+ return InvalidArgument("Negative size index to dynamic slice: %d.",
slice_dim_size);
}
if (slice_dim_size > input_dim_size) {
return InvalidArgument(
- "Slice dim size %lld greater than dynamic slice dimension: %lld.",
+ "Slice dim size %d greater than dynamic slice dimension: %d.",
slice_dim_size, input_dim_size);
}
- VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim,
- slice_dim_size);
+ VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
}
return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
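
The loop above validates dynamic-slice sizes dimension by dimension; the result shape is then simply slice_sizes with the operand's element type. A minimal sketch of the bound check (SliceSizesValid is an illustrative helper):

    #include <cstdint>
    #include <vector>

    bool SliceSizesValid(const std::vector<int64_t>& input_dims,
                         const std::vector<int64_t>& slice_sizes) {
      if (input_dims.size() != slice_sizes.size()) return false;
      for (size_t i = 0; i < input_dims.size(); ++i) {
        // Each requested size must be non-negative and fit inside the input.
        if (slice_sizes[i] < 0 || slice_sizes[i] > input_dims[i]) return false;
      }
      return true;
    }
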
@@ -2104,16 +2128,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
"start indices of dynamic update slice"));
- VLOG(2) << tensorflow::strings::Printf(
+ VLOG(2) << StrFormat(
"updating slice of shape %s at dynamic start_indices %s with update "
"shape %s",
- ShapeUtil::HumanString(operand_shape).c_str(),
- ShapeUtil::HumanString(start_indices_shape).c_str(),
- ShapeUtil::HumanString(update_shape).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ ShapeUtil::HumanString(start_indices_shape),
+ ShapeUtil::HumanString(update_shape));
if (ShapeUtil::Rank(start_indices_shape) != 1) {
return InvalidArgument(
- "Dynamic update slice start indices of rank %lld must be rank1.",
    +        "Dynamic update slice start indices of rank %d must be rank 1.",
ShapeUtil::Rank(start_indices_shape));
}
@@ -2125,17 +2149,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 start_num_dims = start_indices_shape.dimensions(0);
if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
return InvalidArgument(
- "Dynamic update slice start number of dimensions %lld (%s) must match "
- "rank %lld of slice input (%s).",
- start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
- ShapeUtil::Rank(operand_shape),
- ShapeUtil::HumanString(operand_shape).c_str());
+ "Dynamic update slice start number of dimensions %d (%s) must match "
+ "rank %d of slice input (%s).",
+ start_num_dims, ShapeUtil::HumanString(start_indices_shape),
+ ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape));
}
if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
"Dynamic update slice update rank does not match argument rank: "
- "%lld vs %lld.",
+ "%d vs %d.",
ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
}
@@ -2144,8 +2167,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Dynamic update slice update element type does not match argument. "
"operand.element_type: %s vs update.element_type: %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str(),
- PrimitiveType_Name(update_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()),
+ PrimitiveType_Name(update_shape.element_type()));
}
for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
@@ -2153,16 +2176,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 update_dim_size = update_shape.dimensions(dim);
if (update_dim_size < 0) {
return InvalidArgument(
- "Size index %lld to dynamic update slice must be >= 0.",
+ "Size index %d to dynamic update slice must be >= 0.",
update_dim_size);
}
if (update_dim_size > input_dim_size) {
return InvalidArgument(
- "Update dim size %lld greater than dynamic slice dimension: %lld.",
+ "Update dim size %d greater than dynamic slice dimension: %d.",
update_dim_size, input_dim_size);
}
- VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim,
- update_dim_size);
+ VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
}
return operand_shape;
@@ -2177,8 +2199,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
for (int64 dimension : dimensions) {
if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) {
return InvalidArgument(
- "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.",
- dimension, ShapeUtil::HumanString(operand_shape).c_str());
+ "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
+ dimension, ShapeUtil::HumanString(operand_shape));
}
}
return operand_shape;
@@ -2189,14 +2211,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::IsTuple(arg)) {
return InvalidArgument(
"Cannot infer shape: attempting to index into non-tuple: %s.",
- ShapeUtil::HumanString(arg).c_str());
+ ShapeUtil::HumanString(arg));
}
if (index >= arg.tuple_shapes_size()) {
return InvalidArgument(
- "Cannot infer shape: attempt to index out of tuple bounds: %lld "
+ "Cannot infer shape: attempt to index out of tuple bounds: %d "
">= %d in shape %s.",
- index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str());
+ index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
}
return arg.tuple_shapes(index);
@@ -2216,17 +2238,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
auto shape_string = [&]() {
- return tensorflow::strings::Printf(
- "Condition: %s; body: %s; init: %s.",
- ShapeUtil::HumanString(condition).c_str(),
- ShapeUtil::HumanString(body).c_str(),
- ShapeUtil::HumanString(init).c_str());
+ return StrFormat(
+ "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
+ ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
};
// Check the shapes of computation parameters and return types.
if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) {
return InvalidArgument("Condition must return a boolean; got %s.",
- shape_string().c_str());
+ shape_string());
}
if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
!ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
@@ -2234,7 +2254,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"The parameter of condition and body, the result of the body, and init "
"must all have the same shape; got %s.",
- shape_string().c_str());
+ shape_string());
}
return init;
@@ -2246,7 +2266,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const ProgramShape& false_computation) {
if (!ShapeUtil::ShapeIs(predicate, PRED, {})) {
return InvalidArgument("Predicate must be a boolean; got %s.",
- ShapeUtil::HumanString(predicate).c_str());
+ ShapeUtil::HumanString(predicate));
}
if (true_computation.parameters_size() != 1) {
@@ -2255,15 +2275,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) {
auto true_shape_string = [&]() {
- return tensorflow::strings::Printf(
- "true_operand: %s; true_computation: %s",
- ShapeUtil::HumanString(true_operand).c_str(),
- ShapeUtil::HumanString(true_computation).c_str());
+ return StrFormat("true_operand: %s; true_computation: %s",
+ ShapeUtil::HumanString(true_operand),
+ ShapeUtil::HumanString(true_computation));
};
return InvalidArgument(
"true_operand must match the shape of the only parameter of "
"true_computation: got %s.",
- true_shape_string().c_str());
+ true_shape_string());
}
if (false_computation.parameters_size() != 1) {
@@ -2272,28 +2291,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) {
auto false_shape_string = [&]() {
- return tensorflow::strings::Printf(
- "false_operand: %s; false_computation: %s",
- ShapeUtil::HumanString(false_operand).c_str(),
- ShapeUtil::HumanString(false_computation).c_str());
+ return StrFormat("false_operand: %s; false_computation: %s",
+ ShapeUtil::HumanString(false_operand),
+ ShapeUtil::HumanString(false_computation));
};
return InvalidArgument(
"false_operand must match the shape of the only parameter of "
"false_computation: got %s.",
- false_shape_string().c_str());
+ false_shape_string());
}
if (!ShapeUtil::Compatible(true_computation.result(),
false_computation.result())) {
auto shape_string = [&]() {
- return tensorflow::strings::Printf(
+ return StrFormat(
"true_computation result: %s; false_computation result: %s.",
- ShapeUtil::HumanString(true_computation.result()).c_str(),
- ShapeUtil::HumanString(false_computation.result()).c_str());
+ ShapeUtil::HumanString(true_computation.result()),
+ ShapeUtil::HumanString(false_computation.result()));
};
return InvalidArgument(
"the result of true_computation and false_computation must have the "
"same shape: got %s.",
- shape_string().c_str());
+ shape_string());
}
return true_computation.result();
}
@@ -2303,7 +2321,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
- return InvalidArgument("Broadcast with negative dimension size %lld.",
+ return InvalidArgument("Broadcast with negative dimension size %d.",
size);
}
}
@@ -2328,11 +2346,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
return InvalidArgument(
- "Reshape operation has mismatched element counts: from=%lld (%s) "
- "to=%lld (%s).",
- ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(),
+ "Reshape operation has mismatched element counts: from=%d (%s) "
+ "to=%d (%s).",
+ ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
ShapeUtil::ElementsIn(inferred_shape),
- ShapeUtil::HumanString(inferred_shape).c_str());
+ ShapeUtil::HumanString(inferred_shape));
}
std::vector<int64> indices(ShapeUtil::Rank(operand));
@@ -2343,7 +2361,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Reshape dimensions [%s] are not a permutation of the operand "
"dimensions (operand shape is %s).",
- Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str());
+ StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
}
return inferred_shape;
@@ -2378,9 +2396,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
return InvalidArgument("Clamp with different operand types: %s, %s, %s.",
- ShapeUtil::HumanString(min).c_str(),
- ShapeUtil::HumanString(operand).c_str(),
- ShapeUtil::HumanString(max).c_str());
+ ShapeUtil::HumanString(min),
+ ShapeUtil::HumanString(operand),
+ ShapeUtil::HumanString(max));
}
if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
ShapeUtil::IsScalar(min)) &&
@@ -2397,9 +2415,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return ShapeUtil::ChangeElementType(min, operand.element_type());
}
}
- return Unimplemented(
- "%s, %s <clamp> %s is not implemented.", min.ShortDebugString().c_str(),
- max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
+ return Unimplemented("%s, %s <clamp> %s is not implemented.",
+ min.ShortDebugString(), max.ShortDebugString(),
+ operand.ShortDebugString());
}
// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
@@ -2410,13 +2428,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
return InvalidArgument(
"Operands to select must be the same shape; got %s and %s.",
- ShapeUtil::HumanString(on_true).c_str(),
- ShapeUtil::HumanString(on_false).c_str());
+ ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
}
if (pred.element_type() != PRED) {
return InvalidArgument(
"Select's pred operand must have PRED element type; got %s.",
- ShapeUtil::HumanString(pred).c_str());
+ ShapeUtil::HumanString(pred));
}
if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) ||
ShapeUtil::IsScalar(pred)) {
@@ -2429,7 +2446,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Select operation with non-scalar predicate with dimensionality "
" different from the other operands: %s.",
- ShapeUtil::HumanString(pred).c_str());
+ ShapeUtil::HumanString(pred));
}
}
@@ -2440,18 +2457,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::Compatible(on_true, on_false)) {
return InvalidArgument(
"Operands to tuple-select must be the same shape; got %s and %s.",
- ShapeUtil::HumanString(on_true).c_str(),
- ShapeUtil::HumanString(on_false).c_str());
+ ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
}
if (pred.element_type() != PRED) {
return InvalidArgument(
"TupleSelect's pred operand must have PRED element type; got %s.",
- ShapeUtil::HumanString(pred).c_str());
+ ShapeUtil::HumanString(pred));
}
if (!ShapeUtil::IsScalar(pred)) {
return InvalidArgument(
"TupleSelect operation with non-scalar predicate: %s.",
- ShapeUtil::HumanString(pred).c_str());
+ ShapeUtil::HumanString(pred));
}
return on_true;
}
@@ -2463,15 +2479,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (arg_shapes.size() != to_apply.parameters_size()) {
string computation_signature = ShapeUtil::HumanString(to_apply);
string argument_shapes =
- Join(arg_shapes, ", ", [](string* out, const Shape* shape) {
- tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape));
+ StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
+ absl::StrAppend(out, ShapeUtil::HumanString(*shape));
});
return InvalidArgument(
"Call applied function arity must match number of arguments; got: "
- "arity: %d, arguments: %zu; computation signature: %s; argument "
+ "arity: %d, arguments: %u; computation signature: %s; argument "
"shapes: [%s].",
- to_apply.parameters_size(), arg_shapes.size(),
- computation_signature.c_str(), argument_shapes.c_str());
+ to_apply.parameters_size(), arg_shapes.size(), computation_signature,
+ argument_shapes);
}
// All arguments must be compatible with the program shape.
@@ -2482,8 +2498,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Call parameter must match argument; got parameter %d shape: %s, "
"argument shape: %s.",
- i, ShapeUtil::HumanString(param_shape).c_str(),
- ShapeUtil::HumanString(arg_shape).c_str());
+ i, ShapeUtil::HumanString(param_shape),
+ ShapeUtil::HumanString(arg_shape));
}
}
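
The Call-arity message above builds its argument list with absl::StrJoin and a custom formatter lambda. A self-contained sketch of that pattern, with plain strings standing in for the Shape* elements:

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  std::vector<std::string> shapes = {"f32[2]", "s32[]", "pred[4,4]"};
  // The formatter gets the output string and one element at a time; here it
  // simply appends the element, as the HumanString(*shape) lambda does above.
  std::string joined = absl::StrJoin(
      shapes, ", ",
      [](std::string* out, const std::string& s) { absl::StrAppend(out, s); });
  std::cout << "argument shapes: [" << joined << "]\n";
}
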
@@ -2494,17 +2510,17 @@ static Status ValidateGatherDimensionNumbers(
const Shape& input_shape,
tensorflow::gtl::ArraySlice<int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
- if (!c_is_sorted(dim_numbers.offset_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
"Output window dimensions in gather op must be ascending; got: %s.",
- Join(dim_numbers.offset_dims(), ", ").c_str());
+ StrJoin(dim_numbers.offset_dims(), ", "));
}
- if (c_adjacent_find(dim_numbers.offset_dims()) !=
+ if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
dim_numbers.offset_dims().end()) {
return InvalidArgument(
"Output window dimensions in gather op must not repeat; got: %s.",
- Join(dim_numbers.offset_dims(), ", ").c_str());
+ StrJoin(dim_numbers.offset_dims(), ", "));
}
const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
@@ -2515,9 +2531,9 @@ static Status ValidateGatherDimensionNumbers(
int64 offset_dim = dim_numbers.offset_dims(i);
if (offset_dim < 0 || offset_dim >= output_shape_rank) {
return InvalidArgument(
- "Offset dimension %d in gather op is out of bounds; got %lld, but "
+ "Offset dimension %d in gather op is out of bounds; got %d, but "
"should "
- "have been in [0,%lld).",
+ "have been in [0,%d).",
i, offset_dim, output_shape_rank);
}
}
@@ -2526,8 +2542,8 @@ static Status ValidateGatherDimensionNumbers(
start_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
"Gather op has %d elements in start_index_map and the "
- "bound of dimension index_vector_dim=%lld of start_indices is "
- "%lld. These two numbers must be equal.",
+ "bound of dimension index_vector_dim=%d of start_indices is "
+ "%d. These two numbers must be equal.",
dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
start_indices_shape[dim_numbers.index_vector_dim()]);
}
@@ -2537,7 +2553,7 @@ static Status ValidateGatherDimensionNumbers(
if (operand_dim_for_start_index_i < 0 ||
operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid start_index_map; domain is [0, %d), got: %d->%lld.",
+ "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
}
}
@@ -2546,36 +2562,37 @@ static Status ValidateGatherDimensionNumbers(
dim_numbers.start_index_map().begin(),
dim_numbers.start_index_map().end());
- c_sort(sorted_start_index_map);
+ absl::c_sort(sorted_start_index_map);
- if (c_adjacent_find(sorted_start_index_map) != sorted_start_index_map.end()) {
+ if (absl::c_adjacent_find(sorted_start_index_map) !=
+ sorted_start_index_map.end()) {
return InvalidArgument(
"Repeated dimensions are not allowed in start_index_map; "
"got: %s.",
- Join(dim_numbers.start_index_map(), ", ").c_str());
+ StrJoin(dim_numbers.start_index_map(), ", "));
}
for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
return InvalidArgument(
"Invalid collapsed_slice_dims set in gather op; valid range is [0, "
- "%d), got: %lld.",
+ "%d), got: %d.",
input_shape.dimensions_size(), collapsed_dim);
}
}
- if (!c_is_sorted(dim_numbers.collapsed_slice_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
return InvalidArgument(
"collapsed_slice_dims in gather op must be sorted; got: %s",
- Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
+ StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
}
- if (c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
+ if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
dim_numbers.collapsed_slice_dims().end()) {
return InvalidArgument(
"Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
"got: %s.",
- Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
+ StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
}
return Status::OK();
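
The gather validation above now goes through the Abseil container algorithms (absl::c_is_sorted, absl::c_adjacent_find, absl::c_sort, absl::c_binary_search) instead of the local c_* wrappers. A standalone sketch of how these behave on a plain vector:

#include <iostream>
#include <vector>

#include "absl/algorithm/container.h"

int main() {
  std::vector<int> dims = {0, 1, 1, 3};

  // Ascending-order and repeated-value checks, as used for offset_dims above.
  bool sorted = absl::c_is_sorted(dims);                     // true
  bool repeats = absl::c_adjacent_find(dims) != dims.end();  // true (1, 1)

  // Sort a copy, then a membership test, as used for start_index_map above.
  std::vector<int> sorted_dims = dims;
  absl::c_sort(sorted_dims);
  bool has_three = absl::c_binary_search(sorted_dims, 3);    // true

  std::cout << sorted << repeats << has_three << "\n";       // prints 111
}
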
@@ -2593,7 +2610,7 @@ static Status ValidateGatherDimensionNumbers(
if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
return InvalidArgument(
"Gather indices parameter must be an integral tensor; got %s.",
- ShapeUtil::HumanString(start_indices_shape).c_str());
+ ShapeUtil::HumanString(start_indices_shape));
}
// We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
@@ -2606,15 +2623,15 @@ static Status ValidateGatherDimensionNumbers(
return InvalidArgument(
"Gather index leaf dimension must be within [0, rank(start_indices) + "
"1). rank(start_indices) is %d and gather index leaf dimension is "
- "%lld.",
+ "%d.",
start_indices_shape.dimensions_size(),
gather_dim_numbers.index_vector_dim());
}
std::vector<int64> expanded_start_indices_shape;
expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
- c_copy(start_indices_shape.dimensions(),
- std::back_inserter(expanded_start_indices_shape));
+ absl::c_copy(start_indices_shape.dimensions(),
+ std::back_inserter(expanded_start_indices_shape));
if (expanded_start_indices_shape.size() ==
gather_dim_numbers.index_vector_dim()) {
expanded_start_indices_shape.push_back(1);
@@ -2637,8 +2654,8 @@ static Status ValidateGatherDimensionNumbers(
"All components of the offset index in a gather op must either be a "
"offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
"output_slice_sizes=%s, collapsed_slice_dims=%s.",
- slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(),
- Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str());
+ slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
+ StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
}
for (int i = 0; i < slice_sizes.size(); i++) {
@@ -2647,7 +2664,7 @@ static Status ValidateGatherDimensionNumbers(
if (slice_size < 0 || slice_size > corresponding_input_size) {
return InvalidArgument(
"Slice size at index %d in gather op is out of range, must be "
- "within [0, %lld), got %lld.",
+ "within [0, %d), got %d.",
i, corresponding_input_size + 1, slice_size);
}
}
@@ -2656,7 +2673,7 @@ static Status ValidateGatherDimensionNumbers(
if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) {
return InvalidArgument(
"Gather op can only collapse slice dims with bound 1, but bound is "
- "%lld for index %lld at position %d.",
+ "%d for index %d at position %d.",
slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
gather_dim_numbers.collapsed_slice_dims(i), i);
}
@@ -2670,10 +2687,11 @@ static Status ValidateGatherDimensionNumbers(
output_dim_bounds.reserve(result_rank);
for (int64 i = 0; i < result_rank; i++) {
int64 current_bound;
- bool is_window_index = c_binary_search(gather_dim_numbers.offset_dims(), i);
+ bool is_window_index =
+ absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
if (is_window_index) {
- while (c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
- offset_dims_seen)) {
+ while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
+ offset_dims_seen)) {
offset_dims_seen++;
}
current_bound = slice_sizes[offset_dims_seen++];
@@ -2697,44 +2715,44 @@ Status ValidateScatterDimensionNumbers(
tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
// Validate update_window_dims in ScatterDimensionNumbers.
- if (!c_is_sorted(dim_numbers.update_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
return InvalidArgument(
"update_window_dims in scatter op must be sorted; got: %s.",
- Join(dim_numbers.update_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.update_window_dims(), ", "));
}
- if (c_adjacent_find(dim_numbers.update_window_dims()) !=
+ if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
dim_numbers.update_window_dims().end()) {
return InvalidArgument(
"update_window_dims in scatter op must not repeat; got: %s.",
- Join(dim_numbers.update_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.update_window_dims(), ", "));
}
const int64 updates_rank = ShapeUtil::Rank(updates_shape);
for (int64 window_dim : dim_numbers.update_window_dims()) {
if (window_dim < 0 || window_dim >= updates_rank) {
return InvalidArgument(
"Invalid update_window_dims set in scatter op; valid range is [0, "
- "%lld). got: %lld.",
+ "%d). got: %d.",
updates_rank, window_dim);
}
}
// Validate inserted_window_dims in ScatterDimensionNumbers.
- if (!c_is_sorted(dim_numbers.inserted_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
return InvalidArgument(
"inserted_window_dims in scatter op must be sorted; got: %s.",
- Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.inserted_window_dims(), ", "));
}
- if (c_adjacent_find(dim_numbers.inserted_window_dims()) !=
+ if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
dim_numbers.inserted_window_dims().end()) {
return InvalidArgument(
"inserted_window_dims in scatter op must not repeat; got: %s.",
- Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.inserted_window_dims(), ", "));
}
for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
return InvalidArgument(
"Invalid inserted_window_dims set in scatter op; valid range is [0, "
- "%d), got: %lld.",
+ "%d), got: %d.",
operand_shape.dimensions_size(), inserted_dim);
}
}
@@ -2744,7 +2762,7 @@ Status ValidateScatterDimensionNumbers(
scatter_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
"Scatter op has %d elements in scatter_dims_to_operand_dims and the "
- "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. "
+ "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
"These two numbers must be equal.",
dim_numbers.scatter_dims_to_operand_dims_size(),
dim_numbers.index_vector_dim(),
@@ -2757,20 +2775,20 @@ Status ValidateScatterDimensionNumbers(
scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
return InvalidArgument(
"Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
- "got: %d->%lld.",
+ "got: %d->%d.",
operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
}
}
std::vector<int64> sorted_scatter_dims_to_operand_dims(
dim_numbers.scatter_dims_to_operand_dims().begin(),
dim_numbers.scatter_dims_to_operand_dims().end());
- c_sort(sorted_scatter_dims_to_operand_dims);
- if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
+ absl::c_sort(sorted_scatter_dims_to_operand_dims);
+ if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
sorted_scatter_dims_to_operand_dims.end()) {
return InvalidArgument(
"Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
"got: %s.",
- Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str());
+ StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
}
return Status::OK();
@@ -2791,7 +2809,7 @@ Status ValidateScatterDimensionNumbers(
if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
return InvalidArgument(
"Scatter indices parameter must be an integral tensor; got %s.",
- ShapeUtil::HumanString(scatter_indices_shape).c_str());
+ ShapeUtil::HumanString(scatter_indices_shape));
}
if (scatter_indices_shape.dimensions_size() <
@@ -2800,7 +2818,7 @@ Status ValidateScatterDimensionNumbers(
return InvalidArgument(
"Scatter index leaf dimension must be within [0, rank(scatter_indices)"
" + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
- "is %lld.",
+ "is %d.",
scatter_indices_shape.dimensions_size(),
scatter_dim_numbers.index_vector_dim());
}
@@ -2822,7 +2840,7 @@ Status ValidateScatterDimensionNumbers(
int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
scatter_dim_numbers.update_window_dims_size();
if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) {
- return InvalidArgument("Updates tensor must be of rank %lld; got %lld.",
+ return InvalidArgument("Updates tensor must be of rank %d; got %d.",
expected_updates_rank,
ShapeUtil::Rank(updates_shape));
}
@@ -2848,7 +2866,7 @@ Status ValidateScatterDimensionNumbers(
return InvalidArgument(
"Bounds of the window dimensions of updates must not exceed the "
"bounds of the corresponding dimensions of operand. For dimension "
- "%lld, updates bound is %lld, operand bound is %lld.",
+ "%d, updates bound is %d, operand bound is %d.",
update_window_dim, updates_shape.dimensions(update_window_dim),
max_update_slice_sizes[i]);
}
@@ -2857,7 +2875,7 @@ Status ValidateScatterDimensionNumbers(
int64 scatter_dims_seen = 0;
for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) {
bool is_update_window_dim =
- c_binary_search(scatter_dim_numbers.update_window_dims(), i);
+ absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
if (is_update_window_dim) {
continue;
}
@@ -2869,8 +2887,8 @@ Status ValidateScatterDimensionNumbers(
return InvalidArgument(
"Bounds of the scatter dimensions of updates must be same as the "
"bounds of the corresponding dimensions of scatter indices. For "
- "scatter dimension %lld, updates bound is %lld, scatter_indices "
- "bound is %lld.",
+ "scatter dimension %d, updates bound is %d, scatter_indices "
+ "bound is %d.",
i, updates_shape.dimensions(i),
expanded_scatter_indices_shape[scatter_dims_seen]);
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 4974ac9916..235b1a4cf3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -136,6 +136,9 @@ class ShapeInference {
static StatusOr<Shape> InferAllToAllTupleShape(
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ // Infers the shape of a collective permute operation.
+ static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape);
+
// Infers the shape produced by applying the given reduction computation
// shape to the given input operand shape.
//
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 7d7dcac10b..921a984589 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -18,20 +18,19 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::strings::Appendf;
-
ShapedBuffer::ShapedBuffer(const Shape& on_host_shape,
const Shape& on_device_shape,
const se::Platform* platform, int device_ordinal)
@@ -76,7 +75,7 @@ void ShapedBuffer::clear() {
}
string ShapedBuffer::ToString() const {
- string s = tensorflow::strings::StrCat(
+ string s = absl::StrCat(
"ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
"), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
", on-device shape=" +
@@ -92,9 +91,9 @@ string ShapedBuffer::ToString() const {
shape_str = ShapeUtil::HumanStringWithLayout(subshape);
}
const se::DeviceMemoryBase& memory = buffer(index);
- Appendf(&s, " %s%p (%lld bytes) : %s\n",
- string(index.size() * 2, ' ').c_str(), memory.opaque(),
- memory.size(), shape_str.c_str());
+ absl::StrAppendFormat(&s, " %s%p (%d bytes) : %s\n",
+ string(index.size() * 2, ' '), memory.opaque(),
+ memory.size(), shape_str);
});
return s;
}
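
ShapedBuffer::ToString above switches to absl::StrAppendFormat, which appends in place and, like StrFormat, takes std::string and 64-bit sizes without .c_str() or %lld. A standalone sketch with stand-in values (the pointer, size and shape string here are illustrative, not a real device buffer):

#include <cstdint>
#include <iostream>
#include <string>

#include "absl/strings/str_format.h"

int main() {
  std::string s = "ShapedBuffer-style listing:\n";
  int payload = 0;                       // something to take an address of
  int64_t size_bytes = 1024;             // stands in for memory.size()
  std::string shape_str = "f32[16,16]";  // stands in for HumanStringWithLayout
  absl::StrAppendFormat(&s, "  %s%p (%d bytes) : %s\n",
                        std::string(2, ' '), static_cast<void*>(&payload),
                        size_bytes, shape_str);
  std::cout << s;
}
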
diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc
index 0fc2436679..d69e6362e9 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -34,7 +35,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
xla::StreamExecutorMemoryAllocator allocator(platform, executors);
const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {});
const int kDeviceOrdinal = 0;
- auto scoped_buffer = tensorflow::MakeUnique<xla::ScopedShapedBuffer>(
+ auto scoped_buffer = absl::make_unique<xla::ScopedShapedBuffer>(
shape, shape, &allocator, kDeviceOrdinal);
std::unique_ptr<xla::ShapedBuffer> buffer = std::move(scoped_buffer);
buffer = nullptr;
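
The test above swaps tensorflow::MakeUnique for absl::make_unique, a drop-in replacement with the same call shape. A trivial standalone sketch (the Stream type here is a stand-in, not se::Stream):

#include <iostream>
#include <memory>

#include "absl/memory/memory.h"

struct Stream {
  explicit Stream(int ordinal) : ordinal(ordinal) {}
  int ordinal;
};

int main() {
  // Same call shape as the converted sites: absl::make_unique<T>(ctor args).
  std::unique_ptr<Stream> stream = absl::make_unique<Stream>(/*ordinal=*/0);
  std::cout << "stream ordinal: " << stream->ordinal << "\n";
}
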
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
index 8cbaac7b37..dd53c7531b 100644
--- a/tensorflow/compiler/xla/service/source_map_util.cc
+++ b/tensorflow/compiler/xla/service/source_map_util.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/source_map_util.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -26,11 +27,10 @@ Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
string message;
tensorflow::strings::Appendv(&message, format, args);
if (!op_metadata.source_file().empty()) {
- tensorflow::strings::Appendf(&message, " (%s:%d)",
- op_metadata.source_file().c_str(),
- op_metadata.source_line());
+ absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
+ op_metadata.source_line());
}
- return InvalidArgument("%s", message.c_str());
+ return InvalidArgument("%s", message);
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h
index 18e2651abb..c5a7e17cb4 100644
--- a/tensorflow/compiler/xla/service/source_map_util.h
+++ b/tensorflow/compiler/xla/service/source_map_util.h
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
-#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/macros.h"
@@ -24,23 +25,40 @@ namespace xla {
namespace source_map_util {
// Creates an INVALID_ARGUMENT status with the given format string.
+template <typename... Args>
+Status InvalidParameterArgument(const OpMetadata& op_metadata,
+ const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ string message = absl::StrFormat(format, args...);
+ if (!op_metadata.source_file().empty()) {
+ absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
+ op_metadata.source_line());
+ }
+ return InvalidArgument("%s", message);
+}
+
+// Creates an INVALID_ARGUMENT status with the given format string.
//
// Also, attempts to extract the OpMetadata for parameter_number on executable
// and append it to the status message for source mapping to user code.
//
// executable may be nullptr, but parameter_number should not be out of bounds
// or a CHECK-failure may occur.
+template <typename... Args>
Status InvalidParameterArgument(Executable* executable, int parameter_number,
- const char* format, ...)
- TF_PRINTF_ATTRIBUTE(3, 4);
-
-// As above, but takes the parameter metadata directly instead of extracting it
-// from the executable.
-Status InvalidParameterArgument(const OpMetadata& op_metadata,
- const char* format, ...)
- TF_PRINTF_ATTRIBUTE(2, 3);
+ const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ if (executable != nullptr && executable->has_module()) {
+ const HloModule& module = executable->module();
+ const HloComputation& computation = *module.entry_computation();
+ HloInstruction* param = computation.parameter_instruction(parameter_number);
+ const OpMetadata& metadata = param->metadata();
+ return InvalidParameterArgument(metadata, format, args...);
+ }
+ return InvalidArgument(format, args...);
+}
} // namespace source_map_util
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_
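
The new header replaces the varargs/TF_PRINTF_ATTRIBUTE declarations with templates taking absl::FormatSpec<Args...>, so the format string is checked against the argument types at compile time. A standalone sketch of the same shape of helper, with plain strings standing in for Status and OpMetadata:

#include <iostream>
#include <string>

#include "absl/strings/str_format.h"

// Hypothetical helper mirroring the new InvalidParameterArgument templates:
// the FormatSpec parameter makes mismatched format arguments a compile error.
template <typename... Args>
std::string InvalidParameterMessage(const std::string& source_file,
                                    int source_line,
                                    const absl::FormatSpec<Args...>& format,
                                    const Args&... args) {
  std::string message = absl::StrFormat(format, args...);
  if (!source_file.empty()) {
    // Append the source mapping the same way the helper above does.
    absl::StrAppendFormat(&message, " (%s:%d)", source_file, source_line);
  }
  return message;
}

int main() {
  std::cout << InvalidParameterMessage("model.py", 42,
                                       "parameter %d has shape %s", 3,
                                       std::string("f32[2,3]"))
            << "\n";
}
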
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index c0582c6a2d..5d1cd1c442 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/stream_pool.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -35,7 +35,7 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
if (!stream) {
// Create a new stream.
- stream = MakeUnique<se::Stream>(executor);
+ stream = absl::make_unique<se::Stream>(executor);
stream->Init();
VLOG(1) << stream->DebugStreamPointers()
<< " StreamPool created new stream";
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 32d368a904..b8d2d546e5 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -27,7 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
namespace xla {
/* static */ tensorflow::mutex
@@ -61,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
- return MakeUnique<Literal>(std::move(literal));
+ return absl::make_unique<Literal>(std::move(literal));
}
Status TransferManager::TransferLiteralFromDevice(
@@ -120,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
- return MakeUnique<Literal>(std::move(literal));
+ return absl::make_unique<Literal>(std::move(literal));
}
Status TransferManager::TransferArrayToDevice(
@@ -147,7 +149,7 @@ Status TransferManager::TransferArrayToDeviceAsync(
if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
return FailedPrecondition(
"Allocation on device not large enough for array: "
- "%lld < %lld",
+ "%d < %d",
dest.size(), GetByteSizeRequirement(on_device_shape));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
@@ -164,12 +166,12 @@ void TransferManager::TransferArrayFromDevice(
auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
" has a differently shaped representation on-device: ",
ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
- return done(FailedPrecondition("%s", error.c_str()));
+ return done(FailedPrecondition("%s", error));
}
if (source.size() < GetByteSizeRequirement(shape)) {
return done(
FailedPrecondition("Allocation on device not large enough for array: "
- "%lld < %lld",
+ "%d < %d",
source.size(), GetByteSizeRequirement(shape)));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
@@ -201,7 +203,7 @@ void TransferManager::TransferArrayFromDevice(
return NotFound(
"could not find registered transfer manager for platform %s -- check "
"target linkage",
- platform->Name().c_str());
+ platform->Name());
}
if (it->second.manager == nullptr) {
@@ -252,7 +254,7 @@ Status TransferManager::TransferBufferFromDevice(
if (source.size() < size) {
return FailedPrecondition(
"Source allocation on device not large enough for data tranfer: "
- "%lld < %lld",
+ "%d < %d",
source.size(), size);
}
stream->ThenMemcpy(destination, source, size);
@@ -265,7 +267,7 @@ Status TransferManager::TransferBufferToDevice(
if (destination->size() < size) {
return FailedPrecondition(
"Destination allocation on device not large enough for data tranfer: "
- "%lld < %lld",
+ "%d < %d",
destination->size(), size);
}
stream->ThenMemcpy(destination, source, size);
@@ -276,9 +278,8 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
int device_ordinal) {
if (!LayoutUtil::HasLayout(on_host_shape)) {
- return InvalidArgument(
- "Shape must have a layout: %s",
- ShapeUtil::HumanStringWithLayout(on_host_shape).c_str());
+ return InvalidArgument("Shape must have a layout: %s",
+ ShapeUtil::HumanStringWithLayout(on_host_shape));
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 475a2e5c14..f77690a462 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -152,6 +152,26 @@ class TransferManager {
const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
int device_ordinal);
+ // The given ShapedBuffer holds a handle to allocated memory, but it is not
+ // in the general case legal to immediately copy or access that allocated
+ // memory because queued operations on the device may alias that memory.
+ // Memory ordering is enforced by the Stream's happens-before relationship
+ // which allows eager deallocation and reallocation of buffers host-side even
+ // if the device hasn't finished with them.
+ //
+ // In certain cases, it can be known that a ShapedBuffer does not have any
+ // conflicting accesses on the device and thus is eligible to be accessed at
+ // any time from the host.
+ //
+ // This function returns true if device_buffer can be accessed immediately
+ // without waiting for the Stream's previously enqueued items. This only
+ // returns true if all subbuffers in device_buffer can be accessed
+ // immediately.
+ virtual bool CanShapedBufferBeAccessedNow(
+ se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
+ return false;
+ }
+
/////
// The TransferManager class also serves as a point to register objects for
// the various platforms.
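
CanShapedBufferBeAccessedNow, added above, defaults to false so callers fall back to waiting on the stream before touching the buffer. A sketch of the intended call pattern, using stand-in types rather than the real TransferManager, StreamExecutor and ShapedBuffer classes:

#include <iostream>

// Stand-in types; the real ones are se::StreamExecutor, se::Stream,
// xla::ShapedBuffer and xla::TransferManager.
struct StreamExecutor {};
struct ShapedBuffer {};
struct Stream {
  void BlockHostUntilDone() { std::cout << "waited for pending device work\n"; }
};

class TransferManagerLike {
 public:
  virtual ~TransferManagerLike() = default;
  // Mirrors the new virtual above: conservatively report "not safe" unless a
  // platform-specific subclass knows better.
  virtual bool CanShapedBufferBeAccessedNow(StreamExecutor*,
                                            const ShapedBuffer&) const {
    return false;
  }
};

void ReadBack(const TransferManagerLike& tm, StreamExecutor* executor,
              Stream* stream, const ShapedBuffer& buffer) {
  if (!tm.CanShapedBufferBeAccessedNow(executor, buffer)) {
    // Pending device work may still touch the buffer; synchronize first.
    stream->BlockHostUntilDone();
  }
  std::cout << "host accesses the buffer\n";
}

int main() {
  TransferManagerLike tm;
  StreamExecutor executor;
  Stream stream;
  ShapedBuffer buffer;
  ReadBack(tm, &executor, &stream, buffer);
}
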
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 49e1f87319..530f40e4b2 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -109,6 +109,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
dot->shape(), new_lhs, new_rhs, new_dim_numbers);
+ new_dot->set_precision_config(dot->precision_config());
return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}
@@ -178,6 +179,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
auto new_conv = HloInstruction::CreateConvolve(
convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
+ new_conv->set_precision_config(convolution.precision_config());
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));
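
Both hunks above copy the original instruction's precision config onto the freshly created dot/convolution so the fold does not silently discard it. A stand-in sketch (not the XLA classes) of that clone-then-copy-config pattern:

#include <iostream>
#include <memory>
#include <string>

#include "absl/memory/memory.h"

// Stand-in types, not the XLA HloInstruction/PrecisionConfig classes.
struct PrecisionConfig {
  std::string value = "DEFAULT";
};

struct Instruction {
  PrecisionConfig precision;
  const PrecisionConfig& precision_config() const { return precision; }
  void set_precision_config(const PrecisionConfig& p) { precision = p; }
};

int main() {
  Instruction dot;
  dot.precision.value = "HIGHEST";

  // A rewrite builds a replacement instruction...
  auto new_dot = absl::make_unique<Instruction>();
  // ...and must carry the original's precision config across explicitly,
  // which is what the two added set_precision_config calls do above.
  new_dot->set_precision_config(dot.precision_config());

  std::cout << new_dot->precision_config().value << "\n";  // HIGHEST
}
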
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 71e8446452..3e5aa2db60 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -49,7 +49,7 @@ class TransposeFolding : public HloPassInterface {
explicit TransposeFolding(
TransposableGemmOperandsFn transposable_gemm_operands,
TransposableConvOperandsFn transposable_conv_operands);
- tensorflow::StringPiece name() const override { return "transpose-folding"; }
+ absl::string_view name() const override { return "transpose-folding"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 0447807a41..cf00ca102b 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -19,6 +19,10 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -26,17 +30,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
string BufferAlias::ToString() const {
- return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[",
- tensorflow::str_util::Join(index_, ","),
- "])");
+ return absl::StrCat("BufferAlias(", instruction_->name(), "[",
+ absl::StrJoin(index_, ","), "])");
}
std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
@@ -441,7 +441,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
PerInstruction* pi = PerInst(instruction);
CHECK(pi->points_to_set == nullptr)
<< "instruction should not have been present in the map.";
- auto set = MakeUnique<PointsToSet>(&instruction->shape());
+ auto set = absl::make_unique<PointsToSet>(&instruction->shape());
pi->points_to_set = std::move(set);
// Return *set using the iterator returned by emplace.
return *pi->points_to_set;
@@ -462,21 +462,20 @@ Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
return FailedPrecondition(
"LogicalBuffer %s is ill-defined: instruction %s does not define a "
"buffer at that index",
- buffer.ToString().c_str(), buffer.instruction()->name().c_str());
+ buffer.ToString(), buffer.instruction()->name());
}
}
if (buffer.id() < 0 ||
buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
- return FailedPrecondition(
- "LogicalBuffer %s is ill-defined: invalid id %lld",
- buffer.ToString().c_str(), buffer.id());
+ return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
+ buffer.ToString(), buffer.id());
}
if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
GetBuffer(buffer.id()).index() != buffer.index()) {
return FailedPrecondition(
"LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
- buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str());
+ buffer.ToString(), GetBuffer(buffer.id()).ToString());
}
return Status::OK();
@@ -495,8 +494,7 @@ StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
return FailedPrecondition(
"instruction %s does not define buffer at index {%s}",
- instruction->name().c_str(),
- tensorflow::str_util::Join(index, ",").c_str());
+ instruction->name(), absl::StrJoin(index, ","));
}
return buffers[0];
}
@@ -557,13 +555,12 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
}
string TuplePointsToAnalysis::ToString() const {
- string output = tensorflow::strings::Printf(
- "TuplePointsToSet for module %s:\n", module_->name().c_str());
+ string output =
+ absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
for (const auto* computation : module_->MakeNonfusionComputations()) {
const char* entry =
computation == module_->entry_computation() ? "entry " : "";
- tensorflow::strings::StrAppend(&output, entry, "computation ",
- computation->name(), ":\n");
+ absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
for (const HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
InstructionToString(instruction, &output);
@@ -575,12 +572,11 @@ string TuplePointsToAnalysis::ToString() const {
}
}
- tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n");
+ absl::StrAppend(&output, "LogicalBuffers:\n");
for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
- tensorflow::strings::StrAppend(&output, " buffer ", b->ToString(), ":\n");
+ absl::StrAppend(&output, " buffer ", b->ToString(), ":\n");
for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
- tensorflow::strings::StrAppend(&output, " alias ", alias.ToString(),
- "\n");
+ absl::StrAppend(&output, " alias ", alias.ToString(), "\n");
}
}
return output;
@@ -589,20 +585,18 @@ string TuplePointsToAnalysis::ToString() const {
void TuplePointsToAnalysis::InstructionToString(
const HloInstruction* instruction, string* output) const {
const string prefix = instruction->IsFused() ? " " : "";
- tensorflow::strings::StrAppend(output, prefix, " instruction ",
- instruction->ToShortString(), ":\n");
+ absl::StrAppend(output, prefix, " instruction ",
+ instruction->ToShortString(), ":\n");
const PointsToSet& points_to_set = GetPointsToSet(instruction);
points_to_set.ForEachElement([&prefix, &output](
const ShapeIndex& index,
const PointsToSet::BufferList& points_to) {
- tensorflow::strings::StrAppend(
- output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ",
- tensorflow::str_util::Join(
- points_to, ", ",
- [](string* out, const LogicalBuffer* source) {
- out->append(source->ToString());
- }),
- "\n");
+ absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ",
+ absl::StrJoin(points_to, ", ",
+ [](string* out, const LogicalBuffer* source) {
+ out->append(source->ToString());
+ }),
+ "\n");
});
}
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 686bb05328..62c7bb685d 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -109,7 +110,7 @@ class PointsToSet {
// Add a tuple source instruction for the given index.
void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);
- using BufferList = tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;
+ using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>;
// Return the list of logical buffers for the subshape at index.
const BufferList& element(const ShapeIndex& index) const {
@@ -203,7 +204,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
// logical buffer The buffer alias set is the inverse of the points-to set.
// That is, LogicalBuffer B is in the points-to set of instruction I at index
// N iff instruction I, index N is a BufferAlias of B.
- using BufferAliasVector = tensorflow::gtl::InlinedVector<BufferAlias, 1>;
+ using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>;
const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;
// Returns the number of logical buffers in the module
@@ -226,8 +227,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
// instructions produce a single buffer (the top-level buffer), some produce
// no buffers (eg bitcast), and some produce more than one buffer (eg,
// tuple-shaped parameters).
- using BufferDefinitionVector =
- tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;
+ using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>;
const BufferDefinitionVector& GetBuffersDefinedByInstruction(
const HloInstruction* instruction) const;
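
The typedefs above move from tensorflow::gtl::InlinedVector to absl::InlinedVector, which stores the first N elements (here 1) inline and only heap-allocates once that capacity is exceeded. A standalone sketch:

#include <iostream>

#include "absl/container/inlined_vector.h"

int main() {
  // One LogicalBuffer-like element fits inline; a second one spills to the
  // heap, matching the <const LogicalBuffer*, 1> typedefs above.
  absl::InlinedVector<int, 1> buffers;
  buffers.push_back(7);
  buffers.push_back(8);
  std::cout << buffers.size() << " " << buffers[0] << " " << buffers[1] << "\n";
}
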
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 7509501883..8c91d6e69d 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -30,7 +30,7 @@ class TupleSimplifier : public HloPassInterface {
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);
~TupleSimplifier() override {}
- tensorflow::StringPiece name() const override { return "tuple-simplifier"; }
+ absl::string_view name() const override { return "tuple-simplifier"; }
// Run tuple simplification on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index af2cb6dc2a..7e4ac92a7c 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -18,8 +18,8 @@ limitations under the License.
namespace xla {
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
+using absl::nullopt;
+using absl::optional;
// Finds and returns the non-constant operand in instr.
//
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h
index bf59813e8c..bf497f4892 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.h
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -25,8 +25,8 @@ namespace xla {
// nullopt otherwise. max_value_returned limits the number of steps that are
// evaluated while trying to brute force a loop trip count, trip counts larger
// than max_value_returned result in nullopt.
-tensorflow::gtl::optional<int64> ComputeWhileLoopTripCount(
- HloInstruction *while_op, int64 max_value_returned = 128);
+absl::optional<int64> ComputeWhileLoopTripCount(HloInstruction *while_op,
+ int64 max_value_returned = 128);
} // namespace xla
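
ComputeWhileLoopTripCount above now returns absl::optional<int64>, with nullopt meaning the trip count could not be determined. A standalone sketch of that convention (the TripCount function here is a stand-in, not the XLA analysis):

#include <cstdint>
#include <iostream>

#include "absl/types/optional.h"

// Hypothetical stand-in for the analysis: it only "determines" small bounds.
absl::optional<int64_t> TripCount(int64_t bound,
                                  int64_t max_value_returned = 128) {
  if (bound < 0 || bound > max_value_returned) return absl::nullopt;
  return bound;
}

int main() {
  if (absl::optional<int64_t> n = TripCount(16)) {
    std::cout << "trip count: " << *n << "\n";
  }
  std::cout << "large bound known? " << TripCount(1000).has_value() << "\n";  // 0
}
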
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
index 62af45128a..aab1180662 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -32,7 +33,7 @@ static Status ReplaceUsesWhileKeepingLoopInvariance(
std::vector<HloInstruction*> users;
users.reserve(old_instr->user_count());
- c_copy(old_instr->users(), std::back_inserter(users));
+ absl::c_copy(old_instr->users(), std::back_inserter(users));
for (auto* user : users) {
for (int64 i = 0, e = user->operand_count(); i < e; i++) {
@@ -108,10 +109,10 @@ StatusOr<bool> WhileLoopConstantSinking::Run(HloModule* module) {
//
// This will let us sink the constant into the outer while first and then
// into the inner while in a single run of this pass.
- c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
- [](const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kWhile;
- });
+ absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+ [](const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kWhile;
+ });
}
for (HloInstruction* while_instr : while_instrs) {
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 21fb8568a8..2dba7d7f75 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -54,7 +54,7 @@ class WhileLoopConstantSinking : public HloPassInterface {
public:
~WhileLoopConstantSinking() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "while-loop-invariant-code-motion";
}
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index 09ddcffb22..f4098f28b3 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -14,18 +14,19 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace xla {
+using absl::InlinedVector;
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
-using tensorflow::gtl::InlinedVector;
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
// operands as needed. All of its transitive operands are expected to be either
@@ -65,8 +66,8 @@ static void CreateLoopInvariantCopy(
};
InlinedVector<HloInstruction*, 4> new_operands;
- c_transform(old_instruction->operands(), std::back_inserter(new_operands),
- get_new_operand);
+ absl::c_transform(old_instruction->operands(),
+ std::back_inserter(new_operands), get_new_operand);
HloInstruction* new_instruction =
parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands(
@@ -197,7 +198,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
op->opcode() == HloOpcode::kConstant;
};
- if (!c_all_of(instruction->operands(), is_invariant)) {
+ if (!absl::c_all_of(instruction->operands(), is_invariant)) {
continue;
}
@@ -257,10 +258,10 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->computations()) {
- c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
- [](const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kWhile;
- });
+ absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+ [](const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kWhile;
+ });
}
for (HloInstruction* while_instr : while_instrs) {
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 8e6cc87875..2cdf20ce80 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -38,7 +38,7 @@ class WhileLoopInvariantCodeMotion : public HloPassInterface {
: hoist_constants_(hoist_constants) {}
~WhileLoopInvariantCodeMotion() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "while-loop-invariant-code-motion";
}
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 32e69c335b..e14014b961 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -28,6 +28,10 @@ namespace op = xla::testing::opcode_matchers;
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
public:
+ WhileLoopInvariantCodeMotionTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
// Makes a computation which has one parameter, of the given shape, and always
// returns PRED[]{true}. This is useful as a dummy loop condition.
HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index dd8697e680..6a7bfe3f12 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -14,17 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
+using absl::optional;
// Determines whether the given instruction is a send/recv node, or has a
// subcomputation which contains a send/recv node.
@@ -237,12 +236,11 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
<< "Instruction " << user->ToString(print_no_metadata)
<< " should be unused (except by root of while body), but has "
"users: {"
- << tensorflow::str_util::Join(
- user->users(), ", ",
- [&](string* out, const HloInstruction* instr) {
- tensorflow::strings::StrAppend(
- out, instr->ToString(print_no_metadata));
- })
+ << absl::StrJoin(user->users(), ", ",
+ [&](string* out, const HloInstruction* instr) {
+ absl::StrAppend(
+ out, instr->ToString(print_no_metadata));
+ })
<< "}";
replacements.emplace(user, nullptr);
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 3d3e1d60f2..78024f14dc 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -33,9 +33,7 @@ namespace xla {
class WhileLoopSimplifier : public HloPassInterface {
public:
~WhileLoopSimplifier() override {}
- tensorflow::StringPiece name() const override {
- return "simplify-while-loops";
- }
+ absl::string_view name() const override { return "simplify-while-loops"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 2e1571943e..cfe4104f6d 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -15,11 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
@@ -27,6 +28,11 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
+ public:
+ WhileLoopSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
// Makes an HloModule that contains a loop with `num_iters` iteration.
void MakeModuleWithSimpleLoop(int num_iters);
@@ -64,10 +70,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
}
)";
- string hlo_string = tensorflow::str_util::StringReplace(
- hlo_string_template, "{{LOOP_BOUND}}",
- tensorflow::strings::StrCat(42 + num_iters),
- /*replace_all=*/true);
+ string hlo_string = absl::StrReplaceAll(
+ hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
ParseAndVerifyModule(hlo_string);
}
@@ -103,10 +107,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
}
)";
- string hlo_string = tensorflow::str_util::StringReplace(
- hlo_string_template, "{{LOOP_BOUND}}",
- tensorflow::strings::StrCat(42 + num_iters),
- /*replace_all=*/true);
+ string hlo_string = absl::StrReplaceAll(
+ hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
ParseAndVerifyModule(hlo_string);
}
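
The test helpers above replace tensorflow::str_util::StringReplace with absl::StrReplaceAll, which takes {from, to} pairs and always replaces every occurrence, so the /*replace_all=*/true argument disappears. A standalone sketch with a shortened stand-in for the HLO template string used in the test:

#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"

int main() {
  // Shortened stand-in for the hlo_string_template in the test above.
  const std::string hlo_template =
      "compare = pred[] less-than(count, constant({{LOOP_BOUND}}))";
  int num_iters = 5;
  std::string hlo = absl::StrReplaceAll(
      hlo_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
  std::cout << hlo << "\n";  // ... constant(47) ...
}
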
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index 1ef17b9d7d..e8f76ff745 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -14,15 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_util.h"
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::strings::StrCat;
+using absl::StrCat;
static StatusOr<HloComputation*> WidenWhileCondition(
HloComputation* narrow_condition, const Shape& wide_shape) {
@@ -206,7 +207,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
HloInstruction* zero = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
init_values_with_indvar.push_back(zero);
- c_copy(init_values, std::back_inserter(init_values_with_indvar));
+ absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
return computation->AddInstruction(
HloInstruction::CreateTuple(init_values_with_indvar));
}
@@ -215,8 +216,9 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
std::vector<Shape> loop_state_shape_components;
loop_state_shape_components.reserve(init_values.size() + 1);
loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
- c_transform(init_values, std::back_inserter(loop_state_shape_components),
- [](HloInstruction* instr) { return instr->shape(); });
+ absl::c_transform(init_values,
+ std::back_inserter(loop_state_shape_components),
+ [](HloInstruction* instr) { return instr->shape(); });
return ShapeUtil::MakeTupleShape(loop_state_shape_components);
}
diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc
index 2ccb919acf..5e69419333 100644
--- a/tensorflow/compiler/xla/service/while_util_test.cc
+++ b/tensorflow/compiler/xla/service/while_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_util.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
@@ -206,7 +207,7 @@ ENTRY main {
auto is_while = [](const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kWhile;
};
- EXPECT_EQ(c_count_if(main->instructions(), is_while), 1);
+ EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1);
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index 8763e588c4..a7f0e207eb 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -24,7 +24,7 @@ namespace xla {
class ZeroSizedHloElimination : public HloPassInterface {
public:
StatusOr<bool> Run(HloModule* module) override;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "zero_sized_hlo_elimination";
}
};