Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- tensorflow/compiler/aot/tests/BUILD | 15
-rw-r--r-- tensorflow/compiler/aot/tests/make_test_graphs.py | 12
-rw-r--r-- tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt | 13
-rw-r--r-- tensorflow/compiler/aot/tests/tfcompile_test.cc | 25
-rw-r--r-- tensorflow/compiler/aot/tfcompile.bzl | 1
-rw-r--r-- tensorflow/compiler/jit/BUILD | 47
-rw-r--r-- tensorflow/compiler/jit/build_xla_launch_ops_pass.cc | 142
-rw-r--r-- tensorflow/compiler/jit/build_xla_ops_pass.cc | 162
-rw-r--r-- tensorflow/compiler/jit/build_xla_ops_pass.h (renamed from tensorflow/compiler/jit/build_xla_launch_ops_pass.h) | 10
-rw-r--r-- tensorflow/compiler/jit/build_xla_ops_pass_test.cc | 138
-rw-r--r-- tensorflow/compiler/jit/create_xla_launch_op.cc | 2
-rw-r--r-- tensorflow/compiler/jit/deadness_analysis.cc | 214
-rw-r--r-- tensorflow/compiler/jit/deadness_analysis_internal.h | 4
-rw-r--r-- tensorflow/compiler/jit/deadness_analysis_test.cc | 31
-rw-r--r-- tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 56
-rw-r--r-- tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc | 12
-rw-r--r-- tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc | 11
-rw-r--r-- tensorflow/compiler/jit/jit_compilation_pass_registration.cc | 4
-rw-r--r-- tensorflow/compiler/jit/kernels/BUILD | 8
-rw-r--r-- tensorflow/compiler/jit/kernels/xla_launch_op.cc | 276
-rw-r--r-- tensorflow/compiler/jit/kernels/xla_launch_op.h | 87
-rw-r--r-- tensorflow/compiler/jit/kernels/xla_ops.cc | 500
-rw-r--r-- tensorflow/compiler/jit/kernels/xla_ops.h | 168
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass.cc | 80
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 77
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc | 21
-rw-r--r-- tensorflow/compiler/jit/ops/BUILD | 8
-rw-r--r-- tensorflow/compiler/jit/ops/xla_ops.cc | 39
-rw-r--r-- tensorflow/compiler/jit/partially_decluster_pass.cc | 7
-rw-r--r-- tensorflow/compiler/jit/partially_decluster_pass_test.cc | 11
-rw-r--r-- tensorflow/compiler/jit/resource_operation_safety_analysis.cc | 5
-rw-r--r-- tensorflow/compiler/jit/xla_compilation_cache.cc | 33
-rw-r--r-- tensorflow/compiler/jit/xla_compilation_cache.h | 35
-rw-r--r-- tensorflow/compiler/jit/xla_compile_on_demand_op.cc | 9
-rw-r--r-- tensorflow/compiler/jit/xla_cpu_device.cc | 10
-rw-r--r-- tensorflow/compiler/jit/xla_device.cc | 12
-rw-r--r-- tensorflow/compiler/jit/xla_device.h | 12
-rw-r--r-- tensorflow/compiler/jit/xla_device_ops.h | 16
-rw-r--r-- tensorflow/compiler/jit/xla_gpu_device.cc | 11
-rw-r--r-- tensorflow/compiler/jit/xla_interpreter_device.cc | 6
-rw-r--r-- tensorflow/compiler/jit/xla_launch_util.cc | 18
-rw-r--r-- tensorflow/compiler/jit/xla_launch_util.h | 17
-rw-r--r-- tensorflow/compiler/tests/BUILD | 45
-rw-r--r-- tensorflow/compiler/tests/argminmax_test.py | 4
-rw-r--r-- tensorflow/compiler/tests/binary_ops_test.py | 25
-rw-r--r-- tensorflow/compiler/tests/build_defs.bzl | 10
-rw-r--r-- tensorflow/compiler/tests/dense_layer_test.py | 25
-rw-r--r-- tensorflow/compiler/tests/fused_batchnorm_test.py | 40
-rw-r--r-- tensorflow/compiler/tests/gather_test.py | 14
-rw-r--r-- tensorflow/compiler/tests/image_ops_test.py | 55
-rw-r--r-- tensorflow/compiler/tests/jit_test.py | 48
-rw-r--r-- tensorflow/compiler/tests/lstm.py | 2
-rw-r--r-- tensorflow/compiler/tests/nullary_ops_test.py | 43
-rw-r--r-- tensorflow/compiler/tests/permute_test.py | 80
-rw-r--r-- tensorflow/compiler/tests/quantized_ops_test.py | 48
-rw-r--r-- tensorflow/compiler/tests/random_ops_test.py | 19
-rw-r--r-- tensorflow/compiler/tests/randomized_tests.cc | 16
-rw-r--r-- tensorflow/compiler/tests/reverse_sequence_op_test.py | 2
-rw-r--r-- tensorflow/compiler/tests/sort_ops_test.py | 38
-rw-r--r-- tensorflow/compiler/tests/stateless_random_ops_test.py | 9
-rw-r--r-- tensorflow/compiler/tests/tensor_list_ops_test.py | 96
-rw-r--r-- tensorflow/compiler/tests/ternary_ops_test.py | 3
-rw-r--r-- tensorflow/compiler/tests/unary_ops_test.py | 7
-rw-r--r-- tensorflow/compiler/tests/xla_ops_test.py | 2
-rw-r--r-- tensorflow/compiler/tests/xla_test.py | 19
-rw-r--r-- tensorflow/compiler/tf2xla/BUILD | 2
-rw-r--r-- tensorflow/compiler/tf2xla/cc/BUILD | 7
-rw-r--r-- tensorflow/compiler/tf2xla/const_analysis.cc | 12
-rw-r--r-- tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 146
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/BUILD | 26
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc | 12
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 41
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc | 64
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/const_op.cc | 12
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc | 509
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h | 69
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/conv_ops.cc | 551
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/cwise_ops.cc | 57
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/cwise_ops.h | 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/image_ops.cc | 9
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 76
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/permute_op.cc | 98
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc | 21
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/scan_ops.cc | 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/sequence_ops.cc | 39
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/shape_op.cc | 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/sort_ops.cc | 17
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc | 226
-rw-r--r-- tensorflow/compiler/tf2xla/lib/BUILD | 16
-rw-r--r-- tensorflow/compiler/tf2xla/lib/broadcast.cc | 93
-rw-r--r-- tensorflow/compiler/tf2xla/lib/broadcast.h (renamed from tensorflow/compiler/xla/service/gpu/gpu_options.cc) | 24
-rw-r--r-- tensorflow/compiler/tf2xla/lib/scatter.cc | 213
-rw-r--r-- tensorflow/compiler/tf2xla/lib/scatter.h | 6
-rw-r--r-- tensorflow/compiler/tf2xla/ops/xla_ops.cc | 32
-rw-r--r-- tensorflow/compiler/tf2xla/python/xla.py | 18
-rw-r--r-- tensorflow/compiler/tf2xla/resource_operation_table.cc | 14
-rw-r--r-- tensorflow/compiler/tf2xla/resource_operation_table_test.cc | 3
-rw-r--r-- tensorflow/compiler/tf2xla/shape_util.cc | 14
-rw-r--r-- tensorflow/compiler/tf2xla/shape_util.h | 5
-rw-r--r-- tensorflow/compiler/tf2xla/tf2xla_util.cc | 30
-rw-r--r-- tensorflow/compiler/tf2xla/tf2xla_util.h | 51
-rw-r--r-- tensorflow/compiler/tf2xla/type_util.h | 8
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.cc | 17
-rw-r--r-- tensorflow/compiler/tf2xla/xla_cpu_backend.cc | 15
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 40
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_kernel.h | 5
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_registry.cc | 24
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_registry.h | 31
-rw-r--r-- tensorflow/compiler/xla/BUILD | 1
-rw-r--r-- tensorflow/compiler/xla/client/BUILD | 2
-rw-r--r-- tensorflow/compiler/xla/client/lib/testing.cc | 12
-rw-r--r-- tensorflow/compiler/xla/client/xla_builder.cc | 30
-rw-r--r-- tensorflow/compiler/xla/client/xla_builder.h | 38
-rw-r--r-- tensorflow/compiler/xla/executable_run_options.cc | 10
-rw-r--r-- tensorflow/compiler/xla/executable_run_options.h | 8
-rw-r--r-- tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc | 13
-rw-r--r-- tensorflow/compiler/xla/literal.cc | 58
-rw-r--r-- tensorflow/compiler/xla/literal.h | 16
-rw-r--r-- tensorflow/compiler/xla/literal_test.cc | 10
-rw-r--r-- tensorflow/compiler/xla/python/local_computation_builder.cc | 11
-rw-r--r-- tensorflow/compiler/xla/python/local_computation_builder.h | 5
-rw-r--r-- tensorflow/compiler/xla/python/xla_client.py | 44
-rw-r--r-- tensorflow/compiler/xla/python/xla_client_test.py | 24
-rw-r--r-- tensorflow/compiler/xla/rpc/BUILD | 13
-rw-r--r-- tensorflow/compiler/xla/rpc/grpc_service_main.cc | 21
-rw-r--r-- tensorflow/compiler/xla/service/BUILD | 185
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier.cc | 9
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 17
-rw-r--r-- tensorflow/compiler/xla/service/allocation_tracker.h | 10
-rw-r--r-- tensorflow/compiler/xla/service/batch_dot_simplification.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/batchnorm_expander.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/batchnorm_expander.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_conversion_folding.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_normalization.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_propagation.cc | 9
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_propagation.h | 19
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_support.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/buffer_assignment.cc | 94
-rw-r--r-- tensorflow/compiler/xla/service/buffer_assignment.h | 39
-rw-r--r-- tensorflow/compiler/xla/service/buffer_liveness.h | 5
-rw-r--r-- tensorflow/compiler/xla/service/buffer_value_containers.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/call_graph.cc | 9
-rw-r--r-- tensorflow/compiler/xla/service/call_graph.h | 16
-rw-r--r-- tensorflow/compiler/xla/service/call_inliner.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/conditional_simplifier.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/convolution_feature_group_converter.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/copy_insertion.cc | 17
-rw-r--r-- tensorflow/compiler/xla/service/copy_insertion.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/cpu/BUILD | 24
-rw-r--r-- tensorflow/compiler/xla/service/cpu/conv_canonicalization.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 18
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc | 3
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h | 5
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_runtime.cc | 122
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_runtime.h | 44
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 210
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.h | 16
-rw-r--r-- tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc | 236
-rw-r--r-- tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h | 88
-rw-r--r-- tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc | 13
-rw-r--r-- tensorflow/compiler/xla/service/cpu/target_machine_features.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/cpu/target_machine_features.h | 5
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/BUILD | 14
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc | 54
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc | 18
-rw-r--r-- tensorflow/compiler/xla/service/defuser.cc | 3
-rw-r--r-- tensorflow/compiler/xla/service/defuser.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/despecializer.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/despecializer.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/dfs_hlo_visitor.h | 1
-rw-r--r-- tensorflow/compiler/xla/service/dot_decomposer.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 169
-rw-r--r-- tensorflow/compiler/xla/service/flatten_call_graph.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/fusion_queue.h | 53
-rw-r--r-- tensorflow/compiler/xla/service/gather_expander.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD | 51
-rw-r--r-- tensorflow/compiler/xla/service/gpu/backend_configs.proto | 14
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_thunk.cc | 43
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_thunk.h | 25
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 163
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 13
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 56
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 194
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h | 55
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc | 278
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h (renamed from tensorflow/compiler/xla/service/gpu/gpu_options.h) | 22
-rw-r--r-- tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc | 26
-rw-r--r-- tensorflow/compiler/xla/service/gpu/fusion_merger.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc | 11
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h | 11
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 3
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_executable.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc | 105
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h | 8
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc | 17
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 118
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emission_utils.h | 56
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 31
-rw-r--r-- tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 30
-rw-r--r-- tensorflow/compiler/xla/service/gpu/nvptx_compiler.h | 9
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc | 30
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 35
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_insertion.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/partition_assignment.cc | 9
-rw-r--r-- tensorflow/compiler/xla/service/gpu/stream_assignment.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/tests/BUILD | 60
-rw-r--r-- tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc | 283
-rw-r--r-- tensorflow/compiler/xla/service/heap_simulator.cc | 265
-rw-r--r-- tensorflow/compiler/xla/service/heap_simulator.h | 98
-rw-r--r-- tensorflow/compiler/xla/service/heap_simulator_test.cc | 252
-rw-r--r-- tensorflow/compiler/xla/service/hlo.proto | 12
-rw-r--r-- tensorflow/compiler/xla/service/hlo_alias_analysis.cc | 14
-rw-r--r-- tensorflow/compiler/xla/service/hlo_alias_analysis.h | 3
-rw-r--r-- tensorflow/compiler/xla/service/hlo_buffer.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_clone_context.h | 12
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.cc | 97
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.h | 19
-rw-r--r-- tensorflow/compiler/xla/service/hlo_constant_folding.cc | 21
-rw-r--r-- tensorflow/compiler/xla/service/hlo_constant_folding.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_creation_utils.cc | 39
-rw-r--r-- tensorflow/compiler/xla/service/hlo_creation_utils.h | 30
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 33
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.h | 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dce.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_domain_isolator.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_domain_map.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_domain_map.h | 23
-rw-r--r-- tensorflow/compiler/xla/service/hlo_domain_metadata.h | 8
-rw-r--r-- tensorflow/compiler/xla/service/hlo_domain_remover.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_domain_verifier.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_element_type_converter.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator.cc | 201
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator.h | 19
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 174
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 173
-rw-r--r-- tensorflow/compiler/xla/service/hlo_execution_profile.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 111
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h | 48
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc | 56
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.h | 26
-rw-r--r-- tensorflow/compiler/xla/service/hlo_liveness_analysis.cc | 35
-rw-r--r-- tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc | 84
-rw-r--r-- tensorflow/compiler/xla/service/hlo_memory_scheduler.cc | 83
-rw-r--r-- tensorflow/compiler/xla/service/hlo_memory_scheduler.h | 15
-rw-r--r-- tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc | 123
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.cc | 17
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_dce.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_dce.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_metadata.cc | 22
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_metadata.h | 31
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_test.cc | 64
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_util.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_util.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_opcode.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering.h | 8
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering_test.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser.cc | 390
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser.h | 16
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser_test.cc | 247
-rw-r--r-- tensorflow/compiler/xla/service/hlo_pass_interface.h | 35
-rw-r--r-- tensorflow/compiler/xla/service/hlo_pass_pipeline.cc | 194
-rw-r--r-- tensorflow/compiler/xla/service/hlo_pass_pipeline.h | 38
-rw-r--r-- tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc | 259
-rw-r--r-- tensorflow/compiler/xla/service/hlo_reachability.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_rematerialization.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_rematerialization.h | 9
-rw-r--r-- tensorflow/compiler/xla/service/hlo_schedule.cc | 24
-rw-r--r-- tensorflow/compiler/xla/service/hlo_schedule.h | 6
-rw-r--r-- tensorflow/compiler/xla/service/hlo_sharding.cc | 23
-rw-r--r-- tensorflow/compiler/xla/service/hlo_subcomputation_unification.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_value.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_verifier.cc | 487
-rw-r--r-- tensorflow/compiler/xla/service/hlo_verifier.h | 27
-rw-r--r-- tensorflow/compiler/xla/service/hlo_verifier_test.cc | 67
-rw-r--r-- tensorflow/compiler/xla/service/implicit_broadcast_remover.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/indexed_array_analysis.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/indexed_array_analysis.h | 6
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.cc | 34
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.h | 42
-rw-r--r-- tensorflow/compiler/xla/service/interpreter/BUILD | 2
-rw-r--r-- tensorflow/compiler/xla/service/interpreter/compiler.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment.cc | 111
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment.h | 33
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment_test.cc | 106
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/BUILD | 2
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h | 11
-rw-r--r-- tensorflow/compiler/xla/service/logical_buffer_analysis.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/map_inliner.cc (renamed from tensorflow/compiler/xla/service/inliner.cc) | 51
-rw-r--r-- tensorflow/compiler/xla/service/map_inliner.h (renamed from tensorflow/compiler/xla/service/inliner.h) | 22
-rw-r--r-- tensorflow/compiler/xla/service/map_inliner_test.cc (renamed from tensorflow/compiler/xla/service/inliner_test.cc) | 46
-rw-r--r-- tensorflow/compiler/xla/service/multi_output_fusion.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/multi_output_fusion.h | 5
-rw-r--r-- tensorflow/compiler/xla/service/name_uniquer.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/name_uniquer.h | 8
-rw-r--r-- tensorflow/compiler/xla/service/pattern_matcher.h | 762
-rw-r--r-- tensorflow/compiler/xla/service/pattern_matcher_test.cc | 183
-rw-r--r-- tensorflow/compiler/xla/service/platform_util.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/reduce_precision_insertion.h | 3
-rw-r--r-- tensorflow/compiler/xla/service/reshape_mover.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/scatter_expander.cc | 78
-rw-r--r-- tensorflow/compiler/xla/service/scatter_expander.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/shape_inference.cc | 13
-rw-r--r-- tensorflow/compiler/xla/service/shaped_buffer.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/stream_pool.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/stream_pool_test.cc | 34
-rw-r--r-- tensorflow/compiler/xla/service/transpose_folding.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/tuple_points_to_analysis.cc | 23
-rw-r--r-- tensorflow/compiler/xla/service/tuple_points_to_analysis.h | 3
-rw-r--r-- tensorflow/compiler/xla/service/tuple_simplifier.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_constant_sinking.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_constant_sinking.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc | 16
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_simplifier.cc | 14
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_simplifier.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h | 2
-rw-r--r-- tensorflow/compiler/xla/shape_util.cc | 17
-rw-r--r-- tensorflow/compiler/xla/shape_util.h | 4
-rw-r--r-- tensorflow/compiler/xla/tests/BUILD | 51
-rw-r--r-- tensorflow/compiler/xla/tests/build_defs.bzl | 488
-rw-r--r-- tensorflow/compiler/xla/tests/convolution_test.cc | 9
-rw-r--r-- tensorflow/compiler/xla/tests/dot_operation_test.cc | 19
-rw-r--r-- tensorflow/compiler/xla/tests/fusion_test.cc | 7
-rw-r--r-- tensorflow/compiler/xla/tests/hlo_test_base.cc | 14
-rw-r--r-- tensorflow/compiler/xla/tests/hlo_test_base.h | 8
-rw-r--r-- tensorflow/compiler/xla/tests/hlo_verified_test_base.cc | 78
-rw-r--r-- tensorflow/compiler/xla/tests/hlo_verified_test_base.h | 63
-rw-r--r-- tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc | 158
-rw-r--r-- tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc | 120
-rw-r--r-- tensorflow/compiler/xla/tests/reduce_window_test.cc | 20
-rw-r--r-- tensorflow/compiler/xla/tests/scatter_test.cc | 30
-rw-r--r-- tensorflow/compiler/xla/tests/slice_test.cc | 1
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils_test.cc | 5
-rw-r--r-- tensorflow/compiler/xla/tests/while_test.cc | 4
-rw-r--r-- tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/xla.proto | 9
-rw-r--r-- tensorflow/compiler/xrt/ops/xrt_execute_op.cc | 2
-rw-r--r-- tensorflow/compiler/xrt/tests/BUILD | 6
-rw-r--r-- tensorflow/compiler/xrt/tests/raw_api_test.cc | 41
360 files changed, 11553 insertions, 4862 deletions
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 7a0932d44d..10fa33ab5e 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
+ ":test_graph_tftop_k_test",
":tfcompile_test",
],
)
@@ -42,6 +43,7 @@ py_binary(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python:training",
@@ -66,6 +68,7 @@ genrule(
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tfsplits.pb",
+ "test_graph_tftop_k.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
# GPUs which might be present. This is important because builds may run
@@ -208,6 +211,17 @@ tf_library(
],
)
+tf_library(
+ name = "test_graph_tftop_k",
+ testonly = 1,
+ config = "test_graph_tftop_k.config.pbtxt",
+ cpp_class = "TopKComp",
+ graph = "test_graph_tftop_k.pb",
+ tags = [
+ "manual",
+ ],
+)
+
tf_cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
@@ -226,6 +240,7 @@ tf_cc_test(
":test_graph_tfmatmulandadd",
":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
+ ":test_graph_tftop_k",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9ec7df163b..64b861a730 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.training import saver as saver_lib
@@ -46,7 +47,7 @@ def tfadd(_):
def tfadd_with_ckpt(out_dir):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = variables.Variable(constant_op.constant([0]), name='y_saved')
+ y = variables.VariableV1(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
init_op = variables.initialize_all_variables()
@@ -61,7 +62,7 @@ def tfadd_with_ckpt(out_dir):
def tfadd_with_ckpt_saver(out_dir):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = variables.Variable(constant_op.constant([0]), name='y_saved')
+ y = variables.VariableV1(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
init_op = variables.initialize_all_variables()
@@ -142,6 +143,12 @@ def tfsplits(_):
array_ops.identity(y, name='result')
+def tftop_k(_):
+ x = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
+ output = nn_ops.top_k(x, 2, name='values')
+ array_ops.identity(output[1], name='indices')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -163,6 +170,7 @@ def main(_):
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tftop_k, FLAGS.out_dir)
if __name__ == '__main__':
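The tftop_k graph added above feeds a length-5 int32 vector through top_k with k=2 and exposes both the values and the indices outputs. A minimal NumPy sketch (illustrative only, not part of this change) of the expected top-k semantics, values in descending order with ties broken toward the lower index, using the same input as the C++ test further down:

import numpy as np

x = np.array([4, 1, 4, 4, 3], dtype=np.int32)  # same input as the TopK tfcompile test
k = 2
# A stable descending sort keeps the lower index first when values tie.
order = np.argsort(-x, kind="stable")[:k]
print(x[order].tolist(), order.tolist())  # [4, 4] [0, 2]
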
diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
new file mode 100644
index 0000000000..6b4ac2d7cb
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
@@ -0,0 +1,13 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "x" }
+ shape {
+ dim { size: 5 }
+ }
+}
+fetch {
+ id { node_name: "values" }
+}
+fetch {
+ id { node_name: "indices" }
+}
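The feed and fetch ids in this config are matched against node names in the graph produced by tftop_k above. A small hypothetical sanity check (not part of this change; assumes the TF 1.x graph-building API) that rebuilds the same three nodes and lists their names:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  x = tf.placeholder(tf.int32, shape=[5], name='x')    # feed "x"
  values, indices = tf.nn.top_k(x, 2, name='values')   # fetch "values"
  tf.identity(indices, name='indices')                 # fetch "indices"

# The operation list should include 'x', 'values', and 'indices'.
print([op.name for op in g.get_operations()])
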
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 7ac90fb8a9..f10852c785 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -448,6 +449,30 @@ TEST(TFCompileTest, Splits) {
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, TopK) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ TopKComp fn;
+
+ fn.set_thread_pool(&device);
+ // x = [4, 1, 4, 4, 3]
+ fn.arg0(0) = 4;
+ fn.arg0(1) = 1;
+ fn.arg0(2) = 4;
+ fn.arg0(3) = 4;
+ fn.arg0(4) = 3;
+
+ EXPECT_TRUE(fn.Run());
+ EXPECT_EQ(fn.error_msg(), "");
+ const int32 expected_values[] = {4, 4};
+ const int32 expected_indices[] = {0, 2};
+ EXPECT_EQ(expected_values[0], fn.result0(0));
+ EXPECT_EQ(expected_values[1], fn.result0(1));
+ EXPECT_EQ(expected_indices[0], fn.result1(0));
+ EXPECT_EQ(expected_indices[1], fn.result1(1));
+}
+
TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 792b7fe14a..859c84bb91 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -273,6 +273,7 @@ def tf_library(
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 1001c57f3d..661b444a42 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -26,6 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@@ -50,7 +51,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
],
@@ -62,7 +63,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda([
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",
]),
@@ -76,7 +77,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -94,7 +95,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
@@ -111,7 +112,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
@@ -257,6 +258,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -280,7 +282,7 @@ cc_library(
deps = [
":common",
":compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -322,6 +324,8 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -341,7 +345,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -359,7 +363,7 @@ tf_cc_test(
cc_library(
name = "compilation_passes",
srcs = [
- "build_xla_launch_ops_pass.cc",
+ "build_xla_ops_pass.cc",
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
@@ -369,7 +373,7 @@ cc_library(
"partially_decluster_pass.cc",
],
hdrs = [
- "build_xla_launch_ops_pass.h",
+ "build_xla_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
@@ -382,12 +386,16 @@ cc_library(
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
@@ -399,6 +407,8 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -459,7 +469,7 @@ tf_cc_test(
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -470,6 +480,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -477,6 +488,7 @@ tf_cc_test(
name = "compilation_passes_test",
size = "small",
srcs = [
+ "build_xla_ops_pass_test.cc",
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",
"mark_for_compilation_pass_test.cc",
@@ -485,6 +497,7 @@ tf_cc_test(
deps = [
":common",
":compilation_passes",
+ ":node_matchers",
":xla_cluster_util",
":xla_gpu_device",
"//tensorflow/cc:cc_ops",
@@ -493,7 +506,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
@@ -506,6 +519,7 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -524,7 +538,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -628,6 +642,15 @@ tf_cc_test(
],
)
+tf_custom_op_py_library(
+ name = "xla_ops_py",
+ kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
+ visibility = [
+ ":friends",
+ ],
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+)
+
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
deleted file mode 100644
index b17ff589e2..0000000000
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-
-static Status BuildLaunchNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
- Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("XlaLaunch");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
-}
-
-static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
- VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
-
- int num_constant_args, num_resource_args;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
-
- if (num_constant_args < 0 || num_resource_args < 0 ||
- num_constant_args + num_resource_args > node->num_inputs()) {
- return errors::InvalidArgument(
- "Invalid number of constant/resource arguments to XLA kernel.");
- }
- const int num_nonconst_args =
- node->num_inputs() - num_constant_args - num_resource_args;
-
- DataTypeVector const_dtypes(node->input_types().begin(),
- node->input_types().begin() + num_constant_args);
- DataTypeVector arg_dtypes(
- node->input_types().begin() + num_constant_args,
- node->input_types().begin() + num_constant_args + num_nonconst_args);
-
- // Build a XlaLaunch operator to execute the function body.
- Node* launch_node;
- TF_RETURN_IF_ERROR(BuildLaunchNode(
- graph->NewName(node->name()), node->type_string(), node->def().attr(),
- node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
- node->output_types(), graph, &launch_node));
- launch_node->set_assigned_device_name(node->assigned_device_name());
-
- // Copy incoming edges to the launch node.
- for (const Edge* edge : node->in_edges()) {
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(edge->src(), launch_node);
- } else {
- graph->AddEdge(edge->src(), edge->src_output(), launch_node,
- edge->dst_input());
- }
- }
-
- // Copy outgoing edges to the launch node.
- std::vector<const Edge*> out_edges(node->out_edges().begin(),
- node->out_edges().end());
- for (const Edge* edge : out_edges) {
- Node* dst = edge->dst();
- int src_output = edge->src_output();
- int dst_input = edge->dst_input();
- graph->RemoveEdge(edge);
-
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(launch_node, dst);
- } else {
- graph->AddEdge(launch_node, src_output, dst, dst_input);
- }
- }
- graph->RemoveNode(node);
-
- return Status::OK();
-}
-
-Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
- Graph* graph = options.graph->get();
-
- for (Node* n : graph->op_nodes()) {
- // In all cases, only try to compile computational nodes.
- if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
- continue;
- }
-
- // Only compile nodes that are marked for compilation by the
- // compilation-marking pass (via 'attr_name').
- if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
- }
- }
-
- if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
- options.flib_def);
- }
- return Status::OK();
-}
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
new file mode 100644
index 0000000000..5974696b77
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -0,0 +1,162 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "absl/algorithm/container.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace {
+void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+ old_node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ // TODO(sanjoy): This does not update NodeDef inputs. To be able to update
+ // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up
+ // the NodeDef inputs to the function call nodes.
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
+ g->RemoveEdge(edge);
+ }
+}
+
+struct XlaClusterInfo {
+ std::vector<Output> constant_inputs;
+ std::vector<Output> non_constant_inputs;
+ std::vector<Output> resource_inputs;
+ NameAttrList function;
+};
+
+Output IncomingEdgeAsOutput(const Edge* e) {
+ return Output(e->src(), e->src_output());
+}
+
+Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) {
+ int num_constant_inputs, num_resource_inputs;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs));
+
+ if (num_constant_inputs < 0 || num_resource_inputs < 0 ||
+ num_constant_inputs + num_resource_inputs > n->num_inputs()) {
+ return errors::InvalidArgument(
+ "Invalid number of constant/resource arguments to XLA kernel.");
+ }
+
+ int num_non_constant_inputs =
+ n->num_inputs() - num_constant_inputs - num_resource_inputs;
+
+ std::vector<const Edge*> input_edges_vector;
+ TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector));
+ absl::Span<const Edge*> input_edges(input_edges_vector);
+
+ absl::c_transform(input_edges.subspan(0, num_constant_inputs),
+ std::back_inserter(result->constant_inputs),
+ IncomingEdgeAsOutput);
+
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs, num_non_constant_inputs),
+ std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput);
+
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs + num_non_constant_inputs,
+ num_resource_inputs),
+ std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput);
+
+ result->function.set_name(n->type_string());
+ *result->function.mutable_attr() = n->def().attr();
+ return Status::OK();
+}
+
+Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) {
+ for (const Edge* e : from->in_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(e->src(), to);
+ }
+ }
+
+ return Status::OK();
+}
+
+Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) {
+ Status status;
+ Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
+ .NewSubScope(n->name())
+ .WithDevice(n->requested_device())
+ .WithAssignedDevice(n->assigned_device_name());
+
+ XlaClusterInfo cluster_info;
+ TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
+
+ ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
+ /*constants=*/cluster_info.constant_inputs,
+ /*args=*/cluster_info.non_constant_inputs,
+ /*resources=*/cluster_info.resource_inputs,
+ cluster_info.function);
+ TF_RETURN_IF_ERROR(
+ CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
+
+ std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
+ absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+ ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
+ xla_compile.key, n->output_types());
+
+ MoveOutgoingEdges(g, /*old_node=*/n,
+ /*new_node=*/xla_run.operation.node());
+ g->RemoveNode(n);
+
+ return Status::OK();
+}
+} // namespace
+
+Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
+ Graph* graph = options.graph->get();
+
+ for (Node* n : graph->op_nodes()) {
+ // In all cases, only try to compile computational nodes.
+ if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
+ continue;
+ }
+
+ // Only compile nodes that are marked for compilation by the
+ // compilation-marking pass (via 'attr_name').
+ if (IsXlaCompiledKernel(*n)) {
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n));
+ }
+ }
+
+ if (VLOG_IS_ON(1)) {
+ dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
+ }
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h
index 1dfea93f02..1dd38fa951 100644
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.h
@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
-#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
+// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and
+// executes (using XLA) TF function calls marked with "_XlaCompiledKernel".
+class BuildXlaOpsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
new file mode 100644
index 0000000000..9d56db7b6b
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/node_matchers.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+using ::tensorflow::testing::FindNodeByName;
+using ::tensorflow::testing::matchers::CtrlDeps;
+using ::tensorflow::testing::matchers::NodeWith;
+using ::tensorflow::testing::matchers::Op;
+
+Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) {
+ auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
+
+ // Assign all nodes to the CPU device.
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
+ for (Node* n : graph->nodes()) {
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
+ }
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = &graph;
+ BuildXlaOpsPass pass;
+ TF_RETURN_IF_ERROR(pass.Run(opt_options));
+ *result = std::move(graph);
+ return Status::OK();
+}
+
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+ const string& node_name, int num_constant_args,
+ int num_resource_args, Node** result) {
+ NodeDef call_node;
+ call_node.set_name(node_name);
+ call_node.set_op(callee_name);
+ AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node);
+ Status s;
+ *result = graph->AddNode(call_node, &s);
+ return s;
+}
+
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+ const string& node_name, Node** result) {
+ return MakeXlaCompiledKernel(graph, callee_name, node_name,
+ /*num_constant_args=*/0, /*num_resource_args=*/0,
+ result);
+}
+
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
+ value_to_write);
+ return assign_op.operation.node();
+}
+
+FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
+ /*attr_def*/
+ {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
+ /*ret_def=*/{{"out", "out:output:0"}});
+ *flib_def.add_function() = std::move(func);
+ return flib_def;
+}
+
+TEST(BuildXlaOps, ControlDepsPreserved) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
+ Node* write_op = MakeWrite(root, "write");
+ root.graph()->AddControlEdge(call, write_op);
+
+ std::unique_ptr<Graph> graph;
+ TF_ASSERT_OK(BuildXlaOps(root, &graph));
+
+ Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
+ ASSERT_NE(write_op_new, nullptr);
+ EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
+}
+
+TEST(BuildXlaOps, CleanFailureOnBogusAttr) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(
+ MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call));
+ Node* write_op = MakeWrite(root, "write");
+ root.graph()->AddControlEdge(call, write_op);
+
+ std::unique_ptr<Graph> graph;
+ Status failure_status = BuildXlaOps(root, &graph);
+ ASSERT_FALSE(failure_status.ok());
+ EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 56b034a30b..6f1ff85f24 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 9128b48da3..b7ae7fbeb3 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -14,11 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
// ALGORITHM OVERVIEW
@@ -296,7 +299,7 @@ class SymbolPredicate : public Predicate {
template <typename FunctionTy>
/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
- gtl::FlatSet<Predicate*> visited;
+ absl::flat_hash_set<Predicate*> visited;
std::vector<Predicate*> stack;
stack.push_back(p);
@@ -383,6 +386,8 @@ class PredicateFactory {
}
Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
+ Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
+ Predicate::Kind pred_kind);
// Predicate instances are interned, meaning that there is only a single
// instance of a Predicate object with a given content. This makes checking
@@ -417,24 +422,53 @@ class PredicateFactory {
}
};
- gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>,
- HashSignatureForAndOr>
+ absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
+ HashSignatureForAndOr>
interned_and_or_instances_;
- gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>>
+ absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
interned_not_instances_;
- gtl::FlatMap<SignatureForAndRec, std::unique_ptr<Predicate>>
+ absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
interned_and_rec_instances_;
- gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>,
- HashSignatureForSymbol>
+ absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
+ HashSignatureForSymbol>
interned_symbol_instances_;
};
+Predicate* PredicateFactory::MakeInternedAndOr(
+ std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
+ std::stable_sort(
+ simplified_ops.begin(), simplified_ops.end(),
+ [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+
+ auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
+ if (it != interned_and_or_instances_.end()) {
+ return it->second.get();
+ }
+
+ simplified_ops.shrink_to_fit();
+ // NB! Because we'll use a non-owning reference to simplified_ops in the
+ // key for interned_and_or_instances_ we need to be careful to std::move()
+ // it all the way through.
+ absl::Span<Predicate* const> operands_slice = simplified_ops;
+ std::unique_ptr<Predicate> new_pred =
+ pred_kind == Predicate::Kind::kAnd
+ ? Make<AndPredicate>(std::move(simplified_ops))
+ : Make<OrPredicate>(std::move(simplified_ops));
+
+ Predicate* new_pred_ptr = new_pred.get();
+ interned_and_or_instances_.emplace(
+ SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
+ return new_pred_ptr;
+}
+
// Common code to create AndPredicate or OrPredicate instances.
Predicate* PredicateFactory::MakeAndOrImpl(
absl::Span<Predicate* const> operands, bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
- gtl::FlatSet<Predicate*> simplified_ops_set;
+ Predicate::Kind other_pred_kind =
+ is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
+ absl::flat_hash_set<Predicate*> simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
// Simplify A&A => A and A|A => A.
@@ -459,7 +493,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
}
// Simplify "A&~A=>False" and "A|~A=>True".
- gtl::FlatSet<Predicate*> negated_ops;
+ absl::flat_hash_set<Predicate*> negated_ops;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kNot) {
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
@@ -472,30 +506,63 @@ Predicate* PredicateFactory::MakeAndOrImpl(
}
}
- std::stable_sort(
- simplified_ops.begin(), simplified_ops.end(),
- [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+  // If every operand contains the same sub-operand, factor it out using the
+  // distributive property. For example:
+ // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
+ // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
+ //
+ // First find any predicates contained in all subops.
+ std::vector<Predicate*> common_inner_operands;
+ absl::flat_hash_set<Predicate*> common_inner_operands_set;
+ for (Predicate* op : simplified_ops) {
+ if (op->kind() != other_pred_kind) {
+ common_inner_operands.clear();
+ break;
+ }
- auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
- if (it == interned_and_or_instances_.end()) {
- simplified_ops.shrink_to_fit();
- // NB! Because we'll use a non-owning reference to simplified_ops in the
- // key for interned_and_or_instances_ we need to be careful to std::move()
- // it all the way through.
- absl::Span<Predicate* const> operands_slice = simplified_ops;
- std::unique_ptr<Predicate> new_pred =
- is_and ? Make<AndPredicate>(std::move(simplified_ops))
- : Make<OrPredicate>(std::move(simplified_ops));
+ if (common_inner_operands.empty()) {
+ common_inner_operands.insert(common_inner_operands.end(),
+ op->GetOperands().begin(),
+ op->GetOperands().end());
+ } else {
+ std::vector<Predicate*> sub_ops_intersection;
+ common_inner_operands.clear();
+ absl::c_copy_if(op->GetOperands(),
+ std::back_inserter(common_inner_operands),
+ [&](Predicate* sub_op) {
+ return common_inner_operands_set.count(sub_op) == 1;
+ });
+ }
+ if (common_inner_operands.empty()) break;
+ common_inner_operands_set.clear();
+ common_inner_operands_set.insert(common_inner_operands.begin(),
+ common_inner_operands.end());
+ }
- Predicate* new_pred_ptr = new_pred.get();
- CHECK(interned_and_or_instances_
- .emplace(SignatureForAndOr(pred_kind, operands_slice),
- std::move(new_pred))
- .second);
- return new_pred_ptr;
- } else {
- return it->second.get();
+ if (common_inner_operands.empty()) {
+ return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
}
+
+ // For all predicates that can be factored out, remove them and recreate the
+ // subops.
+ std::vector<Predicate*> factored_ops;
+ for (Predicate* op : simplified_ops) {
+ std::vector<Predicate*> new_sub_op_ops;
+ absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
+ [&](Predicate* sub_op) {
+ return std::find(common_inner_operands.begin(),
+ common_inner_operands.end(),
+ sub_op) == common_inner_operands.end();
+ });
+ factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
+ }
+
+ Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
+ std::vector<Predicate*> outer_ops;
+ outer_ops.push_back(new_inner_op);
+ outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
+ common_inner_operands.end());
+ return MakeAndOrImpl(outer_ops, !is_and);
}
class DeadnessAnalysisImpl : public DeadnessAnalysis {
@@ -507,12 +574,14 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
- gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
+ absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
+ const;
private:
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
- std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind);
+ Status GetInputPreds(Node* n, EdgeKind edge_kind,
+ std::vector<Predicate*>* result);
// Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
// bit of `should_revisit` if `pred` is different from the current predicate
@@ -549,7 +618,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
const Graph& graph_;
- gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
+ absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
PredicateFactory predicate_factory_;
bool vlog_;
};
@@ -558,9 +627,10 @@ TensorId InputEdgeToTensorId(const Edge* e) {
return TensorId(e->src()->name(), e->src_output());
}
-std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
- Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) {
- std::vector<Predicate*> incoming_preds;
+Status DeadnessAnalysisImpl::GetInputPreds(
+ Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
+ std::vector<Predicate*>* result) {
+ result->clear();
for (const Edge* in_edge : n->in_edges()) {
bool should_process =
edge_kind == EdgeKind::kDataAndControl ||
@@ -569,17 +639,27 @@ std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
if (should_process) {
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
- CHECK(it != predicate_map_.end()) << n->name();
- incoming_preds.push_back(it->second);
+ if (it == predicate_map_.end()) {
+ GraphCycles graph_cycles;
+ TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles));
+
+ // If we didn't return with an error above then the graph is probably
+ // fine and we have a bug in deadness analysis.
+ return errors::Internal("Could not find input ", in_edge->DebugString(),
+ " to ", n->name(),
+ " when visiting the graph in post-order. Most "
+ "likely indicates a bug in deadness analysis.");
+ }
+ result->push_back(it->second);
}
}
- return incoming_preds;
+ return Status::OK();
}
Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
std::vector<bool>* should_revisit) {
- std::vector<Predicate*> input_preds =
- GetIncomingPreds(n, EdgeKind::kDataAndControl);
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
const Edge* pred_edge;
TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
Predicate* true_switch = predicate_factory_.MakeSymbolPredicate(
@@ -608,17 +688,31 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
}
namespace {
-const Edge* FindUniqueBackedge(Node* merge) {
+Status CreateMultipleNextIterationInputsError(Node* merge) {
+ std::vector<string> backedges;
+ for (const Edge* backedge : merge->in_edges()) {
+ if (backedge->src()->IsNextIteration()) {
+ backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src())));
+ }
+ }
+ return errors::InvalidArgument(
+ "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge),
+ ": \n", absl::StrJoin(backedges, "\n"),
+ "\nMerge nodes can have at most one incoming NextIteration edge.");
+}
+
+Status FindUniqueBackedge(Node* merge, const Edge** result) {
+ *result = nullptr;
CHECK(merge->IsMerge());
- const Edge* result = nullptr;
for (const Edge* e : merge->in_edges()) {
if (e->src()->IsNextIteration()) {
- CHECK_EQ(result, nullptr)
- << "Multiple backedges to " << merge->DebugString();
- result = e;
+ if (*result != nullptr) {
+ return CreateMultipleNextIterationInputsError(merge);
+ }
+ *result = e;
}
}
- return result;
+ return Status::OK();
}
// If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
@@ -697,9 +791,12 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
return Status::OK();
}
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
+
   // We're visiting this merge for the first time and it is an acyclic merge.
- Predicate* input_data_pred = predicate_factory_.MakeOrPredicate(
- GetIncomingPreds(n, EdgeKind::kDataOnly));
+ Predicate* input_data_pred =
+ predicate_factory_.MakeOrPredicate(input_preds);
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
should_revisit);
return Status::OK();
@@ -710,7 +807,9 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// of an unvisited backedge. Try to pattern match the predicate expression
// for that backedge (which should be visited now) into an and recurrence
// for the merge node.
- if (const Edge* unique_backedge = FindUniqueBackedge(n)) {
+ const Edge* unique_backedge;
+ TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
+ if (unique_backedge) {
if (Predicate* step = DeduceStepPredicate(
&predicate_factory_, it->second,
predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
@@ -741,8 +840,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n,
std::vector<bool>* should_revisit) {
// In addition to being alive or dead based on the inputs, a _Recv can also
// acquire a dead signal from a _Send.
- std::vector<Predicate*> input_preds =
- GetIncomingPreds(n, EdgeKind::kDataAndControl);
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false));
SetPredicate(n, {0, Graph::kControlSlot},
@@ -754,8 +853,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n,
Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
std::vector<bool>* should_revisit) {
// Generally nodes are alive iff all their inputs are alive.
- Predicate* pred = predicate_factory_.MakeAndPredicate(
- GetIncomingPreds(n, EdgeKind::kDataAndControl));
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
+ Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
SetPredicate(n, output_idx, pred, should_revisit);
}
@@ -912,9 +1012,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
return Status::OK();
}
-gtl::FlatMap<TensorId, string, TensorId::Hasher>
+absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
- gtl::FlatMap<TensorId, string, TensorId::Hasher> result;
+ absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) {
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
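
As a worked note on the factoring added to MakeAndOrImpl above (informal Boolean algebra, using the same roles as the &, | and ~ in the comments): pulling a common operand out of every sub-operand is what lets the existing "A & ~A => False" / "A | ~A => True" simplifications fire on expressions they previously could not reach. A minimal instance:

\begin{align*}
(A \land B) \lor (A \land \lnot B) &= A \land (B \lor \lnot B) \\
                                   &= A \land \text{true} \\
                                   &= A
\end{align*}

Without the factoring step, neither disjunct on its own is simplifiable, so the expression would simply have been interned as written.
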
diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h
index 3df2679c62..354782374a 100644
--- a/tensorflow/compiler/jit/deadness_analysis_internal.h
+++ b/tensorflow/compiler/jit/deadness_analysis_internal.h
@@ -16,15 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
namespace deadness_analysis_internal {
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only.
-using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>;
+using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
// Returns a map describing the predicate each Tensor was mapped to. For
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 28a56044d5..617e31488c 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) {
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}
-TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
- // This demonstrates one of the weaknesses in the current approach -- since we
- // only do some basic simplifications we can't see that "(A|B)&C" ==
- // "(A&C)|(B&C)".
+TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
+ // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "A");
+ ops::Switch sw_1 = CreateSwitch(root, "B");
+ Output add0 =
+ ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
+ Output add1 =
+ ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
+ ops::Merge or2(root.WithOpName("or2"), {add0, add1});
+ Output add3 =
+ ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
+ ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true});
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+ EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
+}
+
+TEST(DeadnessAnalysisTest, AndOrDistributive) {
+ // (A|B)&C == (A&C)|(B&C)
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
@@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
- EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node()));
}
TEST(DeadnessAnalysisTest, Ternary) {
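
For reference, a short derivation (again informal Boolean algebra, writing A and B for the symbol predicates *A and *B) of why the AndOrDistributiveSimplified expectation above now collapses to #true:

\begin{align*}
A \lor \big(\lnot A \land ((\lnot B \land \lnot A) \lor (\lnot A \land B))\big)
    &= A \lor \big(\lnot A \land \lnot A \land (\lnot B \lor B)\big) \\
    &= A \lor (\lnot A \land \lnot A) \\
    &= A \lor \lnot A \\
    &= \text{true}
\end{align*}

The first step is exactly the new factoring of the common operand ~*A; the remaining steps use the pre-existing A&A and A|~A simplifications.
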
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index e0632ff7e4..da27f837e8 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/session_options.h"
@@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) {
namespace {
bool AreAllParentsGuaranteedConst(
- const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
+ const Node& n,
+ const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
if (n.type_string() == "GuaranteeConst") {
// If the current node is itself a cast-to-const, no need
// to look at the incoming edges.
@@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst(
void MarkGuaranteedConstants(
const Graph& graph,
const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
- gtl::FlatSet<const Node*> guaranteed_const_nodes;
+ absl::flat_hash_set<const Node*> guaranteed_const_nodes;
std::vector<const Node*> srcs;
srcs.reserve(src_arg_pairs.size());
for (const auto& src_arg : src_arg_pairs) {
@@ -748,6 +749,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
graph_->set_versions(graph_in->versions());
}
+ // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
+ // determined. In case of hard placement, ensure all the encapsulated nodes
+ // have the same requested device, which in turn will be the requested device
+ // for the entire encapsulated subgraph. In case of soft placement, use a
+ // deterministic approach to fill in the requested device. Handle co-location
+ // constraints similarly if they exist.
if (device_.empty()) {
device_ = node->assigned_device_name().empty()
? node->requested_device()
@@ -1357,28 +1364,31 @@ void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
Status Encapsulator::GetFunctionNameAttr(
Node const* node, string* attr, string* outside_compilation_attr) const {
- Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
- if (s.code() == error::Code::NOT_FOUND) {
- // Return empty attr if there's no group_attribute.
- attr->clear();
- } else {
- TF_RETURN_IF_ERROR(s);
- }
- bool has_group_attr = s.ok();
- s = GetNodeAttr(node->attrs(), outside_compilation_attribute_,
- outside_compilation_attr);
- if (s.code() == error::Code::NOT_FOUND) {
- // Return empty attr if there's no outside_compilation attribute.
- outside_compilation_attr->clear();
- } else {
- TF_RETURN_IF_ERROR(s);
- if (!has_group_attr) {
- return errors::InvalidArgument(
- "Node ", node->name(), " has ", outside_compilation_attribute_,
- " attribute but no ", group_attribute_, " attribute.");
+ AttrSlice attrs = node->attrs();
+ attr->clear();
+ outside_compilation_attr->clear();
+ bool found_group_attribute = false;
+ bool found_outside_compilation_attribute = false;
+ for (const auto& node_attr : attrs) {
+ if (node_attr.first == group_attribute_) {
+ TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
+ *attr = node_attr.second.s();
+ found_group_attribute = true;
+ } else if (node_attr.first == outside_compilation_attribute_) {
+ TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
+ *outside_compilation_attr = node_attr.second.s();
+ found_outside_compilation_attribute = true;
}
+ if (found_group_attribute && found_outside_compilation_attribute) break;
+ }
+
+ if (found_outside_compilation_attribute && !found_group_attribute) {
+ return errors::InvalidArgument(
+ "Node ", node->name(), " has ", outside_compilation_attribute_,
+ " attribute but no ", group_attribute_, " attribute.");
+ } else {
+ return Status::OK();
}
- return Status::OK();
}
bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) {
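
The rewritten GetFunctionNameAttr above replaces two GetNodeAttr() lookups (each of which has to construct a NOT_FOUND Status when the attribute is absent) with a single pass over the node's AttrSlice, presumably to make the common no-attribute case cheaper. A minimal, self-contained sketch of the same pattern; GetTwoStringAttrs and its attribute-name parameters are hypothetical, not TF API:

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical helper: extracts two optional string attributes in one pass,
// leaving an output empty when the corresponding attribute is missing.
Status GetTwoStringAttrs(const Node& node, const string& first_name,
                         const string& second_name, string* first,
                         string* second) {
  first->clear();
  second->clear();
  for (const auto& attr : node.attrs()) {
    if (attr.first == first_name) {
      TF_RETURN_IF_ERROR(AttrValueHasType(attr.second, "string"));
      *first = attr.second.s();
    } else if (attr.first == second_name) {
      TF_RETURN_IF_ERROR(AttrValueHasType(attr.second, "string"));
      *second = attr.second.s();
    }
  }
  return Status::OK();
}

}  // namespace tensorflow
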
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 97ef8cd3cb..2ce6fa73fc 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -15,13 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) {
}
// Adds the control inputs of `node` to `*deps`.
-void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlInputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
for (const Edge* edge : node.in_edges()) {
if (edge->IsControlEdge()) {
deps->insert(edge->src());
@@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
}
// Adds the control outputs of `node` to `*deps`.
-void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlOutputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
for (const Edge* edge : node.out_edges()) {
if (edge->IsControlEdge()) {
deps->insert(edge->dst());
@@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Data and control inputs to the new XlaLaunch node.
std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
- gtl::FlatSet<Node*> control_inputs;
+ absl::flat_hash_set<Node*> control_inputs;
DataTypeVector arg_types(num_args);
AddControlInputs(*launch, &control_inputs);
@@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Outputs.
const int num_outputs = launch->output_types().size();
- gtl::FlatSet<Node*> control_outputs;
+ absl::flat_hash_set<Node*> control_outputs;
std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
DataTypeVector output_types(num_outputs);
@@ -297,7 +297,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Target the XLA CPU/GPU backends.
VLOG(2) << "Replacing with XlaLaunch";
+ VLOG(2) << "Device is " << launch->requested_device();
def.set_op("XlaLaunch");
+ def.set_device(launch->requested_device());
AddNodeAttr("Tconstants", DataTypeVector{}, &def);
AddNodeAttr("Targs", arg_types, &def);
AddNodeAttr("Nresources", num_variables, &def);
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
index f643fb0cfe..22531a4ace 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -55,6 +55,7 @@ static std::unique_ptr<Graph> MakeOuterGraph(
.Input(u.node()->name(), 0, DT_RESOURCE)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
+ .Device("/gpu:0")
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
.Attr("_variable_start_index", 4)
.Finalize(&def));
@@ -107,10 +108,11 @@ static std::unique_ptr<Graph> MakeBodyGraph() {
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ node->set_requested_device("/gpu:0");
};
auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
-
+ add_attrs(b_identity.node());
auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
add_attrs(read_u.node());
auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
@@ -215,6 +217,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ node->set_requested_device("/gpu:0");
};
auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
@@ -317,8 +320,8 @@ TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
NameAttrList function;
function.set_name("launch0");
auto launch = ops::XlaLaunch(
- scope.WithOpName("launch0"), std::initializer_list<Input>{},
- std::initializer_list<Input>{a, b, c, d},
+ scope.WithOpName("launch0").WithDevice("/gpu:0"),
+ std::initializer_list<Input>{}, std::initializer_list<Input>{a, b, c, d},
std::initializer_list<Input>{u, v, w},
DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 3770eea6d0..085c0e5adb 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -55,6 +55,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
- BuildXlaLaunchOpsPass);
+ BuildXlaOpsPass);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 253a5d2547..26cb3af9d6 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -7,9 +7,9 @@ package(
)
cc_library(
- name = "xla_launch_op",
- srcs = ["xla_launch_op.cc"],
- hdrs = ["xla_launch_op.h"],
+ name = "xla_ops",
+ srcs = ["xla_ops.cc"],
+ hdrs = ["xla_ops.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:xla_compilation_cache",
@@ -26,6 +26,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
deleted file mode 100644
index b6f2f632f7..0000000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ /dev/null
@@ -1,276 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
-
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_launch_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/variable_ops.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function)
- : OpKernel(ctx),
- constants_(constants),
- resources_(resources),
- device_type_(ctx->device_type()),
- function_(function) {
- if (device_type_ == DeviceType(DEVICE_CPU)) {
- platform_id_ = se::host::kHostPlatformId;
- } else if (device_type_ == DeviceType(DEVICE_GPU)) {
- platform_id_ = ctx->device()
- ->tensorflow_gpu_device_info()
- ->stream->parent()
- ->platform()
- ->id();
- } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
- use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
- platform_id_ = xla_device_metadata_->platform()->id();
- }
-}
-
-Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
- if (xla_device_metadata_) {
- *cache = new XlaCompilationCache(xla_device_metadata_->client(),
- xla_device_metadata_->jit_device_type());
- return Status::OK();
- }
-
- auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_);
- if (!platform.ok()) {
- return platform.status();
- }
- xla::LocalClientOptions client_options;
- client_options.set_platform(platform.ValueOrDie());
- client_options.set_intra_op_parallelism_threads(
- ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
- auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
- if (!client.ok()) {
- return client.status();
- }
- const XlaOpRegistry::DeviceRegistration* registration;
- if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
- &registration)) {
- return errors::InvalidArgument("No JIT device registered for ",
- device_type_.type());
- }
- *cache = new XlaCompilationCache(
- client.ValueOrDie(), DeviceType(registration->compilation_device_name));
- return Status::OK();
-}
-
-void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOpBase::Compute "
- << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
- // We store information about the JIT-compiled XLA computation
- // in the ResourceMgr.
- ResourceMgr* rm = ctx->resource_manager();
- OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
-
- se::Stream* stream =
- ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
-
- XlaCompilationCache* cache;
- OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
- rm->default_container(), "xla_cache", &cache,
- [this, ctx](XlaCompilationCache** cache) {
- return BuildCompilationCache(ctx, cache);
- }));
- // Hold the reference to the JIT during evaluation. (We could probably
- // free it sooner because the ResourceMgr will retain a reference, but
- // this is more obviously correct.)
- core::ScopedUnref cache_ref(cache);
-
- std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, resources_);
-
- xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
-
- XlaAllocator local_xla_allocator(client->backend().platform(),
- ctx->device()->GetAllocator({}));
- xla::DeviceMemoryAllocator* xla_allocator;
- // If we are on an XlaDevice, use the underlying XLA platform's allocator
- // directly. We could use the StreamExecutor's allocator which may
- // theoretically be more correct, but XLA returns a nice OOM message in a
- // Status and StreamExecutor does not.
- //
- // Importantly we can't use ctx->device()->GetAllocator() as the allocator
- // (which local_xla_allocator above uses) as on an XlaDevice, this is a
- // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
- // real allocator to allocate real buffers.
- if (xla_device_metadata_) {
- xla_allocator = client->backend().memory_allocator();
- } else {
- xla_allocator = &local_xla_allocator;
- }
-
- XlaCompiler::Options options;
- options.client = client;
- if (ctx->op_device_context() != nullptr) {
- options.device_ordinal =
- ctx->op_device_context()->stream()->parent()->device_ordinal();
- }
- options.device_type = cache->device_type();
- options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
- options.graph_def_version = ctx->function_library()->graph_def_version();
- options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
- options.device_allocator = xla_allocator;
- if (xla_device_metadata_) {
- options.shape_representation_fn =
- xla_device_metadata_->shape_representation_fn();
- }
-
- const XlaCompiler::CompilationResult* kernel;
- xla::LocalExecutable* executable;
-
- std::map<int, Tensor> constant_args;
- for (int i : constants_) {
- constant_args.insert({i, ctx->input(i)});
- }
- XlaCompiler::CompileOptions compile_options;
- compile_options.is_entry_computation = true;
- // If we resolve constants we never emit them on the device, meaning that if
- // they are needed by a following computation the host has to transfer
- // them. Not resolving constants is expected to be faster than resolving
- // constants.
- compile_options.resolve_compile_time_constants = true;
- // Optimization: where possible, have the computation return a naked array
- // rather than a one-element tuple.
- compile_options.always_return_tuple = false;
-
- OP_REQUIRES_OK(
- ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, compile_options));
-
- VLOG(1) << "Executing XLA Computation...";
-
- XlaComputationLaunchContext launch_context(
- client, xla_allocator,
- /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
- use_multiple_streams_);
- launch_context.PopulateInputs(ctx, kernel, variables);
-
- // Execute the computation.
- VLOG(2) << "Executing computation.";
- xla::ExecutableRunOptions run_options;
- run_options.set_stream(stream);
- run_options.set_allocator(xla_allocator);
- run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
- run_options.set_rng_seed(GetXLARandomSeed());
- Env* env = Env::Default();
- auto start_time = env->NowMicros();
-
- auto run_result = executable->Run(launch_context.arguments(), run_options);
- OP_REQUIRES(ctx, run_result.ok(), run_result.status());
-
- auto elapsed = env->NowMicros() - start_time;
- VLOG(2) << "Elapsed time: " << elapsed << "us";
-
- OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
- ctx, kernel, run_result.ConsumeValueOrDie()));
- VLOG(1) << "Done";
-}
-
-namespace {
-
-// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
-// in error case, it returns RET instead of void.
-#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
- do { \
- ::tensorflow::Status _s(__VA_ARGS__); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
- return RET; \
- } \
- } while (0)
-
-// Helper static functions to construct parameters for
-// XlaLocalLaunchBase constructor from OpKernelConstruction.
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
- std::vector<int> constants(constant_types.size());
- std::iota(constants.begin(), constants.end(), 0);
- return constants;
-}
-
-std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
-
- DataTypeVector arg_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Targs", &arg_types));
-
- int num_resources;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Nresources", &num_resources));
-
- std::vector<int> resources(num_resources);
- std::iota(resources.begin(), resources.end(),
- constant_types.size() + arg_types.size());
- return resources;
-}
-
-NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
- const NameAttrList* func;
- OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
- return *func;
-}
-
-#undef OP_REQUIRES_OK_RETURN
-} // namespace
-
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
- FunctionAttr(ctx)) {}
-
-XlaLocalLaunchOp::~XlaLocalLaunchOp() {
- VLOG(1) << "XlaLocalLaunchOp destroyed";
-}
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
- .Device(DEVICE_GPU)
- .HostMemory("constants")
- .HostMemory("resources"),
- XlaLocalLaunchOp);
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
deleted file mode 100644
index e0f10e9817..0000000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
-#include "tensorflow/compiler/jit/xla_device.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
-// The only difference is that it does not require arguments to follow
-// the "constants, then regular args, then resources" order.
-// It takes vectors of constant and resource arguments explicitly.
-// It does not have corresponding OpDef because it is never present
-// in the GraphDef.
-// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
-// this kernel when asked to create a kernel for an XLA-compiled function.
-class XlaLocalLaunchBase : public OpKernel {
- public:
- XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function);
- XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
- XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
- ~XlaLocalLaunchBase() override = default;
-
- void Compute(OpKernelContext* ctx) override;
-
- protected:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache);
-
- // Indexes of compile-time constant inputs
- std::vector<int> constants_;
- // Indexes of resource inputs
- std::vector<int> resources_;
-
- DeviceType device_type_;
- NameAttrList function_;
- se::Platform::Id platform_id_ = nullptr;
- bool use_multiple_streams_ = false;
- const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
-};
-
-// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
-// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
-// responsible for handling interactions with the TensorFlow executor.
-// Once all inputs are present, and their shapes are known, the op can
-// use a 'XlaCompilationCache' to compile and execute code which is specific
-// to the shapes of input Tensors.
-// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
-// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
-// memory.
-class XlaLocalLaunchOp : public XlaLocalLaunchBase {
- public:
- explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
- ~XlaLocalLaunchOp() override;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
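
The new xla_ops.cc below splits the old monolithic launch kernel into an XlaCompile/XlaRun pair that communicate through a keyed store of executable closures: the compile op Produce()s a closure under a fresh string key and emits that key as a scalar tensor, and the run op Consume()s (and erases) the closure for the key it receives. A minimal, simplified sketch of that handoff using only the standard library (the real XlaExecutableClosureStore below uses a mutex-guarded absl::flat_hash_map and integer keys formatted with absl::StrCat):

#include <cstdint>
#include <map>
#include <mutex>
#include <string>
#include <utility>

// Simplified stand-in for XlaExecutableClosureStore; Closure stands in for
// XlaExecutableClosure (executable, compilation result, resource-var snapshots).
template <typename Closure>
class ClosureStore {
 public:
  // Stores `closure` and returns the key the consumer must present later.
  std::string Produce(Closure closure) {
    std::lock_guard<std::mutex> lock(mu_);
    std::string key = std::to_string(next_key_++);
    closures_.emplace(key, std::move(closure));
    return key;
  }

  // Removes and returns the closure stored under `key`. Keys are single-use;
  // this sketch assumes `key` was produced earlier and not yet consumed.
  Closure Consume(const std::string& key) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = closures_.find(key);
    Closure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

 private:
  std::mutex mu_;
  int64_t next_key_ = 0;
  std::map<std::string, Closure> closures_;
};

Making each key single-use is what lets the run op take ownership of the resource-variable snapshots captured at compile time instead of re-snapshotting them at run time.
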
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
new file mode 100644
index 0000000000..accc86a86d
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -0,0 +1,500 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status PlatformInfoFromContext(OpKernelConstruction* ctx,
+ XlaPlatformInfo* result) {
+ DeviceType device_type = ctx->device_type();
+ se::Platform::Id platform_id = nullptr;
+ const XlaDevice::Metadata* xla_device_metadata = nullptr;
+ std::unique_ptr<XlaAllocator> xla_allocator;
+ xla::DeviceMemoryAllocator* device_allocator = nullptr;
+
+ if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+ platform_id = se::host::kHostPlatformId;
+ } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
+ platform_id = ctx->device()
+ ->tensorflow_gpu_device_info()
+ ->stream->parent()
+ ->platform()
+ ->id();
+ } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+ // If we are on an XlaDevice, use the underlying XLA platform's allocator
+ // directly. We could use the StreamExecutor's allocator which may
+ // theoretically be more correct, but XLA returns a nice OOM message in a
+ // Status and StreamExecutor does not.
+ //
+ // Importantly we can't use ctx->device()->GetAllocator() as the allocator
+ // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
+ // allocator that returns XlaTensor objects. The XlaCompiler needs a real
+ // allocator to allocate real buffers.
+
+ platform_id = xla_device_metadata->platform()->id();
+ device_allocator =
+ xla_device_metadata->client()->backend().memory_allocator();
+ }
+
+ if (!device_allocator) {
+ TF_ASSIGN_OR_RETURN(se::Platform* const platform,
+ se::MultiPlatformManager::PlatformWithId(platform_id));
+ xla_allocator = absl::make_unique<XlaAllocator>(
+ platform, ctx->device()->GetAllocator({}));
+ }
+
+ *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
+ std::move(xla_allocator), device_allocator);
+
+ return Status::OK();
+}
+
+// A closure describing how to run a compiled version of a TensorFlow function.
+//
+// It may seem unusual to stick the resource variable snapshots in this class.
+// This is necessary: we need to use the snapshots observed by the compiler as
+// the initial values for the resource variables (and cannot snapshot them again
+// during execution) because otherwise we risk observing a different snapshot
+// with shapes different from what we compiled for.
+class XlaExecutableClosure {
+ public:
+ explicit XlaExecutableClosure(
+ xla::LocalClient* client, xla::LocalExecutable* executable,
+ const XlaCompiler::CompilationResult* compilation_result,
+ std::map<int, OptionalTensor> resource_var_snapshots,
+ int num_constant_args)
+ : client_(client),
+ executable_(executable),
+ compilation_result_(compilation_result),
+ resource_var_snapshots_(std::move(resource_var_snapshots)),
+ num_constant_args_(num_constant_args) {}
+
+ XlaExecutableClosure(XlaExecutableClosure&&) = default;
+ XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
+
+ xla::LocalClient* client() const { return client_; }
+ xla::LocalExecutable* executable() const { return executable_; }
+ const XlaCompiler::CompilationResult* compilation_result() const {
+ return compilation_result_;
+ }
+ const std::map<int, OptionalTensor>& resource_var_snapshots() const {
+ return resource_var_snapshots_;
+ }
+ int num_constant_args() const { return num_constant_args_; }
+
+ private:
+ xla::LocalClient* client_;
+ xla::LocalExecutable* executable_;
+ const XlaCompiler::CompilationResult* compilation_result_;
+ std::map<int, OptionalTensor> resource_var_snapshots_;
+ int num_constant_args_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
+};
+
+// This maintains a mapping from a globally unique ID to XlaExecutableClosure
+// instances.
+class XlaExecutableClosureStore {
+ public:
+ XlaExecutableClosureStore() : key_counter_(0) {}
+
+ using KeyT = string;
+
+ KeyT Produce(XlaExecutableClosure result) {
+ mutex_lock l(mutex_);
+ KeyT key = absl::StrCat(key_counter_++);
+ bool insert_successful = closures_.emplace(key, std::move(result)).second;
+ DCHECK(insert_successful);
+ (void)insert_successful;
+ return key;
+ }
+
+ XlaExecutableClosure Consume(const KeyT& key) {
+ mutex_lock l(mutex_);
+ auto it = closures_.find(key);
+ DCHECK(it != closures_.end());
+ XlaExecutableClosure value = std::move(it->second);
+ closures_.erase(it);
+ return value;
+ }
+
+ static XlaExecutableClosureStore* Global() {
+ static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
+ return instance;
+ }
+
+ private:
+ mutex mutex_;
+ int64 key_counter_ GUARDED_BY(mutex_);
+ absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
+};
+
+} // namespace
+
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function)
+ : OpKernel(ctx),
+ constants_(constants),
+ resources_(resources),
+ function_(function) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+static Status BuildCompilationCache(OpKernelContext* ctx,
+ const XlaPlatformInfo& platform_info,
+ XlaCompilationCache** cache) {
+ if (platform_info.xla_device_metadata()) {
+ *cache = new XlaCompilationCache(
+ platform_info.xla_device_metadata()->client(),
+ platform_info.xla_device_metadata()->jit_device_type());
+ return Status::OK();
+ }
+
+ auto platform =
+ se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
+ if (!platform.ok()) {
+ return platform.status();
+ }
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform.ValueOrDie());
+ client_options.set_intra_op_parallelism_threads(
+ ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+ auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
+ if (!client.ok()) {
+ return client.status();
+ }
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
+ &registration)) {
+ return errors::InvalidArgument("No JIT device registered for ",
+ platform_info.device_type().type());
+ }
+ *cache = new XlaCompilationCache(
+ client.ValueOrDie(), DeviceType(registration->compilation_device_name));
+ return Status::OK();
+}
+
+static Status CompileToLocalExecutable(
+ OpKernelContext* ctx, const NameAttrList& function,
+ const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
+ absl::Span<const int> constants, xla::LocalClient** client,
+ std::map<int, OptionalTensor>* variables,
+ const XlaCompiler::CompilationResult** kernel,
+ xla::LocalExecutable** executable) {
+ // We store information about the JIT-compiled XLA computation
+ // in the ResourceMgr.
+ ResourceMgr* rm = ctx->resource_manager();
+ if (!rm) {
+ return errors::Internal("No resource manager.");
+ }
+
+ XlaCompilationCache* cache;
+ TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
+ rm->default_container(), "xla_cache", &cache,
+ [&](XlaCompilationCache** cache) {
+ return BuildCompilationCache(ctx, platform_info, cache);
+ }));
+ // Hold the reference to the JIT during evaluation. (We could probably
+ // free it sooner because the ResourceMgr will retain a reference, but
+ // this is more obviously correct.)
+ core::ScopedUnref cache_ref(cache);
+
+ *variables = SnapshotResourceVariables(ctx, resources);
+ *client = static_cast<xla::LocalClient*>(cache->client());
+
+ XlaCompiler::Options options;
+ options.client = *client;
+ if (ctx->op_device_context() != nullptr) {
+ options.device_ordinal =
+ ctx->op_device_context()->stream()->parent()->device_ordinal();
+ }
+ options.device_type = cache->device_type();
+ options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ options.graph_def_version = ctx->function_library()->graph_def_version();
+ options.allow_cpu_custom_calls =
+ (platform_info.platform_id() == se::host::kHostPlatformId);
+ options.device_allocator = platform_info.allocator();
+ if (platform_info.xla_device_metadata()) {
+ options.shape_representation_fn =
+ platform_info.xla_device_metadata()->shape_representation_fn();
+ }
+
+ std::map<int, Tensor> constant_args;
+ for (int i : constants) {
+ constant_args.insert({i, ctx->input(i)});
+ }
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.is_entry_computation = true;
+ // If we resolve constants we never emit them on the device, meaning that if
+ // they are needed by a following computation the host has to transfer
+ // them. Not resolving constants is expected to be faster than resolving
+ // constants.
+ compile_options.resolve_compile_time_constants = true;
+ // Optimization: where possible, have the computation return a naked array
+ // rather than a one-element tuple.
+ compile_options.always_return_tuple = false;
+
+ return cache->Compile(options, function, constant_args, *variables, ctx,
+ compile_options, kernel, executable);
+}
+
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOpBase::Compute "
+ << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
+
+ xla::LocalClient* client;
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ std::map<int, OptionalTensor> variables;
+
+ OP_REQUIRES_OK(
+ ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+ constants_, &client, &variables, &kernel,
+ &executable));
+
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+ VLOG(1) << "Executing XLA Computation...";
+
+ XlaComputationLaunchContext launch_context(
+ client, platform_info_.allocator(),
+ /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+ platform_info_.UseMultipleStreams());
+ launch_context.PopulateInputs(ctx, kernel, variables,
+ /*missing_ctx_input_prefix=*/0);
+
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(platform_info_.allocator());
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(GetXLARandomSeed());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ auto run_result = executable->Run(launch_context.arguments(), run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+ OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+ ctx, kernel, run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/0));
+ VLOG(1) << "Done";
+}
+
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return RET; \
+ } \
+ } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+ std::vector<int> constants(constant_types.size());
+ std::iota(constants.begin(), constants.end(), 0);
+ return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+
+ DataTypeVector arg_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Targs", &arg_types));
+
+ int num_resources;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Nresources", &num_resources));
+
+ std::vector<int> resources(num_resources);
+ std::iota(resources.begin(), resources.end(),
+ constant_types.size() + arg_types.size());
+ return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+ return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+} // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+ FunctionAttr(ctx)) {}
+
+XlaLocalLaunchOp::~XlaLocalLaunchOp() {
+ VLOG(1) << "XlaLocalLaunchOp destroyed";
+}
+
+XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx),
+ constants_(ConstantsVector(ctx)),
+ resources_(ResourcesVector(ctx)),
+ function_(FunctionAttr(ctx)) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaCompileOp::Compute(OpKernelContext* ctx) {
+ xla::LocalClient* client;
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ std::map<int, OptionalTensor> variables;
+
+ OP_REQUIRES_OK(
+ ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+ constants_, &client, &variables, &kernel,
+ &executable));
+
+ // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
+ // if it didn't have to compile the cluster because of a compilation-cache
+ // hit. This is because we at least need new snapshots of the resource
+ // variables.
+ XlaExecutableClosureStore::KeyT key =
+ XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
+ client, executable, kernel, std::move(variables), constants_.size()));
+
+ Allocator* cpu_allocator = [&] {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ return ctx->device()->GetAllocator(host_alloc_attrs);
+ }();
+
+ Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
+ compilation_key.flat<string>()(0) = key;
+
+ Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
+ compilation_successful.flat<bool>()(0) = true;
+
+ ctx->set_output(0, compilation_key);
+ ctx->set_output(1, compilation_successful);
+}
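A minimal sketch of the key handoff this Compute method sets up, using only calls that appear in this file; the parameter names and the num_constant_args value are illustrative stand-ins for the values computed above.

// Hypothetical helper, not part of the patch.
void ExampleKeyHandoff(xla::LocalClient* client,
                       xla::LocalExecutable* executable,
                       const XlaCompiler::CompilationResult* kernel,
                       std::map<int, OptionalTensor> snapshots) {
  // Publish a closure under a fresh key, as XlaCompileOp::Compute does.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(snapshots),
          /*num_constant_args=*/2));
  // The key travels through the graph as the DT_STRING scalar emitted above;
  // the paired _XlaRun op later looks the closure up by that same key.
  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);
  (void)closure;
}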
+
+XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaRunOp::Compute(OpKernelContext* ctx) {
+ Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
+ const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
+
+ XlaExecutableClosure closure =
+ XlaExecutableClosureStore::Global()->Consume(key);
+
+ XlaComputationLaunchContext launch_context(
+ closure.client(), platform_info_.allocator(),
+ /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+ /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
+
+  // We're missing the must-be-constant inputs, so tell `PopulateInputs`
+  // about this. We don't actually need these inputs because they have
+  // already been baked into the compiled kernel.
+ launch_context.PopulateInputs(
+ ctx, closure.compilation_result(), closure.resource_var_snapshots(),
+ /*missing_ctx_input_prefix=*/closure.num_constant_args());
+
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(platform_info_.allocator());
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(GetXLARandomSeed());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ auto run_result =
+ closure.executable()->Run(launch_context.arguments(), run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
+
+ OP_REQUIRES_OK(
+ ctx,
+ launch_context.PopulateOutputs(
+ ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/closure.num_constant_args()));
+}
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
+ .Device(DEVICE_GPU)
+ .HostMemory("constants")
+ .HostMemory("resources"),
+ XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
+ .Device(DEVICE_GPU)
+ .HostMemory("constants")
+ .HostMemory("resources"),
+ XlaCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h
new file mode 100644
index 0000000000..489d26eb30
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.h
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+// Holds some information about the platform on which an
+// XlaLaunch/_XlaCompile/_XlaRun op must run.
+class XlaPlatformInfo {
+ public:
+ XlaPlatformInfo() : device_type_("") {}
+ explicit XlaPlatformInfo(const DeviceType device_type,
+ se::Platform::Id platform_id,
+ const XlaDevice::Metadata* xla_device_metadata,
+ std::unique_ptr<XlaAllocator> xla_allocator,
+ xla::DeviceMemoryAllocator* device_allocator)
+ : device_type_(device_type),
+ platform_id_(platform_id),
+ xla_device_metadata_(xla_device_metadata),
+ xla_allocator_(std::move(xla_allocator)),
+ device_allocator_(device_allocator) {
+ CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
+ }
+
+ XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
+
+ bool UseMultipleStreams() const {
+ return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
+ }
+
+ xla::DeviceMemoryAllocator* allocator() const {
+ return device_allocator_ ? device_allocator_ : xla_allocator_.get();
+ }
+ DeviceType device_type() const { return device_type_; }
+
+ // This is equal to xla_device_metadata()->platform()->id() if
+ // xla_device_metadata() is not nullptr.
+ se::Platform::Id platform_id() const { return platform_id_; }
+
+ // This may be null if the op this XlaPlatformInfo is for was not placed on an
+ // XLA device.
+ const XlaDevice::Metadata* xla_device_metadata() const {
+ return xla_device_metadata_;
+ }
+ bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
+
+ private:
+ DeviceType device_type_;
+ se::Platform::Id platform_id_;
+
+ // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
+ // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
+ // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
+ const XlaDevice::Metadata* xla_device_metadata_;
+
+ // If the op associated with this XlaPlatformInfo is placed on an XLA device
+ // then device_allocator_ is the xla::Backend's memory allocator and
+ // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
+ // then device_allocator_ is null and xla_allocator_ points to an appropriate
+ // XlaAllocator instance.
+ std::unique_ptr<XlaAllocator> xla_allocator_;
+ xla::DeviceMemoryAllocator* device_allocator_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
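A minimal usage sketch (hypothetical helper, not part of the header): the launch path only consults an XlaPlatformInfo through accessors such as allocator() and UseMultipleStreams(), and the XOR CHECK in the constructor guarantees allocator() is non-null on every device type.

xla::ExecutableRunOptions MakeRunOptions(const XlaPlatformInfo& info,
                                         se::Stream* stream) {
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  // Non-null regardless of whether the kernel sits on an XLA_* device or a
  // regular CPU/GPU device, because exactly one allocator source is set.
  run_options.set_allocator(info.allocator());
  return run_options;
}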
+
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have a corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by the eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ NameAttrList function_;
+ XlaPlatformInfo platform_info_;
+};
+
+// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use an 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
+// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
+// memory.
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+ public:
+ explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
+ ~XlaLocalLaunchOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
+};
+
+class XlaCompileOp : public OpKernel {
+ public:
+ explicit XlaCompileOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ NameAttrList function_;
+
+ XlaPlatformInfo platform_info_;
+};
+
+class XlaRunOp : public OpKernel {
+ public:
+ explicit XlaRunOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ XlaPlatformInfo platform_info_;
+};
+
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index e6cc6e52ae..4f0c370e65 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
@@ -365,10 +365,13 @@ bool IsXlaFusable(const NodeDef& node) {
return elementwise_ops->count(node.op()) > 0;
}
+// Nodes that XLA can compile are put in `candidates`. Nodes put in
+// `isolated_nodes` must either be unclustered or be put in trivial single-node
+// clusters.
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
- OrderedNodeSet* candidates) {
+ OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -411,6 +414,8 @@ Status FindCompilationCandidates(
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(node->assigned_device_name(), &device_type));
+ VLOG(4) << "Device type for " << node->name() << ": "
+ << device_type.type_string();
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
// is_compilable_fn has already logged the reason if it returned false.
@@ -439,19 +444,56 @@ Status FindCompilationCandidates(
<< node->type_string();
continue;
}
- if (compile_time_const_nodes[node->id()] &&
- !registration->requires_compilation) {
+ if (compile_time_const_nodes[node->id()]) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(
graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
if (op_def->is_stateful()) {
- // We need to be able to constant fold the nodes in
- // compile_time_const_nodes given constant inputs (required by XLA) and
- // therefore can't auto-cluster stateful ops since these can never be
- // constant folded.
- VLOG(2) << "Rejecting " << node->name()
- << ": must-be-constant stateful op";
- continue;
+ // It is easiest to demonstrate the problem we're trying to solve with
+ // an example. Say we have this graph:
+ //
+ // shape = RandomUniformInt();
+ // reshape = Reshape(input, shape)
+ //
+ // Both RandomUniformInt and Reshape are compilable by XLA so, absent
+ // any other reason, we will try to put both shape and reshape in the
+ // same cluster. However, since XLA only supports statically shaped
+ // values, it will expect to be able to constant fold `shape` to get a
+ // static shape for `reshape`. This is a problem because side-effecting
+ // ops like RandomUniformInt() cannot be constant folded. We fix this
+ // by putting `shape` and `reshape` in different clusters, which results
+ // in us recompiling `reshape`'s cluster for every new value of `shape`,
+ // making `reshape` statically sized within each compilation. We
+ // simplify the solution even further by disallowing operations like
+ // `shape` from being part of *any* non-trivial cluster. They're either
+ // not compiled by XLA altogether or, if assigned to an XLA_* device
+ // with "must compile" semantics, compiled into a trivial single-op
+ // cluster. This approach leaves some room for improvement, and we can
+ // consider implementing a more aggressive data-flow-analysis based
+ // solution in the future if needed.
+ //
+ // One ugly problem we have to contend with: certain sets of ops *have*
+ // to be in the same cluster because values flowing between them have
+ // types that can't be live-in or live-out of a cluster. These ops are:
+ //
+ // - TensorArray ops operating on the same TensorArray instance.
+ // - Stack ops operating on the same Stack instance.
+ //
+ // To work around this we avoid isolating these specific ops. Because
+ // of this concession it is unsound to auto-cluster them because then
+ // we'd create clusters we could not compile (because we can't constant
+ // fold, say, a TensorArrayRead or a StackPopV2). But we don't
+ // auto-cluster these operations today so we're good for now.
+ const XlaResourceOpInfo* op_info =
+ GetResourceOpInfoForOp(node->type_string());
+ bool is_tensor_array_or_stack_op =
+ op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
+ if (!is_tensor_array_or_stack_op) {
+ VLOG(2) << "Isolating " << node->name()
+ << ": must-be-constant stateful op";
+ isolated_nodes->insert(node);
+ // Keep going and execute all the other checks.
+ }
}
}
// We don't auto-cluster functional control flow nodes containing resource
@@ -807,11 +849,12 @@ Status MarkForCompilationPass::RunImpl(
Graph* graph = options.graph->get();
OrderedNodeSet compilation_candidates;
+ absl::flat_hash_set<Node*> isolated_nodes;
TF_RETURN_IF_ERROR(FindCompilationCandidates(
*graph, options.flib_def,
(options.session_options != nullptr) ? options.session_options->env
: Env::Default(),
- is_compilable_fn, &compilation_candidates));
+ is_compilable_fn, &compilation_candidates, &isolated_nodes));
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
@@ -856,6 +899,11 @@ Status MarkForCompilationPass::RunImpl(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
+
+ if (isolated_nodes.count(node_from)) {
+ continue;
+ }
+
string from_scope;
string to_scope;
for (int to : cycles.Successors(from)) {
@@ -873,6 +921,9 @@ Status MarkForCompilationPass::RunImpl(
node_to->assigned_device_name()) {
continue;
}
+ if (isolated_nodes.count(node_to)) {
+ continue;
+ }
// Look for an _XlaScope on both nodes. If both nodes have a
// scope and the scopes do not match, do not cluster along this
// edge. This restriction is overridden if the global_jit_level is ON. If
@@ -931,6 +982,11 @@ Status MarkForCompilationPass::RunImpl(
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
+ if (flags->tf_xla_clustering_debug) {
+ dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph,
+ options.flib_def);
+ }
+
// Mark clusters for compilation that:
// * are placed on a device that requires compilation (an XlaDevice),
// * are explicitly marked for compilation (_XlaCompile=true), or
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index c59770a4c8..2a80c745e3 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
@@ -61,10 +62,10 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
return ids;
}
-gtl::FlatMap<string, std::vector<string>> GetClusterSets(
+absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
const Graph& g, std::vector<string>* cluster_names = nullptr) {
CHECK(cluster_names == nullptr || cluster_names->empty());
- gtl::FlatMap<string, std::vector<string>> cluster_sets;
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets;
for (const auto& p : GetClusters(g)) {
cluster_sets[p.second].push_back(p.first);
}
@@ -566,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph);
ASSERT_EQ(cluster_sets.size(), 1);
std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
@@ -586,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph);
ASSERT_EQ(cluster_sets.size(), 1);
std::vector<string> expected_clustered_nodes = {"AssignmentW",
@@ -616,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) {
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::vector<string> cluster_names;
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph, &cluster_names);
ASSERT_EQ(cluster_sets.size(), 2);
@@ -894,5 +895,71 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) {
EXPECT_EQ(clusters["fn_call"], "");
}
+TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output shape_shape =
+ ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
+ Output shape =
+ ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
+ ops::Const(root.WithOpName("test/minval"), 1),
+ ops::Const(root.WithOpName("test/maxval"), 20));
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/shape_rng"], "");
+ EXPECT_NE(clusters["test/reshape"], "");
+ EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
+}
+
+TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
+ DT_INT32);
+ Output zero = ops::Const(root.WithOpName("test/zero"), 0);
+ ops::TensorArrayWrite tensor_array_write(
+ root.WithOpName("test/write"), tensor_array.handle, zero,
+ ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
+ Output tensor_array_read =
+ ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
+ zero, tensor_array_write.flow_out, DT_INT32);
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"),
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
+ tensor_array_read);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/read"], "");
+ EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index 65669877f7..d56d0f8ccf 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,18 +14,35 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
SessionOptions* session_options) {
- // Assign all nodes to the CPU device.
+ // Assign all unassigned nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
}
+ // Call AddDevices to register the XLA devices.
+ //
+ // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
+ // make this more direct, but probably not worth it solely for this test.
+ std::vector<Device*> devices;
+ TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
+
+ auto delete_devices = gtl::MakeCleanup([&] {
+ for (Device* d : devices) {
+ delete d;
+ }
+ });
+
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.session_options = session_options;
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index 13804c6a05..f72224545b 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -4,9 +4,17 @@ package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
cc_library(
name = "xla_ops",
srcs = ["xla_ops.cc"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
+
+tf_gen_op_wrapper_py(
+ name = "xla_ops_wrapper_py",
+ out = "xla_ops.py",
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+)
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 1a29c3caab..bcd1a29b1f 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -51,4 +51,43 @@ REGISTER_OP("XlaClusterOutput")
"Operator that connects the output of an XLA computation to other "
"consumer graph nodes.");
+REGISTER_OP("_XlaCompile")
+ .Input("constants: Tconstants")
+ .Attr("Tconstants: list(type) >= 0")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Input("resources: Nresources * resource")
+ .Attr("Nresources: int >= 0")
+ .Output("key: string")
+ .Output("compilation_successful: bool")
+ .Attr("function: func")
+ // The compilation cache is stateful.
+ .SetIsStateful()
+ .Doc(R"(XLA Compile Op. For use by the XLA JIT only.
+
+Compiles a TensorFlow function into an XLA LocalExecutable and returns a key
+that _XlaRun can use to look up the LocalExecutable and execute it.
+
+key: A key that can be used to look up the local executable compiled by the
+ node and associated metadata.
+
+compilation_successful: True iff the compilation was successful. Always true
+for now.
+)");
+
+REGISTER_OP("_XlaRun")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Output("results: Tresults")
+ .Attr("Tresults: list(type) >= 0")
+ .Input("key: string")
+ // XLA random-number generation ops are stateful.
+ // TODO(phawkins): create stateful and non-stateful variants of _XlaRun.
+ .SetIsStateful()
+ .Doc(R"(XLA Run Op. For use by the XLA JIT only.
+
+Executes a TensorFlow function previously compiled into a LocalExecutable by an
+_XlaCompile op.
+)");
+
} // namespace tensorflow
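A small sketch of how the registrations above can be observed through the op registry (a hypothetical test-style check, not part of the patch):

void CheckXlaCompileRegistration() {
  const tensorflow::OpDef* op_def = nullptr;
  TF_CHECK_OK(
      tensorflow::OpRegistry::Global()->LookUpOpDef("_XlaCompile", &op_def));
  CHECK_EQ(op_def->output_arg_size(), 2);  // "key" and "compilation_successful"
}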
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 10fc9e85d9..b1f9e9088f 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -15,17 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace tensorflow {
namespace {
-Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
+Status FindNodesToDecluster(const Graph& graph,
+ absl::flat_hash_set<Node*>* result,
absl::Span<Node* const> post_order) {
// Find nodes that have at least one user outside their cluster that expects
// hostmem output. These nodes should be cloned to outside the cluster to
@@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/NotBackedge);
- gtl::FlatSet<Node*> nodes_to_partially_decluster;
+ absl::flat_hash_set<Node*> nodes_to_partially_decluster;
TF_RETURN_IF_ERROR(
FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 35872daa65..0feb73a89e 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -60,9 +60,9 @@ class FakeBinaryOp : public OpKernel {
void Compute(OpKernelContext* ctx) override { CHECK(false); }
};
-class FakeResourceVarUpdateOp : public OpKernel {
+class FakeResourceUpdateOp : public OpKernel {
public:
- explicit FakeResourceVarUpdateOp(OpKernelConstruction* context)
+ explicit FakeResourceUpdateOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override { CHECK(false); }
@@ -74,10 +74,9 @@ REGISTER_KERNEL_BUILDER(Name("FakeBinary")
.HostMemory("host_out"),
FakeBinaryOp);
-REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate")
- .Device(DEVICE_CPU)
- .HostMemory("something_else"),
- FakeResourceVarUpdateOp);
+REGISTER_KERNEL_BUILDER(
+ Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"),
+ FakeResourceUpdateOp);
Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
FixupSourceAndSinkEdges(graph->get());
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
index 56e35c0059..e039d46ec8 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -82,6 +82,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
@@ -89,8 +90,6 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/ptr_util.h"
@@ -177,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) {
// point.
class ResourceOpSet {
private:
- using Impl = gtl::FlatSet<ResourceOp>;
+ using Impl = absl::flat_hash_set<ResourceOp>;
public:
ResourceOpSet() = default;
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 3aa9e9c7ed..0471995015 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -228,37 +228,38 @@ Status XlaCompilationCache::Compile(
const XlaCompiler::Options& options, const NameAttrList& function,
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
- const XlaCompiler::CompilationResult** compilation_result,
- xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions& compile_options) {
+ const XlaCompiler::CompileOptions& compile_options,
+ const XlaCompiler::CompilationResult** out_compilation_result,
+ xla::LocalExecutable** out_executable) {
return CompileImpl(options, function, constant_args, variable_args, ctx,
- compilation_result, executable, compile_options, false);
+ compile_options, /*compile_single_op=*/false,
+ out_compilation_result, out_executable);
}
Status XlaCompilationCache::CompileSingleOp(
const XlaCompiler::Options& options,
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
- const XlaCompiler::CompilationResult** compilation_result,
- xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions& compile_options) {
+ const XlaCompiler::CompileOptions& compile_options,
+ const XlaCompiler::CompilationResult** out_compilation_result,
+ xla::LocalExecutable** out_executable) {
const NodeDef& def = ctx->op_kernel().def();
NameAttrList name;
name.set_name(def.op());
*name.mutable_attr() = def.attr();
- return CompileImpl(options, name, constant_args, variable_args, ctx,
- compilation_result, executable, compile_options, true);
+ return CompileImpl(
+ options, name, constant_args, variable_args, ctx, compile_options,
+ /*compile_single_op=*/true, out_compilation_result, out_executable);
}
Status XlaCompilationCache::CompileImpl(
const XlaCompiler::Options& options, const NameAttrList& function,
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
- const XlaCompiler::CompilationResult** compilation_result,
- xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions& compile_options,
- bool compile_single_op) {
- CHECK_NE(executable, nullptr);
+ const XlaCompiler::CompileOptions& compile_options, bool compile_single_op,
+ const XlaCompiler::CompilationResult** out_compilation_result,
+ xla::LocalExecutable** out_executable) {
+ DCHECK_NE(out_executable, nullptr);
VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
if (VLOG_IS_ON(2)) {
@@ -357,8 +358,8 @@ Status XlaCompilationCache::CompileImpl(
}
}
TF_RETURN_IF_ERROR(entry->compilation_status);
- *compilation_result = &entry->compilation_result;
- *executable = entry->executable.get();
+ *out_compilation_result = &entry->compilation_result;
+ *out_executable = entry->executable.get();
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 10ad87e38c..75c7758f73 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -68,9 +68,9 @@ class XlaCompilationCache : public ResourceBase {
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx,
- const XlaCompiler::CompilationResult** compilation_result,
- xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions& compile_options);
+ const XlaCompiler::CompileOptions& compile_options,
+ const XlaCompiler::CompilationResult** out_compilation_result,
+ xla::LocalExecutable** out_executable);
// As above, but calls XlaCompiler::CompileSingleOp instead of
// XlaCompiler::CompileFunction.
@@ -78,9 +78,9 @@ class XlaCompilationCache : public ResourceBase {
const XlaCompiler::Options& options,
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
- const XlaCompiler::CompilationResult** compilation_result,
- xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions& compile_options);
+ const XlaCompiler::CompileOptions& compile_options,
+ const XlaCompiler::CompilationResult** out_compilation_result,
+ xla::LocalExecutable** out_executable);
xla::LocalClient* client() const { return client_; }
const DeviceType& device_type() const { return device_type_; }
@@ -89,15 +89,14 @@ class XlaCompilationCache : public ResourceBase {
private:
// Common implementation of Compile and CompileSingleOp.
- Status CompileImpl(const XlaCompiler::Options& options,
- const NameAttrList& function,
- const std::map<int, Tensor>& constant_args,
- const std::map<int, OptionalTensor>& variable_args,
- OpKernelContext* ctx,
- const XlaCompiler::CompilationResult** compilation_result,
- xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions& compile_options,
- bool compile_single_op);
+ Status CompileImpl(
+ const XlaCompiler::Options& options, const NameAttrList& function,
+ const std::map<int, Tensor>& constant_args,
+ const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
+ const XlaCompiler::CompileOptions& compile_options,
+ bool compile_single_op,
+ const XlaCompiler::CompilationResult** out_compilation_result,
+ xla::LocalExecutable** out_executable);
// Takes `result` which has been compiled from a Tensorflow subgraph to a
// XLA computation already, and generates an XLA LocalExecutable `executable`.
@@ -152,7 +151,7 @@ class XlaCompilationCache : public ResourceBase {
};
mutex compile_cache_mu_;
- gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
+ absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(compile_cache_mu_);
struct CompileStats {
@@ -165,7 +164,7 @@ class XlaCompilationCache : public ResourceBase {
mutex compile_stats_mu_;
// Maps cluster names to compilation statistics for said cluster.
- gtl::FlatMap<string, CompileStats> compile_stats_
+ absl::flat_hash_map<string, CompileStats> compile_stats_
GUARDED_BY(compile_stats_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 3ba48e8c31..79976c85df 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -34,6 +34,7 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
OptionalTensor& optional = variables[i];
optional.name = handle.name();
if (LookupResource(ctx, handle, &variable).ok()) {
+ core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
optional.present = true;
optional.value = *variable->tensor();
@@ -58,7 +59,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
- launch_context.PopulateInputs(ctx, result, variables);
+ launch_context.PopulateInputs(ctx, result, variables,
+ /*missing_ctx_input_prefix=*/0);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -79,7 +81,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
TF_RETURN_IF_ERROR(run_result.status());
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
- ctx, result, run_result.ConsumeValueOrDie()));
+ ctx, result, run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/0));
return Status::OK();
}
@@ -177,7 +180,7 @@ Status XlaCompileOnDemandOp::Compile(
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
- result, executable, compile_options);
+ compile_options, result, executable);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 7e159e3171..003c1d8081 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "Host" (CPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
@@ -65,10 +65,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 7> kAllXlaCpuTypes = {
- {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 51797def04..0824c4644e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -373,7 +373,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
- TracingDevice::Compute(op_kernel, context);
+ op_kernel->Compute(context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
@@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
return status;
}
+void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
+ mutex_lock lock(mu_);
+ sync_on_completion_ = sync_on_completion;
+}
+
+bool XlaDevice::RequiresSyncOnCompletion() const {
+ mutex_lock lock(mu_);
+ return sync_on_completion_;
+}
+
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 92891ffa8c..0f06b3fc80 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice {
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
+ // Instructs this XlaDevice to return 'sync_on_completion' for
+ // RequiresSyncOnCompletion().
+ void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
+
+ bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+
private:
xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
@@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice {
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
- mutex mu_;
+ mutable mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls.
@@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice {
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
+
+ // True if the device requires XlaDevice::Sync to be called on completion
+ // regardless of status.
+ bool sync_on_completion_ GUARDED_BY(mu_) = false;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 49c8582682..6967ad1f03 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -65,6 +65,16 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.HostMemory("resources"), \
KERNEL);
+#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \
+ .Device(DEVICE) \
+ .HostMemory("constants") \
+ .HostMemory("resources"), \
+ KERNEL);
+
+#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
+
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
@@ -90,9 +100,15 @@ class XlaAssignVariableOp : public AsyncOpKernel {
Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \
ResourceHandleOp<Var>); \
REGISTER_KERNEL_BUILDER( \
+ Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \
+ ResourceHandlesOp<Var>); \
+ REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
REGISTER_KERNEL_BUILDER( \
+ Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \
+ ReadVariablesOp); \
+ REGISTER_KERNEL_BUILDER( \
Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \
DestroyResourceOp); \
REGISTER_KERNEL_BUILDER(Name("Shape") \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index ef4466f005..60979556a3 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -74,11 +74,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 8> kAllXlaGpuTypes = {
- {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
- DT_BFLOAT16}};
+constexpr std::array<DataType, 13> kAllXlaGpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 4574559674..19e681af0c 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -15,7 +15,7 @@ limitations under the License.
// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -72,6 +72,10 @@ static bool OpFilter(KernelDef* kdef) { return true; }
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
kExecAllTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
+ kExecAllTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index affeab4a8c..4f6fc4e068 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -42,13 +42,14 @@ using xla::ShapedBuffer;
} // anonymous namespace
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables) {
+ OpKernelContext* ctx, absl::Span<const int> variables) {
std::map<int, OptionalTensor> snapshot;
for (int i : variables) {
Var* variable = nullptr;
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
+ core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
tensor.present = true;
@@ -133,7 +134,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- const std::map<int, OptionalTensor>& variables) {
+ const std::map<int, OptionalTensor>& variables,
+ int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
@@ -145,12 +147,13 @@ void XlaComputationLaunchContext::PopulateInputs(
const Tensor* t;
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->input_mapping[i];
+ DCHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = kernel->xla_input_shapes[i];
if (variables.count(arg_num)) {
t = &(variables.at(arg_num).value);
CHECK(t);
} else {
- t = &(ctx->input(arg_num));
+ t = &(ctx->input(arg_num - missing_ctx_input_prefix));
}
if (use_multiple_streams_) {
@@ -187,7 +190,7 @@ void XlaComputationLaunchContext::PopulateInputs(
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- ScopedShapedBuffer output) {
+ ScopedShapedBuffer output, int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -275,6 +278,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
+ TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
+        << "Invalid input index for output " << i;
ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
@@ -313,7 +318,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
+ int actual_input_index = write.input_index - missing_ctx_input_prefix;
+ if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
@@ -323,7 +329,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index), &variable,
+ ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 7ac275fab8..326d70a027 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
class XlaAllocator;
@@ -43,7 +44,7 @@ class XlaAllocator;
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables);
+ OpKernelContext* ctx, absl::Span<const int> variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -88,14 +89,24 @@ class XlaComputationLaunchContext {
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
+ //
+ // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
+ // missing and adjusts input indices accordingly. All elements in kernel's
+ // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
+ // (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- const std::map<int, OptionalTensor>& variables);
+ const std::map<int, OptionalTensor>& variables,
+ int missing_ctx_input_prefix);
// Given the XLA output in `output`, populate all outputs of `ctx`.
+ //
+ // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
+ // missing and adjusts input indices accordingly.
Status PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- xla::ScopedShapedBuffer output);
+ xla::ScopedShapedBuffer output,
+ int missing_ctx_input_prefix);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
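A hedged sketch of the index adjustment described above (hypothetical helper; it only restates the arithmetic PopulateInputs and PopulateOutputs perform when the first missing_ctx_input_prefix compiled arguments are not fed to the kernel, as in the _XlaRun case):

int CtxInputIndex(int compiled_arg_index, int missing_ctx_input_prefix) {
  // Mirrors the DCHECK_GE in PopulateInputs: an argument the kernel actually
  // needs can never be one of the missing leading inputs.
  DCHECK_GE(compiled_arg_index, missing_ctx_input_prefix);
  return compiled_arg_index - missing_ctx_input_prefix;
}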
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 97ed554171..ba2401ed26 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -895,6 +895,22 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "tensor_list_ops_test",
+ size = "small",
+ srcs = ["tensor_list_ops_test.py"],
+ # TensorList ops are not implemented in the on-demand compilation model yet.
+    disabled_backends = ["cpu_ondemand"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:list_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python/eager:function",
+ ],
+)
+
+tf_xla_py_test(
name = "ternary_ops_test",
size = "small",
srcs = ["ternary_ops_test.py"],
@@ -978,7 +994,7 @@ tf_xla_py_test(
name = "gather_test",
size = "medium",
srcs = ["gather_test.py"],
- tags = ["noasan"], # times out, http://b/78599043
+ tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -1029,6 +1045,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "permute_test",
+ size = "small",
+ srcs = ["permute_test.py"],
+ deps = [
+ "//tensorflow/compiler/tests:xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:nn_ops",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
@@ -1105,6 +1134,7 @@ cc_library(
"//tensorflow/core:test",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -1198,6 +1228,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "quantized_ops_test",
+ size = "small",
+ srcs = ["quantized_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_ops_test",
size = "medium",
srcs = ["xla_ops_test.py"],
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py
index 4155342787..68f52e796c 100644
--- a/tensorflow/compiler/tests/argminmax_test.py
+++ b/tensorflow/compiler/tests/argminmax_test.py
@@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase):
def testArgMinMax(self):
# Complex numbers do not support argmin/argmax.
- minmax_types = set(self.numeric_types) - set(self.complex_types)
+ minmax_types = self.all_types & {np.int32, np.int64}
for dtype in minmax_types:
# output_type is a numpy data type that is used to specify the desired
# output type of the op as well as to convert the Python number to the
# array scalar of the type.
- for output_type in self.int_types:
+ for output_type in minmax_types:
self._assertOpOutputMatchesExpected(
math_ops.argmax,
axis=0,
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 17280e445b..1b39d53dc0 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
equality_test=self.ListsAreClose)
def testIntOps(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types:
self._testBinary(
gen_math_ops.truncate_div,
np.array([3, 3, -1, -9, -8], dtype=dtype),
@@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
- if dtype not in self.complex_types: # min/max not supported for complex
+ # min/max not supported for complex
+ if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
math_ops.maximum,
np.array([1, 2], dtype=dtype),
@@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([[70], [14]], dtype=dtype))
# Complex support for squared_difference is incidental, see b/68205550
- if dtype not in self.complex_types:
+ if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
math_ops.squared_difference,
np.array([1, 2], dtype=dtype),
@@ -559,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(2),
expected=np.array([[5], [2]], dtype=dtype))
+ if dtype in [np.float32, np.float64]:
+ nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+ divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
+ self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result)
+
if dtype not in self.complex_types: # floordiv unsupported for complex.
self._testBinary(
gen_math_ops.floor_div,
@@ -567,7 +575,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
def testIntDivision(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types:
self._testDivision(dtype)
def testFloatDivision(self):
@@ -588,7 +596,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, 1, -1, 0], dtype=dtype))
def testIntRemainder(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types - {np.int8}:
self._testRemainder(dtype)
def testFloatRemainder(self):
@@ -1437,6 +1445,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([4, 0], dtype=np.int32),
expected=np.zeros([4, 0], dtype=dtype))
+ x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array((3, 7, 8, 9), dtype=np.int32),
+ expected=np.tile(x, (1, 7, 8, 9)))
+
if __name__ == "__main__":
googletest.main()
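
The div_no_nan case added above zeroes the columns whose divisor is zero before comparing; the same reference semantics can be written as a masked divide. A small NumPy sketch (illustration of the expected behaviour only, not part of the test; div_no_nan_reference is a hypothetical helper):

    import numpy as np

    def div_no_nan_reference(x, y):
        # x / y elementwise, but 0 wherever the divisor is exactly zero.
        out = np.zeros(np.broadcast(x, y).shape, dtype=np.result_type(x, y))
        np.true_divide(x, y, out=out, where=(y != 0))
        return out

    nums = np.arange(-10, 10, 0.25, dtype=np.float32).reshape(80, 1)
    divs = np.arange(-3, 3, 0.25, dtype=np.float32).reshape(1, 24)
    ref = np.true_divide(nums, divs)
    ref[:, divs[0] == 0] = 0
    assert np.allclose(div_no_nan_reference(nums, divs), ref)
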
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index a76f136736..1d3979b21b 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -2,6 +2,10 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
def all_backends():
b = ["cpu"] + plugins.keys()
@@ -58,14 +62,14 @@ def tf_xla_py_test(
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
]
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
]
- backend_tags += ["requires-gpu-sm35"]
+ backend_tags += tf_cuda_tests_tags()
elif backend in plugins:
backend_args += [
"--test_device=" + plugins[backend]["device"],
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 0af74c2d8f..9390870e07 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -45,17 +45,21 @@ def InLabels(labels, substr):
return any([substr in x for x in labels])
-def XlaLaunchOpCount(labels):
- """Count how many XlaLaunch labels are present."""
- return sum("XlaLaunch(" in x for x in labels)
+class DenseLayerTest(test.TestCase):
+ def countXlaOps(self, labels):
+ """Count how many XlaCompile/XlaRun labels are present."""
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+ return xla_run_count
-class DenseLayerTest(test.TestCase):
def testDenseLayerAutoJit(self):
"""Tests dense layer compilation in auto-jit mode.
- Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
+ Dense layer should be compiled into a single XlaCompile/XlaRun op pair in
+ auto-jit mode.
"""
os.environ["TF_XLA_FLAGS"] = (
@@ -77,14 +81,14 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer with static shape input tensor should be compiled into a single
- XlaLaunch op by XLA.
+ XlaCompile/XlaRun op pair by XLA.
"""
with self.cached_session() as sess:
@@ -101,7 +105,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
# No need to check whether ListDiff is compiled or not because ListDiff op
# is not used when input tensor shape is fully defined.
@@ -111,7 +115,8 @@ class DenseLayerTest(test.TestCase):
Dense layer uses shape op to get shape of input tensor if its shape is not
fully defined. XLA does not cluster shape op with other operators. But in
experimental_jit_scope, XLA is forced to compile shape op into its own
- cluster, causing dense layer to be split into TWO XlaLaunch ops.
+ cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op
+ pairs.
"""
with self.cached_session() as sess:
@@ -128,7 +133,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(2, XlaLaunchOpCount(labels))
+ self.assertEqual(2, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 8c018cccb8..374942a0b3 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
+DATA_FORMATS = (
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+)
+
class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
@@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testInference(self, data_format):
channel = 3
x_shape = [2, 2, 6, channel]
@@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
self.assertAllClose(var_val, var_ref, atol=1e-3)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testLearning(self, data_format):
self._testLearning(False, data_format)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testLearningWithGradientChecker(self, data_format):
self._testLearning(True, data_format)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testGradientTraining(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
@@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testGradientInference(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 089d95daab..a38e1edafe 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase):
indices_tf = constant_op.constant(indices)
gather_t = array_ops.gather(params, indices_tf)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- np_val = params_np[indices]
+ np_val = constant_op.constant(params_np[indices])
self.assertAllEqual(np_val, gather_val)
def testScalar2D(self):
@@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant(2)
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, 2, axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, 2, axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self):
@@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant([0, 1, 0, 2])
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32_Int64Indices(self):
@@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase):
params: params_np,
indices: indices_np
})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testHigherRank(self):
@@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase):
tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis)
gather_value = sess.run(gather, feed_dict={tf_params: params})
- gather_np = np.take(params, indices, axis=axis)
+ gather_np = constant_op.constant(
+ np.take(params, indices, axis=axis), dtype)
self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self):
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 6fe5a66e0e..68fdb5caf4 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -605,10 +605,6 @@ class ResizeBilinearTest(xla_test.XLATestCase):
class NonMaxSuppressionTest(xla_test.XLATestCase):
def testNMS128From1024(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
num_boxes = 1024
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
@@ -644,10 +640,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(indices_tf.size, max_output_size)
def testNMS3From6Boxes(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
# Three boxes are selected based on IOU.
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
@@ -693,10 +685,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
# Three boxes are selected based on IOU.
# One is filtered out by score threshold.
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
@@ -736,6 +724,49 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 2)
self.assertAllClose(indices_tf[:num_valid], [3, 0])
+ def testNMS3Then1WithScoreMaxThresh(self):
+ # Three boxes are selected based on IOU.
+ # One is filtered out by score threshold.
+ # One is filtered out by max_output_size.
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+ max_output_size = 1
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.4, dtype=np.float32)
+
+ with self.cached_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ iou_threshold: iou_threshold_np,
+ score_threshold: score_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 1)
+ self.assertAllClose(indices_tf[:num_valid], [3])
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 0839fb123e..de68ff0e32 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -77,11 +77,11 @@ def InLabels(labels, substr):
return any([substr in x for x in labels])
-def MetadataHasXlaLaunch(run_metadata):
- """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
+def MetadataHasXlaOp(run_metadata):
+ """Returns true if there are XlaRun kernels in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
- return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
+ return InLabels(RunMetadataLabels(run_metadata), "XlaRun")
class JitLaunchTest(test.TestCase):
@@ -90,9 +90,10 @@ class JitLaunchTest(test.TestCase):
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
# the same number of tensors as arguments that are in 'args', and must return
# a tuple of output tensors.
- # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
- # actually ran. However, it is sometimes possible for XlaLaunch ops to be
- # constant-folded away, so the check is optional.
+ #
+ # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
+ # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
+ # ops to be constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
placeholders = []
@@ -115,7 +116,7 @@ class JitLaunchTest(test.TestCase):
print("Compiled Result {}".format(compiled))
if require_kernel_launch:
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
direct = sess.run(direct_op, feeds)
print("Direct Result {}".format(direct))
@@ -149,10 +150,10 @@ class JitLaunchTest(test.TestCase):
y = math_ops.add(x, x)
return y, y
- # Exercises compling a function (say, Foo) which calls another
- # function (say, Bar) which is not inlined. When the compiler compiles
- # Foo, it needs to symbolic execute Bar correctly regardless whether
- # Bar is inlined or not.
+ # Exercises compiling a function (say, Foo) which calls another function
+ # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
+ # to symbolically execute Bar correctly regardless of whether Bar is inlined
+ # or not.
# TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
@@ -259,7 +260,7 @@ class JitLaunchTest(test.TestCase):
# TODO(phawkins): really we would like to test that there were exactly
# two kernel launches. However, we have no reliable way to determine
# that.
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
expected = np.square(np.dot(dx, dw) + db)
self.assertAllClose(expected, output, rtol=1e-1)
@@ -289,7 +290,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
def testIgnoredArguments(self):
@@ -313,7 +314,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(28, out)
def testLoops(self):
@@ -331,7 +332,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(95), rtol=1e-1)
def testCond(self):
@@ -356,7 +357,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(6), rtol=1e-1)
def testNestedFunction(self):
@@ -441,14 +442,16 @@ class XlaCompilationTest(test.TestCase):
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
- self.assertFalse(InLabels(labels, "XlaLaunch"))
+ self.assertFalse(InLabels(labels, "XlaCompile"))
+ self.assertFalse(InLabels(labels, "XlaRun"))
- # Compile the backprop. One XlaLaunch.
+ # Compile the backprop. One XlaCompile/XlaRun pair.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
- self.assertTrue(InLabels(labels, "XlaLaunch"))
+ self.assertTrue(InLabels(labels, "XlaCompile"))
+ self.assertTrue(InLabels(labels, "XlaRun"))
class ElementWiseFusionTest(test.TestCase):
@@ -482,9 +485,12 @@ class ElementWiseFusionTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = RunMetadataLabels(run_metadata)
- count = sum("XlaLaunch(" in x for x in labels)
- return output, count
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+
+ return output, xla_run_count
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py
index 43c469d032..73b3638e80 100644
--- a/tensorflow/compiler/tests/lstm.py
+++ b/tensorflow/compiler/tests/lstm.py
@@ -117,7 +117,7 @@ def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq):
def RandomVar(shape, name=None):
"""Returns a variable of the given shape initialized to random values."""
- return variables.Variable(
+ return variables.VariableV1(
random_ops.random_uniform(shape), dtype=dtypes.float32, name=name)
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
index f985c5d2d9..38cb2f83ef 100644
--- a/tensorflow/compiler/tests/nullary_ops_test.py
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase):
output.run()
def testConstants(self):
- constants = [
- np.float32(42),
- np.array([], dtype=np.float32),
- np.array([1, 2], dtype=np.float32),
- np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
- np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
- dtype=np.float32),
- np.array([[[]], [[]]], dtype=np.float32),
- np.array([[[[1]]]], dtype=np.float32),
- ]
- for c in constants:
- self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+ for dtype in self.numeric_types:
+ constants = [
+ dtype(42),
+ np.array([], dtype=dtype),
+ np.array([1, 2], dtype=dtype),
+ np.array([7, 7, 7, 7, 7], dtype=dtype),
+ np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
+ np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
+ dtype=dtype),
+ np.array([[[]], [[]]], dtype=dtype),
+ np.array([[[[1]]]], dtype=dtype),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+
+ def testComplexConstants(self):
+ for dtype in self.complex_types:
+ constants = [
+ dtype(42 + 3j),
+ np.array([], dtype=dtype),
+ np.ones([50], dtype=dtype) * (3 + 4j),
+ np.array([1j, 2 + 1j], dtype=dtype),
+ np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype),
+ np.array([[[1, 2], [3, 4 + 6j], [5, 6]],
+ [[10 + 7j, 20], [30, 40], [50, 60]]],
+ dtype=dtype),
+ np.array([[[]], [[]]], dtype=dtype),
+ np.array([[[[1 + 3j]]]], dtype=dtype),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py
new file mode 100644
index 0000000000..dbb9274df4
--- /dev/null
+++ b/tensorflow/compiler/tests/permute_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the DataFormatVecPermute operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+
+class XlaPermuteOpTest(xla_test.XLATestCase):
+
+ def _runPermuteAndCompare(self, x, src_format, dst_format, expected):
+ with self.cached_session() as session:
+ with self.test_scope():
+ placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape)
+ param = {placeholder: x}
+ output = nn_ops.data_format_vec_permute(
+ placeholder, src_format=src_format, dst_format=dst_format)
+ result = session.run(output, param)
+ self.assertAllEqual(result, expected)
+
+ def testNHWCToNCHW(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9])
+
+ def testNCHWToNHWC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4])
+
+ def testNHWCToHWNC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3])
+
+ def testHWNCToNHWC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3])
+
+ def testNHWCToNCHW2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "NCHW",
+ [[7, 4], [5, 1], [9, 3], [4, 5]])
+
+ def testNHWCToHWNC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "HWNC",
+ [[9, 3], [4, 5], [7, 4], [5, 1]])
+
+ def testHWNCToNHWC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "HWNC", "NHWC",
+ [[4, 5], [7, 4], [9, 3], [5, 1]])
+
+ def testNCHWToNHWC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NCHW", "NHWC",
+ [[7, 4], [4, 5], [5, 1], [9, 3]])
+
+
+if __name__ == "__main__":
+ test.main()
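
The expected vectors in the new permute test can be derived with a one-line rule: output position i holds the input entry whose axis letter in src_format matches dst_format[i]. A short sketch of that rule (permute_vec is an illustrative helper, not the TF kernel):

    def permute_vec(x, src_format, dst_format):
        # output[i] is the entry of x sitting at dst_format[i]'s position in src_format.
        return [x[src_format.index(d)] for d in dst_format]

    assert permute_vec([7, 4, 9, 3], "NHWC", "NCHW") == [7, 3, 4, 9]
    # The 2-D case permutes whole rows the same way:
    rows = [[7, 4], [9, 3], [4, 5], [5, 1]]
    assert permute_vec(rows, "NHWC", "NCHW") == [[7, 4], [5, 1], [9, 3], [4, 5]]
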
diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py
new file mode 100644
index 0000000000..80c338513b
--- /dev/null
+++ b/tensorflow/compiler/tests/quantized_ops_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantized operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class QuantizedOpsTest(xla_test.XLATestCase):
+
+ # Verify that quantized types can be clustered by XLA.
+ def testQuantizedTypeRoundtrip(self):
+ with self.cached_session() as session:
+ for dtype in self.quantized_tf_types:
+ in_values = np.array([1, 2, 3, 4, 5, 6])
+ expected = [[1, 2], [3, 4], [5, 6]]
+ with self.test_scope():
+ p = array_ops.placeholder(dtype=dtypes.int32)
+ x = math_ops.cast(p, dtype)
+ x = array_ops.reshape(x, [3, 2])
+
+ value = session.run(x, {p: in_values})
+ self.assertAllEqual(value, expected)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 6e18344117..36ef6ed5fe 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -35,7 +35,8 @@ class RandomOpsTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def _random_types(self):
- return set(self.numeric_types) - set(self.complex_types)
+ return set(self.numeric_types) - set(
+ self.complex_types) - {np.uint8, np.int8}
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
@@ -68,9 +69,8 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- dtype = dtypes.float32
- self._testRngIsNotConstant(rng, dtype)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testRandomUniformIsInRange(self):
for dtype in self._random_types():
@@ -92,13 +92,13 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- self._testRngIsNotConstant(rng, dtypes.float32)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self):
count = 10000000
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ # TODO(b/34339814): make this test work with 16 bit float types.
+    for dtype in self._random_types() & {np.float32, np.float64}:
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
@@ -144,9 +144,6 @@ class RandomOpsTest(xla_test.XLATestCase):
self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testShuffle1d(self):
- # TODO(b/26783907): this test requires the CPU backend to implement sort.
- if self.device in ["XLA_CPU"]:
- return
with self.cached_session() as sess:
with self.test_scope():
x = math_ops.range(1 << 16)
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index bddda6f302..dc119fb0f8 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -45,6 +45,7 @@ limitations under the License.
#include <random>
#include <unordered_map>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
@@ -63,7 +64,6 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@@ -457,7 +457,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
- gtl::FlatSet<float> already_generated;
+ absl::flat_hash_set<float> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<float>(&tensor, [&](int i) -> float {
float generated;
@@ -470,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_DOUBLE: {
- gtl::FlatSet<double> already_generated;
+ absl::flat_hash_set<double> already_generated;
std::uniform_real_distribution<double> distribution(-1.0, 1.0);
test::FillFn<double>(&tensor, [&](int i) -> double {
double generated;
@@ -483,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_COMPLEX64: {
- gtl::FlatSet<std::pair<float, float>> already_generated;
+ absl::flat_hash_set<std::pair<float, float>> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<complex64>(&tensor, [&](int i) {
complex64 generated;
@@ -500,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_INT32: {
- gtl::FlatSet<int32> already_generated;
+ absl::flat_hash_set<int32> already_generated;
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
test::FillFn<int32>(&tensor, [&](int i) -> int32 {
int32 generated;
@@ -513,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_INT64: {
- gtl::FlatSet<int64> already_generated;
+ absl::flat_hash_set<int64> already_generated;
std::uniform_int_distribution<int64> distribution(-(1LL << 40),
1LL << 40);
test::FillFn<int64>(&tensor, [&](int i) -> int64 {
@@ -527,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_BOOL: {
- gtl::FlatSet<bool> already_generated;
+ absl::flat_hash_set<bool> already_generated;
std::bernoulli_distribution distribution;
test::FillFn<bool>(&tensor, [&](int i) -> bool {
bool generated;
@@ -1820,7 +1820,7 @@ TEST_F(OpTest, Diag) {
do {
dims = RandomDims(1);
size = TensorShape(dims).num_elements();
- } while (size * size < tf_xla_max_tensor_size);
+ } while (size * size > tf_xla_max_tensor_size);
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type));
});
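
The flipped comparison in the Diag case above is the whole fix: Diag of an input with size elements produces a size*size output, so candidate shapes must be re-drawn while the squared size exceeds tf_xla_max_tensor_size, not while it still fits. A rough Python sketch of that retry loop, with a hypothetical stand-in for RandomDims:

    import numpy as np

    def pick_diag_dims(max_tensor_size):
        # Hypothetical stand-in for the RandomDims(1) retry loop above: keep
        # re-drawing candidate shapes while Diag's size*size output would
        # exceed the per-test tensor-size budget.
        rng = np.random.default_rng()
        while True:
            dims = rng.integers(1, 8, size=int(rng.integers(1, 4))).tolist()
            size = int(np.prod(dims))
            if size * size <= max_tensor_size:
                return dims
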
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index 60c2337743..abc822ef36 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase):
def testSeqLength(self):
for dtype in self.all_types:
- for seq_dtype in self.int_types:
+ for seq_dtype in self.all_types & {np.int32, np.int64}:
self._testBasic(dtype, seq_dtype)
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 51c04b5c47..57f0ab7a9e 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -48,22 +48,30 @@ class XlaSortOpTest(xla_test.XLATestCase):
self.assertAllClose(v, result, rtol=1e-3)
def testSort(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
- supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
x = np.arange(101, dtype=dtype)
np.random.shuffle(x)
self._assertOpOutputMatchesExpected(
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
- def testTopK(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
+ def testKeyValueSort(self):
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for key_type in supported_types.intersection(self.numeric_types):
+ for value_type in supported_types.intersection(self.numeric_types):
+ x = np.arange(101, dtype=key_type)
+ np.random.shuffle(x)
+ y = (-x).astype(value_type)
+ self._assertOpOutputMatchesExpected(
+ xla.key_value_sort, [x, y],
+ expected=[
+ np.arange(101, dtype=key_type),
+ -np.arange(101, dtype=value_type)
+ ])
+ def testTopK(self):
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
@@ -89,10 +97,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
expected=[x[indices].astype(dtype), indices])
def testTopK2D(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
@@ -122,10 +126,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
@@ -144,10 +144,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKInfinities(self):
"""Tests that positive and negative infinity sort correctly."""
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
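
The new key/value sort test expects the values to be carried along with the sorted keys; a stable argsort expresses the same contract in NumPy (reference sketch only, the op itself is exercised by the test):

    import numpy as np

    keys = np.array([3, 1, 2], dtype=np.int32)
    values = np.array([30, 10, 20], dtype=np.int32)
    order = np.argsort(keys, kind="stable")
    sorted_keys, sorted_values = keys[order], values[order]
    assert sorted_keys.tolist() == [1, 2, 3]
    assert sorted_values.tolist() == [10, 20, 30]
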
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 1bea7d9355..e8741bc468 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateless random-number generator operators."""
def _random_types(self):
- return [dtypes.float32]
+    return self.float_types & {np.float32, np.float64}
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
@@ -91,7 +91,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
- x = stateless.stateless_random_uniform(
+ x = stateless.stateless_random_normal(
shape=[10000], seed=seed_t, dtype=dtype)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
self.assertTrue(np.all(np.isfinite(y)))
@@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertTrue(self._anderson_darling(y) < 2.492)
def testTruncatedNormalIsInRange(self):
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ for dtype in self._random_types():
with self.cached_session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000
@@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
- self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+ self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
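
The relaxed atol above is measured against the closed-form mean of a truncated normal, mu + (pdf(alpha) - pdf(beta)) / Z * sigma. A self-contained check of that formula with illustrative bounds:

    import math

    def normal_pdf(x):
        return math.exp(-x * x / 2.0) / math.sqrt(2.0 * math.pi)

    def normal_cdf(x):
        return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

    # Truncation of N(mu, sigma^2) to [a, b]; alpha/beta are the standardized bounds.
    mu, sigma, a, b = 0.0, 1.0, -2.0, 2.0
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    z = normal_cdf(beta) - normal_cdf(alpha)
    expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
    assert abs(expected_mean) < 1e-12  # symmetric truncation keeps the mean at mu
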
diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py
new file mode 100644
index 0000000000..5c079d595c
--- /dev/null
+++ b/tensorflow/compiler/tests/tensor_list_ops_test.py
@@ -0,0 +1,96 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ops which manipulate lists of tensors via bridge."""
+
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+def scalar_shape():
+ return ops.convert_to_tensor([], dtype=dtypes.int32)
+
+
+class ListOpsTest(xla_test.XLATestCase):
+
+ def testElementShape(self):
+ with self.cached_session() as sess, self.test_scope():
+ dim = array_ops.placeholder(dtypes.int32)
+ l = list_ops.tensor_list_reserve(
+ element_shape=(dim, 15), num_elements=20,
+ element_dtype=dtypes.float32)
+ e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32)
+ e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64)
+ self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15))
+ self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15))
+
+ def testPushPop(self):
+ with self.cached_session() as sess, self.test_scope():
+ num = array_ops.placeholder(dtypes.int32)
+ l = list_ops.tensor_list_reserve(
+ element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32)
+ l = list_ops.tensor_list_push_back(
+ l, constant_op.constant(1.0, shape=(7, 15)))
+ l = list_ops.tensor_list_push_back(
+ l, constant_op.constant(2.0, shape=(7, 15)))
+ l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15)))
+ self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15)))
+
+ def testPushPopSeparateLists(self):
+ with self.cached_session() as sess, self.test_scope():
+ num = array_ops.placeholder(dtypes.int32)
+ l = list_ops.tensor_list_reserve(
+ element_shape=scalar_shape(),
+ num_elements=num,
+ element_dtype=dtypes.float32)
+ l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
+ l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
+ l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
+ _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
+ l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
+ l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
+ l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
+ result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20})
+ self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
+
+ def testEmptyTensorList(self):
+ dim = 7
+ with self.cached_session() as sess, self.test_scope():
+ p = array_ops.placeholder(dtypes.int32)
+ l = list_ops.empty_tensor_list(
+ element_shape=(p, 15), element_dtype=dtypes.float32)
+ l = list_ops.tensor_list_push_back(
+ l, constant_op.constant(1.0, shape=(dim, 15)))
+ _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Use TensorListReserve instead"):
+ self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15)))
+
+
+if __name__ == "__main__":
+ test.main()
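
testPushPopSeparateLists above depends on TensorList value semantics: each push_back returns a new list value, so the two later pushes onto the original list do not see each other. A pure-Python analogy of the behaviour the assertions encode (this is not the TF API, just the same bookkeeping on plain lists):

    def push_back(lst, x):
        return lst + [x]          # returns a new list value, input is untouched

    def pop_back(lst):
        return lst[:-1], lst[-1]  # shorter list plus the popped element

    l = push_back([], 1.0)
    l2 = push_back(l, 2.0)   # copies l, then appends 2.0
    l3 = push_back(l, 3.0)   # copies the original l again
    assert pop_back(l)[1] == 1.0
    assert [pop_back(l2)[1], pop_back(pop_back(l2)[0])[1]] == [2.0, 1.0]
    assert [pop_back(l3)[1], pop_back(pop_back(l3)[0])[1]] == [3.0, 1.0]
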
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 55a992195f..98a07709c6 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -122,8 +122,7 @@ class TernaryOpsTest(xla_test.XLATestCase):
expected=np.array([[2], [5]], dtype=dtype))
def testClipByValue(self):
- # TODO(b/78258593): enable integer types here too.
- for dtype in self.float_types:
+ for dtype in self.numeric_types - self.complex_types:
test_cases = [
(np.array([2, 4, 5], dtype=dtype), dtype(7)), #
(dtype(1), np.array([2, 4, 5], dtype=dtype)), #
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 5b0e57f83f..77f6eee0cf 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
self.assertAllClose(result[i], expected[i], rtol, atol)
def testAllTypeOps(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(
array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype),
np.array(
@@ -158,9 +158,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
def testFloatOps(self):
for dtype in self.float_types:
- # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018.
- if dtype == np.float16 and self.device == "XLA_CPU":
- continue
x = np.arange(-0.90, 0.90, 0.25)
self._assertOpOutputMatchesExpected(
math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype))
@@ -633,7 +630,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
def testNumericOps(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[2, -1]], dtype=dtype),
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 1e600c44e9..4cf88fc523 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -181,7 +181,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
dtype=dtype))
def testNeg(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.uint8, np.int8}:
self._assertOpOutputMatchesExpected(
xla.neg,
args=(np.array([1, 2, 3], dtype=dtype),),
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 88827cb53b..98a41981cf 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -97,10 +97,23 @@ class XLATestCase(test.TestCase):
])
self._numeric_tf_types = set(
self.int_tf_types | self._float_tf_types | self.complex_tf_types)
-
- self._all_types = set(
- [dtype.as_numpy_dtype for dtype in self._all_tf_types])
+ self.quantized_tf_types = set(
+ dtype for dtype in self._all_tf_types if dtype.is_quantized)
+
+    # Quantized types don't have a numpy equivalent; include them in
+    # all_tf_types but not in all_types.
+ # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
+ # and remove all_types.
+ self._all_types = set(dtype.as_numpy_dtype
+ for dtype in self._all_tf_types
+ if not dtype.is_quantized)
self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
+ self.signed_int_types = set(dtype.as_numpy_dtype
+ for dtype in self.int_tf_types
+ if not dtype.is_unsigned)
+ self.unsigned_int_types = set(dtype.as_numpy_dtype
+ for dtype in self.int_tf_types
+ if dtype.is_unsigned)
self._float_types = set(
[dtype.as_numpy_dtype for dtype in self._float_tf_types])
self.complex_types = set([
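
The bookkeeping added to xla_test.py partitions the configured TF dtypes into quantized, numpy-convertible, signed-integer and unsigned-integer subsets. A condensed sketch of that derivation, using an illustrative dtype set rather than the test harness configuration:

    import numpy as np
    import tensorflow as tf

    all_tf_types = {tf.float32, tf.int32, tf.int8, tf.uint8, tf.qint8}  # example set
    quantized_tf_types = {d for d in all_tf_types if d.is_quantized}
    # Quantized dtypes have no numpy equivalent, so they stay out of all_types.
    all_types = {d.as_numpy_dtype for d in all_tf_types if not d.is_quantized}
    int_tf_types = {d for d in all_tf_types if d.is_integer and not d.is_quantized}
    signed_int_types = {d.as_numpy_dtype for d in int_tf_types if not d.is_unsigned}
    unsigned_int_types = {d.as_numpy_dtype for d in int_tf_types if d.is_unsigned}
    assert np.int8 in signed_int_types and np.uint8 in unsigned_int_types
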
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index ba1e3b2b4f..3f631f91ec 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -635,6 +635,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
@@ -649,6 +650,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index ea8d1b3d14..adcdb6c8f7 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -30,14 +30,15 @@ cc_library(
tf_gen_op_wrapper_cc(
name = "xla_jit_op_gen",
- out_ops_file = "ops/xla_jit_op",
+ include_internal_ops = 1,
+ out_ops_file = "ops/xla_jit_ops",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
cc_library(
name = "xla_jit_ops",
- srcs = ["ops/xla_jit_op.cc"],
- hdrs = ["ops/xla_jit_op.h"],
+ srcs = ["ops/xla_jit_ops.cc"],
+ hdrs = ["ops/xla_jit_ops.h"],
deps = [
"//tensorflow/cc:const_op",
"//tensorflow/cc:ops",
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 922ae7c79a..027ca6d2d2 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -29,14 +29,6 @@ Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
std::function<bool(const Edge&)> edge_filter) {
- // Operators that don't look at the data of their inputs, just the shapes.
- const std::unordered_set<string> metadata_ops = {
- "Rank",
- "Shape",
- "ShapeN",
- "Size",
- };
-
std::vector<bool> compile_time_const_nodes_impl;
if (compile_time_const_nodes) {
CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
@@ -50,7 +42,9 @@ Status BackwardsConstAnalysis(const Graph& g,
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
- if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return;
+ if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
+ return;
+ }
// If this node must be const, and it isn't a metadata op, then all of its
// parents must be const.
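
The const_analysis change swaps a locally hard-coded metadata-op set for the shared XlaOpRegistry::IsMetadataOp predicate; the rule itself is unchanged. A minimal sketch of that backwards rule over a hypothetical dict-based graph (Python for brevity):

    # Metadata-only ops inspect only the shapes of their inputs, so a
    # compile-time-constant requirement on their output does not force their
    # inputs to be constant.
    METADATA_OPS = {"Rank", "Shape", "ShapeN", "Size"}

    def propagate_const_requirement(node, must_be_const, op_of, inputs_of):
        # node: graph node id; must_be_const: dict id -> bool, updated in place.
        if not must_be_const.get(node, False):
            return
        if op_of[node] in METADATA_OPS:
            return  # shape-only op: input *data* is irrelevant
        for inp in inputs_of[node]:
            must_be_const[inp] = True
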
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index f792c52032..0362682bd6 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,11 +31,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -77,7 +79,10 @@ Status FunctionalizeControlFlowForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
- std::map<string, string>* canonicalized_name_to_new_name) {
+ std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
+ bool* modified) {
+ *modified = false;
+
// Convert the function to Graph.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
@@ -89,7 +94,20 @@ Status FunctionalizeControlFlowForFunction(
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
- const FunctionDef& fdef = body->fdef;
+ Graph* g = body->graph;
+
+ // Check if the graph has Switch or Merge node.
+ bool has_switch_or_merge = false;
+ for (Node* n : body->graph->nodes()) {
+ if (n->type_string() == "Switch" || n->type_string() == "Merge") {
+ has_switch_or_merge = true;
+ break;
+ }
+ }
+  // Even if the graph has no Switch or Merge node, we cannot return here
+  // directly: it might contain function call nodes, or If/While nodes whose
+  // function bodies contain Switch/Merge. We still need to rewrite those
+  // functions and modify the corresponding nodes.
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
@@ -97,7 +115,7 @@ Status FunctionalizeControlFlowForFunction(
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
- for (auto* n : body->graph->nodes()) {
+ for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -108,57 +126,86 @@ Status FunctionalizeControlFlowForFunction(
auto associated_functions = iter.second;
for (auto& associated_function : associated_functions) {
string name = associated_function.func_name();
- string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
+ string canonicalized_name =
+ Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
+ bool function_modified;
if (iter != canonicalized_name_to_new_name->end()) {
- // If we already functionalized this function, skip functionalization
- // but still rewrite the node.
- new_name = iter->second;
+ // If we already processed this function, check if it was rewritten. If
+ // the function was rewritten, the entry will be non-empty. Otherwise
+ // the entry will be empty.
+ function_modified = iter->second.has_value();
+ if (function_modified) {
+ new_name = iter->second.value();
+ }
} else {
- new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ if (associated_function.type() ==
+ AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
+ // For SymbolicGradient, `name` is always "SymbolicGradient",
+ // which is not very informative. Use node name instead.
+ new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"));
+ } else {
+ new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ }
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
- name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
- (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ name, new_name, associated_function.attrs(), fld, flr,
+ canonicalized_name_to_new_name, &function_modified));
+ if (function_modified) {
+        // If the function was rewritten, add a non-empty entry so that later
+        // we know this function has been processed and was rewritten into
+        // another function.
+ (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ } else {
+        // If the function was not rewritten, add an empty entry so that later
+        // we know this function has been processed and does not need to be
+        // rewritten.
+ (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
+ }
+ }
+ if (function_modified) {
+ *modified = true;
+
+ // Notice that if "n" is a function call, RewriteAssociatedFunction()
+ // will delete it and create a new node instead, making "n" an invalid
+ // pointer. That's fine because in that case, associated_functions will
+ // only have one member and the loop will only run once.
+ TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+ g, n, fld, associated_function, new_name));
}
- // Notice that if "n" is a function call, RewriteAssociatedFunction() will
- // delete it and create a new node instead, making "n" an invalid pointer.
- // That's fine because in that case, associated_functions will only have
- // one member and the loop will only run once.
- TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- body->graph, n, fld, associated_function, new_name));
}
}
- // Functionalize the function body.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *body->graph, fld);
- }
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld));
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *body->graph, fld);
+ if (has_switch_or_merge) {
+ *modified = true;
+
+ // Functionalize the function body.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+ *g, fld);
+ }
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
+ fld);
+ }
}
- FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef));
-
- // Copy signature and ret from original FunctionDef.
- *functionalized_fdef.mutable_signature() = fdef.signature();
- *functionalized_fdef.mutable_ret() = fdef.ret();
- functionalized_fdef.mutable_signature()->set_name(new_func_name);
-
- // Add rewritten FunctionDef into library.
- if (func_name == new_func_name) {
- VLOG(2) << "Replacing function " << func_name;
+
+ if (*modified) {
+ // Add rewritten FunctionDef into library.
+ FunctionDef functionalized_fdef;
TF_RETURN_IF_ERROR(
- fld->ReplaceFunction(new_func_name, functionalized_fdef));
- } else {
- VLOG(2) << "Adding function " << new_func_name;
- TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
+ if (func_name == new_func_name) {
+ VLOG(2) << "Replacing function " << func_name;
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(new_func_name, functionalized_fdef));
+ } else {
+ VLOG(2) << "Adding function " << new_func_name;
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ }
}
return ret_status;
@@ -184,7 +231,7 @@ Status FunctionalizeControlFlowPass::Run(
{"TPUCompile", "function"},
{"XlaLaunch", "function"},
};
- std::map<string, string> canonicalized_name_to_new_name;
+ std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
@@ -199,12 +246,15 @@ Status FunctionalizeControlFlowPass::Run(
<< ". Corresponding function: " << func.name();
string new_func_name = options.flib_def->UniqueFunctionName(
absl::StrCat(func.name(), "_f15n_"));
+ bool modified;
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
func.name(), new_func_name, func.attr(), options.flib_def, flr,
- &canonicalized_name_to_new_name));
- n->ClearAttr(func_attr);
- func.set_name(new_func_name);
- n->AddAttr(func_attr, func);
+ &canonicalized_name_to_new_name, &modified));
+ if (modified) {
+ n->ClearAttr(func_attr);
+ func.set_name(new_func_name);
+ n->AddAttr(func_attr, func);
+ }
}
}
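
As a rough illustration of the bookkeeping this change introduces (a memo map whose value records whether a processed function was actually rewritten), here is a minimal standalone sketch using std::optional in place of absl::optional; ProcessFunction and the memo contents are hypothetical and only mimic the pattern, not the real pass.

#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Stand-in for FunctionalizeControlFlowForFunction: returns the new name if a
// rewrite happened, or nullopt if the function was processed but unchanged.
std::optional<std::string> ProcessFunction(const std::string& name) {
  if (name.find("while") != std::string::npos) {
    return name + "_f15n_0";
  }
  return std::nullopt;
}

int main() {
  std::map<std::string, std::optional<std::string>> memo;
  const std::vector<std::string> calls = {"while_body", "add", "while_body"};
  for (const std::string& fn : calls) {
    auto it = memo.find(fn);
    if (it == memo.end()) {
      // First visit: process once and remember the outcome either way.
      it = memo.emplace(fn, ProcessFunction(fn)).first;
    }
    if (it->second.has_value()) {
      std::cout << fn << " -> " << *it->second << " (caller is retargeted)\n";
    } else {
      std::cout << fn << " unchanged (caller keeps its function attr)\n";
    }
  }
  return 0;
}
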
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 46794f7b50..224e5ea123 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -62,6 +62,7 @@ tf_kernel_library(
"one_hot_op.cc",
"pack_op.cc",
"pad_op.cc",
+ "permute_op.cc",
"pooling_ops.cc",
"qr_op.cc",
"quantize_and_dequantize_op.cc",
@@ -94,6 +95,7 @@ tf_kernel_library(
"stateless_random_ops.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
+ "tensor_list_ops.cc",
"tile_ops.cc",
"topk_op.cc",
"training_ops.cc",
@@ -113,11 +115,13 @@ tf_kernel_library(
"shape_util.h",
],
deps = [
+ ":conv_op_helpers",
":if_op",
":while_op",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
+ "//tensorflow/compiler/tf2xla/lib:broadcast",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
@@ -156,6 +160,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:cwise_op",
+ "//tensorflow/core/kernels:list_kernels",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:pooling_ops",
@@ -172,6 +177,27 @@ tf_kernel_library(
],
)
+cc_library(
+ name = "conv_op_helpers",
+ srcs = ["conv_op_helpers.cc"],
+ hdrs = ["conv_op_helpers.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:numeric",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/kernels:bounds_check",
+ "//tensorflow/core/kernels:conv_ops",
+ "//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
tf_kernel_library(
name = "while_op",
srcs = ["while_op.cc"],
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index b3ad0aea84..a267c0c72f 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -34,12 +34,6 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
- OP_REQUIRES(ctx,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW ||
- data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN),
- errors::InvalidArgument(
- "Unsupported data format ", ToString(data_format_),
- "; supported formats are NHWC, NCHW, HWNC and HWCN"));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -110,12 +104,6 @@ class FusedBatchNormGradOp : public XlaOpKernel {
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
- OP_REQUIRES(ctx,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW ||
- data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN),
- errors::InvalidArgument(
- "Unsupported data format ", ToString(data_format_),
- "; supported formats are NHWC, NCHW, HWNC and HWCN"));
}
void Compile(XlaOpKernelContext* ctx) override {
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 0d9a768a6f..47e517a657 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
+// Implementation of DivNoNan. Pseudo-code:
+// if (y == 0) {
+// return 0
+// } else {
+// return x / y;
+// }
+static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto y_equals_0 = xla::Eq(y, zero);
+ auto zeros = xla::ZerosLike(x);
+ auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
+ return result;
+}
+XLA_MAKE_BINARY(DivNoNan,
+ DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);
@@ -65,7 +84,7 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
// }
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
if (DataTypeIsUnsigned(dtype)) {
return xla::Div(x, y);
}
@@ -84,12 +103,30 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
XLA_MAKE_BINARY(FloorDiv,
FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto is_zero = xla::Eq(x, zero);
+ return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
+}
+XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
+static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto is_zero = xla::Eq(x, zero);
+ return xla::Select(is_zero, zero, xla::Div(x, y));
+}
+XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorMod. Pseudo-code:
// T trunc_mod = std::fmod(x, y);
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
auto trunc_mod = xla::Rem(x, y);
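
The three lowerings added above (DivNoNan, Xlogy, Xdivy) all use the same select-against-zero pattern. As a plain-C++ sketch of the scalar rule each op applies elementwise after broadcasting (illustrative only, no XLA involved):

#include <cassert>
#include <cmath>

double DivNoNan(double x, double y) { return y == 0.0 ? 0.0 : x / y; }
double Xlogy(double x, double y)    { return x == 0.0 ? 0.0 : x * std::log(y); }
double Xdivy(double x, double y)    { return x == 0.0 ? 0.0 : x / y; }

int main() {
  assert(DivNoNan(3.0, 0.0) == 0.0);  // no Inf/NaN when the divisor is zero
  assert(Xlogy(0.0, 0.0) == 0.0);     // 0 * log(0) is defined to be 0
  assert(Xdivy(0.0, 0.0) == 0.0);     // 0 / 0 is defined to be 0
  assert(Xdivy(4.0, 2.0) == 2.0);     // otherwise an ordinary division
  return 0;
}
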
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 4bd7c74dca..9bb11fb67e 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "absl/algorithm/container.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
@@ -37,60 +32,9 @@ class BroadcastToOp : public XlaOpKernel {
TensorShape output_shape;
OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
- OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(),
- errors::InvalidArgument(
- "Input rank (", input_shape.dims(),
- ") must be less than or equal to the output rank (",
- output_shape.dims(), ")"));
-
- auto input_dims = input_shape.dim_sizes();
- auto output_dims = output_shape.dim_sizes();
-
- // Broadcasting is done right-to-left on right-aligned dimensions; reverse
- // the two vectors so elements to be broadcast are aligned.
- absl::c_reverse(input_dims);
- absl::c_reverse(output_dims);
-
- std::vector<int64> broadcast_dims;
- std::vector<int64> broadcast_shape;
- for (int i = 0; i < output_shape.dims(); ++i) {
- if (i < input_shape.dims()) {
- OP_REQUIRES(
- context,
- (output_dims[i] == 0 && input_dims[i] == 0) ||
- (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0),
- errors::InvalidArgument("invalid shape to broadcast from ",
- input_shape.DebugString(), " to ",
- output_shape.DebugString()));
-
- broadcast_dims.push_back(broadcast_shape.size());
- if (output_dims[i] == input_dims[i] || input_dims[i] == 1) {
- broadcast_shape.push_back(output_dims[i]);
- }
- if (output_dims[i] != input_dims[i]) {
- // Add dimensions [I, O/I], which we will later flatten to just
- // [O]. We must do this in two phases since XLA broadcasting does not
- // support tiling.
- broadcast_shape.push_back(input_dims[i]);
- broadcast_shape.push_back(output_dims[i] / input_dims[i]);
- }
- } else {
- broadcast_shape.push_back(output_dims[i]);
- }
- }
- absl::c_reverse(broadcast_dims);
- int broadcast_shape_size = broadcast_shape.size();
- for (int64& broadcast_dim : broadcast_dims) {
- broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
- }
- absl::c_reverse(broadcast_shape);
- xla::XlaOp output = xla::Reshape(
- xla::BroadcastInDim(context->Input(0),
- xla::ShapeUtil::MakeShape(
- context->input_xla_type(0), broadcast_shape),
- broadcast_dims),
- output_shape.dim_sizes());
- context->SetOutput(0, output);
+ auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
+ OP_REQUIRES_OK(context, output.status());
+ context->SetOutput(0, output.ValueOrDie());
}
};
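
The kernel now delegates to BroadcastTo from //tensorflow/compiler/tf2xla/lib:broadcast instead of building the broadcast inline. For reference, a minimal standalone sketch of the right-aligned compatibility rule the removed code enforced (the tiling via [I, O/I] intermediate dimensions and the final reshape are omitted; shapes and names are illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true if `input` can be broadcast to `output` under the rule used by
// the removed code: align shapes on the right; each input dimension must
// evenly divide the corresponding output dimension (or both be zero).
bool BroadcastCompatible(std::vector<int64_t> input,
                         std::vector<int64_t> output) {
  if (input.size() > output.size()) return false;
  auto in = input.rbegin();
  for (auto out = output.rbegin(); in != input.rend(); ++in, ++out) {
    bool ok = (*out == 0 && *in == 0) || (*in != 0 && *out % *in == 0);
    if (!ok) return false;
  }
  return true;
}

int main() {
  std::cout << BroadcastCompatible({3, 1}, {2, 3, 4}) << "\n";  // 1: ok
  std::cout << BroadcastCompatible({3, 5}, {2, 3, 4}) << "\n";  // 0: 5 !| 4
  return 0;
}
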
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index da8cf3fc6f..2628ef8e24 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
namespace {
@@ -76,6 +77,17 @@ class ConstOp : public XlaOpKernel {
return;
}
break;
+ case DT_COMPLEX64:
+ if (proto_.scomplex_val_size() == 2) {
+ ctx->SetOutput(
+ 0,
+ xla::Broadcast(xla::ConstantR0<xla::complex64>(
+ b, xla::complex64(proto_.scomplex_val(0),
+ proto_.scomplex_val(1))),
+ shape.dim_sizes()));
+ return;
+ }
+ break;
case DT_INT32:
if (proto_.int_val_size() == 1) {
ctx->SetOutput(
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
new file mode 100644
index 0000000000..c9a1be4940
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -0,0 +1,509 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for 2D convolution.
+
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+// Returns the expanded size of a filter used for depthwise convolution.
+// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
+xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
+ int num_dims = shape.dimensions_size();
+ CHECK_GE(num_dims, 2); // Crash OK
+ xla::Shape expanded_shape = shape;
+ expanded_shape.set_dimensions(
+ num_dims - 1,
+ shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
+ return expanded_shape;
+}
+
+// Create a mask for depthwise convolution that will make a normal convolution
+// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
+// depthwise filter this returns a [2, 2, 3, 6] tensor
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// The first step is to create one tensor, A, that is [3]
+// 0 1 2
+//
+// and another tensor, B, that is [3 * 2]
+// 0 1 2 3 4 5
+//
+// and divide B by 2 to get
+// 0 0 1 1 2 2
+//
+// then we broadcast the B to [2, 2, 3, 3 * 2]
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// Finally compare A and broadcasted B in dimension 2 and return the result at
+// the beginning of the comment.
+xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
+ xla::XlaBuilder* builder) {
+ xla::Shape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ int64 depthwise_multiplier =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 1);
+ int64 input_feature =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 2);
+
+ // Create an M sized linspace and an M*N sized linspace that will be
+ // broadcasted into perpendicular dimensions and compared.
+ xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+ xla::XlaOp expanded_feature_iota =
+ xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
+
+ // Divide the M*N sized linspace by the depthwise_multiplier to create
+ // [0 0 1 1 2 2] in the example in the function comment.
+ expanded_feature_iota =
+ xla::Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
+
+ // Broadcast the N*M linspace to [H, W, ..., M, M*N].
+ std::vector<int64> expanded_feature_broadcast_dims(
+ expanded_filter_shape.dimensions().begin(),
+ expanded_filter_shape.dimensions().end());
+ expanded_feature_broadcast_dims.pop_back();
+ auto broadcasted_expanded_feature_iota =
+ xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
+
+ // Compare the broadcasted linspace to the input feature linspace in the
+ // input feature dimension to create a diagonal predicate.
+ return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dimensions_size() - 2});
+}
+
+// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
+// build a depthwise convolution.
+xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter) {
+ int64 input_feature_dim = filter_shape.dimensions_size() - 2;
+ int64 output_feature_dim = filter_shape.dimensions_size() - 1;
+ int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
+ int64 input_feature = filter_shape.dimensions(input_feature_dim);
+
+ // Create a [H, W, ..., 1, N*M] reshape of the filter.
+ xla::Shape implicit_broadcast_filter_shape = filter_shape;
+ implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
+ implicit_broadcast_filter_shape.set_dimensions(
+ output_feature_dim, depthwise_multiplier * input_feature);
+ return xla::Reshape(
+ filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
+}
+
+// Reduces the results of the convolution with an expanded filter to the
+// non-expanded filter.
+xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter_backprop,
+ xla::XlaBuilder* builder) {
+ auto masked_expanded_filter =
+ xla::Select(CreateExpandedFilterMask(filter_shape, builder),
+ filter_backprop, xla::ZerosLike(filter_backprop));
+
+ auto elem_type = filter_shape.element_type();
+ return xla::Reshape(
+ // This reduce does not need inputs to be converted with
+ // XlaHelpers::SumAccumulationType() since the select above guarantees
+ // that only one element is non zero, so there cannot be accumulated
+ // precision error.
+ xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
+ CreateScalarAddComputation(elem_type, builder),
+ {filter_shape.dimensions_size() - 2}),
+ xla::AsInt64Slice(filter_shape.dimensions()));
+}
+
+// Performs some basic checks on ConvOpAttrs that apply to all kinds of XLA
+// convolutions (as currently implemented).
+Status CheckConvAttrs(const ConvOpAttrs& attrs) {
+ const int num_dims = attrs.num_spatial_dims + 2;
+ if (attrs.strides.size() != num_dims) {
+ return errors::InvalidArgument("Sliding window strides field must specify ",
+ num_dims, " dimensions");
+ }
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+ if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
+ return errors::Unimplemented(
+ "Current implementation does not yet support strides in the batch and "
+ "depth dimensions.");
+ }
+ if (attrs.dilations.size() != num_dims) {
+ return errors::InvalidArgument("Dilations field must specify ", num_dims,
+ " dimensions");
+ }
+ if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
+ return errors::Unimplemented(
+ "Current implementation does not support dilations in the batch and "
+ "depth dimensions.");
+ }
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ if (attrs.dilations[input_dim] < 1) {
+ return errors::Unimplemented("Dilation values must be positive; ", i,
+ "th spatial dimension had dilation ",
+ attrs.dilations[input_dim]);
+ }
+ }
+ return Status::OK();
+}
+
+// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
+// to TensorShapes.
+Status ConvBackpropComputeDimensionsV2XlaShapes(
+ StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
+ const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
+ absl::Span<const int32> dilations, const std::vector<int32>& strides,
+ Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
+ TensorShape input_tensor_shape, filter_tensor_shape,
+ out_backprop_tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
+ TF_RETURN_IF_ERROR(
+ XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
+ return ConvBackpropComputeDimensionsV2(
+ label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
+ out_backprop_tensor_shape, dilations, strides, padding, data_format,
+ dims);
+}
+
+} // anonymous namespace
+
+xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
+ bool depthwise,
+ OpKernelConstruction* ctx) {
+ ConvOpAttrs attrs;
+ attrs.num_spatial_dims = num_spatial_dims;
+ attrs.depthwise = depthwise;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
+
+ string data_format;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
+ if (!FormatFromString(data_format, &attrs.data_format)) {
+ return errors::InvalidArgument("Invalid data format: ", data_format);
+ }
+
+ return attrs;
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
+ xla::XlaOp conv_input,
+ xla::XlaOp filter,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = conv_input.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
+ // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+
+ // For 2D convolution, there should be 4 dimensions.
+ int num_dims = attrs.num_spatial_dims + 2;
+ if (input_shape.dimensions_size() != num_dims) {
+ return errors::InvalidArgument("input must be ", num_dims, "-dimensional",
+ input_shape.DebugString());
+ }
+ if (filter_shape.dimensions_size() != num_dims) {
+ return errors::InvalidArgument(
+ "filter must be ", num_dims,
+ "-dimensional: ", filter_shape.DebugString());
+ }
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
+ // The 'C' dimension for input is in_depth. It must be the same as
+ // the filter's in_depth.
+ if (in_depth != input_shape.dimensions(feature_dim)) {
+ return errors::InvalidArgument(
+ "input and filter must have the same depth: ", in_depth, " vs ",
+ input_shape.dimensions(feature_dim));
+ }
+
+ if (attrs.depthwise) {
+ filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
+ }
+
+ xla::ConvolutionDimensionNumbers dims;
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+
+ dims.set_input_batch_dimension(batch_dim);
+ dims.set_output_batch_dimension(batch_dim);
+ dims.set_input_feature_dimension(feature_dim);
+ dims.set_output_feature_dimension(feature_dim);
+ dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
+ dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dims.add_input_spatial_dimensions(dim);
+ dims.add_kernel_spatial_dimensions(i);
+ dims.add_output_spatial_dimensions(dim);
+ window_strides[i] = attrs.strides.at(dim);
+ rhs_dilation[i] = attrs.dilations.at(dim);
+
+ int64 unused_output_size;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
+ input_shape.dimensions(dim), filter_shape.dimensions(i),
+ rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
+ &padding[i].first, &padding[i].second));
+ }
+
+ return xla::ConvGeneralDilated(
+ conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
+ dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ int num_dims = attrs.num_spatial_dims + 2;
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ auto* builder = filter.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(out_backprop));
+
+ xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
+ out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
+ attrs.data_format, &dims));
+
+ // The input gradients are computed by a convolution of the output
+ // gradients and the filter, with some appropriate padding. See the
+ // comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(batch_dim);
+ dnums.set_output_batch_dimension(batch_dim);
+ dnums.set_input_feature_dimension(feature_dim);
+ dnums.set_output_feature_dimension(feature_dim);
+
+ // TF filter shape is [ H, W, ..., inC, outC ]
+ // Transpose the input and output features for computing the gradient.
+ dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
+ dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);
+
+ std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(i);
+ dnums.add_output_spatial_dimensions(dim);
+
+ kernel_spatial_dims[i] = i;
+ padding[i] = {dims.spatial_dims[i].pad_before,
+ dims.spatial_dims[i].pad_after};
+ lhs_dilation[i] = dims.spatial_dims[i].stride;
+ rhs_dilation[i] = attrs.dilations[dim];
+ }
+
+ // Mirror the filter in the spatial dimensions.
+ xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
+
+ // activation gradients
+ // = gradients (with padding and dilation) <conv> mirrored_weights
+ return xla::ConvGeneralDilated(
+ out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
+ lhs_dilation, rhs_dilation, dnums,
+ /*feature_group_count=*/
+ attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
+ filter_shape.dimensions(attrs.num_spatial_dims + 1)
+ : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = activations.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
+ builder->GetShape(activations));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(gradients));
+ const xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, activations_shape,
+ expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
+ attrs.padding, attrs.data_format, &dims));
+
+ // The filter gradients are computed by a convolution of the input
+ // activations and the output gradients, with some appropriate padding.
+ // See the comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+
+ // The activations (inputs) form the LHS of the convolution.
+ // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
+ // For the gradient computation, we flip the roles of the batch and
+ // feature dimensions.
+ // Each spatial entry has size in_depth * batch
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int num_dims = attrs.num_spatial_dims + 2;
+ int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ // Swap n_dim and c_dim in the activations.
+ dnums.set_input_batch_dimension(c_dim);
+ dnums.set_input_feature_dimension(n_dim);
+
+ // The gradients become the RHS of the convolution.
+ // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
+ // where the batch becomes the input feature for the convolution.
+ dnums.set_kernel_input_feature_dimension(n_dim);
+ dnums.set_kernel_output_feature_dimension(c_dim);
+
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+
+ // Tensorflow filter shape is [ H, W, ..., inC, outC ].
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ dnums.add_output_spatial_dimensions(i);
+ }
+ dnums.set_output_batch_dimension(attrs.num_spatial_dims);
+ dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(dim);
+
+ // We will also need to pad the input with zeros such that after the
+ // convolution, we get the right size for the filter.
+ // The padded_in_rows should be such that when we convolve this with the
+ // expanded_out_rows as a filter, we should get filter_rows back.
+ //
+ const int64 padded_in_size =
+ dims.spatial_dims[i].expanded_output_size +
+ (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];
+
+ // However it can be smaller than input_rows: in this
+ // case it means some of the inputs are not used.
+ //
+ // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
+ //
+ // INPUT = [ A B C ]
+ //
+ // FILTER = [ x y ]
+ //
+ // and the output will only have one column: a = A * x + B * y
+ //
+ // and input "C" is not used at all.
+ //
+ // We apply negative padding in this case.
+ const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
+
+ // + For the VALID padding, we don't pad anything on the top/left side
+ // and pad the bottom/right side with the remaining space.
+ // + For the SAME padding, we pad top/left side the same as bottom/right
+ // side.
+ //
+ // In addition, if the padded input size is smaller than the input size,
+ // we need to ignore some trailing elements of the input. We do this by
+ // applying negative padding on the right/bottom.
+ const int64 pad_before =
+ attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
+
+ padding[i] = {pad_before, pad_total - pad_before};
+ rhs_dilation[i] = dims.spatial_dims[i].stride;
+ window_strides[i] = attrs.dilations[dim];
+ }
+
+ // Besides padding the input, we will also expand output_rows to
+ // expanded_out_rows = (output_rows - 1) * stride + 1
+ // with zeros in between:
+ //
+ // a . . . b . . . c . . . d . . . e
+ //
+ // This is done by specifying the window dilation factors in the
+ // convolution HLO below.
+ auto filter_backprop =
+ xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
+ /*lhs_dilation=*/ones, rhs_dilation, dnums);
+
+ if (attrs.depthwise) {
+ filter_backprop = ContractFilterForDepthwiseBackprop(
+ filter_shape, filter_backprop, activations.builder());
+ }
+
+ return filter_backprop;
+}
+
+} // namespace tensorflow
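
To make the negative-padding discussion in MakeXlaBackpropFilterConvOp concrete, the following standalone sketch reruns the worked example (input_cols = 3, filter_cols = 2, stride = 2, dilation = 1, VALID padding), assuming the standard VALID output-size formula; it ends with pad_after = -1, i.e. input column C is dropped.

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t input_size = 3, filter_size = 2, stride = 2, dilation = 1;
  const bool same_padding = false;  // VALID

  // Forward VALID output size, then the "expanded" size with stride-1 zeros
  // inserted between output elements (the lhs_dilation trick).
  const int64_t output_size = (input_size - filter_size) / stride + 1;    // 1
  const int64_t expanded_output_size = (output_size - 1) * stride + 1;    // 1

  const int64_t padded_in_size =
      expanded_output_size + (filter_size - 1) * dilation;                // 2
  const int64_t pad_total = padded_in_size - input_size;                  // -1
  const int64_t pad_before =
      same_padding ? std::max<int64_t>(pad_total / 2, 0) : 0;             // 0

  // Prints "padding = {0, -1}": the negative pad_after ignores column C.
  std::cout << "padding = {" << pad_before << ", " << pad_total - pad_before
            << "}\n";
  return 0;
}
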
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
new file mode 100644
index 0000000000..6e1b70a478
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+// This header exposes utilities for translating TensorFlow convolution ops into
+// XLA ops.
+//
+// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g.
+// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in
+// this header to implement a new and exciting convolution op, for example a
+// fused TensorFlow op that contains a convolution and other things.
+
+namespace tensorflow {
+
+// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
+// convolution.
+struct ConvOpAttrs {
+ // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`.
+ static xla::StatusOr<ConvOpAttrs> Create(int num_spatial_dims, bool depthwise,
+ OpKernelConstruction* ctx);
+
+ bool depthwise;
+ int num_spatial_dims;
+ std::vector<int32> dilations;
+ std::vector<int32> strides;
+ Padding padding;
+ TensorFormat data_format;
+};
+
+// Creates a new XLA forward or backward convolution with the given inputs and
+// attributes.
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
+ xla::XlaOp conv_input,
+ xla::XlaOp filter,
+ const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
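
A condensed sketch of how a kernel consumes this header, based on the ConvOp rewrite further down in this diff; XlaConvExample is a hypothetical name, op registration is omitted, and the snippet only builds inside the TensorFlow tree.

#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

class XlaConvExample : public XlaOpKernel {
 public:
  explicit XlaConvExample(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    // Read strides/dilations/padding/data_format once, at construction time.
    xla::StatusOr<ConvOpAttrs> attrs =
        ConvOpAttrs::Create(/*num_spatial_dims=*/2, /*depthwise=*/false, ctx);
    OP_REQUIRES_OK(ctx, attrs.status());
    attrs_ = attrs.ValueOrDie();
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // Input 0 is the activations, input 1 the filter, as in Conv2D.
    xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
        ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
    OP_REQUIRES_OK(ctx, conv.status());
    ctx->SetOutput(0, conv.ValueOrDie());
  }

 private:
  ConvOpAttrs attrs_;
};

}  // namespace tensorflow
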
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 674720e22f..cd7c820be0 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -15,12 +15,17 @@ limitations under the License.
// XLA-specific Ops for 2D convolution.
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -33,250 +38,28 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
-
namespace {
-// Returns the expanded size of a filter used for depthwise convolution.
-// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
-TensorShape ExpandedFilterShapeForDepthwiseConvolution(
- const TensorShape& shape) {
- int num_dims = shape.dims();
- CHECK_GE(num_dims, 2);
- TensorShape expanded_shape = shape;
- expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) *
- shape.dim_size(num_dims - 1));
- return expanded_shape;
-}
-
-// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
-xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
- xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- return xla::Broadcast(XlaHelpers::Zero(builder, dtype),
- expanded_filter_shape.dim_sizes());
-}
-
-// Create a mask for depthwise convolution that will make a normal convolution
-// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
-// depthwise filter this returns a [2, 2, 3, 6] tensor
-// 1 1 0 0 0 0 1 1 0 0 0 0
-// 0 0 1 1 0 0 0 0 1 1 0 0
-// 0 0 0 0 1 1 0 0 0 0 1 1
-//
-// 1 1 0 0 0 0 1 1 0 0 0 0
-// 0 0 1 1 0 0 0 0 1 1 0 0
-// 0 0 0 0 1 1 0 0 0 0 1 1
-//
-// The first step is to create a one tensor, A, that is [3]
-// 0 1 2
-//
-// and another tensor, B, that is [3 * 2]
-// 0 1 2 3 4 5
-//
-// and divide B it by 2 to get
-// 0 0 1 1 2 2
-//
-// then we broadcast the B to [2, 2, 3, 3 * 2]
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-//
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-//
-// Finally compare A and broadcasted B in dimension 2 amd return the result at
-// the beginning of the comment.
-xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
- xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
- int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
-
- // Create a M sized linspace and an M*N sized linspace that will be
- // broadcasted into perpendicular dimensions and compared.
- xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
- xla::XlaOp expanded_feature_iota =
- xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
-
- // Divide the M*N sized linspace by the depthwise_multiplier to create
- // [0 0 1 1 2 2] in the example in the function comment.
- expanded_feature_iota =
- xla::Div(expanded_feature_iota,
- XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
- depthwise_multiplier));
-
- // Broadcast the N*M linspace to [H, W, ..., M, M*N].
- auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
- expanded_feature_broadcast_dims.pop_back();
- auto broadcasted_expanded_feature_iota =
- xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
-
- // Compare the broadcasted linspace to the input feature linspace in the
- // input feature dimension to create a diagonal predicate.
- return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
- {expanded_filter_shape.dims() - 2});
-}
-
-// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
-// build a depthwise convolution.
-xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape,
- const xla::XlaOp& filter) {
- int64 input_feature_dim = filter_shape.dims() - 2;
- int64 output_feature_dim = filter_shape.dims() - 1;
- int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim);
- int64 input_feature = filter_shape.dim_size(input_feature_dim);
-
- // Create a [H, W, ..., 1, N*M] reshape of the filter.
- TensorShape implicit_broadcast_filter_shape = filter_shape;
- implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1);
- implicit_broadcast_filter_shape.set_dim(output_feature_dim,
- depthwise_multiplier * input_feature);
- return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
-}
-
-// Reduces the results of the convolution with an expanded filter to the
-// non-expanded filter.
-xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
- const TensorShape& filter_shape,
- DataType dtype,
- const xla::XlaOp& filter_backprop,
- xla::XlaBuilder* builder) {
- auto masked_expanded_filter = xla::Select(
- CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
- CreateExpandedZero(filter_shape, dtype, builder));
- return xla::Reshape(
- // This reduce does not need inputs to be converted with
- // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with
- // ExpandedZero guarantees that only one element is non zero, so there
- // cannot be accumulated precision error.
- xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
- *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}),
- filter_shape.dim_sizes());
-}
-
class ConvOp : public XlaOpKernel {
public:
explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
-
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, strides_.size() == num_dims(),
- errors::InvalidArgument("Sliding window strides field must "
- "specify ",
- num_dims(), " dimensions"));
- int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- OP_REQUIRES(
- ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- const TensorShape input_shape = ctx->InputShape(0);
- // Input filter is of the following dimensions:
- // [ filter_rows, filter_cols, ..., in_depth, out_depth]
- const TensorShape filter_shape = ctx->InputShape(1);
-
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(
- ctx, input_shape.dims() == num_dims(),
- errors::InvalidArgument("input must be ", num_dims(), "-dimensional",
- input_shape.DebugString()));
- OP_REQUIRES(
- ctx, filter_shape.dims() == num_dims(),
- errors::InvalidArgument("filter must be ", num_dims(),
- "-dimensional: ", filter_shape.DebugString()));
-
- // The last two dimension of the filter are the input and output shapes.
- const int64 in_depth = filter_shape.dim_size(num_spatial_dims_);
-
- // The 'C' dimension for input is in_depth. It must be the same as
- // the filter's in_depth.
- OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", in_depth,
- " vs ", input_shape.dim_size(feature_dim)));
-
- xla::XlaOp filter = ctx->Input(1);
- if (depthwise_) {
- filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
- }
-
- xla::ConvolutionDimensionNumbers dims;
- std::vector<int64> window_strides(num_spatial_dims_);
- std::vector<int64> lhs_dilation(num_spatial_dims_, 1);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
-
- dims.set_input_batch_dimension(batch_dim);
- dims.set_output_batch_dimension(batch_dim);
- dims.set_input_feature_dimension(feature_dim);
- dims.set_output_feature_dimension(feature_dim);
- dims.set_kernel_input_feature_dimension(num_spatial_dims_);
- dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1);
-
- for (int i = 0; i < num_spatial_dims_; ++i) {
- const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dims.add_input_spatial_dimensions(dim);
- dims.add_kernel_spatial_dimensions(i);
- dims.add_output_spatial_dimensions(dim);
- window_strides[i] = strides_.at(dim);
- rhs_dilation[i] = dilations_.at(dim);
-
- int64 unused_output_size;
- OP_REQUIRES_OK(
- ctx, GetWindowedOutputSizeVerboseV2(
- input_shape.dim_size(dim), filter_shape.dim_size(i),
- rhs_dilation[i], window_strides[i], padding_,
- &unused_output_size, &padding[i].first, &padding[i].second));
- }
-
- xla::XlaOp conv = xla::ConvGeneralDilated(
- ctx->Input(0), filter, window_strides, padding, lhs_dilation,
- rhs_dilation, dims,
- /*feature_group_count=*/depthwise_ ? in_depth : 1);
- ctx->SetOutput(0, conv);
+ xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
+ ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
+ OP_REQUIRES_OK(ctx, conv.status());
+ ctx->SetOutput(0, conv.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvOp);
@@ -308,124 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel {
public:
explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, strides_.size() == num_dims(),
- errors::InvalidArgument("Sliding window strides field must "
- "specify ",
- num_dims(), " dimensions"));
- int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- OP_REQUIRES(
- ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- TensorShape input_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
-
- const TensorShape filter_shape = ctx->InputShape(1);
- const TensorShape out_backprop_shape = ctx->InputShape(2);
-
- const TensorShape expanded_filter_shape =
- depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
- : filter_shape;
- // Reuse dimension computation logic from conv_grad_ops.cc.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx,
- ConvBackpropComputeDimensionsV2(
- type_string(), num_spatial_dims_, input_shape,
- expanded_filter_shape, out_backprop_shape, dilations_,
- strides_, padding_, data_format_, &dims));
-
- auto filter = ctx->Input(1);
- auto out_backprop = ctx->Input(2);
-
- // The input gradients are computed by a convolution of the output
- // gradients and the filter, with some appropriate padding. See the
- // comment at the top of conv_grad_ops.h for details.
-
- xla::ConvolutionDimensionNumbers dnums;
- dnums.set_input_batch_dimension(batch_dim);
- dnums.set_output_batch_dimension(batch_dim);
- dnums.set_input_feature_dimension(feature_dim);
- dnums.set_output_feature_dimension(feature_dim);
-
- // TF filter shape is [ H, W, ..., inC, outC ]
- // Transpose the input and output features for computing the gradient.
- dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1);
- dnums.set_kernel_output_feature_dimension(num_spatial_dims_);
-
- std::vector<int64> kernel_spatial_dims(num_spatial_dims_);
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
- std::vector<int64> lhs_dilation(num_spatial_dims_);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<int64> ones(num_spatial_dims_, 1);
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_input_spatial_dimensions(dim);
- dnums.add_kernel_spatial_dimensions(i);
- dnums.add_output_spatial_dimensions(dim);
-
- kernel_spatial_dims[i] = i;
- padding[i] = {dims.spatial_dims[i].pad_before,
- dims.spatial_dims[i].pad_after};
- lhs_dilation[i] = dims.spatial_dims[i].stride;
- rhs_dilation[i] = dilations_[dim];
- }
-
- // Mirror the filter in the spatial dimensions.
- xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
-
- // activation gradients
- // = gradients (with padding and dilation) <conv> mirrored_weights
- xla::XlaOp in_backprop = xla::ConvGeneralDilated(
- out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
- lhs_dilation, rhs_dilation, dnums,
- /*feature_group_count=*/
- depthwise_ ? out_backprop_shape.dim_size(feature_dim) /
- filter_shape.dim_size(num_spatial_dims_ + 1)
- : 1);
-
- ctx->SetOutput(0, in_backprop);
+ TensorShape input_tensor_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
+ xla::Shape input_shape =
+ TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
+
+ xla::StatusOr<xla::XlaOp> in_backprop =
+ MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
+ ctx->Input(1), ctx->Input(2), attrs_);
+ OP_REQUIRES_OK(ctx, in_backprop.status());
+ ctx->SetOutput(0, in_backprop.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
@@ -462,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel {
public:
explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
-
- OP_REQUIRES(
- ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1),
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- const TensorShape activations_shape = ctx->InputShape(0);
- TensorShape filter_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
- const TensorShape out_backprop_shape = ctx->InputShape(2);
-
- const TensorShape expanded_filter_shape =
- depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
- : filter_shape;
-
- // Reuse dimension computation logic from conv_grad_ops.cc.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx,
- ConvBackpropComputeDimensionsV2(
- type_string(), num_spatial_dims_, activations_shape,
- expanded_filter_shape, out_backprop_shape, dilations_,
- strides_, padding_, data_format_, &dims));
-
- xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp activations = ctx->Input(0);
- xla::XlaOp gradients = ctx->Input(2);
-
- // The filter gradients are computed by a convolution of the input
- // activations and the output gradients, with some appropriate padding.
- // See the comment at the top of conv_grad_ops.h for details.
-
- xla::ConvolutionDimensionNumbers dnums;
-
- // The activations (inputs) form the LHS of the convolution.
- // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
- // For the gradient computation, we flip the roles of the batch and
- // feature dimensions.
- // Each spatial entry has size in_depth * batch
-
- // Swap n_dim and c_dim in the activations.
- dnums.set_input_batch_dimension(c_dim);
- dnums.set_input_feature_dimension(n_dim);
-
- // The gradients become the RHS of the convolution.
- // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
- // where the batch becomes the input feature for the convolution.
- dnums.set_kernel_input_feature_dimension(n_dim);
- dnums.set_kernel_output_feature_dimension(c_dim);
-
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<int64> window_strides(num_spatial_dims_);
- std::vector<int64> ones(num_spatial_dims_, 1);
-
- // Tensorflow filter shape is [ H, W, ..., inC, outC ].
- for (int i = 0; i < num_spatial_dims_; ++i) {
- dnums.add_output_spatial_dimensions(i);
- }
- dnums.set_output_batch_dimension(num_spatial_dims_);
- dnums.set_output_feature_dimension(num_spatial_dims_ + 1);
-
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_input_spatial_dimensions(dim);
- dnums.add_kernel_spatial_dimensions(dim);
-
- // We will also need to pad the input with zeros such that after the
- // convolution, we get the right size for the filter.
- // The padded_in_rows should be such that when we convolve this with the
- // expanded_out_rows as a filter, we should get filter_rows back.
- //
- const int64 padded_in_size =
- dims.spatial_dims[i].expanded_output_size +
- (dims.spatial_dims[i].filter_size - 1) * dilations_[dim];
-
- // However it can be smaller than input_rows: in this
- // case it means some of the inputs are not used.
- //
- // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
- //
- // INPUT = [ A B C ]
- //
- // FILTER = [ x y ]
- //
- // and the output will only have one column: a = A * x + B * y
- //
- // and input "C" is not used at all.
- //
- // We apply negative padding in this case.
- const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
-
- // + For the VALID padding, we don't pad anything on the top/left side
- // and pad the bottom/right side with the remaining space.
- // + For the SAME padding, we pad top/left side the same as bottom/right
- // side.
- //
- // In addition, if the padded input size is smaller than the input size,
- // we need to ignore some trailing elements of the input. We do this by
- // applying negative padding on the right/bottom.
- const int64 pad_before =
- padding_ == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
-
- padding[i] = {pad_before, pad_total - pad_before};
- rhs_dilation[i] = dims.spatial_dims[i].stride;
- window_strides[i] = dilations_[dim];
- }
-
- // Besides padding the input, we will also expand output_rows to
- // expanded_out_rows = (output_rows - 1) * stride + 1
- // with zeros in between:
- //
- // a . . . b . . . c . . . d . . . e
- //
- // This is done by specifying the window dilation factors in the
- // convolution HLO below.
- auto filter_backprop =
- xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
- /*lhs_dilation=*/ones, rhs_dilation, dnums);
-
- if (depthwise_) {
- filter_backprop = ContractFilterForDepthwiseBackprop(
- ctx, filter_shape, ctx->input_type(0), filter_backprop, b);
- }
- ctx->SetOutput(0, filter_backprop);
+ TensorShape filter_tensor_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
+ xla::Shape filter_shape =
+ TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
+
+ xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp(
+ ctx->op_kernel().type_string(), ctx->Input(0), filter_shape,
+ ctx->Input(2), attrs_);
+ OP_REQUIRES_OK(ctx, filter_backprop.status());
+ ctx->SetOutput(0, filter_backprop.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp);
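For reference, a Python sketch of the per-spatial-dimension padding arithmetic that the removed code performed by hand and that MakeXlaBackpropFilterConvOp is now expected to handle internally; the helper name and standalone form are illustrative only, not part of the change.

def filter_backprop_padding(input_size, filter_size, output_size, stride,
                            dilation, same_padding):
    # Zero-dilate the output gradient: (output_size - 1) * stride + 1 taps.
    expanded_output_size = (output_size - 1) * stride + 1
    # Pad the input so that convolving it with the expanded output gradient
    # yields exactly filter_size elements.
    padded_in_size = expanded_output_size + (filter_size - 1) * dilation
    pad_total = padded_in_size - input_size   # may be negative (unused input)
    pad_before = max(pad_total // 2, 0) if same_padding else 0
    return pad_before, pad_total - pad_before

# The example from the removed comment: input=3, filter=2, stride=2 with
# VALID padding leaves input element "C" unused via negative right padding.
print(filter_backprop_padding(3, 2, 1, 2, 1, same_padding=False))  # (0, -1)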
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index ef1015552d..234f7b4a01 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
// compute valid broadcast shapes, but rely below on XLA to
// automatically perform the broadcast assuming its valid shapes are
// a superset of TensorFlow's valid shapes.
- BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
+ BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
+ /*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
lhs_shape.DebugString(), " vs. ",
@@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
}
/* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper) {
- // Manually construct the broadcasting since MapN does not do
- // automatic broadcasting. The bcast helper ensures that
- // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
- // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
- // the same shape, so can be operated on by MapN.
-
- // First reshape the inputs, which should be a metadata-only
- // operation since we are flattening the dimensions in order.
- auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
- auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
-
- // Next broadcast the necessary input dimensions. We rely on the
- // XLA optimizer to be smart about the fact that we are asking
- // it to broadcast size 1 on some of these dimensions, to avoid
- // adding complexity to this code.
- auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
- int lhs_size = broadcast_helper.x_bcast().size();
- auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
- int rhs_size = broadcast_helper.y_bcast().size();
-
- // Now reshape them to the correct output shape. After the
- // broadcast each side is twice as wide as it should be, since the
- // broadcast dimensions were prepended to the shape. Reshape
- // flattening each original dimension with the prepended broadcast
- // dimension. E.g. if we started out with lhs_shaped with shape
- // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
- // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
- std::vector<int64> lhs_reorder;
- for (int i = 0; i < lhs_size; ++i) {
- lhs_reorder.push_back(i);
- lhs_reorder.push_back(i + lhs_size);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) {
+ auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape());
+ if (!lhs_output.ok()) {
+ xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
+ return {error, error};
}
- auto lhs_output =
- xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
- std::vector<int64> rhs_reorder;
- for (int i = 0; i < rhs_size; ++i) {
- rhs_reorder.push_back(i);
- rhs_reorder.push_back(i + rhs_size);
+ auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape());
+ if (!rhs_output.ok()) {
+ xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
+ return {error, error};
}
- auto rhs_output =
- xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
-
- return {lhs_output, rhs_output};
+ return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
}
} // namespace tensorflow
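As a rough NumPy analogue (assumption: NumPy-compatible shapes stand in for BCast), the rewritten XlaBinaryOp::Broadcast now simply brings both operands to the full output shape instead of the reshape/broadcast/reshape sequence it replaced:

import numpy as np

# lhs has shape (2, 3, 1), rhs has shape (1, 4); both are expanded to the
# common output shape (2, 3, 4) before the element-wise op, which is what the
# new BroadcastTo-based helper does on the XLA side.
lhs = np.arange(6, dtype=np.float32).reshape(2, 3, 1)
rhs = np.arange(4, dtype=np.float32).reshape(1, 4)
out_shape = np.broadcast_shapes(lhs.shape, rhs.shape)   # (2, 3, 4)
lhs_b = np.broadcast_to(lhs, out_shape)
rhs_b = np.broadcast_to(rhs, out_shape)
assert lhs_b.shape == rhs_b.shape == (2, 3, 4)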
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index 6653944a91..516ead4bfe 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel {
// 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
// shape.
static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper);
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 33a73fe5fd..921b4340c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
OP_REQUIRES(
context, output_size >= 0,
errors::InvalidArgument("Need output_size >= 0, got ", output_size));
+ OP_REQUIRES(context, output_size <= kint32max,
+ errors::InvalidArgument("Need output_size <= kint32Max, got ",
+ output_size));
xla::XlaOp score_thresh = context->Input("score_threshold");
xla::XlaOp iou_thresh = context->Input("iou_threshold");
@@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel {
xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
- // num_valid is scalar.
- xla::XlaOp num_valid = xla::Reduce(
+ // num_valid is scalar. Value should be bound by output_size.
+ xla::XlaOp num_valid_total = xla::Reduce(
ones_included,
/*init_value=*/xla::ConstantR0<int>(builder, 0),
/*computation=*/CreateScalarAddComputation(xla::S32, builder),
/*dimensions_to_reduce=*/{0});
+ xla::XlaOp num_valid =
+ xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));
xla::XlaOp output_tuple = TopK(scores_included, output_size);
xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1);
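The intent of the hunk above, sketched in plain Python with made-up values: the number of boxes reported as valid is now clamped to the requested output_size.

import numpy as np

included = np.array([1, 0, 1, 1, 1, 0], dtype=np.int32)   # hypothetical mask
output_size = 3
num_valid_total = int(included.sum())      # 4 boxes survived suppression
num_valid = min(num_valid_total, output_size)
print(num_valid)                           # 3, never more than output_size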
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d9a0257b70..7b2bb4a7c5 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -132,14 +133,14 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
// If the 2D kernel would be very large, the 1D kernel can be applied once in
// each dimension due to the symmetry of the kernel along all axes to reduce the
// computational intensity.
-std::vector<float> Make1DKernel(int64 n) {
+xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) {
std::vector<float> kernel(n * 2 - 1);
for (int64 i = 0; i < n; ++i) {
float v = (i + 1.0f) / n;
kernel[i] = v;
kernel[n * 2 - 2 - i] = v;
}
- return kernel;
+ return xla::ConstantR1<float>(builder, kernel);
}
// Kernels with more than 16 spatial elements are considered intense and the
@@ -149,41 +150,26 @@ const int64 kMax2DKernelSize = 16;
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
+ auto depthwise_kernel = xla::Broadcast(
+ xla::Zero(builder, xla::F32),
+ {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1});
- auto diag = xla::ConvertElementType(
- xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
- 2 * kernel_size[1] - 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
return xla::Mul(
- xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+ xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]),
/*broadcast_dimensions=*/{1}),
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+ Make1DKernel(builder, kernel_size[0]),
/*broadcast_dimensions=*/{0});
}
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels, int64 dim) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
-
- auto diag = xla::ConvertElementType(
- xla::Eq(
- xla::Broadcast(channels_iota,
- {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
- dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
- if (dim == 1) {
- return xla::Mul(
- diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
- /*broadcast_dimensions=*/{1});
- }
- return xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
- /*broadcast_dimensions=*/{0});
+ auto depthwise_kernel =
+ xla::Broadcast(xla::Zero(builder, xla::F32),
+ {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+ dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1});
+ return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]),
+ /*broadcast_dimensions=*/{dim});
}
xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
@@ -206,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
dimension_numbers.add_input_spatial_dimensions(1 + i);
dimension_numbers.add_output_spatial_dimensions(1 + i);
@@ -285,7 +271,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
{{dims.kernel_size[0] - 1, upper_padding[0]},
{dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -294,7 +281,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*padding=*/
{{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
xla::XlaOp kernel1 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
output = xla::ConvGeneralDilated(
@@ -302,7 +290,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// Add broadcasts to handle expanding from a size == 1 dimension to a
@@ -331,15 +320,15 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
- dimension_numbers.add_input_spatial_dimensions(1 + i);
- dimension_numbers.add_output_spatial_dimensions(1 + i);
+ dimension_numbers.add_input_spatial_dimensions(i + 1);
+ dimension_numbers.add_output_spatial_dimensions(i + 1);
dimension_numbers.add_kernel_spatial_dimensions(i);
}
- dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
- dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
xla::XlaOp output;
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
@@ -362,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.stride,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -388,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
/*lhs_dilation=*/{dims.stride[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
output = xla::ConvGeneralDilated(
output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/{1, dims.stride[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
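For reference, a NumPy sketch of the bilinear kernel this resize path builds (names illustrative): Make1DKernel is a triangular window of length 2n-1, and the 2-D kernel is its separable outer product, now shared across channels via feature_group_count rather than materialized as a per-channel diagonal.

import numpy as np

def make_1d_kernel(n):
    # Triangular window: (1/n, 2/n, ..., 1, ..., 2/n, 1/n), length 2n - 1.
    k = np.empty(2 * n - 1, dtype=np.float32)
    for i in range(n):
        v = (i + 1.0) / n
        k[i] = v
        k[2 * n - 2 - i] = v
    return k

k0, k1 = make_1d_kernel(2), make_1d_kernel(3)
kernel_2d = np.outer(k0, k1)   # shape (3, 5); identical for every channel
print(k0)                      # [0.5 1.  0.5]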
diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc
new file mode 100644
index 0000000000..0764e5503d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+class DataFormatVecPermuteOp : public XlaOpKernel {
+ public:
+ explicit DataFormatVecPermuteOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_));
+ OP_REQUIRES(
+ ctx, src_format_.size() == 4,
+ errors::InvalidArgument("Data format should have 4 characters"));
+ TensorFormat data_format;
+ OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_));
+ OP_REQUIRES(
+ ctx, dst_format_.size() == 4,
+ errors::InvalidArgument("Data format should have 4 characters"));
+ OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format),
+ errors::InvalidArgument("Invalid data format"));
+ }
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto builder = ctx->builder();
+ const TensorShape input_tensor_shape = ctx->InputShape(0);
+ int input_rank = input_tensor_shape.dims();
+ OP_REQUIRES(ctx, input_rank == 1 || input_rank == 2,
+ errors::InvalidArgument(
+ "Input must be a vector or matrix, but got shape ",
+ input_tensor_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, input_tensor_shape.dim_size(0) == 4,
+ errors::InvalidArgument(
+ "First dimension of input must be of size 4, but got shape ",
+ input_tensor_shape.DebugString()));
+ if (input_rank == 2) {
+ OP_REQUIRES(
+ ctx, input_tensor_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "Second dimension of 2D input must be of size 2, but got shape ",
+ input_tensor_shape.DebugString()));
+ }
+ std::vector<int32> dst_indices(4, 0);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (src_format_[i] == dst_format_[j]) {
+ dst_indices[i] = j;
+ break;
+ }
+ }
+ }
+ auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
+ if (input_rank == 2) {
+ keys = xla::BroadcastInDim(
+ keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0});
+ }
+ auto sorted = xla::Sort(keys, ctx->Input(0), 0);
+ auto output = xla::GetTupleElement(sorted, 1);
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ string src_format_;
+ string dst_format_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp);
+};
+
+// TODO(b/115384656): Support DT_INT64.
+REGISTER_XLA_OP(Name("DataFormatVecPermute").TypeConstraint("T", DT_INT32),
+ DataFormatVecPermuteOp);
+
+} // namespace
+} // namespace tensorflow
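A plain-Python sketch of what the new kernel computes (the key-value sort realizes a permutation), using NHWC to NCHW as an example:

src_format, dst_format = "NHWC", "NCHW"
# dst_indices[i] is where character src_format[i] lives in dst_format.
dst_indices = [dst_format.index(c) for c in src_format]   # [0, 2, 3, 1]

values = [8, 224, 224, 3]          # e.g. a shape vector given in NHWC order
permuted = [None] * len(values)
for key, value in zip(dst_indices, values):
    permuted[key] = value          # xla::Sort on (keys, values) does this
print(permuted)                    # [8, 3, 224, 224] -- NCHW order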
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index 8102faad28..8eee5b1299 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel {
std::vector<int64> window_dimensions;
std::vector<int64> window_strides;
+ std::vector<int64> base_dilations;
+ std::vector<int64> window_dilations;
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
"window_dimensions", &window_dimensions));
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
&window_strides));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations",
+ &base_dilations));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dilations", &window_dilations));
const int rank = input_shape.dims();
OP_REQUIRES(context, rank == window_dimensions.size(),
@@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel {
"The size of window_strides must be equal to the input "
"rank (",
window_strides.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == base_dilations.size(),
+ errors::InvalidArgument(
+ "The size of base_dilations must be equal to the input "
+ "rank (",
+ base_dilations.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_dilations.size(),
+ errors::InvalidArgument(
+ "The size of window_dilations must be equal to the input "
+ "rank (",
+ window_dilations.size(), " vs. ", rank, ")"));
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel {
xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
context->Input(0), context->Input(1), *reducer.computation,
- window_dimensions, window_strides, padding);
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding);
context->SetOutput(0, output);
}
@@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("XlaReduceWindow")
.CompileTimeConstInput("window_dimensions")
.CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("base_dilations")
+ .CompileTimeConstInput("window_dilations")
.CompileTimeConstInput("padding"),
ReduceWindowOp);
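The newly plumbed dilation inputs follow the usual dilation convention; a short sketch of the effective window extent (base_dilations dilate the operand analogously):

def effective_window_size(window_dimension, window_dilation):
    # A window of size w with dilation d covers (w - 1) * d + 1 input cells.
    return (window_dimension - 1) * window_dilation + 1

print(effective_window_size(3, 2))   # 5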
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index ab094d7dd1..57afd608de 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel {
}
auto output = xla::ReduceWindowWithGeneralPadding(
XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init,
- *reducer, window_dims, window_strides, padding);
+ *reducer, window_dims, window_strides,
+ /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
output =
XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0));
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index 25a5bcbe1d..0c32b8def0 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -55,10 +57,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
// The type-specific part of the implementation of Range.
template <typename T>
-Status CreateRangeTensor(const xla::LiteralSlice& start_literal,
- const xla::LiteralSlice& limit_literal,
- const xla::LiteralSlice& delta_literal,
- Tensor* output) {
+xla::StatusOr<xla::XlaOp> CreateRangeTensor(
+ const xla::LiteralSlice& start_literal,
+ const xla::LiteralSlice& limit_literal,
+ const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) {
T start = start_literal.Get<T>({});
T limit = limit_literal.Get<T>({});
T delta = delta_literal.Get<T>({});
@@ -82,14 +84,10 @@ Status CreateRangeTensor(const xla::LiteralSlice& start_literal,
? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
: std::ceil(std::abs((limit - start) / delta)));
- *output = Tensor(DataTypeToEnum<T>::v(), TensorShape({size}));
- auto flat = output->flat<T>();
- T val = start;
- for (int64 i = 0; i < size; ++i) {
- flat(i) = val;
- val += delta;
- }
- return Status::OK();
+ return xla::ConstantR0(builder, start) +
+ xla::ConstantR0(builder, delta) *
+ xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType<T>(),
+ size);
}
class RangeOp : public XlaOpKernel {
@@ -115,27 +113,26 @@ class RangeOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta));
DataType type = input_type(0);
- Tensor output;
- Status status;
+ xla::StatusOr<xla::XlaOp> output;
switch (type) {
case DT_INT32:
- status = CreateRangeTensor<int32>(start, limit, delta, &output);
+ output = CreateRangeTensor<int32>(start, limit, delta, ctx->builder());
break;
case DT_INT64:
- status = CreateRangeTensor<int64>(start, limit, delta, &output);
+ output = CreateRangeTensor<int64>(start, limit, delta, ctx->builder());
break;
case DT_FLOAT:
- status = CreateRangeTensor<float>(start, limit, delta, &output);
+ output = CreateRangeTensor<float>(start, limit, delta, ctx->builder());
break;
case DT_DOUBLE:
- status = CreateRangeTensor<double>(start, limit, delta, &output);
+ output = CreateRangeTensor<double>(start, limit, delta, ctx->builder());
break;
default:
- status = errors::InvalidArgument("Invalid type for Range ",
+ output = errors::InvalidArgument("Invalid type for Range ",
DataTypeString(type));
}
- OP_REQUIRES_OK(ctx, status);
- ctx->SetConstantOutput(0, output);
+ OP_REQUIRES_OK(ctx, output.status());
+ ctx->SetOutput(0, output.ValueOrDie());
}
};
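A NumPy analogue of the rewritten CreateRangeTensor (assumption: NumPy stands in for the XLA builder); the range is now emitted as start + delta * iota(size) rather than a host-computed constant tensor:

import numpy as np

def range_tensor(start, limit, delta):
    if isinstance(start, int):
        size = (abs(limit - start) + abs(delta) - 1) // abs(delta)
    else:
        size = int(np.ceil(abs((limit - start) / delta)))
    return start + delta * np.arange(size, dtype=type(start))

print(range_tensor(2, 10, 3))        # [2 5 8]
print(range_tensor(0.0, 1.0, 0.3))   # [0.  0.3 0.6 0.9]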
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 2e0a69b70e..c8a0f31a03 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -44,7 +44,7 @@ class ShapeOp : public XlaOpKernel {
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
+REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
@@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel {
private:
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
+REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);
class RankOp : public XlaOpKernel {
public:
@@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
+REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);
class SizeOp : public XlaOpKernel {
public:
@@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
+REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);
class ExpandDimsOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index aaeeae01cc..45f03d8c21 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel {
explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
- context->SetOutput(0, xla::Sort(context->Input(0)));
+ context->SetOutput(0, xla::Sort(context->Input("input")));
}
};
REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp);
+class XlaKeyValueSortOp : public XlaOpKernel {
+ public:
+ explicit XlaKeyValueSortOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaOp result =
+ xla::Sort(context->Input("keys"), context->Input("values"));
+ context->SetOutput(0, xla::GetTupleElement(result, 0));
+ context->SetOutput(1, xla::GetTupleElement(result, 1));
+ }
+};
+
+REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp);
+
} // namespace
} // namespace tensorflow
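The semantics of the new XlaKeyValueSort, expressed with a stable NumPy argsort (illustrative only):

import numpy as np

keys = np.array([3.0, 1.0, 2.0], dtype=np.float32)
values = np.array([30, 10, 20], dtype=np.int32)
order = np.argsort(keys, kind="stable")
sorted_keys, sorted_values = keys[order], values[order]
print(sorted_keys)     # [1. 2. 3.]
print(sorted_values)   # [10 20 30] -- values follow their keys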
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
new file mode 100644
index 0000000000..74d4fcc425
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -0,0 +1,226 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA TensorList operators.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op,
+ TensorShape* tensor_list_shape) {
+ auto shape_or_status = builder->GetShape(op);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ xla::Shape shape = shape_or_status.ValueOrDie();
+ TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));
+ return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
+ tensor_list_shape);
+}
+
+class TensorListReserveOp : public XlaOpKernel {
+ public:
+ explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape element_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
+ int64 num_elements;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
+
+ TensorShape tensor_shape;
+ tensor_shape.AddDim(num_elements);
+ tensor_shape.AppendShape(element_shape);
+
+ xla::XlaBuilder* b = ctx->builder();
+ ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
+ tensor_shape.dim_sizes()),
+ xla::ConstantR0<int32>(b, 0)}));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListReserve")
+ .CompileTimeConstInput("element_shape")
+ .CompileTimeConstInput("num_elements"),
+ TensorListReserveOp);
+
+class EmptyTensorListOp : public XlaOpKernel {
+ public:
+ explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ ctx->CtxFailure(
+ errors::InvalidArgument("XLA compilation requires a fixed tensor list "
+ "size. Use TensorListReserve instead."));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
+};
+
+REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp);
+
+class TensorListElementShapeOp : public XlaOpKernel {
+ public:
+ explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape));
+ shape.RemoveDim(0);
+
+ switch (shape_type_) {
+ case DT_INT64:
+ ctx->SetOutput(0, xla::ConstantR1<int64>(b, shape.dim_sizes()));
+ break;
+ case DT_INT32: {
+ std::vector<int32> size;
+ for (int64 s : shape.dim_sizes()) {
+ size.push_back(s);
+ }
+ ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
+ break;
+ }
+ default:
+ ctx->CtxFailure(
+ errors::InvalidArgument("Unsupported shape type requested"));
+ return;
+ }
+ }
+
+ private:
+ DataType shape_type_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp);
+
+class TensorListPushBackOp : public XlaOpKernel {
+ public:
+ explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp list = ctx->Input(0);
+ TensorShape elem_shape = ctx->InputShape(1);
+
+ xla::XlaOp ta = xla::GetTupleElement(list, 0);
+ xla::XlaOp index = xla::GetTupleElement(list, 1);
+ xla::XlaOp value = ctx->Input(1);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto start_indices =
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+
+ TensorShape slice_shape = elem_shape;
+ slice_shape.InsertDim(0, 1LL);
+ auto update = xla::Reshape(value, slice_shape.dim_sizes());
+
+ // TODO(phawkins): We don't check the index is in bounds --- there is no
+ // error mechanism in XLA.
+ ctx->SetOutput(
+ 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
+ index + xla::ConstantR0<int32>(b, 1)}));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp);
+
+class TensorListPopBackOp : public XlaOpKernel {
+ public:
+ explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp state = ctx->Input(0);
+
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
+
+ xla::XlaOp ta = xla::GetTupleElement(state, 0);
+ xla::XlaOp index = xla::GetTupleElement(state, 1);
+
+ index = index - xla::ConstantR0<int32>(b, 1);
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices =
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}}));
+
+ auto slice_shape = shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ // TODO(phawkins): We don't check the index is in bounds --- there is no
+ // error mechanism in XLA.
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
+ // Remove the leading '1' dimension.
+ std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
+
+ ctx->SetOutput(0, xla::Tuple(b, {ta, index}));
+ ctx->SetOutput(1, xla::Reshape(read, value_shape));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp);
+
+} // anonymous namespace
+} // namespace tensorflow
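A minimal Python model of the representation these kernels use, assuming the (buffer, index) tuple layout shown above; like the kernels, it performs no bounds checking:

import numpy as np

def reserve(num_elements, element_shape, dtype=np.float32):
    return np.zeros((num_elements, *element_shape), dtype=dtype), 0

def push_back(state, value):
    buffer, index = state
    buffer = buffer.copy()
    buffer[index] = value                   # DynamicUpdateSlice at [index, 0, ...]
    return buffer, index + 1

def pop_back(state):
    buffer, index = state
    index -= 1
    return (buffer, index), buffer[index]   # DynamicSlice at [index, 0, ...]

state = reserve(3, (2,))
state = push_back(state, [1.0, 2.0])
state, value = pop_back(state)
print(value)                                # [1. 2.]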
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 8597e7f139..1ce3930fd1 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -32,6 +32,22 @@ cc_library(
)
cc_library(
+ name = "broadcast",
+ srcs = ["broadcast.cc"],
+ hdrs = ["broadcast.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc
new file mode 100644
index 0000000000..3e402ef855
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace tensorflow {
+
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+ absl::Span<int64 const> output_dims) {
+ xla::XlaBuilder* builder = input.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+ absl::Span<int64 const> input_dims =
+ xla::AsInt64Slice(input_shape.dimensions());
+
+ if (input_dims == output_dims) {
+ return input;
+ }
+
+ if (input_dims.size() > output_dims.size()) {
+ return errors::InvalidArgument(
+ "Input shape (", xla::ShapeUtil::HumanString(input_shape),
+ ") must have rank less than or equal to the output shape [",
+ absl::StrJoin(output_dims, ","), "]");
+ }
+
+ std::vector<int64> broadcast_dims;
+ std::vector<int64> broadcast_shape;
+ auto input_it = input_dims.rbegin();
+ for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend();
+ ++output_it) {
+ if (input_it != input_dims.rend()) {
+ if (!(*output_it == 0 && *input_it == 0) &&
+ !(*input_it != 0 && *output_it % *input_it == 0)) {
+ return errors::InvalidArgument("Invalid shape broadcast from ",
+ xla::ShapeUtil::HumanString(input_shape),
+ " to [", absl::StrJoin(output_dims, ","),
+ "]");
+ }
+
+ broadcast_dims.push_back(broadcast_shape.size());
+ if (*output_it == *input_it) {
+ broadcast_shape.push_back(*output_it);
+ } else if (*output_it != *input_it) {
+ // Add dimensions [I, O/I], which we will later flatten to just
+ // [O]. We must do this in two phases since XLA broadcasting does not
+ // support tiling.
+ broadcast_shape.push_back(*input_it);
+ broadcast_shape.push_back(*output_it / *input_it);
+ }
+ ++input_it;
+ } else {
+ broadcast_shape.push_back(*output_it);
+ }
+ }
+ TF_RET_CHECK(input_it == input_dims.rend());
+
+ absl::c_reverse(broadcast_dims);
+ int broadcast_shape_size = broadcast_shape.size();
+ for (int64& broadcast_dim : broadcast_dims) {
+ broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
+ }
+ absl::c_reverse(broadcast_shape);
+ xla::XlaOp output = xla::BroadcastInDim(
+ input,
+ xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape),
+ broadcast_dims);
+ if (broadcast_shape != output_dims) {
+ output = xla::Reshape(output, output_dims);
+ }
+ return output;
+}
+
+} // namespace tensorflow
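The tiling case is the interesting one; a NumPy sketch of the two-phase trick (broadcast into [O/I, I], then reshape), needed because XLA's BroadcastInDim cannot tile on its own:

import numpy as np

# Grow a dimension of size 2 to size 6: first broadcast to [O/I, I] = [3, 2],
# with the input mapped to the minor dimension, then flatten back to [6].
x = np.array([1, 2])
tiled = np.broadcast_to(x[None, :], (3, 2)).reshape(6)
print(tiled)   # [1 2 1 2 1 2], i.e. the same as np.tile(x, 3)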
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/tf2xla/lib/broadcast.h
index 35b4b4e20b..591e696f06 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_options.cc
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.h
@@ -13,16 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
-namespace xla {
-namespace gpu {
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
-bool ConvUseLayoutHeuristic(const HloModuleConfig& config) {
- return !config.debug_options().xla_backend_extra_options().count(
- "xla_gpu_experimental_conv_disable_layout_heuristic");
-}
+namespace tensorflow {
-} // namespace gpu
-} // namespace xla
+// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting
+// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling.
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+ absl::Span<int64 const> output_dims);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 38dfde165d..2b1c2ced92 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -38,12 +38,10 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
combiner,
xla::XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
- TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
+ TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates));
TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
absl::Span<const int64> indices_dims =
xla::AsInt64Slice(indices_shape.dimensions());
- absl::Span<const int64> buffer_dims =
- xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
// the indices to update. Otherwise the indices are all scalars.
@@ -81,104 +79,129 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
}
}
- // Shape of the non-indexed dimensions of the buffer.
- std::vector<int64> buffer_shape_post_axes(
- buffer_dims.begin() + num_index_dims, buffer_dims.end());
-
- // Flatten the major dimensions of indices and updates into a single dimension
- // for ease of iteration.
- std::vector<int64> flat_indices_shape({num_indices});
- if (indices_are_vectors) {
- flat_indices_shape.push_back(num_index_dims);
+ // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of
+ // shape [3,3]:
+ // NOTE: ***This case will not be generated by any of the tf.scatter ops.***
+ //
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // updates = s32[3,2] parameter(2)
+ // scatter = s32[3,3] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={0},
+ // inserted_window_dims={1},
+ // scatter_dims_to_operand_dims={1},
+ // index_vector_dim=1
+ //
+ //
+ // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of
+ // shape [3,3]:
+ //
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // updates = s32[2,3] parameter(2)
+ // scatter = s32[3,3] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={1},
+ // inserted_window_dims={0},
+ // scatter_dims_to_operand_dims={0},
+ // index_vector_dim=1
+ //
+ //
+ // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of
+ // shape [3,3,2]
+ //
+ // operand = s32[3,3,2] parameter(0)
+ // indices = s32[2,2] parameter(1)
+ // updates = s32[2,2] parameter(2)
+ // scatter = s32[3,3,2] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={1},
+ // inserted_window_dims={0,1},
+ // scatter_dims_to_operand_dims={0,1},
+ // index_vector_dim=1
+ //
+ //
+ // Example of a scatter updating slices of shape [] in a tensor of shape [1,1]
+ //
+ // operand = s32[1,1] parameter(0)
+ // indices = s32[1] parameter(1)
+ // updates = s32[1] parameter(2)
+ // scatter = s32[1,1] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={},
+ // inserted_window_dims={0,1},
+ // scatter_dims_to_operand_dims={0},
+ // index_vector_dim=1
+ // Note that updates operand would be broadcasted into [1] in this case.
+ //
+
+ xla::ScatterDimensionNumbers dim_numbers;
+ dim_numbers.set_index_vector_dim(indices_are_vectors
+ ? indices_shape.dimensions_size() - 1
+ : indices_shape.dimensions_size());
+
+ int64 updates_rank = xla::ShapeUtil::Rank(updates_shape);
+ int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape);
+ int64 num_window_dims_in_updates = buffer_rank - num_index_dims;
+
+ // If the rank of `updates` is 0 and does not match the expected rank of
+ // updates, broadcast `updates` to the expected shape of updates.
+ auto new_updates = updates;
+ std::vector<int64> expected_updates_dims(indices_dims.begin(),
+ indices_dims.end());
+ for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) {
+ expected_updates_dims.push_back(buffer_shape.dimensions(dim));
+ }
+ int64 expected_updates_rank = expected_updates_dims.size();
+ if (updates_rank == 0 && expected_updates_rank != 0) {
+ new_updates = xla::Broadcast(updates, expected_updates_dims);
+ TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
+ updates_rank = xla::ShapeUtil::Rank(updates_shape);
}
- std::vector<int64> flat_updates_shape({num_indices});
- flat_updates_shape.insert(flat_updates_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
-
- // Construct the initial values of the loop-carried Tensors.
- auto flat_indices = xla::Reshape(indices, flat_indices_shape);
- auto flat_updates = xla::Reshape(updates, flat_updates_shape);
- auto init = {flat_indices, flat_updates, buffer};
-
- // Constructs the loop body. The implementation of scatter is essentially:
- // for i in range(num_indices):
- // index = dynamic-slice(indices, i)
- // update = dynamic-slice(updates, i)
- // buffer = dynamic-update-slice(buffer, update, index)
- auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
- xla::XlaBuilder* body_builder) {
- auto indices = loop_vars[0];
- auto updates = loop_vars[1];
- auto buffer = loop_vars[2];
-
- auto zero_index = xla::ConstantLiteral(
- body_builder, xla::LiteralUtil::Zero(indices_shape.element_type()));
-
- // Slice the i-th index from the indices array.
- xla::XlaOp index;
- auto indices_offset = xla::Reshape(i, {1});
- if (indices_are_vectors) {
- indices_offset = xla::Pad(indices_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, 1}}));
-
- index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims});
- index = xla::Collapse(index, {0, 1});
- } else {
- index = xla::DynamicSlice(indices, indices_offset, {1});
+ if (updates_rank > 0) {
+ for (int64 i = (updates_rank - num_window_dims_in_updates);
+ i < updates_rank; ++i) {
+ dim_numbers.add_update_window_dims(i);
}
+ }
- // Discard updates with negative indices, since some users expect this.
- auto index_in_range = xla::ReduceAll(
- xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true),
- xla::CreateScalarAndComputation(xla::PRED, body_builder));
-
- // Make the index in bounds to prevent implementation defined behavior.
- index = xla::Max(index, zero_index);
- index = xla::Pad(
- index, zero_index,
- xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
-
- // Slice the i-th index from the updates array.
- auto updates_offset = xla::Reshape(i, {1});
- updates_offset = xla::Pad(
- updates_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
- std::vector<int64> flat_updates_slice_shape({1});
- flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
- auto update =
- xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape);
-
- // Unflatten the major (iteration) dimensions of the slice to their
- // original shape.
- std::vector<int64> updates_slice_shape(num_index_dims, 1);
- updates_slice_shape.insert(updates_slice_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
- update = xla::Reshape(update, updates_slice_shape);
-
- // Apply the update to the buffer. If there is a combiner, use it to merge
- // the current values with the update.
- auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape);
+ for (int64 i = 0; i < num_index_dims; ++i) {
+ dim_numbers.add_inserted_window_dims(i);
+ dim_numbers.add_scatter_dims_to_operand_dims(i);
+ }
+
+ // Build the combiner computation.
+ xla::XlaComputation combiner_computation;
+ {
+ xla::XlaBuilder cb("scatter-combiner");
+ auto xla_scalar_shape =
+ xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {});
+ auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
+ auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1");
if (combiner) {
- update = combiner(current_value, update, body_builder);
+ combiner(p0, p1, &cb);
}
- // Use the current value instead of the update if the index is out of
- // bounds.
- update = xla::Select(index_in_range, update, current_value);
- // Apply the update.
- buffer = xla::DynamicUpdateSlice(buffer, update, index);
-
- return std::vector<xla::XlaOp>{indices, updates, buffer};
- };
-
- TF_ASSIGN_OR_RETURN(auto outputs,
- XlaForEachIndex(num_indices, indices_shape.element_type(),
- body_fn, init, "scatter", builder));
- return outputs[2];
+ combiner_computation = cb.Build().ConsumeValueOrDie();
+ }
+
+ VLOG(3) << "Scatter op:";
+ VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape);
+ VLOG(3) << " Indices: " << xla::ShapeUtil::HumanString(indices_shape);
+ VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape);
+ VLOG(3) << " Scatter Dimension Numbers: ";
+ VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim();
+ VLOG(3) << " update_window_dims: ["
+ << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]";
+ VLOG(3) << " inserted_window_dims: ["
+ << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]";
+ VLOG(3) << " scatter_dims_to_operand_dims: ["
+ << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",")
+ << "]";
+
+ return xla::Scatter(buffer, indices, new_updates, combiner_computation,
+ dim_numbers);
}
} // namespace tensorflow
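A NumPy analogue of the second example in the comment block above (two [1,3] rows scattered into a [3,3] operand along dimension 0, assuming an add combiner):

import numpy as np

operand = np.zeros((3, 3), dtype=np.int32)
indices = np.array([0, 2])
updates = np.array([[1, 2, 3],
                    [4, 5, 6]], dtype=np.int32)
# update_window_dims={1}: dim 1 of `updates` is the slice that gets applied;
# inserted_window_dims={0} / scatter_dims_to_operand_dims={0}: `indices`
# selects the row of `operand` to update.
for i, row in zip(indices, updates):
    operand[i] += row
print(operand)
# [[1 2 3]
#  [0 0 0]
#  [4 5 6]]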
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h
index 13a5f1b850..4cf478c4b9 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.h
+++ b/tensorflow/compiler/tf2xla/lib/scatter.h
@@ -34,7 +34,11 @@ namespace tensorflow {
// Otherwise, `indices_are_vectors`, then indices are multidimensional and the
// minor dimension of `indices` represents a vector of indices.
//
-// If any indices are negative, the corresponding update is discarded.
+// If `updates` is a scalar, then it will be broadcasted into the expected shape
+// of updates.
+//
+// If any part of the update region is out-of-bounds, the corresponding update
+// is discarded.
//
// If a `combiner` is provided, updates are combined with the existing values in
// the buffer using the combiner function. Otherwise, the updates replace the
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 02363500ef..bd2c0a5ee8 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -121,8 +121,8 @@ Wraps the XLA DynamicSlice operator, documented at
DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices, which specify the end point of exclusive slice intervals in each
-dimension -- [start, start + size). The shape of start_indices must be rank ==
-1, with dimension size equal to the rank of operand.
+dimension -- [start, start + size). The shape of start_indices must have rank 1,
+with dimension size equal to the rank of operand.
input: A `Tensor` of type T.
@@ -131,7 +131,8 @@ start_indices: Rank 1 tensor of N integers containing the starting indices of
size_indices: List of N integers containing the slice size for each
dimension. Each value must be strictly greater than zero, and start + size
- must be less
+ must be less than or equal to the size of the dimension to avoid
+ implementation defined behavior.
)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
@@ -282,6 +283,8 @@ REGISTER_OP("XlaReduceWindow")
.Input("init_value: T")
.Input("window_dimensions: Tindices")
.Input("window_strides: Tindices")
+ .Input("base_dilations: Tindices")
+ .Input("window_dilations: Tindices")
.Input("padding: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
@@ -353,12 +356,33 @@ Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
-Sorts a tensor. Currently only rank 1 sorts in ascending order are supported.
+Sorts a tensor. Currently only sorts in ascending order are supported.
input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");
+REGISTER_OP("XlaKeyValueSort")
+ .Input("keys: K")
+ .Input("values: V")
+ .Output("sorted_keys: K")
+ .Output("sorted_values: V")
+ .Attr("K: realnumbertype")
+ .Attr("V: type")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Wraps the XLA Sort operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#sort
+.
+
+Sorts a tensor of keys, reordering the values to match. Currently only sorts in ascending order are supported.
+
+keys: A `Tensor` of type K.
+values: A `Tensor` of type V.
+sorted_keys: A `Tensor` of type K.
+sorted_values: A `Tensor` of type V.
+)doc");
+
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 27dd18a9bb..5e86b5d8ec 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast
def broadcast(x, dims, name=None):
x = ops.convert_to_tensor(x)
- shape = array_ops.concat(
- [constant_op.constant(dims),
- array_ops.shape(x)], axis=0)
+ shape = array_ops.concat([constant_op.constant(dims),
+ array_ops.shape(x)],
+ axis=0)
return array_ops.broadcast_to(x, shape, name=name)
@@ -320,6 +320,8 @@ def reduce_window(operand,
reducer,
window_dimensions,
window_strides=None,
+ base_dilations=None,
+ window_dilations=None,
padding=None,
name=None):
"""Wraps the XLA ReduceWindow operator.
@@ -332,22 +334,27 @@ def reduce_window(operand,
init: a scalar tensor representing the initial value for the reduction
reducer: a reduction function that combines a pair of scalars.
window_dimensions: shape of the window, as a list of integers
- window_strides: inter-window strides, as a list of integers. Optional;
- if omitted, defaults to strides of 1.
+ window_strides: inter-window strides, as a list of integers. Optional; if
+ omitted, defaults to strides of 1.
padding: padding to apply to 'operand'. List of (low, high) pairs of
integers that specify the padding to apply before and after each
dimension. Optional; if omitted, defaults to no padding.
name: the operator name, or None.
+
Returns:
A tensor that represents the output of the reduce_window operator.
"""
window_strides = window_strides or [1] * len(window_dimensions)
+ base_dilations = base_dilations or [1] * len(window_dimensions)
+ window_dilations = window_dilations or [1] * len(window_dimensions)
padding = padding or [(0, 0)] * len(window_dimensions)
return gen_xla_ops.xla_reduce_window(
input=operand,
init_value=init,
window_dimensions=window_dimensions,
window_strides=window_strides,
+ base_dilations=base_dilations,
+ window_dilations=window_dilations,
padding=padding,
computation=reducer,
name=name)
@@ -377,4 +384,5 @@ def slice(x, start_dims, limit_dims, strides):
sort = gen_xla_ops.xla_sort
+key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 20f2ce2919..72b240996f 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "absl/algorithm/container.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "absl/container/flat_hash_map.h"
namespace tensorflow {
/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
@@ -30,9 +30,9 @@ namespace tensorflow {
}
}
-static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
+static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap() {
- auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;
+ auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
XlaResourceKind resource_kind) {
@@ -103,15 +103,15 @@ CreateResourceOpInfoMap() {
return result;
}
-static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
+static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap() {
- static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
- CreateResourceOpInfoMap();
+ static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
+ op_info_map = CreateResourceOpInfoMap();
return *op_info_map;
}
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
- const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
+ const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
GetStaticResourceOpInfoMap();
auto it = op_infos.find(op);
return it == op_infos.end() ? nullptr : &it->second;
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
index a85ef040a7..956f597301 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -33,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) {
}
TEST(ResourceOperationTableTest, HaveAllResourceOps) {
- gtl::FlatMap<string, bool> known_resource_ops;
+ absl::flat_hash_map<string, bool> known_resource_ops;
for (absl::string_view known_resource_op :
resource_op_table_internal::GetKnownResourceOps()) {
ASSERT_TRUE(
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
index 9d1992205b..b589512dcd 100644
--- a/tensorflow/compiler/tf2xla/shape_util.cc
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
// Convert a TensorShape into the equivalent XLA Shape proto.
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
xla::Shape* shape) {
+ xla::PrimitiveType type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
+ *shape = TensorShapeToXLAShape(type, tensor_shape);
+ return Status::OK();
+}
+
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+ const TensorShape& tensor_shape) {
int rank = tensor_shape.dims();
std::vector<int64> dimensions(rank);
std::vector<int64> layout(rank);
@@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
// XLA uses minor-to-major; Tensorflow uses major-to-minor.
std::iota(layout.rbegin(), layout.rend(), 0);
- xla::PrimitiveType type;
- TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
-
- *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
- return Status::OK();
+ return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h
index 58240b9c96..f7e34a5b40 100644
--- a/tensorflow/compiler/tf2xla/shape_util.h
+++ b/tensorflow/compiler/tf2xla/shape_util.h
@@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
xla::Shape* shape);
+// Converts a TensorShape into the equivalent XLA Shape proto, taking an
+// xla::PrimitiveType to specify the element type. This never fails.
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+ const TensorShape& tensor_shape);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
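A minimal sketch of the new infallible overload declared above (the concrete shape and element type are illustrative only):

    tensorflow::TensorShape tensor_shape({2, 3});
    xla::Shape xla_shape =
        tensorflow::TensorShapeToXLAShape(xla::F32, tensor_shape);
    // xla_shape is f32[2,3] with minor-to-major layout {1, 0}, matching
    // TensorFlow's major-to-minor (row-major) ordering.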
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index d6f42bac86..01dd3ba10f 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -336,9 +336,9 @@ bool HasAssociatedFunction(const NodeDef& node_def,
}
if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
- return false;
+ // Gradient op has "f" attr, which is set to the function we are getting
+ // gradient for. We need to functionalize the gradient function.
+ return true;
}
for (const auto& iter : node_def.attr()) {
@@ -357,17 +357,18 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
// This is a function call node.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
- results.emplace_back(AssociatedFunctionInfo(op, attrs));
+ results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
} else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
+ // This is a SymbolicGradient op.
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+ results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
} else {
// Collect all function attrs for the node.
for (auto& iter : node.attrs()) {
if (iter.second.has_func()) {
VLOG(2) << "Found function attr for node " << node.name() << ": "
<< iter.first << " = " << iter.second.func().name();
- results.emplace_back(AssociatedFunctionInfo(
+ results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
iter.second.func().name(), iter.second.func().attr(), iter.first));
}
}
@@ -410,6 +411,21 @@ Status RewriteAssociatedFunction(
graph->RemoveNode(node);
break;
}
+ case AssociatedFunctionInfo::kSymbolicGradient: {
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
+ GradientDef gradient_def;
+ gradient_def.set_function_name(func.name());
+ gradient_def.set_gradient_func(rewritten_function_name);
+ string original_grad_func = fld->FindGradient(func.name());
+ if (original_grad_func.empty()) {
+ TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
+ } else if (original_grad_func != rewritten_function_name) {
+ TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
+ }
+ break;
+ }
case AssociatedFunctionInfo::kFunctionAttr: {
// Change function attr to rewritten functions.
NameAttrList func;
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 6065d0bb9a..53eab8b63e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -65,21 +65,33 @@ uint32 GetXLARandomSeed();
class AssociatedFunctionInfo {
public:
enum AssociatedFunctionType {
- kFunctionCallNode = 0,
- kFunctionAttr = 1,
+ kFunctionAttr = 0,
+ kFunctionCallNode = 1,
+ kSymbolicGradient = 2,
};
- // The node is a function call.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
- : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
-
// The function is an attr of the node.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
- const string& attr_name)
- : type_(kFunctionAttr),
- func_name_(func_name),
- attrs_(attrs),
- attr_name_(attr_name) {}
+ static AssociatedFunctionInfo FunctionAttr(const string& func_name,
+ const AttrValueMap& attrs,
+ const string& attr_name) {
+ return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
+ }
+
+ // The node is a function call.
+ static AssociatedFunctionInfo FunctionCall(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
+ /*attr_name=*/"");
+ }
+
+ // The node is a SymbolicGradient op.
+ static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
+ /*attr_name=*/"");
+ }
AssociatedFunctionType type() const { return type_; }
@@ -90,6 +102,13 @@ class AssociatedFunctionInfo {
const AttrValueMap& attrs() const { return attrs_; }
private:
+ AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
+ const AttrValueMap& attrs, const string& attr_name)
+ : type_(type),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
// Available for all instances.
AssociatedFunctionType type_;
string func_name_;
@@ -105,14 +124,18 @@ bool HasAssociatedFunction(const NodeDef& node_def,
// Gets functions associated with the node. Current cases:
// 1. For function call node, its function name;
-// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient",
+// and returned attrs will be this node's attributes;
+// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
const Node& node, FunctionLibraryRuntime* flr);
// Changes associated functions for the node. Current cases:
// 1. For function call node, creates a new node with the new function name and
// remove the old node;
-// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+// 2. For SymbolicGradient op, add or replace GradientDef in
+// FunctionLibraryDefinition;
+// 3. For nodes like XlaWhile/XlaIf, modify their function attributes.
Status RewriteAssociatedFunction(
Graph* graph, Node* node, FunctionLibraryDefinition* fld,
const AssociatedFunctionInfo& associated_function,
diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h
index bda667eb1f..6354216eee 100644
--- a/tensorflow/compiler/tf2xla/type_util.h
+++ b/tensorflow/compiler/tf2xla/type_util.h
@@ -25,6 +25,14 @@ namespace tensorflow {
// Converts a Tensorflow DataType to an XLA PrimitiveType.
Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type);
+// N.B.: there is intentionally no function to convert an XLA PrimitiveType to
+// a TensorFlow DataType. The mapping from TF types to XLA types is not
+// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the
+// inverse would not be a well-defined function. If you find that you want the
+// inverse mapping, then most likely you should be preserving the original
+// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow
+// type.
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
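A small sketch of the ambiguity the comment above describes, using only DataTypeToPrimitiveType as declared in this header:

    xla::PrimitiveType t1, t2;
    TF_CHECK_OK(tensorflow::DataTypeToPrimitiveType(tensorflow::DT_INT8, &t1));
    TF_CHECK_OK(tensorflow::DataTypeToPrimitiveType(tensorflow::DT_QINT8, &t2));
    // Both calls succeed and both yield xla::S8, so an inverse mapping would
    // have to pick one of the two TensorFlow types arbitrarily.
    CHECK_EQ(t1, t2);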
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 739e47778a..b2c57e8880 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -194,6 +194,17 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
std::unique_ptr<Graph> graph = GetGraph(fbody);
+ // Clear the "_kernel" attribute if it is set to "host". This is used to
+ // indicate that a computation should happen on the host instead of the
+ // accelerator, but doesn't make sense in XLA.
+ const char* const kKernelAttr = "_kernel";
+ for (Node* n : graph->nodes()) {
+ string value;
+ if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") {
+ n->ClearAttr(kKernelAttr);
+ }
+ }
+
// _Arg and _Retval nodes don't exist in the stored subgraph for the function;
// they are added by the function body looked up. Therefore, they don't have
// core assignments here.
@@ -333,10 +344,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
}
// Builds the XLA computation.
-//
-// `retvals` is the list of retvals produced by _Retval operators, in index
-// order. `variable_map` is a map from variable ID numbers to XlaOpContext
-// variable states, generated by the symbolic evaluation.
+// `args` is the list of input arguments, `retvals` is the list of retvals
+// produced by _Retval operators, in index order.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index 23d04d43b3..bc44301d40 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -20,21 +20,6 @@ limitations under the License.
namespace tensorflow {
bool CpuOpFilter(KernelDef* kdef) {
- // TODO(b/34339814): implement inverse erf for double types and remove this
- // workaround.
- if (kdef->op() == "RandomStandardNormal") {
- kdef->clear_constraint();
- // Change the type constraint to permit only DTD_FLOAT.
- KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
- attr_constraint->set_name("dtype");
- attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
- DT_FLOAT);
- return true;
- }
- // TODO(b/26783907): The CPU backend currently does not implement sort.
- if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
- return false;
- }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 2a9eaeee14..dd3498ef7a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK();
}
+Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
+ Tensor** output) {
+ // The step's default allocator is the dummy XlaCompilationAllocator which
+ // simply allocates a metadata buffer to hold the expression to which it
+ // corresponds.
+ if (expected_output_dtype(index) == DT_VARIANT) {
+ // tensor_data() is not supported for variant Tensor (i.e.,
+ // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
+ // XlaExpression inside the Tensor's tensor_data() does not work for
+ // variant. Instead construct a uint8 tensor and store the expression in its
+ // value.
+ // TODO(jpienaar): This should be refactored to stop masquerading
+ // XlaExpressions as Tensors.
+ *output = new Tensor();
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(
+ context_->allocate_temp(DT_UINT8, tensor_shape, *output));
+ context_->set_output(index, **output);
+ } else {
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
+ TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
+ }
+ return Status::OK();
+}
+
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
- auto shape = builder()->GetShape(handle);
- if (!shape.ok()) {
- SetStatus(shape.status());
+ auto shape_or = builder()->GetShape(handle);
+ if (!shape_or.ok()) {
+ SetStatus(shape_or.status());
return;
}
- // The step's default allocator is the dummy XlaCompilationAllocator which
- // simply allocates a metadata buffer to hold the expression to which it
- // corresponds.
- TensorShape tensor_shape;
- OP_REQUIRES_OK(context_,
- XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape));
OP_REQUIRES_OK(context_,
- context_->allocate_output(index, tensor_shape, &output));
+ allocate_output(index, shape_or.ValueOrDie(), &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index a3a0d10cc0..aa00a45496 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -255,6 +255,11 @@ class XlaOpKernelContext {
// Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(absl::string_view name);
+ // Wraps OpKernelContext's allocate_output method while providing special
+  // behavior for DT_VARIANT: a variant is represented as a DT_UINT8 scalar,
+  // since an XlaExpression cannot be stored in a variant Tensor's buffer.
+ Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
+
OpKernelContext* const context_;
};
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index b0eeee3174..91d48125f1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default;
<< " have incompatible compile time constant inputs.";
return false;
}
+ if (x.is_metadata_op != y.is_metadata_op) {
+ LOG(WARNING) << "Registrations of " << x.name
+ << " have incompatible values for is_metadata_op.";
+ return false;
+ }
return true;
}
@@ -350,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
return &it->second.front()->compile_time_constant_inputs;
}
+/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) {
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ auto it = registry.ops_.find(op);
+ if (it == registry.ops_.end() || it->second.empty()) {
+ return false;
+ }
+
+ // The test in IsCompatible ensures that if there are multiple matching
+ // registrations for this op name, they all have the same value of
+ // is_metadata_op, so only the first match is returned.
+ return it->second.front()->is_metadata_op;
+}
+
std::vector<string> XlaOpRegistry::BackendNames() {
std::vector<string> names;
XlaOpRegistry& registry = Instance();
@@ -432,6 +451,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
return *this;
}
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
+ registration_->is_metadata_op = true;
+ return *this;
+}
+
std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 74a4885f1f..4b2c2bacd6 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -47,17 +47,18 @@ extern const char* const DEVICE_XLA_GPU;
constexpr std::array<DataType, 4> kFloatTypes = {
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
-constexpr std::array<DataType, 9> kNumericTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BFLOAT16}};
+constexpr std::array<DataType, 11> kNumericTypes = {
+ {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
+ DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}};
-constexpr std::array<DataType, 9> kCpuAllTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 14> kCpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
-constexpr std::array<DataType, 10> kGpuAllTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 15> kGpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
+ DT_BFLOAT16}};
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
@@ -136,6 +137,10 @@ class XlaOpRegistry {
static const std::unordered_set<string>* CompileTimeConstantInputs(
const string& op);
+ // Returns true if `op` is a "metadata" op, one that only looks at the shapes
+ // of its operands and not their values.
+ static bool IsMetadataOp(const string& op);
+
private:
friend class XlaBackendRegistrar;
friend class XlaOpRegistrar;
@@ -192,6 +197,10 @@ class XlaOpRegistry {
// Names of arguments that must be compile-time constants.
std::unordered_set<string> compile_time_constant_inputs;
+ // True if this is a "metadata" op, one that only looks at the shapes of its
+ // operands and not their values.
+ bool is_metadata_op = false;
+
// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};
@@ -256,6 +265,10 @@ class XlaOpRegistrationBuilder {
// Mark 'input_name' as an argument whose value must be known at compile-time.
XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name);
+ // Mark this op as a "metadata" op, one that only looks at the shapes of its
+ // operands and not their values.
+ XlaOpRegistrationBuilder& IsMetadataOp();
+
std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);
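A hypothetical registration using the new builder method, plus the query that compilation passes can now make; the choice of "Shape"/ShapeOp is illustrative, and the ops actually marked by this change are not shown in this hunk:

    REGISTER_XLA_OP(Name("Shape").IsMetadataOp(), ShapeOp);

    // Elsewhere, e.g. when deciding whether declustering an op is cheap:
    bool shape_only = XlaOpRegistry::IsMetadataOp("Shape");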
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index ef70c1f8ac..cc7390c6e6 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -245,6 +245,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index f825f67b44..dc097f3696 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -220,6 +220,8 @@ cc_library(
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 25cc37edc4..ff0ec76a7f 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -97,13 +97,11 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
<< "Computation should have progran shape.";
auto program_shape = computation.proto().program_shape();
- // Create and run a program which produces a tuple with one element per
- // parameter, then return the tuple's constituent buffers.
- std::vector<Shape> param_shapes(program_shape.parameters().begin(),
- program_shape.parameters().end());
- auto fake_input_tuple =
- MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client);
- return client->DeconstructTuple(*fake_input_tuple).ValueOrDie();
+ std::vector<std::unique_ptr<GlobalData>> results;
+ for (const Shape& shape : program_shape.parameters()) {
+ results.push_back(MakeFakeDataOrDie(shape, client));
+ }
+ return results;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 95ff6432a5..6b31831010 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
@@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
case HloOpcode::kWhile:
// TODO(b/32495713): We aren't checking the condition and body
// computations themselves.
+ case HloOpcode::kScatter:
+ // TODO(b/32495713): We aren't checking the embedded computation in
+ // Scatter.
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kParameter:
@@ -1278,7 +1281,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
absl::Span<const XlaOp> operands,
- const Shape& shape) {
+ const Shape& shape, const string& opaque) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1289,6 +1292,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
}
*instr.mutable_shape() = shape;
instr.set_custom_call_target(call_target_name);
+ instr.set_custom_call_opaque(opaque);
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -1785,9 +1789,9 @@ XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
std::vector<std::pair<int64, int64>> padding_values =
MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
window_strides, padding);
- return ReduceWindowWithGeneralPadding(operand, init_value, computation,
- window_dimensions, window_strides,
- padding_values);
+ return ReduceWindowWithGeneralPadding(
+ operand, init_value, computation, window_dimensions, window_strides,
+ /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
});
}
@@ -1796,6 +1800,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1806,7 +1812,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
MakeWindow(window_dimensions, window_strides, padding,
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
+ /*lhs_dilation=*/base_dilations,
+ /*rhs_dilation=*/window_dilations));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferReduceWindowShape(operand_shape, init_shape,
@@ -2289,7 +2296,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
// also a valid dependency order). The related ops will be added to the
// subgraph in the same order.
std::set<int64> related_ops;
- tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
+ absl::flat_hash_set<int64> related_calls; // Related computations.
std::queue<int64> worklist;
worklist.push(root->id());
related_ops.insert(root->id());
@@ -2681,8 +2688,9 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
}
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape) {
- return builder->CustomCall(call_target_name, operands, shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque) {
+ return builder->CustomCall(call_target_name, operands, shape, opaque);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
@@ -2795,10 +2803,12 @@ XlaOp ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return operand.builder()->ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
- padding);
+ base_dilations, window_dilations, padding);
}
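A sketch of a call through the extended overload; CreateScalarAddComputation is assumed to come from xla/client/lib/arithmetic.h, and the window parameters are illustrative:

    xla::XlaBuilder b("dilated_reduce_window");
    auto input = xla::ConstantR2<float>(&b, {{1, 2, 3, 4},
                                             {5, 6, 7, 8}});
    auto zero = xla::ConstantR0<float>(&b, 0.0f);
    xla::XlaComputation add = xla::CreateScalarAddComputation(xla::F32, &b);
    std::vector<std::pair<xla::int64, xla::int64>> padding = {{0, 0}, {0, 0}};
    auto summed = xla::ReduceWindowWithGeneralPadding(
        input, zero, add,
        /*window_dimensions=*/{1, 2},
        /*window_strides=*/{1, 1},
        /*base_dilations=*/{1, 1},    // no dilation of the operand itself
        /*window_dilations=*/{1, 2},  // window taps are 2 elements apart in dim 1
        padding);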
XlaOp CrossReplicaSum(const XlaOp& operand,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index d0c59fa6f2..2e14e47a35 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <type_traits>
#include <utility>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
@@ -34,8 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"
@@ -577,11 +577,9 @@ class XlaBuilder {
absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
- // During code generation, a call instruction is emitted which targets a
- // symbol with the name |call_target_name|. The |operands| are passed to the
- // call instruction. |shape| is the resultant shape.
XlaOp CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -673,6 +671,8 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
@@ -1029,7 +1029,7 @@ class XlaBuilder {
// A map from XlaOp::Handle to the index in the instructions_ vector where the
// instruction is held.
- tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
+ absl::flat_hash_map<int64, int64> handle_to_index_;
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
@@ -1037,7 +1037,7 @@ class XlaBuilder {
std::map<int64, HloComputationProto> embedded_;
// The unique parameter numbers.
- tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+ absl::flat_hash_set<int64> parameter_numbers_;
// The metadata to attach to each op. This is structured as a "modal"-like
// operation, in order to simplify client code (and not sprinkle this metadata
@@ -1195,7 +1195,8 @@ class XlaBuilder {
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
@@ -1246,6 +1247,8 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
friend XlaOp CrossReplicaSum(const XlaOp& operand,
absl::Span<const ReplicaGroup> replica_groups);
@@ -1717,12 +1720,17 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
-// Enqueues a custom call instruction onto the computation.
-// During code generation, a call instruction is emitted which targets a
-// symbol with the name |call_target_name|. The |operands| are passed to the
-// call instruction. |shape| is the resultant shape.
+// Enqueues a custom call instruction onto the computation. A custom call
+// invokes code external to XLA. The |operands| are passed to the external code,
+// and the external code is expected to produce a result of the given
+// |shape|. The exact mechanism is backend-specific. For example, in the CPU
+// backend, a call instruction is emitted which targets a symbol with the name
+// |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
+// but |call_target_name| should be short as it may be used in labels. |opaque|
+// can encode arbitrarily large amounts of information.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque = "");
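A sketch of passing backend-specific configuration through the new |opaque| argument; the "my_custom_kernel" target and its payload are hypothetical:

    xla::XlaBuilder b("custom_call_example");
    auto arg = xla::ConstantR1<float>(&b, {1.0f, 2.0f, 3.0f});
    auto out = xla::CustomCall(
        &b, /*call_target_name=*/"my_custom_kernel", /*operands=*/{arg},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {3}),
        /*opaque=*/"alpha=0.5;mode=fast");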
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -1814,6 +1822,8 @@ XlaOp ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index a472747bd1..0f9b591c70 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -45,6 +45,16 @@ stream_executor::Stream* ExecutableRunOptions::stream() const {
return stream_;
}
+ExecutableRunOptions& ExecutableRunOptions::set_host_to_device_stream(
+ stream_executor::Stream* stream) {
+ host_to_device_stream_ = stream;
+ return *this;
+}
+
+stream_executor::Stream* ExecutableRunOptions::host_to_device_stream() const {
+ return host_to_device_stream_;
+}
+
ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool(
const Eigen::ThreadPoolDevice* intra_op_thread_pool) {
intra_op_thread_pool_ = intra_op_thread_pool;
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index 416131be00..ba3217f31b 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -65,6 +65,13 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_stream(stream_executor::Stream* stream);
stream_executor::Stream* stream() const;
+ // If set, this is the stream to perform any pre-computation transfers on.
+ // The platform of the stream must match the platform the executable was
+ // built for. A value of nullptr indicates the option has not been set.
+ ExecutableRunOptions& set_host_to_device_stream(
+ stream_executor::Stream* stream);
+ stream_executor::Stream* host_to_device_stream() const;
+
// Sets the thread pool device on which to run Eigen subcomputations.
// Does not take ownership.
ExecutableRunOptions& set_intra_op_thread_pool(
@@ -90,6 +97,7 @@ class ExecutableRunOptions {
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
ExecutionProfile* execution_profile_ = nullptr;
int rng_seed_ = 0;
+ stream_executor::Stream* host_to_device_stream_ = nullptr;
};
} // namespace xla
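A sketch of how a client wires in the new stream; compute_stream and transfer_stream are assumed to be se::Stream* objects the caller already owns, on the executable's platform:

    xla::ExecutableRunOptions run_options;
    run_options.set_stream(compute_stream);
    run_options.set_host_to_device_stream(transfer_stream);
    // Pre-computation (host-to-device) transfers now go onto transfer_stream
    // instead of the main compute stream.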
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 0d3136b0cc..3ed3afcfce 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -57,6 +57,8 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
// regression.
flags->set_xla_cpu_enable_fast_math(true);
flags->set_xla_gpu_enable_fast_math(true);
+
+ flags->set_xla_force_host_platform_device_count(1);
}
// Allocates flag_values and flag_objects; this function must not be called more
@@ -323,6 +325,17 @@ void AllocateFlags() {
flag_values->xla_gpu_crash_on_verification_failures(),
"Crashes the program on extra verification failures, e.g. cuDNN "
"cross checking failures"),
+ tensorflow::Flag(
+ "xla_force_host_platform_device_count",
+ int32_setter_for(
+ &DebugOptions::set_xla_force_host_platform_device_count),
+ flag_values->xla_force_host_platform_device_count(),
+          "Force the host platform to pretend that there are this many "
+          "host \"devices\". All of these host devices are backed by the same "
+ "threadpool. Setting this to anything other than 1 can increase "
+ "overhead from context switching but we let the user override this "
+ "behavior to help run tests on the host that run models in parallel "
+ "across multiple devices."),
});
ParseFlagsFromEnv(*flag_objects);
}
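The new flag is normally supplied through the environment-variable mechanism handled by ParseFlagsFromEnv above. A programmatic sketch of the same knob, assuming GetDebugOptionsFromFlags is the accessor declared alongside these flag helpers:

    xla::DebugOptions opts = xla::legacy_flags::GetDebugOptionsFromFlags();
    opts.set_xla_force_host_platform_device_count(4);  // pretend 4 host "devices"
    // All four "devices" share one threadpool; this mainly helps host-side
    // tests that exercise multi-device execution paths.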
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 5035f41988..656ce720a1 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -287,6 +287,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
return InvalidArgument("LiteralProto has no layout");
}
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
+
Literal literal(proto.shape());
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
@@ -725,16 +727,34 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
LayoutUtil::MinorToMajor(shape()));
switch (result_shape.element_type()) {
- case F32:
- return SliceInternal<float>(result_shape, start_indices);
+ case PRED:
+ return SliceInternal<bool>(result_shape, start_indices);
+ case U8:
+ return SliceInternal<uint8>(result_shape, start_indices);
+ case U16:
+ return SliceInternal<uint16>(result_shape, start_indices);
+ case U32:
+ return SliceInternal<uint32>(result_shape, start_indices);
+ case U64:
+ return SliceInternal<uint64>(result_shape, start_indices);
+ case S8:
+ return SliceInternal<int8>(result_shape, start_indices);
+ case S16:
+ return SliceInternal<int16>(result_shape, start_indices);
+ case S32:
+ return SliceInternal<int32>(result_shape, start_indices);
+ case S64:
+ return SliceInternal<int64>(result_shape, start_indices);
+ case F16:
+ return SliceInternal<half>(result_shape, start_indices);
case BF16:
return SliceInternal<bfloat16>(result_shape, start_indices);
+ case F32:
+ return SliceInternal<float>(result_shape, start_indices);
+ case F64:
+ return SliceInternal<double>(result_shape, start_indices);
case C64:
return SliceInternal<complex64>(result_shape, start_indices);
- case S32:
- return SliceInternal<int32>(result_shape, start_indices);
- case U32:
- return SliceInternal<uint32>(result_shape, start_indices);
default:
LOG(FATAL) << "not yet implemented: "
<< PrimitiveType_Name(result_shape.element_type());
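A sketch of Slice on one of the element types newly handled above; the LiteralUtil helpers are assumed from literal_util.h:

    xla::Literal full = xla::LiteralUtil::CreateR2<xla::int64>({{1, 2, 3},
                                                                {4, 5, 6}});
    // Rows [0, 2) and columns [1, 3): {{2, 3}, {5, 6}}.
    xla::Literal part = full.Slice(/*start_indices=*/{0, 1},
                                   /*limit_indices=*/{2, 3});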
@@ -1850,6 +1870,24 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
+ if (LayoutUtil::IsSparseArray(subshape())) {
+ // Compute the number of elements (indices) in the sparse shape and reserve
+    // the necessary space in sparse_indices.
+ TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0)
+ << "Scalar shapes cannot be sparse";
+ TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0)
+ << "Unexpected number of indices in proto ("
+ << proto.sparse_indices_size() << ") for shape of rank "
+ << ShapeUtil::Rank(subshape());
+ const int64 index_count =
+ proto.sparse_indices_size() / ShapeUtil::Rank(subshape());
+ sparse_indices()->Resize(index_count);
+
+ // Copy the indices from the proto into the SparseIndexArray object.
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(),
+ proto.sparse_indices()));
+ }
+
switch (subshape().element_type()) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
@@ -1907,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
}
} break;
case TUPLE:
- LOG(FATAL) << "Should not be called on tuple shapes: "
- << ShapeUtil::HumanString(subshape());
- break;
+ return InvalidArgument("Should not be called on tuple shapes: %s",
+ ShapeUtil::HumanString(subshape()));
default:
- LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+      return InvalidArgument("Called on unsupported shape: %s",
+ ShapeUtil::HumanString(subshape()));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 1e0a2ad0dd..3cd3541fe1 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -203,6 +203,10 @@ class LiteralBase {
// Returns the count of the elements in the array at the given shape index in
// this literal.
int64 element_count(const ShapeIndex& index = {}) const {
+ if (index.empty()) {
+ // Common case, avoid GetSubshape().
+ return ShapeUtil::ElementsIn(shape());
+ }
return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
}
@@ -852,9 +856,9 @@ class BorrowingLiteral : public LiteralBase {
template <typename NativeT>
absl::Span<const NativeT> LiteralBase::Piece::data() const {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
+ DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ DCHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
<< "Attempting to access "
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
@@ -865,9 +869,9 @@ absl::Span<const NativeT> LiteralBase::Piece::data() const {
template <typename NativeT>
absl::Span<NativeT> LiteralBase::Piece::data() {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
+ DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ DCHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
<< "Attempting to access "
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 7ad287c897..dd5b54e4c9 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -224,6 +224,16 @@ TEST_F(LiteralUtilTest, CreateSparse) {
absl::Span<const int64>(expected_indices.data(),
expected_indices.num_elements()));
EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
+
+ // Serialize then deserialize and verify the resulting literal.
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto,
+ Literal::CreateFromProto(literal.ToProto()));
+
+ EXPECT_EQ(literal_from_proto.sparse_indices()->data(),
+ absl::Span<const int64>(expected_indices.data(),
+ expected_indices.num_elements()));
+ EXPECT_EQ(literal_from_proto.data<int64>(),
+ absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 9da5dc0d2d..ffa336f304 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -469,9 +469,11 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated(
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
- lhs_dilation, rhs_dilation, dimension_numbers);
+ lhs_dilation, rhs_dilation, dimension_numbers,
+ feature_group_count);
}
LocalOp LocalComputationBuilder::ConvertElementType(
@@ -530,10 +532,13 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return xla::ReduceWindowWithGeneralPadding(
operand.op(), init_value.op(), local_computation.computation(),
- window_dimensions, window_strides, padding);
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding);
}
LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu,
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 1d5dfe5911..43332e0abd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -248,7 +248,8 @@ class LocalComputationBuilder {
absl::Span<const std::pair<int64, int64> > padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
LocalOp ConvertElementType(const LocalOp& operand,
PrimitiveType new_element_type);
@@ -277,6 +278,8 @@ class LocalComputationBuilder {
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64> > padding);
LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index fa4366ff07..f8197488fb 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -995,7 +995,30 @@ class ComputationBuilder(object):
window_strides)
return self._client.ReduceWindowWithGeneralPadding(
operand, init_value, computation_to_apply.c_local_computation,
- window_dimensions, window_strides, pads)
+ window_dimensions, window_strides, (), (), pads)
+
+ def ReduceWindowWithGeneralPadding(
+ self, operand, init_value, computation_to_apply, window_dimensions,
+ window_strides, base_dilations, window_dilations, padding):
+ """Enqueues a windowed reduction operation onto the computation.
+
+ Args:
+ operand: reduction operand (LocalOp).
+ init_value: reduction initial value (LocalOp).
+ computation_to_apply: a binary reduction function (Computation).
+ window_dimensions: dimensions of window (sequence of integers).
+ window_strides: strides for window (sequence of integers).
+ base_dilations: dilations for the base (sequence of integers).
+ window_dilations: dilations for window (sequence of integers).
+ padding: length-N array-like of pairs of integers of (low, high) padding.
+
+ Returns:
+ A LocalOp representing the added ReduceWindow op.
+ """
+ return self._client.ReduceWindowWithGeneralPadding(
+ operand, init_value, computation_to_apply.c_local_computation,
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding)
def RngNormal(self, mu, sigma, dims):
"""Enqueues an RngNormal operation onto the computation.
@@ -1109,7 +1132,7 @@ class ComputationBuilder(object):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return self._client.DotGeneral(lhs, rhs, dimension_numbers)
- def Conv(self, lhs, rhs, window_strides, padding):
+ def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
"""Enqueues a Conv operation onto the computation.
Args:
@@ -1117,6 +1140,7 @@ class ComputationBuilder(object):
rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the Conv operation.
"""
@@ -1125,10 +1149,11 @@ class ComputationBuilder(object):
self.GetShape(rhs).dimensions()[2:], window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (),
- (), dimension_numbers)
+ (), dimension_numbers,
+ feature_group_count)
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation):
+ lhs_dilation, rhs_dilation, feature_group_count=1):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
@@ -1138,6 +1163,7 @@ class ComputationBuilder(object):
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors.
+ feature_group_count: number of feature groups for grouped convolution.
Returns:
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
@@ -1145,7 +1171,8 @@ class ComputationBuilder(object):
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1163,7 +1190,8 @@ class ComputationBuilder(object):
return dimension_numbers
def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
- rhs_dilation, dimension_numbers):
+ rhs_dilation, dimension_numbers,
+ feature_group_count=1):
"""Enqueues a ConvGeneralDilated operation onto the computation.
Args:
@@ -1190,6 +1218,7 @@ class ComputationBuilder(object):
labels appear in the rhs_spec string, so that window_strides[0] is
matched with the dimension corresponding to the first character
appearing in rhs_spec that is not 'I' or 'O'.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the ConvGenralDilated operation.
"""
@@ -1215,7 +1244,8 @@ class ComputationBuilder(object):
key=lambda i: rhs_spec.index(out_spec[i])))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def Sort(self, operand, dimension=-1):
"""Enqueues a sort operation onto the computation."""
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fd98e19457..82103f0313 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest):
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
+ def testConvGeneralDilatedGroupedConvolutionF32(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 2, 2, 3)
+ rhs = a(2, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+ dimension_numbers = ("NCHW", "OIHW", "NCHW")
+ feature_group_count = 2
+ c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
+ strides, pads, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count)
+ result = np.array([[[[0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.]],
+ [[0., 0., 0.],
+ [330., 380., 160.],
+ [0., 0., 0.],
+ [480., 530., 220.]]]])
+ self._ExecuteAndCompareClose(c, expected=result)
+
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD
index 97fcd37f6b..3abb3855a4 100644
--- a/tensorflow/compiler/xla/rpc/BUILD
+++ b/tensorflow/compiler/xla/rpc/BUILD
@@ -34,19 +34,28 @@ cc_library(
],
)
-tf_cc_binary(
- name = "grpc_service_main_cpu",
+cc_library(
+ name = "grpc_service_main_library",
srcs = ["grpc_service_main.cc"],
deps = [
":grpc_service",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings:str_format",
],
)
+tf_cc_binary(
+ name = "grpc_service_main_cpu",
+ deps = [
+ ":grpc_service_main_library",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ ],
+)
+
tf_cc_test(
name = "grpc_client_test",
srcs = ["grpc_client_test.cc"],
diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
index d6b5149a24..522ab99fb1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "grpcpp/server_builder.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -29,8 +30,15 @@ namespace {
int RealMain(int argc, char** argv) {
int32 port = 1685;
+ bool any_address = false;
+ string platform_str;
std::vector<tensorflow::Flag> flag_list = {
- tensorflow::Flag("port", &port, "port to listen on"),
+ tensorflow::Flag("platform", &platform_str,
+ "The XLA platform this service should be bound to"),
+ tensorflow::Flag("port", &port, "The TCP port to listen on"),
+ tensorflow::Flag(
+ "any", &any_address,
+ "Whether to listen to any host address or simply localhost"),
};
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
@@ -40,19 +48,24 @@ int RealMain(int argc, char** argv) {
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ se::Platform* platform = nullptr;
+ if (!platform_str.empty()) {
+ platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie();
+ }
std::unique_ptr<xla::GRPCService> service =
- xla::GRPCService::NewService().ConsumeValueOrDie();
+ xla::GRPCService::NewService(platform).ConsumeValueOrDie();
::grpc::ServerBuilder builder;
- string server_address(absl::StrFormat("localhost:%d", port));
+ string server_address(
+ absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port));
+ builder.SetMaxReceiveMessageSize(INT_MAX);
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
std::unique_ptr<::grpc::Server> server(builder.BuildAndStart());
LOG(INFO) << "Server listening on " << server_address;
server->Wait();
-
return 0;
}
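Taken together, the new flags make the standalone gRPC service configurable at startup: --platform picks which registered XLA platform backs the service (when unset, platform stays nullptr and GRPCService::NewService keeps its previous default), --port keeps its old meaning, --any switches the listening address from localhost to [::] so remote clients can connect, and the builder now lifts gRPC's default receive-size cap via SetMaxReceiveMessageSize(INT_MAX). A rough invocation sketch; the binary path and the "Host" platform name are assumptions, not part of the change:

  bazel-bin/tensorflow/compiler/xla/rpc/grpc_service_main_cpu --platform=Host --port=1685 --any=true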
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fb80c78f68..2b292ed053 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -146,6 +146,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -182,6 +184,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
@@ -251,6 +254,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -296,6 +300,7 @@ cc_library(
"hlo_opcode.cc",
"hlo_schedule.cc",
"hlo_sharding.cc",
+ "hlo_sharding_metadata.cc",
],
hdrs = [
"dfs_hlo_visitor.h",
@@ -309,6 +314,7 @@ cc_library(
"hlo_opcode.h",
"hlo_schedule.h",
"hlo_sharding.h",
+ "hlo_sharding_metadata.h",
],
deps = [
":hlo_casting_utils",
@@ -333,6 +339,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -365,8 +373,11 @@ cc_library(
hdrs = ["pattern_matcher.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/utility",
],
)
@@ -392,6 +403,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
],
)
@@ -482,6 +494,8 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -590,6 +604,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
@@ -772,6 +787,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -899,6 +915,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -948,6 +965,8 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -983,6 +1002,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
@@ -1030,6 +1051,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -1083,6 +1106,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
@@ -1121,6 +1145,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
@@ -1142,6 +1168,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],
)
@@ -1166,6 +1193,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_module_group",
+ ":hlo_module_group_metadata",
":hlo_parser",
":hlo_proto",
"//tensorflow/compiler/xla:test",
@@ -1191,6 +1219,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
@@ -1211,6 +1240,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -1255,6 +1286,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -1275,6 +1308,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -1290,15 +1324,25 @@ cc_library(
)
cc_library(
+ name = "fusion_queue",
+ hdrs = ["fusion_queue.h"],
+ deps = [
+ ":hlo",
+ ],
+)
+
+cc_library(
name = "instruction_fusion",
srcs = ["instruction_fusion.cc"],
hdrs = ["instruction_fusion.h"],
deps = [
+ ":fusion_queue",
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],
)
@@ -1325,6 +1369,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -1380,6 +1426,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
@@ -1635,6 +1682,8 @@ cc_library(
":while_loop_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
@@ -1666,6 +1715,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -1792,42 +1842,6 @@ tf_cc_test(
)
cc_library(
- name = "inliner",
- srcs = ["inliner.cc"],
- hdrs = ["inliner.h"],
- deps = [
- ":hlo",
- ":hlo_pass",
- ":hlo_query",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/core:lib",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-tf_cc_test(
- name = "inliner_test",
- srcs = ["inliner_test.cc"],
- deps = [
- ":cpu_plugin",
- ":hlo",
- ":hlo_matchers",
- ":inliner",
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
- "//tensorflow/compiler/xla/tests:literal_test_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/memory",
- ],
-)
-
-cc_library(
name = "computation_placer",
srcs = ["computation_placer.cc"],
hdrs = ["computation_placer.h"],
@@ -2038,6 +2052,7 @@ cc_library(
":logical_buffer",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -2073,6 +2088,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -2094,6 +2110,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -2177,6 +2194,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -2198,6 +2216,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@@ -2258,6 +2278,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -2314,6 +2336,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -2340,6 +2364,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -2411,6 +2437,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -2423,6 +2450,7 @@ tf_cc_test(
":hlo",
":hlo_parser",
":hlo_verifier",
+ ":layout_assignment",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -2455,6 +2483,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -2557,6 +2587,7 @@ cc_library(
],
deps = [
":hlo",
+ ":hlo_module_group",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -2582,12 +2613,34 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)
+tf_cc_test(
+ name = "hlo_pass_pipeline_test",
+ srcs = ["hlo_pass_pipeline_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_parser",
+ ":hlo_pass_pipeline",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "hlo_cse",
srcs = ["hlo_cse.cc"],
@@ -2601,6 +2654,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -2675,27 +2729,13 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
cc_library(
- name = "hlo_sharding_metadata",
- srcs = ["hlo_sharding_metadata.cc"],
- hdrs = [
- "hlo_sharding_metadata.h",
- ],
- deps = [
- ":hlo",
- "//tensorflow/compiler/xla:shape_tree",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/core:lib",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
name = "hlo_domain_verifier",
srcs = ["hlo_domain_verifier.cc"],
hdrs = ["hlo_domain_verifier.h"],
@@ -2745,7 +2785,6 @@ tf_cc_test(
":hlo_domain_isolator",
":hlo_domain_remover",
":hlo_parser",
- ":hlo_sharding_metadata",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -3121,6 +3160,7 @@ cc_library(
":hlo_pass_pipeline",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -3243,6 +3283,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -3272,6 +3314,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -3328,6 +3371,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -3355,7 +3400,6 @@ cc_library(
deps = [
":hlo",
":hlo_lexer",
- ":hlo_sharding_metadata",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@@ -3413,6 +3457,39 @@ cc_library(
deps = ["//tensorflow/core:lib"],
)
+cc_library(
+ name = "map_inliner",
+ srcs = ["map_inliner.cc"],
+ hdrs = ["map_inliner.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ ":hlo_query",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "map_inliner_test",
+ srcs = ["map_inliner_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_matchers",
+ ":map_inliner",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_absl//absl/memory",
+ ],
+)
+
tf_cc_test(
name = "hlo_casting_utils_test",
srcs = ["hlo_casting_utils_test.cc"],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4ef1dffa73..86d9dbea90 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -754,11 +754,12 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
};
auto reshape_if_necessary = [&](HloInstruction* hlo) {
+ hlo = as_type(hlo, dot->shape().element_type());
if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
hlo = computation_->AddInstruction(
HloInstruction::CreateReshape(dot->shape(), hlo));
}
- return as_type(hlo, dot->shape().element_type());
+ return hlo;
};
auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
@@ -2056,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return Status::OK();
}
+ // Bail on dilation.
+ if (window_util::HasDilation(window)) {
+ VLOG(10) << "Not folding pad into reduce-window as there is dilation.";
+ return Status::OK();
+ }
+
VLOG(10) << "Considering folding Pad: " << pad->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString()
<< (convert != nullptr
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index b864c372fa..9f8d0ee88b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
// A pass which performs algebraic simplifications.
-class AlgebraicSimplifier : public HloPassInterface {
+class AlgebraicSimplifier : public HloModulePass {
public:
// Given shapes 'from_shape' and 'to_shape', determines if it is valid to
// bitcast from 'from_shape' to 'to_shape' after considering platform
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 3fc1ba2427..2047f894b4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -3233,17 +3233,18 @@ INSTANTIATE_TEST_CASE_P(
class DotStrengthReductionTest
: public AlgebraicSimplifierTest,
public ::testing::WithParamInterface<
- ::testing::tuple<int, int, int, bool, bool>> {};
+ ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
int m, k, n;
bool transpose_lhs, transpose_rhs;
- std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
+ PrimitiveType element_type;
+ std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
- Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
- Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
- Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
- Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
- Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
+ Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
+ Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
+ Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
+ Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
+ Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
HloComputation::Builder builder(TestName());
auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -3285,7 +3286,7 @@ INSTANTIATE_TEST_CASE_P(
DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
::testing::Values(1, 2), ::testing::Bool(),
- ::testing::Bool()));
+ ::testing::Bool(), ::testing::Values(F32, BF16)));
struct DotOfConcatTestSpec {
int64 m;
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index a7d8927cf7..43feccee3c 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -110,7 +111,7 @@ class AllocationTracker {
// A map from device memory opaque value to allocation. One such map is
// maintained per device ordinal.
- using AllocationMap = tensorflow::gtl::FlatMap<const void*, Allocation>;
+ using AllocationMap = absl::flat_hash_map<const void*, Allocation>;
tensorflow::mutex mutex_;
@@ -123,10 +124,7 @@ class AllocationTracker {
int64 next_handle_ GUARDED_BY(mutex_);
// A map from device ordinal to AllocationMap.
- //
- // This is not a TF FlatMap because (currently) FlatMap (and therefore
- // AllocationMap) is not movable.
- std::unordered_map<int, AllocationMap> opaque_to_allocation_map_
+ absl::flat_hash_map<int, AllocationMap> opaque_to_allocation_map_
GUARDED_BY(mutex_);
// A map from data handle to a vector of shaped buffers that represent the
@@ -146,7 +144,7 @@ class AllocationTracker {
// non-owning "view" into a tuple's sub-buffers. The sub-buffers are then
// free'd when both the view *and* the original tuple are Unregistered. This
// refcounting is managed in opaque_to_allocation_map_.
- tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
+ absl::flat_hash_map<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
handle_to_shaped_buffers_ GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
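The AllocationTracker hunk above is representative of a change that recurs throughout this diff: tensorflow::gtl::FlatMap/FlatSet members become absl::flat_hash_map/flat_hash_set, the matching absl/container includes are added, and the gtl headers are dropped. A minimal, self-contained sketch of the replacement containers' surface; the key and value types here are placeholders, not the real Allocation types:

#include <cstdint>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"

int main() {
  // Same emplace/find/erase surface as the gtl containers being replaced.
  absl::flat_hash_map<int64_t, std::string> allocations;
  allocations.emplace(0, "buffer #0");
  auto it = allocations.find(0);
  if (it != allocations.end()) {
    it->second += " (found)";
  }
  allocations.erase(0);

  // flat_hash_set covers the visited/colocated membership sets used elsewhere
  // in this diff; insert().second is false on a repeat visit.
  absl::flat_hash_set<const void*> visited;
  int node = 0;
  bool newly_visited = visited.insert(&node).second;
  (void)newly_visited;

  // Unlike the old FlatMap (see the comment removed from AllocationTracker
  // above), these containers are movable.
  absl::flat_hash_map<int64_t, std::string> moved = std::move(allocations);
  (void)moved;
  return 0;
}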
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
index 79d37f08d3..5b625bf3b9 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.h
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -25,7 +25,7 @@ namespace xla {
// Normally these would live in the algebraic simplifier, but we want to run
// this to fixpoint (this pass reaches fixed point in one execution) before we
// run the DotDecomposer.
-class BatchDotSimplification : public HloPassInterface {
+class BatchDotSimplification : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override;
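The base-class change from HloPassInterface to HloModulePass seen here repeats across most of the pass headers in this diff (AlgebraicSimplifier, BatchNormExpander, the BFloat16 passes, CallInliner, ConditionalSimplifier, ConvolutionFeatureGroupConverter, and others); the visible interface, a name() accessor plus StatusOr<bool> Run(HloModule*), is unchanged. A minimal sketch of a pass written against the renamed base class; NoOpPass is invented purely for illustration, and it assumes HloModulePass is declared next to HloPassInterface in hlo_pass_interface.h:

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Illustrative module-level pass: it reports a name and never changes the HLO.
class NoOpPass : public HloModulePass {
 public:
  absl::string_view name() const override { return "no-op-pass"; }

  // Run() returns true iff the module was modified; this sketch never does so.
  StatusOr<bool> Run(HloModule* module) override {
    (void)module;
    return false;
  }
};

}  // namespace xla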
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 30d33e0d35..f70f6ddfec 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 76e32174f3..147f3ae7b6 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -26,7 +26,7 @@ namespace xla {
// A pass which rewrites batch norm operations into more operations. Breaking a
// big operation into smaller operations helps leverage our generic fusion
// logic.
-class BatchNormExpander : public HloPassInterface {
+class BatchNormExpander : public HloModulePass {
public:
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
index 5dcd31b83d..cb3d12f0bf 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
@@ -31,7 +31,7 @@ namespace xla {
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changes made by this pass.
-class BFloat16ConversionFolding : public HloPassInterface {
+class BFloat16ConversionFolding : public HloModulePass {
public:
explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
index 30b6346312..f48e925823 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.h
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
// support BF16 input/output or mixed precision, according to the passed-in
// backend-specific BF16 support rules.
-class BFloat16Normalization : public HloPassInterface {
+class BFloat16Normalization : public HloModulePass {
public:
explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
@@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface {
// use mixed precision; it removes mixed precision even if the backend supports
// it. This pass is used to make the HLO module valid for other HLO passes which
// do not support mixed precision.
-class BFloat16MixedPrecisionRemoval : public HloPassInterface {
+class BFloat16MixedPrecisionRemoval : public HloModulePass {
public:
BFloat16MixedPrecisionRemoval() {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 58f78f8e24..002be9c970 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
};
auto root = fusion->fused_instructions_computation()->root_instruction();
- tensorflow::gtl::FlatSet<const HloValue*> changed_root_buffers;
+ absl::flat_hash_set<const HloValue*> changed_root_buffers;
auto root_changes_it = changes_to_bf16_.find(root);
if (root_changes_it != changes_to_bf16_.end()) {
@@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
HloComputation* computation,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations) {
+ absl::flat_hash_set<const HloComputation*>* visited_computations) {
bool parameter_changed = false;
auto insts = computation->MakeInstructionPostOrder();
// Do the adjustment on each instruction in the computation in reverse
@@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
// another input parameter. A fixed point will be reached because the
// parameters can only be changed from BF16 to F32, not the other way
// around.
- tensorflow::gtl::FlatSet<const HloComputation*> visited_in_while;
+ absl::flat_hash_set<const HloComputation*> visited_in_while;
while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
&visited_in_while) ||
ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
@@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
HloModule* module) {
const auto& computations_topological_order =
module->MakeComputationPostOrder();
- tensorflow::gtl::FlatSet<const HloComputation*> resolved;
+ absl::flat_hash_set<const HloComputation*> resolved;
for (auto comp_it = computations_topological_order.rbegin();
comp_it != computations_topological_order.rend(); ++comp_it) {
if (ContainsKey(resolved, *comp_it)) {
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 1ee64971ab..5fcaa15c83 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -58,7 +60,7 @@ namespace xla {
// BFloat16ConversionFolding. If other passes are needed after this pass, run
// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
// pass.
-class BFloat16Propagation : public HloPassInterface {
+class BFloat16Propagation : public HloModulePass {
public:
explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);
@@ -81,7 +83,7 @@ class BFloat16Propagation : public HloPassInterface {
// The set of instructions to consider using bfloat16, computed in the forward
// pass.
- tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_;
+ absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;
// ***************************
// Functions called and state produced by the backward pass (from root to
@@ -110,12 +112,12 @@ class BFloat16Propagation : public HloPassInterface {
// The set of HloInstructions that have been visited in the
// opportunity-finding pass.
- tensorflow::gtl::FlatSet<const HloInstruction*>
+ absl::flat_hash_set<const HloInstruction*>
instructions_visited_in_backward_pass_;
// The set of HloComputations that have been visited in the
// opportunity-finding pass.
- tensorflow::gtl::FlatSet<const HloComputation*>
+ absl::flat_hash_set<const HloComputation*>
computations_visited_in_backward_pass_;
// ***************************
@@ -131,7 +133,7 @@ class BFloat16Propagation : public HloPassInterface {
// point is reached.
bool ResolveInconsistencyOfAliasingBuffersHelper(
HloComputation* computation,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations);
+ absl::flat_hash_set<const HloComputation*>* visited_computations);
// Makes the parameters of called computations match how they are called by
// the given HLO.
@@ -182,11 +184,11 @@ class BFloat16Propagation : public HloPassInterface {
PrimitiveType target_type);
// The set of F32 HLO values that must be kept in F32.
- tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
+ absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;
// Mapping from each HloComputation to the number of callers to it in the
// module. Populated at the beginning of this pass.
- tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_;
+ absl::flat_hash_map<const HloComputation*, int64> caller_counts_;
// We first store the potential F32-to-BF16 changes to changes_to_bf16_, which
// are subject to further adjustment, then finally applied to the HLOs. This
@@ -195,8 +197,7 @@ class BFloat16Propagation : public HloPassInterface {
//
// For each HloInstruction, changes_to_bf16_ stores the affected buffers in
// the output as a map from in-place pointers to subshapes to shape indices.
- tensorflow::gtl::FlatMap<HloInstruction*,
- tensorflow::gtl::FlatMap<Shape*, ShapeIndex>>
+ absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>>
changes_to_bf16_;
// Whether the last processed HLO module has been changed by this pass.
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc
index 23645346e6..5b48f10505 100644
--- a/tensorflow/compiler/xla/service/bfloat16_support.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_support.cc
@@ -78,8 +78,10 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
const HloInstruction& hlo, int64 operand_index) {
switch (hlo.opcode()) {
case HloOpcode::kAbs:
+ case HloOpcode::kAllToAll:
case HloOpcode::kBroadcast:
case HloOpcode::kClamp:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kConcatenate:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 65fa951afe..2c2d1626c2 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <ostream>
#include <utility>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -41,10 +43,10 @@ limitations under the License.
namespace xla {
namespace {
+using absl::flat_hash_map;
+using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
-using ::tensorflow::gtl::FlatMap;
-using ::tensorflow::gtl::FlatSet;
using ::tensorflow::strings::HumanReadableNumBytes;
template <typename T>
@@ -128,8 +130,8 @@ Status GatherComputationsByAllocationType(
// Sets for quickly checking membership. Computations are returned in vectors
// for stable iteration.
- FlatSet<const HloComputation*> thread_local_set;
- FlatSet<const HloComputation*> global_set;
+ flat_hash_set<const HloComputation*> thread_local_set;
+ flat_hash_set<const HloComputation*> global_set;
while (!worklist.empty()) {
auto worklist_front = worklist.front();
@@ -444,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex(
bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
const HloInstruction* hlo_b) const {
using SliceSet =
- FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
+ flat_hash_set<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
// Gets the slices of all of instr's subshapes. If any subshape doesn't have an
// assigned slice, returns the empty set.
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
@@ -519,7 +521,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation,
// BufferAllocation.
void BufferAssignment::CombineTempAllocations() {
VLOG(1) << "CombineTempAllocations()";
- FlatMap<LogicalBuffer::Color, BufferAllocation, LogicalBuffer::Color::Hasher>
+ flat_hash_map<LogicalBuffer::Color, BufferAllocation,
+ LogicalBuffer::Color::Hasher>
combined_allocation_map;
// Move all temp allocations into a single run at the end of the allocations
@@ -582,7 +585,8 @@ void BufferAssignment::CombineTempAllocations() {
}
// Update allocation indices to their new positions.
- allocation_index_for_buffer_.clear_no_resize();
+ allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(),
+ allocation_index_for_buffer_.end());
for (size_t index = 0; index < allocations_.size(); ++index) {
BufferAllocation* allocation = &allocations_[index];
allocation->set_index(index);
@@ -812,9 +816,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
Status BufferAssigner::AssignBuffersForComputation(
const HloComputation* computation, bool is_thread_local,
- const FlatSet<const LogicalBuffer*>& colocated_buffers,
- const FlatSet<BufferAllocation::Index>& colocated_allocations,
- FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>*
+ const flat_hash_set<const LogicalBuffer*>& colocated_buffers,
+ const flat_hash_set<BufferAllocation::Index>& colocated_allocations,
+ flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>*
buffers_to_assign_sequentially,
BufferAssignment* assignment) {
// Buffers are sorted and assigned to BufferAllocations in decreasing order of
@@ -833,7 +837,7 @@ Status BufferAssigner::AssignBuffersForComputation(
// Generate a post order sort of instructions for sorting of the
// LogicalBuffers.
- FlatMap<const HloInstruction*, int> post_order_position;
+ flat_hash_map<const HloInstruction*, int> post_order_position;
int position = 0;
for (auto* instruction : computation->MakeInstructionPostOrder()) {
post_order_position.emplace(instruction, position);
@@ -850,8 +854,8 @@ Status BufferAssigner::AssignBuffersForComputation(
// buffers_to_assign_sequentially map, even if we end up with an empty set
// of buffers. This ensures we can correctly determine whether to run
// whole-module heap simulation.
- buffers_to_assign_sequentially->emplace(computation,
- FlatSet<const LogicalBuffer*>());
+ buffers_to_assign_sequentially->emplace(
+ computation, flat_hash_set<const LogicalBuffer*>());
}
// Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
@@ -1043,12 +1047,12 @@ Status BufferAssigner::AssignBuffersForComputation(
return Status::OK();
}
-FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
- LogicalBuffer::Color::Hasher>
+flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
+ LogicalBuffer::Color::Hasher>
BufferAssigner::SplitBuffersByColor(
- const FlatSet<const LogicalBuffer*>& buffers) {
- FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
- LogicalBuffer::Color::Hasher>
+ const flat_hash_set<const LogicalBuffer*>& buffers) {
+ flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
+ LogicalBuffer::Color::Hasher>
color_map;
for (auto buffer : buffers) {
color_map[buffer->color()].insert(buffer);
@@ -1057,23 +1061,38 @@ BufferAssigner::SplitBuffersByColor(
}
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
- const FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>&
+ const flat_hash_map<const HloComputation*,
+ flat_hash_set<const LogicalBuffer*>>&
buffers_to_assign_sequentially,
bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
// Run the sequence of instructions through the heap simulator. The heuristic
// that seems to give the best results is lazy-best-fit, with all runs of
// alloc / free calls sorted in decreasing size order.
const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
+
+ // Returns a heap algorithm that chooses the best result from several
+ // algorithms.
+ auto get_heap_algorithm = [&](int64 alignment) {
+ auto algorithms =
+ absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
+ algorithms->push_back(absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)));
+ algorithms->push_back(
+ absl::make_unique<GlobalDecreasingSizeBestFitHeap>(alignment));
+ return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
+ };
+
if (run_whole_module_heap_simulation) {
// Run the heap simulation over the whole module. This reduces memory usage,
// since buffers for kCall, kWhile, and kConditional sub-computations are
// only live for the duration of their calling instructions.
VLOG(1) << "Running whole-module heap simulation";
HloSchedule schedule(&assignment->module());
- FlatSet<const LogicalBuffer*> all_buffers_to_assign;
+ flat_hash_set<const LogicalBuffer*> all_buffers_to_assign;
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
- const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+ const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
+ pair.second;
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
@@ -1093,8 +1112,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
- absl::make_unique<LazyBestFitHeap>(alignment)),
+ HeapSimulator::Run(get_heap_algorithm(alignment),
assignment->module(), schedule,
assignment->points_to_analysis(),
assignment->buffer_size_, options));
@@ -1108,7 +1126,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
VLOG(1) << "Running per-computation heap simulation";
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
- const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+ const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
+ pair.second;
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
@@ -1123,12 +1142,10 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(
- absl::make_unique<DecreasingSizeRunsHeap>(
- absl::make_unique<LazyBestFitHeap>(alignment)),
- *computation, HloInstructionSequence(*instruction_sequence),
- assignment->points_to_analysis(), assignment->buffer_size_,
- options));
+ HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
+ HloInstructionSequence(*instruction_sequence),
+ assignment->points_to_analysis(),
+ assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
@@ -1145,9 +1162,8 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
// Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
// buffers in this allocation.
- tensorflow::gtl::FlatMap<LogicalBuffer::Id, const LogicalBuffer*>
- id_to_buffer;
- tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> buffer_sizes;
+ absl::flat_hash_map<LogicalBuffer::Id, const LogicalBuffer*> id_to_buffer;
+ absl::flat_hash_map<const LogicalBuffer*, int64> buffer_sizes;
for (const auto& pair : allocation.assigned_buffers()) {
const LogicalBuffer* buffer = pair.first;
const BufferAllocation::OffsetSize& offset_size = pair.second;
@@ -1186,7 +1202,7 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
// Next gather the set of logical buffers live at the earliest point of
// maximal live set size.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers;
+ absl::flat_hash_set<const LogicalBuffer*> live_buffers;
live_size = 0;
for (const auto& event : heap_trace.events()) {
const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
@@ -1576,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets(
void BufferAssigner::AssignColocatedBufferSets(
const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
BufferAssignment* assignment,
- FlatSet<const LogicalBuffer*>* colocated_buffers,
- FlatSet<BufferAllocation::Index>* colocated_allocations) {
+ flat_hash_set<const LogicalBuffer*>* colocated_buffers,
+ flat_hash_set<BufferAllocation::Index>* colocated_allocations) {
for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
BufferAllocation* allocation = nullptr;
// Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry
@@ -1650,8 +1666,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// Once b/32491382 enables module-level liveness analysis, we may be able
// to assign colocated buffers (or at least reuse their allocation for
// buffers outside of the set) in AssignBuffersForComputation.
- FlatSet<const LogicalBuffer*> colocated_buffers;
- FlatSet<BufferAllocation::Index> colocated_allocations;
+ flat_hash_set<const LogicalBuffer*> colocated_buffers;
+ flat_hash_set<BufferAllocation::Index> colocated_allocations;
std::vector<ColocatedBufferSet> colocated_buffer_sets;
BuildColocatedBufferSets(module, assignment->liveness(),
assignment->buffer_size_, &colocated_buffer_sets);
@@ -1669,7 +1685,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// First assign buffers for global computations. Temporary buffers for
// sequential computations are collected in 'buffers_to_assign_sequentially'.
- FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>
+ flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>
buffers_to_assign_sequentially;
for (auto* computation : global_computations) {
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 24ba7c16f5..899cd36e1f 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -22,6 +22,8 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
@@ -33,8 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -148,7 +148,7 @@ class BufferAllocation {
// Access to the logical buffers assigned to this allocation, and their
// associated logical offsets and sizes.
- const tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize>&
+ const absl::flat_hash_map<const LogicalBuffer*, OffsetSize>&
assigned_buffers() const {
return assigned_buffers_;
}
@@ -323,7 +323,7 @@ class BufferAllocation {
// Mapping from the set of buffers assigned to this allocation to their
// logical offsets and sizes.
- tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize> assigned_buffers_;
+ absl::flat_hash_map<const LogicalBuffer*, OffsetSize> assigned_buffers_;
int64 fragmentation_bytes_ = 0;
std::vector<HeapSimulatorTrace> heap_traces_;
@@ -500,7 +500,7 @@ class BufferAssignment {
int64 temp_allocation_total_size_ = 0;
// Maps Buffers to the index of the BufferAllocation which holds the buffer.
- tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferAllocation::Index>
+ absl::flat_hash_map<const LogicalBuffer*, BufferAllocation::Index>
allocation_index_for_buffer_;
const HloModule* module_;
@@ -554,11 +554,10 @@ class BufferAssigner {
// true.
Status AssignBuffersForComputation(
const HloComputation* computation, bool is_thread_local,
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
- const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
- colocated_allocations,
- tensorflow::gtl::FlatMap<const HloComputation*,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
+ const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers,
+ const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations,
+ absl::flat_hash_map<const HloComputation*,
+ absl::flat_hash_set<const LogicalBuffer*>>*
buffers_to_assign_sequentially,
BufferAssignment* assignment);
@@ -568,9 +567,8 @@ class BufferAssigner {
// 'run_whole_module_heap_simulation' is true, the heap simulation will be run
// assuming all global computations are sequentially ordered.
Status AssignBuffersWithSequentialOrdering(
- const tensorflow::gtl::FlatMap<
- const HloComputation*,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
+ const absl::flat_hash_map<const HloComputation*,
+ absl::flat_hash_set<const LogicalBuffer*>>&
buffers_to_assign_sequentially,
bool run_whole_module_heap_simulation, BufferAssignment* assignment);
@@ -590,7 +588,7 @@ class BufferAssigner {
// alias. Explicitly handling these colocated buffers is necessary because
// points-to analysis is computation level scope and does not recognize
// aliasing across computations (b/32491382).
- using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>;
+ using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>;
// Returns a vector of ColocatedBufferSet objects, where each
// ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
@@ -605,8 +603,8 @@ class BufferAssigner {
void AssignColocatedBufferSets(
const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
BufferAssignment* assignment,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
- tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations);
+ absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers,
+ absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations);
// Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
// the invariant that all sets in 'colocated_buffer_sets' are disjoint.
@@ -624,11 +622,10 @@ class BufferAssigner {
// Split a set of buffers into several sets, each of which contains buffers
// colored with the same color.
- tensorflow::gtl::FlatMap<LogicalBuffer::Color,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>,
- LogicalBuffer::Color::Hasher>
- SplitBuffersByColor(
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers);
+ absl::flat_hash_map<LogicalBuffer::Color,
+ absl::flat_hash_set<const LogicalBuffer*>,
+ LogicalBuffer::Color::Hasher>
+ SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers);
// If true, buffer assignments assumes that input parameter buffers and output
// buffers can be shared if their sizes match.
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h
index cdd3cf4032..f939a426ea 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.h
+++ b/tensorflow/compiler/xla/service/buffer_liveness.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -27,8 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -102,7 +101,7 @@ class BufferLiveness {
// Set of LogicalBuffers which are aliased in the output of other
// instructions. For example, a LogicalBuffer which is inserted into a tuple
// is considered to be aliased and will be in this set.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_;
+ absl::flat_hash_set<const LogicalBuffer*> aliased_buffers_;
// LogicalBuffers that may be live out of the entry computation.
PointsToSet::BufferSet maybe_live_out_buffers_;
diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h
index 305914fca8..cc46af5eee 100644
--- a/tensorflow/compiler/xla/service/buffer_value_containers.h
+++ b/tensorflow/compiler/xla/service/buffer_value_containers.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet(
return output;
}
-using BufferValueFlatSet = tensorflow::gtl::FlatSet<const BufferValue*>;
+using BufferValueFlatSet = absl::flat_hash_set<const BufferValue*>;
template <class LogicalBufferContainerT>
BufferValueFlatSet ToBufferValueFlatSet(
const LogicalBufferContainerT& logical_buffer_container) {
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 23b2a32709..bdd5069632 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <queue>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
bool CallGraph::DominatesHelper(
const HloComputation* a, const HloComputation* b,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited) const {
+ absl::flat_hash_set<const HloComputation*>* visited) const {
if (a == b || ContainsKey(*visited, b)) {
// The call graph is guaranteed to be acyclic so any previously visited node
// we encounter was already determined to be dominated.
@@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper(
bool CallGraph::Dominates(const HloComputation* a,
const HloComputation* b) const {
- tensorflow::gtl::FlatSet<const HloComputation*> visited;
+ absl::flat_hash_set<const HloComputation*> visited;
return DominatesHelper(a, b, &visited);
}
@@ -277,7 +278,7 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
Status CallGraph::VisitNodesInternal(
const VisitorFunction& visitor_func, const CallGraphNode& node,
- tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
+ absl::flat_hash_set<const CallGraphNode*>* visited) const {
auto pair = visited->insert(&node);
if (!pair.second) {
// Node was not inserted. Node has already been visited.
@@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal(
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
bool visit_unreachable_nodes) const {
- tensorflow::gtl::FlatSet<const CallGraphNode*> visited;
+ absl::flat_hash_set<const CallGraphNode*> visited;
if (visit_unreachable_nodes) {
// Traverse from all roots in the call graph.
for (const CallGraphNode& node : nodes()) {
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
index 3af2ab5edf..cb56f4789d 100644
--- a/tensorflow/compiler/xla/service/call_graph.h
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <ostream>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -145,19 +145,19 @@ class CallGraphNode {
// The computations called by this computation. The vector is used for a
// stable ordering and the set enables fast membership testing.
std::vector<HloComputation*> callees_;
- tensorflow::gtl::FlatSet<HloComputation*> callee_set_;
+ absl::flat_hash_set<HloComputation*> callee_set_;
// The computations which call this computation. The vector is used for a
// stable ordering and the set enables fast membership testing.
std::vector<HloComputation*> callers_;
- tensorflow::gtl::FlatSet<HloComputation*> caller_set_;
+ absl::flat_hash_set<HloComputation*> caller_set_;
// The call sites in this computation
std::vector<CallSite> callsites_;
// The map from instruction to index in callsites_ for looking up the callsite
// (if any) associated with a particular instruction in this computation.
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> callsite_instructions_;
+ absl::flat_hash_map<const HloInstruction*, int64> callsite_instructions_;
// The call sites in other computations which call this computation.
std::vector<CallSite> caller_callsites_;
@@ -250,14 +250,14 @@ class CallGraph {
// 'visited'.
Status VisitNodesInternal(
const VisitorFunction& visitor_func, const CallGraphNode& node,
- tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;
+ absl::flat_hash_set<const CallGraphNode*>* visited) const;
// Recursive helper for computing whether 'a' dominates 'b' in the call
// graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
// and 'visited' is the set of computations which have been visited.
bool DominatesHelper(
const HloComputation* a, const HloComputation* b,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited) const;
+ absl::flat_hash_set<const HloComputation*>* visited) const;
// The HLO module represented by this call graph.
const HloModule* module_ = nullptr;
@@ -267,7 +267,7 @@ class CallGraph {
// Map from HLO computation to the index of the corresponding call graph node
// in nodes_.
- tensorflow::gtl::FlatMap<const HloComputation*, int64> node_indices_;
+ absl::flat_hash_map<const HloComputation*, int64> node_indices_;
};
} // namespace xla
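The container swaps above (and throughout the rest of this change) are mechanical because absl::flat_hash_set and absl::flat_hash_map expose the same surface the code relies on: insert() returning a pair whose second member says whether the element was new, operator[], and fast pointer-keyed lookup. A minimal standalone sketch, assuming Abseil is on the include path; the Node type below is an illustrative stand-in, not an XLA type:

#include <cstdint>
#include <iostream>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"

struct Node {};  // stand-in for HloComputation / CallGraphNode

int main() {
  Node a, b;
  absl::flat_hash_set<const Node*> visited;
  absl::flat_hash_map<const Node*, int64_t> node_indices;

  // insert() reports whether the element was newly added, which is exactly
  // what the visitation code above checks via pair.second.
  std::cout << visited.insert(&a).second << "\n";  // 1: newly inserted
  std::cout << visited.insert(&a).second << "\n";  // 0: already visited

  node_indices[&a] = 0;
  node_indices[&b] = 1;
  std::cout << node_indices.contains(&b) << "\n";  // 1
}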
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index c5cd88b9ea..08c4aff4f7 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -25,7 +25,7 @@ namespace xla {
// For every kCall operation in the main computation, we inline the body of the
// called function, and proceed recursively.
-class CallInliner : public HloPassInterface {
+class CallInliner : public HloModulePass {
public:
using InlinedInstructionMap =
std::unordered_map<HloInstruction*, HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
index 3de50cbd7f..2223ad6753 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.h
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that removes kConditional with a constant predicate, replacing them
// with their true or false computation as appropriate.
-class ConditionalSimplifier : public HloPassInterface {
+class ConditionalSimplifier : public HloModulePass {
public:
absl::string_view name() const override { return "simplify-conditional"; }
StatusOr<bool> Run(HloModule* module) override;
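The HloPassInterface to HloModulePass renames in this header and the neighboring ones only change the base class; the two overrides stay exactly as shown above. A minimal sketch of a module pass after the rename, assuming the XLA headers are available; the pass name and body are illustrative and not part of this change:

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

class ExampleNoOpPass : public HloModulePass {
 public:
  absl::string_view name() const override { return "example-no-op"; }

  StatusOr<bool> Run(HloModule* module) override {
    // A real pass would walk module->computations() and rewrite instructions;
    // this placeholder mutates nothing, so it reports that nothing changed.
    (void)module;
    return false;
  }
};

}  // namespace xla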
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
index 498894737f..ce0138e56f 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which rewrites convolutions with feature_group_count > 1 into
// convolutions with feature_group_count = 1.
-class ConvolutionFeatureGroupConverter : public HloPassInterface {
+class ConvolutionFeatureGroupConverter : public HloModulePass {
public:
ConvolutionFeatureGroupConverter() {}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index b65dfef9c9..f35324aa35 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
@@ -31,8 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -432,7 +432,7 @@ class CopyRemover {
// Construct a list for each HLO buffer in the alias analysis. Maintain a
// map from HloValue to the respective list element representing that
// value. The map is used to construct the copy info map below.
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node;
+ absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
for (const HloBuffer& buffer : alias_analysis.buffers()) {
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
@@ -480,7 +480,7 @@ class CopyRemover {
// respective ValueNode representing that value.
void AddValueList(
absl::Span<const HloValue* const> values,
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
+ absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
ValueNode* tail = nullptr;
ValueNode* head = nullptr;
for (const HloValue* value : values) {
@@ -516,8 +516,7 @@ class CopyRemover {
// respective ValueNode.
void CreateCopyMap(
const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>&
- value_to_node) {
+ const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
for (HloComputation* computation : module.computations()) {
for (HloInstruction* instruction : computation->instructions()) {
// Add copies with unambiguous source values to the map. Copies with
@@ -905,7 +904,7 @@ class CopyRemover {
// The heads of all the value lists. Each value list represents the HLO
// values contained in a particular HLO buffer. The values in the list are
// in dependency order.
- tensorflow::gtl::FlatSet<const ValueNode*> value_lists_;
+ absl::flat_hash_set<const ValueNode*> value_lists_;
// Copy removal requires fast access to the value list elements
// corresponding to the source and destination values of the kCopy
@@ -916,7 +915,7 @@ class CopyRemover {
ValueNode* src = nullptr;
ValueNode* dest = nullptr;
};
- tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_;
+ absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
};
HloModule* module_;
@@ -1010,7 +1009,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
HloInstruction* root = computation->root_instruction();
// Mark nondistinct/ambiguous indices.
- tensorflow::gtl::FlatSet<const HloBuffer*> seen;
+ absl::flat_hash_set<const HloBuffer*> seen;
ShapeUtil::ForEachSubshape(
root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const HloBuffer*> buffers_at_index =
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index d308f6bc84..c097089e30 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -43,7 +43,7 @@ namespace xla {
// (3) The buffer set of the root instruction of the entry computation must be
// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
// InstructionAliasSet::IsDistinct return true.
-class CopyInsertion : public HloPassInterface {
+class CopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 8cc522a59e..58abb330a6 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -93,6 +94,7 @@ cc_library(
":target_machine_features",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
+ "//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
@@ -126,7 +128,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
- "//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
@@ -180,6 +181,7 @@ cc_library(
":runtime_conv2d_mkl",
":runtime_fft",
":runtime_fork_join",
+ ":runtime_key_value_sort",
":runtime_matmul",
":runtime_matmul_mkl",
":runtime_single_threaded_conv2d",
@@ -288,6 +290,8 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
@@ -307,6 +311,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@llvm//:analysis",
"@llvm//:target",
],
@@ -461,12 +466,16 @@ cc_library(
],
copts = runtime_copts(),
deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "//tensorflow/stream_executor",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
)
@@ -624,6 +633,18 @@ cc_library(
)
cc_library(
+ name = "runtime_key_value_sort",
+ srcs = ["runtime_key_value_sort.cc"],
+ hdrs = ["runtime_key_value_sort.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "runtime_fork_join",
srcs = ["runtime_fork_join.cc"],
hdrs = ["runtime_fork_join.h"],
@@ -745,6 +766,7 @@ cc_library(
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index 59437e88af..becee3f81f 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -31,7 +31,7 @@ namespace cpu {
// called canonical convolutions). This pass expands non-canonical convolutions
// into reshapes and canonical convolutions, so that these non-canonical
// convolutions can run faster.
-class ConvCanonicalization : public HloPassInterface {
+class ConvCanonicalization : public HloModulePass {
public:
explicit ConvCanonicalization(
const TargetMachineFeatures* target_machine_features)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 18fc144efe..68c715a086 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -86,8 +86,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
-#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
@@ -249,9 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
- // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding
- // where we will take this pass in future.
- // pipeline.AddPass<Inliner>();
+ pipeline.AddPass<MapInliner>();
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
// pass.
@@ -308,7 +306,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout(), target_machine_features);
+ module->mutable_entry_computation_layout(),
+ LayoutAssignment::InstructionCanChangeLayout, target_machine_features);
return pipeline.Run(module).status();
}
@@ -328,8 +327,13 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
{
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
"simplification after layout assignement");
- pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false);
+ // TODO(b/117156505): When the bug is fixed, the CPU backend should not
+ // produce layout changing elementwise operations. We will then pass
+ // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to
+ // enable stricter verification.
+ pass.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return true; },
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
index d49f7d7cc2..076235f887 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
@@ -30,7 +30,7 @@ namespace xla {
//
// TODO(b/62548313): Remove this when buffer assignment is smarter
// (module-scoped).
-class CpuCopyInsertion : public HloPassInterface {
+class CpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 6af724b2a5..a39a9d4724 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -23,7 +23,7 @@ namespace xla {
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the CPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
-class CpuHloSupportChecker : public HloPassInterface {
+class CpuHloSupportChecker : public HloModulePass {
public:
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
index bfecbd6e01..c291bf2d1b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <numeric>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
@@ -38,7 +39,7 @@ using absl::nullopt;
using absl::optional;
using ShouldMakeOperandColMajorCache =
- tensorflow::gtl::FlatMap<const HloInstruction*, bool>;
+ absl::flat_hash_map<const HloInstruction*, bool>;
} // namespace
static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
index 3c4fe68b83..f4da35dd37 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -30,8 +30,11 @@ class CpuLayoutAssignment : public LayoutAssignment {
public:
explicit CpuLayoutAssignment(
ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func,
const TargetMachineFeatures* target_machine_features)
- : LayoutAssignment(entry_computation_layout),
+ : LayoutAssignment(entry_computation_layout,
+ std::move(instruction_can_change_layout_func)),
target_machine_features_(*target_machine_features) {}
~CpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 4668f3872d..97659b88a7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -54,8 +54,9 @@ class CpuLayoutAssignmentTest : public HloTestBase {
[](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
- cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout,
- &target_machine_features);
+ cpu::CpuLayoutAssignment layout_assignment(
+ entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ &target_machine_features);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
};
@@ -321,8 +322,9 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
[](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
- cpu::CpuLayoutAssignment layout_assignment(&computation_layout,
- &target_machine_features);
+ cpu::CpuLayoutAssignment layout_assignment(
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ &target_machine_features);
TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
layout_assignment.Run(module));
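The extra constructor argument threaded through CpuLayoutAssignment in the tests above is plain predicate injection: the layout pass now receives a std::function that decides whether an instruction may change layout, and callers pass the stock LayoutAssignment::InstructionCanChangeLayout unless the backend has a stricter rule. A standalone toy with the same shape; every name below is illustrative, not XLA's:

#include <functional>
#include <iostream>
#include <utility>

struct Instr { bool elementwise; };

class ToyLayoutPass {
 public:
  using CanChangeLayoutFn = std::function<bool(const Instr*)>;

  // Stock predicate, playing the role of
  // LayoutAssignment::InstructionCanChangeLayout in the calls above.
  static bool DefaultCanChangeLayout(const Instr* i) { return !i->elementwise; }

  explicit ToyLayoutPass(CanChangeLayoutFn can_change)
      : can_change_(std::move(can_change)) {}

  bool MayChangeLayout(const Instr& i) const { return can_change_(&i); }

 private:
  CanChangeLayoutFn can_change_;
};

int main() {
  Instr add{/*elementwise=*/true};
  ToyLayoutPass pass(&ToyLayoutPass::DefaultCanChangeLayout);
  std::cout << pass.MayChangeLayout(add) << "\n";  // 0
}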
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 8a44c384bb..a9febe891b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -17,19 +17,29 @@ limitations under the License.
#include <functional>
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
namespace cpu {
namespace runtime {
-XfeedManager* GetXfeedManager() {
- static XfeedManager* manager = new XfeedManager;
- return manager;
+XfeedManager* GetXfeedManager(int device_ordinal) {
+ static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
+ static absl::Mutex* mutex = new absl::Mutex();
+
+ absl::MutexLock lock(mutex);
+ auto it = managers->find(device_ordinal);
+ if (it == managers->end()) {
+ it = managers->emplace(device_ordinal, new XfeedManager()).first;
+ }
+ return it->second;
}
extern const char* const kEigenMatMulF16SymbolName =
@@ -74,6 +84,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
"__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
"__xla_cpu_runtime_ParallelForkJoin";
+extern const char* const kKeyValueSortPREDSymbolName =
+ "__xla_cpu_runtime_KeyValueSortPRED";
+extern const char* const kKeyValueSortS8SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS8";
+extern const char* const kKeyValueSortU8SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU8";
+extern const char* const kKeyValueSortS16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS16";
+extern const char* const kKeyValueSortU16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU16";
+extern const char* const kKeyValueSortF16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF16";
+extern const char* const kKeyValueSortS32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS32";
+extern const char* const kKeyValueSortU32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU32";
+extern const char* const kKeyValueSortF32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF32";
+extern const char* const kKeyValueSortS64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS64";
+extern const char* const kKeyValueSortU64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU64";
+extern const char* const kKeyValueSortF64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF64";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
} // namespace runtime
@@ -94,14 +128,18 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
} // namespace
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
-__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
- const void* shape,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "AcquireInfeedBufferForDequeue: "
- << ShapeString(shape, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "AcquireInfeedBufferForDequeue: "
+ << ShapeString(shape, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->infeed()->BlockingDequeueBuffer();
@@ -114,15 +152,18 @@ __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
-__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
- void* buffer_ptr,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "ReleaseInfeedBufferAfterDeque: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "ReleaseInfeedBufferAfterDeque: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
xla::StatusOr<xla::Shape> shape =
xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
@@ -130,14 +171,18 @@ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
-__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "AcquireOutfeedBufferForPopulation: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "AcquireOutfeedBufferForPopulation: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->outfeed()->BlockingDequeueBuffer();
@@ -150,15 +195,18 @@ __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
-__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length,
- void* buffer_ptr,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
xla::StatusOr<xla::Shape> shape =
xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
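GetXfeedManager now hands out one XfeedManager per device ordinal instead of a single process-wide instance, using a deliberately leaked static map guarded by a mutex. A standalone sketch of that registry pattern, with std::mutex and a plain struct standing in for absl::Mutex and XfeedManager:

#include <iostream>
#include <mutex>
#include <unordered_map>

struct Manager {};  // stand-in for xla::cpu::runtime::XfeedManager

Manager* GetManager(int device_ordinal) {
  // Leaked on purpose: the registry lives for the whole process, like the
  // static flat_hash_map in GetXfeedManager above.
  static auto* managers = new std::unordered_map<int, Manager*>();
  static auto* mu = new std::mutex();

  std::lock_guard<std::mutex> lock(*mu);
  auto it = managers->find(device_ordinal);
  if (it == managers->end()) {
    it = managers->emplace(device_ordinal, new Manager()).first;
  }
  return it->second;
}

int main() {
  // Same ordinal -> same manager; different ordinals -> distinct managers.
  std::cout << (GetManager(0) == GetManager(0)) << "\n";  // 1
  std::cout << (GetManager(0) == GetManager(1)) << "\n";  // 0
}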
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index aa0e967123..b2e760a224 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -26,6 +26,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
+#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/types.h"
@@ -63,13 +64,26 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
+extern const char* const kKeyValueSortPREDSymbolName;
+extern const char* const kKeyValueSortS8SymbolName;
+extern const char* const kKeyValueSortU8SymbolName;
+extern const char* const kKeyValueSortS16SymbolName;
+extern const char* const kKeyValueSortU16SymbolName;
+extern const char* const kKeyValueSortF16SymbolName;
+extern const char* const kKeyValueSortS32SymbolName;
+extern const char* const kKeyValueSortU32SymbolName;
+extern const char* const kKeyValueSortF32SymbolName;
+extern const char* const kKeyValueSortS64SymbolName;
+extern const char* const kKeyValueSortU64SymbolName;
+extern const char* const kKeyValueSortF64SymbolName;
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
extern const char* const kXlaCpuRuntimeSymbolNamePrefix;
-// Returns the infeed manager used by the CPU runtime.
-XfeedManager* GetXfeedManager();
+// Returns the infeed manager used by the CPU runtime for the CPU device
+// `device_ordinal`. Note the device ordinal does not name a physical CPU; it
+// identifies the XLA CPU device (stream executor) whose infeed/outfeed queues
+// are returned.
+XfeedManager* GetXfeedManager(int device_ordinal);
} // namespace runtime
} // namespace cpu
@@ -77,6 +91,18 @@ XfeedManager* GetXfeedManager();
extern "C" {
+// Some things common to all of the runtime entry points below:
+//
+// * The shape pointer and shape_length reflect values that can be deserialized
+// via llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass
+// reified type information from the generated program to the runtime, which
+// helps check the type safety and contract for the emitted-code/runtime
+// communication.
+//
+// * run_options is used to look up the device ordinal for the stream executor
+// we're executing under. If it is null the device ordinal is assumed to be
+// 0 (this behavior helps in writing tests).
+
// Note: in the runtime entry points below, the shape pointer and shape_length
// reflect values that can be deserialized via
// llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified
@@ -89,7 +115,8 @@ extern "C" {
// the length would be more exact, but the length check is chosen as a
// tradeoff between error checking and speed/simplicity.
extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- xla::int32 buffer_length, const void* shape, xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape, xla::int32 shape_length);
// Relinquishes the next infeed buffer that was returned by
// __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call
@@ -104,13 +131,14 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
// implemented we will add support for multiple outstanding buffers
// that can be returned out of order.
extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
// Blocks until the next outfeed buffer is available to be populated, then
// returns it.
extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape_ptr, xla::int32 shape_length);
// Relinquishes the outfeed buffer after it has been populated.
// buffer_ptr must have been previously returned by
@@ -122,8 +150,8 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
// acquired, i.e., there may only be one outstanding outfeed buffer in
// use by the runtime.
extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
} // extern "C"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 5519a43b2f..1cc2844470 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
@@ -128,7 +129,8 @@ Status CpuTransferManager::TransferLiteralToInfeed(
buffers.push_back(buffer);
}
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers);
cleanup.release();
@@ -141,7 +143,8 @@ Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor,
TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
TransferBufferToInfeedInternal(executor, size, source));
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer});
return Status::OK();
@@ -265,7 +268,8 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
buffer_pointers.push_back(b.get());
}
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers);
VLOG(2) << "Waiting for buffer to be notified as populated.";
std::vector<Shape> outfed_shapes;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index df8c2a636b..b2abdb39a5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -24,6 +24,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
@@ -67,8 +69,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -404,13 +404,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value * shape_ptr,
llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
- // The signature of the acquire infeed buffer function is:
- //
- // (void*)(int32 length);
llvm::Type* int32_type = b_.getInt32Ty();
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::FunctionType* acquire_type = llvm::FunctionType::get(
- i8_ptr_type, {int32_type, i8_ptr_type, int32_type},
+ i8_ptr_type,
+ {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
+ /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type},
/*isVarArg=*/false);
llvm::Function* acquire_func;
@@ -423,11 +422,11 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
}
acquire_func->setCallingConv(llvm::CallingConv::C);
- // The signature of the release infeed buffer function is:
- //
- // (void)(int32 length, void* buffer);
llvm::FunctionType* release_type = llvm::FunctionType::get(
- b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
+ b_.getVoidTy(),
+ {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
+ /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type,
+ /*shape_length*/ int32_type},
/*isVarArg=*/false);
llvm::Function* release_func;
@@ -444,9 +443,9 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer =
- Call(acquire_func,
- {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer = Call(
+ acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
+ shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
@@ -458,8 +457,8 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
/*SrcAlign=*/1, length_32);
}
- Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr,
- b_.getInt32(shape_length)});
+ Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
+ acquired_pointer, shape_ptr, b_.getInt32(shape_length)});
return Status::OK();
}
@@ -495,8 +494,150 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
}
Status IrEmitter::HandleSort(HloInstruction* sort) {
- // TODO(b/26783907): Implement sort on CPU.
- return Unimplemented("Sort is not implemented on CPU.");
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
+ auto keys = sort->operand(0);
+ auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+ ShapeIndex keys_shape_index({});
+ ShapeIndex values_shape_index({});
+ if (values != nullptr) {
+ keys_shape_index = ShapeIndex({0});
+ values_shape_index = ShapeIndex({1});
+ }
+ auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
+ auto keys_destination_address =
+ EmitBufferPointer(keys_destination, keys->shape());
+ auto values_destination = GetAllocationSlice(*sort, values_shape_index);
+ llvm::Value* values_destination_address = nullptr;
+
+ // The sort is implemented in-place, therefore we first copy the operand
+ // buffer to the output buffer if they are not the same.
+ if (keys_destination != GetAllocationSlice(*keys)) {
+ int64 primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type());
+ auto source_buffer = GetEmittedValueFor(keys);
+ int64 keys_size = ByteSizeOf(keys->shape());
+ MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size,
+ source_buffer,
+ /*SrcAlign=*/primitive_type_size, keys_size);
+ }
+ if (values != nullptr) {
+ values_destination_address =
+ EmitBufferPointer(values_destination, values->shape());
+ if (values_destination != GetAllocationSlice(*values)) {
+ int64 primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type());
+ auto source_buffer = GetEmittedValueFor(values);
+ int64 values_size = ByteSizeOf(values->shape());
+ MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size,
+ source_buffer,
+ /*SrcAlign=*/primitive_type_size, values_size);
+ }
+ }
+
+ // Normalize the shape and the dimension to sort.
+ Shape normalized_keys_shape =
+ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+ keys->shape());
+ int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
+ keys->shape().layout())[sort->dimensions(0)];
+
+ int64 sort_dimension_elements =
+ normalized_keys_shape.dimensions(physical_dimension_to_sort);
+ int64 higher_dimensions = 1;
+ for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
+ higher_dimensions *= normalized_keys_shape.dimensions(i);
+ }
+ int64 lower_dimensions = 1;
+ for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1;
+ i > physical_dimension_to_sort; --i) {
+ lower_dimensions *= normalized_keys_shape.dimensions(i);
+ }
+
+ PrimitiveType keys_type = keys->shape().element_type();
+ const char* fn_name = nullptr;
+ llvm::Type* keys_native_type = nullptr;
+ switch (keys_type) {
+ case PRED:
+ fn_name = runtime::kKeyValueSortPREDSymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case S8:
+ fn_name = runtime::kKeyValueSortS8SymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case U8:
+ fn_name = runtime::kKeyValueSortU8SymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case S16:
+ fn_name = runtime::kKeyValueSortS16SymbolName;
+ keys_native_type = b_.getInt16Ty()->getPointerTo();
+ break;
+ case U16:
+ fn_name = runtime::kKeyValueSortU16SymbolName;
+ keys_native_type = b_.getInt16Ty()->getPointerTo();
+ break;
+ case F16:
+ fn_name = runtime::kKeyValueSortF16SymbolName;
+ keys_native_type = b_.getHalfTy()->getPointerTo();
+ break;
+ case S32:
+ fn_name = runtime::kKeyValueSortS32SymbolName;
+ keys_native_type = b_.getInt32Ty()->getPointerTo();
+ break;
+ case U32:
+ fn_name = runtime::kKeyValueSortU32SymbolName;
+ keys_native_type = b_.getInt32Ty()->getPointerTo();
+ break;
+ case F32:
+ fn_name = runtime::kKeyValueSortF32SymbolName;
+ keys_native_type = b_.getFloatTy()->getPointerTo();
+ break;
+ case S64:
+ fn_name = runtime::kKeyValueSortS64SymbolName;
+ keys_native_type = b_.getInt64Ty()->getPointerTo();
+ break;
+ case U64:
+ fn_name = runtime::kKeyValueSortU64SymbolName;
+ keys_native_type = b_.getInt64Ty()->getPointerTo();
+ break;
+ case F64:
+ fn_name = runtime::kKeyValueSortF64SymbolName;
+ keys_native_type = b_.getDoubleTy()->getPointerTo();
+ break;
+ default:
+ return Unimplemented(
+ "Element type %s not supported in the Sort op on CPU.",
+ PrimitiveType_Name(keys_type));
+ }
+
+ llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
+ b_.getVoidTy(),
+ {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
+ b_.getInt8PtrTy(), b_.getInt32Ty()},
+ /*isVarArg=*/false);
+ auto* key_value_sort_func = llvm::cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, key_value_sort_type));
+ key_value_sort_func->setCallingConv(llvm::CallingConv::C);
+ key_value_sort_func->setDoesNotThrow();
+ key_value_sort_func->setOnlyAccessesArgMemory();
+ Call(key_value_sort_func,
+ {PointerCast(keys_destination_address, keys_native_type),
+ b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+ b_.getInt64(lower_dimensions),
+ values != nullptr
+ ? PointerCast(values_destination_address, b_.getInt8PtrTy())
+ : llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType(
+ values->shape().element_type())
+ : 0)});
+
+ if (values != nullptr) {
+ llvm_ir::EmitTuple(GetIrArrayFor(sort),
+ {keys_destination_address, values_destination_address},
+ &b_, module_);
+ }
+ return Status::OK();
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
@@ -547,8 +688,25 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ input_index[i] = NSWSub(
+ NSWAdd(strided_index,
+ NSWMul(window_index[i],
+ b_.getInt64(window.dimensions(i).window_dilation()))),
+ b_.getInt64(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())),
+ b_.getInt64(0));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = dilation_condition;
+ } else {
+ in_bounds_condition = And(in_bounds_condition, dilation_condition);
+ }
+
+ // Apply base dilation to the index.
+ input_index[i] =
+ SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation()));
// We need to check if 0 <= input_index[i] < bound, as otherwise we are in
// the padding so that we can skip the computation. That is equivalent to
@@ -587,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
/*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32, F16}));
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(reduce_window->window())) {
- return Unimplemented(
- "Dilation for ReduceWindow is not implemented on CPU.");
- }
-
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
@@ -1257,10 +1409,10 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {
//
// So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
// [0->0, 3->1].
- gtl::FlatMap<int64, int64> unreduced_dim_map;
+ absl::flat_hash_map<int64, int64> unreduced_dim_map;
- gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(),
- reduce.dimensions().end());
+ absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
+ reduce.dimensions().end());
const Shape& operand_shape = reduce.operand(0)->shape();
const Shape& result_shape = reduce.shape();
@@ -1836,7 +1988,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
//
// * Implement the memcpy within the innermost loop.
- gtl::FlatSet<int64> inner_dims;
+ absl::flat_hash_set<int64> inner_dims;
for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
break;
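The dilation handling added to EmitTargetElementLoopBodyForReduceWindow above maps an (output index, window index) pair to an input index as out*stride + win*window_dilation - padding_low, treats the position as out of bounds unless it is divisible by base_dilation, then divides by base_dilation and bounds-checks the result. A standalone sketch of that arithmetic for a single dimension; the function and parameter names are illustrative:

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Returns the input coordinate selected by (output_index, window_index) in one
// dimension, or nullopt if the position falls in padding or in a dilation gap.
std::optional<int64_t> MapToInput(int64_t output_index, int64_t window_index,
                                  int64_t stride, int64_t window_dilation,
                                  int64_t base_dilation, int64_t padding_low,
                                  int64_t input_bound) {
  int64_t idx =
      output_index * stride + window_index * window_dilation - padding_low;
  // Negative positions are rejected here or by the bounds check below, just
  // like the SRem/ICmp sequence in the emitted IR.
  if (idx % base_dilation != 0) return std::nullopt;  // dilation gap
  idx /= base_dilation;
  if (idx < 0 || idx >= input_bound) return std::nullopt;  // padding
  return idx;
}

int main() {
  // stride 1, window dilation 2, base dilation 2, low padding 1, 3 inputs.
  for (int64_t o = 0; o < 3; ++o) {
    for (int64_t w = 0; w < 2; ++w) {
      auto idx = MapToInput(o, w, /*stride=*/1, /*window_dilation=*/2,
                            /*base_dilation=*/2, /*padding_low=*/1,
                            /*input_bound=*/3);
      std::cout << "o=" << o << " w=" << w << " -> "
                << (idx ? std::to_string(*idx) : "skip") << "\n";
    }
  }
}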
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3df99464ba..586f27b104 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
@@ -47,7 +48,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -163,6 +163,12 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status Preprocess(HloInstruction* hlo) override;
Status Postprocess(HloInstruction* hlo) override;
+ // A convenient helper for calling BufferAssignment::GetUniqueSlice.
+ BufferAllocation::Slice GetAllocationSlice(
+ const HloInstruction& hlo, const ShapeIndex& index = {}) const {
+ return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
+ }
+
private:
// Private helper to initialize an IR function for the computation.
void InitializeIrFunction(const string& function_name);
@@ -421,7 +427,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Maps the buffer allocation slices for the parameters to the computation
// being compiled to their parameter numbers. Only relevant for thread local
// computations.
- tensorflow::gtl::FlatMap<BufferAllocation::Index, int64>
+ absl::flat_hash_map<BufferAllocation::Index, int64>
computation_parameter_allocations_;
// Maps HLO instructions to their index into the profile counter array.
@@ -561,11 +567,11 @@ class IrEmitter : public DfsHloVisitorWithDefault,
}
};
- tensorflow::gtl::FlatMap<const Literal*, llvm::Constant*,
- LiteralPtrHashFunctor, LiteralPtrEqualityFunctor>
+ absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
+ LiteralPtrEqualityFunctor>
emitted_literals_;
- tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*>
+ absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
constant_buffer_to_global_;
std::vector<const HloComputation*> thread_local_computations_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index b4c0c09ec0..ede7f433ca 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -142,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
+ opcode == HloOpcode::kSort ||
(opcode == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction,
target_machine_features_)) ||
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index a99cd99c14..3822d5300e 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -60,7 +60,7 @@ class ParallelTaskAssignment {
// own embedded computation, which is compiled as a parallel compute function,
// and which is invoked from a kCall instruction that is lowered in codegen to
// a runtime parallel fork/join call.
-class ParallelTaskAssigner : public HloPassInterface {
+class ParallelTaskAssigner : public HloModulePass {
public:
// 'max_parallelism': the maximum parallel task count per instruction.
// 'shape_size': shape size function used by HloCostAnalysis during parallel
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
new file mode 100644
index 0000000000..e0e7deb98e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -0,0 +1,236 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace {
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::int8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+using tensorflow::uint8;
+
+template <typename KeyType>
+void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements);
+}
+
+// For floating point numbers, we want a total order comparator. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0. Also we want to have a stable sort, so if the keys are the
+// same, we compare the index values.
+template <typename KeyType>
+bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) {
+ bool lhs_is_negative = std::signbit(lhs);
+ bool rhs_is_negative = std::signbit(rhs);
+ // If the signs are different, we can just compare the signs.
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(lhs);
+ bool rhs_nan = std::isnan(rhs);
+ // Is exactly one of the two numbers NaN?
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;
+ }
+ return !rhs_is_negative;
+ }
+ if (lhs != rhs) {
+ return lhs < rhs;
+ }
+ return lhs_index < rhs_index;
+}
+
+template <>
+void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<double, int64>& lhs,
+ const std::pair<double, int64>& rhs) -> bool {
+ return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+ });
+}
+
+template <>
+void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<float, int64>& lhs,
+ const std::pair<float, int64>& rhs) -> bool {
+ return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+ });
+}
+
+template <>
+void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
+ int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<Eigen::half, int64>& lhs,
+ const std::pair<Eigen::half, int64>& rhs) -> bool {
+ return LessThan(
+ Eigen::half_impl::half_to_float(lhs.first), lhs.second,
+ Eigen::half_impl::half_to_float(rhs.first), rhs.second);
+ });
+}
+
+template <typename KeyType>
+void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ // High-level idea of the iteration/sorting logic:
+ // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
+ // dimension to sort, c is the product of the more minor dimensions (set to 1
+ // if b is the most minor dimension), and a is the product of the more major
+ // dimensions (set to 1 if b is the most major dimension). There are a * c
+ // many rows that we need to sort. We iterate through these, calculate a
+ // 'base_offset' value which points to the first element in that row, and add
+ // i * c for accessing the 'i'-th element in that row.
+
+ int64 sort_dimension_elements = b;
+ int64 num_iteration_elements = a * c;
+ int64 sort_dimension_offset = c;
+
+ std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
+ new std::pair<KeyType, int64>[sort_dimension_elements]);
+ std::unique_ptr<std::string[]> reordered_values(
+ new std::string[sort_dimension_elements]);
+ for (int64 index = 0; index < num_iteration_elements; ++index) {
+ // 'index' can be split into two values which index into the 'c' dimension
+ // and the 'a' dimension, respectively. 'index' % 'c' is the index into the
+ // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When
+ // calculating the base offset, we need to multiply the index into the 'a'
+ // dimension with 'b' * 'c'.
+ // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'.
+ int64 base_offset =
+ index % sort_dimension_offset +
+ (index - index % sort_dimension_offset) * sort_dimension_elements;
+ // TODO(b/26783907): We could define a custom iterator class that references
+ // both arrays. Then we could avoid the intermediate copy. However, this
+ // would become more complicated, and it is not clear if the benefit is high
+ // enough.
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ row_to_sort[i] =
+ std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
+ }
+ KeyValueSort(row_to_sort.get(), sort_dimension_elements);
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
+ }
+ if (values == nullptr) {
+ continue;
+ }
+
+ // Reorder the values according to the order defined by the keys.
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ int64 memory_index =
+ (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+ values_primitive_type_size_in_bytes;
+
+ reordered_values[i] = std::string(values + memory_index,
+ values_primitive_type_size_in_bytes);
+ }
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ int64 memory_index = (base_offset + i * sort_dimension_offset) *
+ values_primitive_type_size_in_bytes;
+ memcpy(values + memory_index, reordered_values[i].c_str(),
+ values_primitive_type_size_in_bytes);
+ }
+ }
+}
+} // namespace
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
+ bool* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
+ int8* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
+ uint8* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
+ int16* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
+ uint16* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
+ Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
+ int32* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
+ uint32* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
+ float* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
+ int64* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
+ uint64* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
+ double* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
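For readers tracing the index arithmetic in KeyValueSortImpl above: the sketch below (standalone, not part of the change; all names are local to it) sorts the middle dimension of a row-major [a, b, c] buffer with the same base_offset formula, where sort_dimension_elements corresponds to b and sort_dimension_offset to c.

    // Standalone illustration: sort dimension 'b' of a row-major [a, b, c]
    // array using the same index arithmetic as KeyValueSortImpl.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    void SortMiddleDimension(std::vector<float>& keys, int64_t a, int64_t b,
                             int64_t c) {
      const int64_t sort_dimension_elements = b;  // length of the sorted dim
      const int64_t sort_dimension_offset = c;    // stride of the sorted dim
      std::vector<float> row(sort_dimension_elements);
      for (int64_t index = 0; index < a * c; ++index) {
        // Same formula as above: offset of the first element of row 'index'.
        const int64_t base_offset =
            index % sort_dimension_offset +
            (index - index % sort_dimension_offset) * sort_dimension_elements;
        for (int64_t i = 0; i < sort_dimension_elements; ++i) {
          row[i] = keys[base_offset + i * sort_dimension_offset];
        }
        std::sort(row.begin(), row.end());
        for (int64_t i = 0; i < sort_dimension_elements; ++i) {
          keys[base_offset + i * sort_dimension_offset] = row[i];
        }
      }
    }

    int main() {
      // A [1, 3, 2] array holding the interleaved rows {5, 1, 3} and {6, 2, 4}.
      std::vector<float> keys = {5, 6, 1, 2, 3, 4};
      SortMiddleDimension(keys, /*a=*/1, /*b=*/3, /*c=*/2);
      for (float k : keys) std::printf("%g ", k);  // prints: 1 2 3 4 5 6
      std::printf("\n");
      return 0;
    }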
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
new file mode 100644
index 0000000000..28e35e82c1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
+// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr.
+// If 'values' is not nullptr, the elements in 'values' are reordered so that
+// if the element at index 'i' in 'keys' is moved to index 'j', the element at
+// index 'i' in 'values' is also moved to index 'j'; each key therefore stays
+// paired with its original value.
+extern void __xla_cpu_runtime_KeyValueSortPRED(
+ bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS8(
+ tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU8(
+ tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS16(
+ tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU16(
+ tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF16(
+ Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS32(
+ tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU32(
+ tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF32(
+ float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS64(
+ tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU64(
+ tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF64(
+ double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+}
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
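A minimal sketch of invoking one of the entry points declared above directly (normally the CPU backend emits this call from compiled code); the buffer contents are illustrative only and the program assumes the runtime library is linked in.

    #include <cstdio>

    #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"

    int main() {
      // 'keys' has logical shape [a=1, b=4, c=1]; dimension 'b' gets sorted.
      float keys[4] = {3.0f, 1.0f, 4.0f, 2.0f};
      // Values travel as raw bytes; here each value is a 4-byte int32.
      tensorflow::int32 values[4] = {30, 10, 40, 20};
      __xla_cpu_runtime_KeyValueSortF32(
          keys, /*a=*/1, /*b=*/4, /*c=*/1, reinterpret_cast<char*>(values),
          /*values_primitive_type_size_in_bytes=*/sizeof(tensorflow::int32));
      for (int i = 0; i < 4; ++i) {
        // Prints 1 -> 10, 2 -> 20, 3 -> 30, 4 -> 40: values follow their keys.
        std::printf("%g -> %d\n", keys[i], values[i]);
      }
      return 0;
    }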
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index bf98064647..9ec0c8f657 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
@@ -202,6 +203,18 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
index a0cd8ee2d2..5cdac203af 100644
--- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
+++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
+#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace cpu {
diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h
index 8b00ae9e47..a383b4a4a0 100644
--- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h
+++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
+#include "absl/container/flat_hash_map.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace cpu {
@@ -97,8 +97,7 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures {
// This is mutated from within `GetTargetTransformInfoFor` which is
// semantically a getter (and thus `const`); and is therefore declared
// mutable. Making this mutable is okay because it has cache semantics.
- mutable tensorflow::gtl::FlatMap<const llvm::Function*,
- llvm::TargetTransformInfo>
+ mutable absl::flat_hash_map<const llvm::Function*, llvm::TargetTransformInfo>
target_transform_info_cache_;
llvm::TargetMachine* target_machine_;
};
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index c55206eee7..4b129c95d4 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -180,3 +180,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cpu_key_value_sort_test",
+ srcs = ["cpu_key_value_sort_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
new file mode 100644
index 0000000000..3934c03a04
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class CpuKeyValueSortTest : public CpuCodegenTest {};
+
+TEST_F(CpuKeyValueSortTest, SortR1) {
+ const string hlo_text = R"(
+HloModule KeyValueSort
+
+ENTRY main {
+ a = f32[10] parameter(0)
+
+ ROOT result = f32[10] sort(f32[10] a), dimensions={0}
+}
+)";
+
+ string filecheck_pattern = R"(
+CHECK: call void @__xla_cpu_runtime_KeyValueSort
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text));
+
+ CpuAotCompilationOptions options{
+ /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"",
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+ /*match_optimized_ir=*/true);
+}
+
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index 7af51db55a..b35fd9dad8 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -121,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) {
CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]]
CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32
CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48
- CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]}
+ CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]}
CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]}
CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]}
)";
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
index 8fe65f488a..cc38b81455 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
@@ -66,9 +66,9 @@ void ProcessNextBuffer(int32 length) {
auto shape = ShapeUtil::MakeShape(U8, {length});
string bytes = shape.SerializeAsString();
void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- length, bytes.data(), bytes.size());
- __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer,
- bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, bytes.data(), bytes.size());
+ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size());
}
// Performs the acquire/release sequence on the outfeed, as the generated CPU
@@ -76,16 +76,16 @@ void ProcessNextBuffer(int32 length) {
void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) {
string bytes = shape.SerializeAsString();
void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- length, bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, bytes.data(), bytes.size());
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- length, buffer, bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size());
}
TEST_F(InfeedManagerTest, SingleThreadedSequential) {
TestInfeedBuffer* a = new TestInfeedBuffer(64);
TestInfeedBuffer* b = new TestInfeedBuffer(32);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->infeed()->EnqueueBuffersAtomically({a});
xfeed->infeed()->EnqueueBuffersAtomically({b});
@@ -97,7 +97,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
TestInfeedBuffer* a = new TestInfeedBuffer(64);
TestInfeedBuffer* b = new TestInfeedBuffer(32);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->infeed()->EnqueueBuffersAtomically({a});
ProcessNextBuffer(a->length());
@@ -108,7 +108,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
TEST_F(InfeedManagerTest, MultiThreaded) {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
const int32 length = 64;
@@ -130,7 +130,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) {
TEST_F(InfeedManagerTest, OutfeedWrongShape) {
TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->outfeed()->EnqueueBuffersAtomically({b});
ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33}));
diff --git a/tensorflow/compiler/xla/service/defuser.cc b/tensorflow/compiler/xla/service/defuser.cc
index d124f74d19..661539cccb 100644
--- a/tensorflow/compiler/xla/service/defuser.cc
+++ b/tensorflow/compiler/xla/service/defuser.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -48,7 +49,7 @@ Status Defuse(HloInstruction* fusion_instruction) {
fusion_instruction->fused_instructions_computation();
// A map from fused instruction to its defused clone.
- tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*>
defused_instructions;
// Initialize map to contain the fusion instruction parameters mapping
// to the operands of the fusion instruction.
diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h
index c326beb899..aaa41fc4fe 100644
--- a/tensorflow/compiler/xla/service/defuser.h
+++ b/tensorflow/compiler/xla/service/defuser.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which replaces all fusion instructions with the equivalent un-fused
// instructions.
-class Defuser : public HloPassInterface {
+class Defuser : public HloModulePass {
public:
Defuser() {}
~Defuser() override {}
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index ba2a674d9a..b3549acfc2 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -24,7 +24,7 @@ namespace xla {
namespace {
// Pass which strips control dependencies from all instructions in the module.
-class ControlDepRemover : public HloPassInterface {
+class ControlDepRemover : public HloModulePass {
public:
ControlDepRemover() = default;
absl::string_view name() const override { return "control-dep-remover"; }
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
index 7be70add2f..46dcc3a438 100644
--- a/tensorflow/compiler/xla/service/despecializer.h
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -30,7 +30,7 @@ namespace xla {
//
// Current despecialization passes are Defuser, ImplicitBroadcastRemover,
// and BFloat16MixedPrecisionRemoval.
-class Despecializer : public HloPassInterface {
+class Despecializer : public HloModulePass {
public:
Despecializer();
absl::string_view name() const override { return "despecializer"; }
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 5761573791..68d01d75a2 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index fc38e31700..40e7a3b4c2 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -23,7 +23,7 @@ namespace xla {
// DotDecomposer is a pass which decomposes batch Dot operations into a
// sequence of smaller (R2) Dot operations.
-class DotDecomposer : public HloPassInterface {
+class DotDecomposer : public HloModulePass {
public:
// Decomposes batch Dot operations when 'decompose_batch_dot' is true.
DotDecomposer(bool decompose_batch_dot = true)
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 4bb1e071d8..515267edd7 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -847,29 +847,34 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) {
- if (prim_type != F32) {
- // TODO(b/34339814): Implement inverse erf for F64.
+ if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
return Unimplemented(
"Inverse erf is only implemented for element "
- "type F32.");
+ "types F16, F32 and F64.");
}
- auto getFloat = [&](const float f) {
- return llvm::ConstantFP::get(b_->getFloatTy(), f);
+
+ // Upcast half to float.
+ if (prim_type == F16) {
+ x = b_->CreateFPExt(x, b_->getFloatTy());
+ }
+
+ auto get_float = [&](const double f) {
+ return llvm::ConstantFP::get(x->getType(), f);
};
- auto multiply_add = [&](absl::Span<const float> coefficients,
+ auto multiply_add = [&](absl::Span<const double> coefficients,
llvm::Value* w) {
- llvm::Value* p = getFloat(coefficients.front());
+ llvm::Value* p = get_float(coefficients.front());
coefficients.remove_prefix(1);
for (float coefficient : coefficients) {
- p = FAdd(FMul(p, w), getFloat(coefficient));
+ p = FAdd(FMul(p, w), get_float(coefficient));
}
return p;
};
// Approximation for inverse error function from
// Giles, M., "Approximating the erfinv function".
- // The approximation has the form:
- // w = log((1-x)*(1+x))
+ // The approximation has the form (float version):
+ // w = -log((1-x)*(1+x))
// if ( w < 5 ) {
// w = w - 2.5
// p = sum_{i=1}^n lq[i]*w^i
@@ -879,46 +884,124 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
// }
// return p*x
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::log, {b_->getFloatTy()});
+ module_, llvm::Intrinsic::log, {x->getType()});
- llvm::Value* w = FNeg(
- Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = FNeg(Call(
+ logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));
llvm::Value* p_addr =
- llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
+ llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);
+
+ if (prim_type == F16 || prim_type == F32) {
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
+ // Handle true BB.
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(2.5f));
+ absl::Span<const double> lq{
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ llvm::Value* p = multiply_add(lq, lw);
+ Store(p, p_addr);
+ }
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
- // Handle true BB.
- SetToFirstInsertPoint(if_data.true_block, b_);
- {
- llvm::Value* lw = FSub(w, getFloat(2.5f));
- absl::Span<const float> lq{
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- llvm::Value* p = multiply_add(lq, lw);
- Store(p, p_addr);
- }
+ // Handle false BB.
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
+ absl::Span<const double> gq{
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+ llvm::Value* p = multiply_add(gq, gw);
+ Store(p, p_addr);
+ }
- // Handle false BB.
- SetToFirstInsertPoint(if_data.false_block, b_);
- {
- llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
-
- llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
- absl::Span<const float> gq{
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
- llvm::Value* p = multiply_add(gq, gw);
- Store(p, p_addr);
- }
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ } else {
+ DCHECK(prim_type == F64);
+
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);
+
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(3.125));
+ absl::Span<const double> c{
+ -3.6444120640178196996e-21, -1.685059138182016589e-19,
+ 1.2858480715256400167e-18, 1.115787767802518096e-17,
+ -1.333171662854620906e-16, 2.0972767875968561637e-17,
+ 6.6376381343583238325e-15, -4.0545662729752068639e-14,
+ -8.1519341976054721522e-14, 2.6335093153082322977e-12,
+ -1.2975133253453532498e-11, -5.4154120542946279317e-11,
+ 1.051212273321532285e-09, -4.1126339803469836976e-09,
+ -2.9070369957882005086e-08, 4.2347877827932403518e-07,
+ -1.3654692000834678645e-06, -1.3882523362786468719e-05,
+ 0.0001867342080340571352, -0.00074070253416626697512,
+ -0.0060336708714301490533, 0.24015818242558961693,
+ 1.6536545626831027356};
+ llvm::Value* p = multiply_add(c, lw);
+ Store(p, p_addr);
+ }
- SetToFirstInsertPoint(if_data.after_block, b_);
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
+ SetToFirstInsertPoint(if_data_second.true_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
+ absl::Span<const double> t1{
+ 2.2137376921775787049e-09, 9.0756561938885390979e-08,
+ -2.7517406297064545428e-07, 1.8239629214389227755e-08,
+ 1.5027403968909827627e-06, -4.013867526981545969e-06,
+ 2.9234449089955446044e-06, 1.2475304481671778723e-05,
+ -4.7318229009055733981e-05, 6.8284851459573175448e-05,
+ 2.4031110387097893999e-05, -0.0003550375203628474796,
+ 0.00095328937973738049703, -0.0016882755560235047313,
+ 0.0024914420961078508066, -0.0037512085075692412107,
+ 0.005370914553590063617, 1.0052589676941592334,
+ 3.0838856104922207635};
+ llvm::Value* p = multiply_add(t1, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data_second.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
+ absl::Span<const double> t2{
+ -2.7109920616438573243e-11, -2.5556418169965252055e-10,
+ 1.5076572693500548083e-09, -3.7894654401267369937e-09,
+ 7.6157012080783393804e-09, -1.4960026627149240478e-08,
+ 2.9147953450901080826e-08, -6.7711997758452339498e-08,
+ 2.2900482228026654717e-07, -9.9298272942317002539e-07,
+ 4.5260625972231537039e-06, -1.9681778105531670567e-05,
+ 7.5995277030017761139e-05, -0.00021503011930044477347,
+ -0.00013871931833623122026, 1.0103004648645343977,
+ 4.8499064014085844221};
+ llvm::Value* p = multiply_add(t2, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ }
llvm::Value* p = Load(p_addr);
- return FMul(p, x);
+ x = FMul(p, x);
+ // Trunc back to half if needed.
+ if (prim_type == F16) {
+ x = b_->CreateFPTrunc(x, b_->getHalfTy());
+ }
+ return x;
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
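The IR built above follows the Giles approximation sketched in the comment; as a reading aid, here is a plain C++ transcription of the F16/F32 branch (coefficients copied verbatim from the code above). This function is purely illustrative: XLA emits the equivalent LLVM IR rather than calling anything like it.

    #include <cmath>
    #include <cstdio>

    float ErfInvApprox(float x) {
      float w = -std::log((1.0f - x) * (1.0f + x));
      float p;
      if (w < 5.0f) {
        w = w - 2.5f;
        // Coefficients 'lq' from above, evaluated with Horner's scheme.
        p = 2.81022636e-08f;
        for (float c : {3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f,
                        0.00021858087f, -0.00125372503f, -0.00417768164f,
                        0.246640727f, 1.50140941f}) {
          p = p * w + c;
        }
      } else {
        w = std::sqrt(w) - 3.0f;
        // Coefficients 'gq' from above.
        p = -0.000200214257f;
        for (float c : {0.000100950558f, 0.00134934322f, -0.00367342844f,
                        0.00573950773f, -0.0076224613f, 0.00943887047f,
                        1.00167406f, 2.83297682f}) {
          p = p * w + c;
        }
      }
      return p * x;
    }

    int main() {
      // erf(1) is roughly 0.8427, so the result should be close to 1.
      std::printf("%f\n", ErfInvApprox(0.8427f));
      return 0;
    }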
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
index 3cccec9862..986970f886 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.h
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -26,7 +26,7 @@ namespace xla {
// Flattening associates each call site with a unique computation (for
// sequential calling contexts). This simplifies buffer assignment and
// points-to analysis (see b/36865746 for details).
-class FlattenCallGraph : public HloPassInterface {
+class FlattenCallGraph : public HloModulePass {
public:
absl::string_view name() const override { return "flatten-call-graph"; }
diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h
new file mode 100644
index 0000000000..1208a7dda8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/fusion_queue.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+
+// A queue interface that allows implementations to choose fusion candidates in
+// custom order.
+class FusionQueue {
+ public:
+ FusionQueue() = default;
+ virtual ~FusionQueue() = default;
+
+ // Dequeues the next fusion candidates: a consumer and the list of producers
+  // to fuse, given as operand indices of that consumer.
+ virtual std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
+
+ // A callback passed to the queue implementation right before the producer is
+ // fused into the consumer.
+ virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
+
+ // A callback passed to the queue implementation right after the fusion is
+ // created. Note that original_producer could have been destroyed.
+ virtual void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) {}
+
+ // A callback passed to the queue implementation to notify the removal of an
+ // instruction.
+ virtual void RemoveInstruction(HloInstruction* instruction) = 0;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
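A hypothetical FusionQueue implementation (illustration only, not part of this change) showing how the interface is meant to be driven: it walks a fixed worklist and proposes fusing every operand of each consumer. The class name, the worklist policy, and the convention that a null consumer signals exhaustion are assumptions of this sketch.

    #include <deque>
    #include <utility>
    #include <vector>

    #include "tensorflow/compiler/xla/service/fusion_queue.h"
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"

    namespace xla {

    class AllOperandsFusionQueue : public FusionQueue {
     public:
      explicit AllOperandsFusionQueue(std::vector<HloInstruction*> instructions)
          : worklist_(instructions.begin(), instructions.end()) {}

      std::pair<HloInstruction*, std::vector<int64>>
      DequeueNextInstructionAndOperandsToFuseInOrder() override {
        if (worklist_.empty()) {
          // Assumed convention: a null consumer means nothing is left to fuse.
          return {nullptr, {}};
        }
        HloInstruction* consumer = worklist_.front();
        worklist_.pop_front();
        std::vector<int64> operand_indices;
        for (int64 i = 0; i < consumer->operand_count(); ++i) {
          operand_indices.push_back(i);
        }
        return {consumer, std::move(operand_indices)};
      }

      void RemoveInstruction(HloInstruction* instruction) override {
        // Drop the instruction from the pending worklist if it is still there.
        for (auto it = worklist_.begin(); it != worklist_.end(); ++it) {
          if (*it == instruction) {
            worklist_.erase(it);
            break;
          }
        }
      }

     private:
      std::deque<HloInstruction*> worklist_;
    };

    }  // namespace xla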
diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h
index 7bd9ea5984..2b39359aae 100644
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -23,7 +23,7 @@ namespace xla {
// This pass rewrites gather operations into (roughly) while loops of dynamic
// slices. This lets backends that don't support gather directly nevertheless
// provide a minimum level of support.
-class GatherExpander : public HloPassInterface {
+class GatherExpander : public HloModulePass {
public:
absl::string_view name() const override { return "gather_expander"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 64b9683628..350fd32537 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -68,9 +68,7 @@ cc_library(
# srcs = [
# "partition_assignment_test.cc",
# ],
-# tags = [
-# "requires-gpu-sm35",
-# ],
+# tags = tf_cuda_tests_tags(),
# deps = [
# ":partition_assignment",
# "//tensorflow/core:stream_executor_no_cuda",
@@ -93,6 +91,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],
)
@@ -359,6 +358,7 @@ cc_library(
"//tensorflow/core/platform/default/build_config:cufft_plugin",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -373,7 +373,6 @@ cc_library(
hdrs = ["ir_emission_utils.h"],
deps = [
":backend_configs",
- ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -405,6 +404,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
],
)
@@ -414,6 +414,8 @@ cc_library(
srcs = ["cudnn_convolution_runner.cc"],
hdrs = ["cudnn_convolution_runner.h"],
deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -422,8 +424,10 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -432,6 +436,7 @@ cc_library(
srcs = ["cudnn_convolution_rewriter.cc"],
hdrs = ["cudnn_convolution_rewriter.h"],
deps = [
+ ":backend_configs",
":ir_emission_utils",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
@@ -472,6 +477,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:instruction_fusion",
"//tensorflow/compiler/xla/service:pattern_matcher",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -504,6 +510,7 @@ cc_library(
"//tensorflow/compiler/xla/service:multi_output_fusion",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -537,6 +544,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -596,14 +604,11 @@ cc_library(
hdrs = ["pad_for_tensor_cores.h"],
deps = [
":ir_emission_utils",
- "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo_creation_utils",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
- "//tensorflow/compiler/xla/service:shape_inference",
],
)
@@ -656,6 +661,7 @@ cc_library(
deps = [
":cudnn_convolution_algorithm_picker",
":cudnn_convolution_rewriter",
+ ":cudnn_fused_convolution_rewriter",
":fusion_merger",
":gpu_constants",
":gpu_copy_insertion",
@@ -713,6 +719,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -774,7 +781,6 @@ cc_library(
srcs = ["gpu_layout_assignment.cc"],
hdrs = ["gpu_layout_assignment.h"],
deps = [
- ":gpu_options",
":ir_emission_utils",
":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
@@ -783,6 +789,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@@ -875,16 +882,6 @@ cc_library(
)
cc_library(
- name = "gpu_options",
- srcs = ["gpu_options.cc"],
- hdrs = ["gpu_options.h"],
- deps = [
- "//tensorflow/compiler/xla/service:hlo_module_config",
- "//tensorflow/core:lib_internal",
- ],
-)
-
-cc_library(
name = "stream_executor_util",
srcs = ["stream_executor_util.cc"],
hdrs = ["stream_executor_util.h"],
@@ -967,3 +964,19 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
+
+cc_library(
+ name = "cudnn_fused_convolution_rewriter",
+ srcs = ["cudnn_fused_convolution_rewriter.cc"],
+ hdrs = ["cudnn_fused_convolution_rewriter.h"],
+ deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
index 640c6392b8..78e14d860e 100644
--- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto
+++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
@@ -24,4 +24,18 @@ message CudnnConvBackendConfig {
// true, cudnn may choose not to use tensor cores, e.g. because the GPU or
// selected algorithm doesn't support it.
bool tensor_ops_enabled = 2;
+
+ // The scaling factor multiplied with the convolution result.
+ double conv_result_scale = 4;
+
+ // Below are the fields related to cuDNN's fused convolution. Refer to
+ // CudnnConvParams for their meanings.
+
+  // The requested activation (e.g. relu) applied after the convolution. Its
+  // value is interpreted as a stream_executor::dnn::ActivationMode.
+ int64 activation_mode = 3;
+
+ // The scaling factor multiplied with the side input. If no side input buffer
+ // is provided, this field must be 0.
+ double side_input_scale = 5;
}
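A sketch of how a rewriter pass might fill in the new fields when it forms a fused convolution custom call, using the backend_config accessors that appear elsewhere in this change. The function name and the choice of kRelu are illustrative assumptions, not part of the change.

    #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"
    #include "tensorflow/compiler/xla/status.h"
    #include "tensorflow/stream_executor/dnn.h"

    namespace xla {
    namespace gpu {

    Status ConfigureFusedConv(HloInstruction* conv, double side_input_scale) {
      auto config_or = conv->backend_config<CudnnConvBackendConfig>();
      if (!config_or.ok()) {
        return config_or.status();
      }
      CudnnConvBackendConfig config = config_or.ValueOrDie();
      config.set_conv_result_scale(1.0);
      config.set_activation_mode(
          static_cast<int64>(stream_executor::dnn::ActivationMode::kRelu));
      // Must stay 0 unless a side-input operand is actually passed.
      config.set_side_input_scale(side_input_scale);
      return conv->set_backend_config(config);
    }

    }  // namespace gpu
    }  // namespace xla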
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 3a23ac1d63..4effea637d 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,37 +29,38 @@ limitations under the License.
namespace xla {
namespace gpu {
-using se::dnn::AlgorithmDesc;
+ConvolutionThunk::ConvolutionThunk(
+ const HloCustomCallInstruction* cudnn_call,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ operand_buffers_(std::move(operand_slices)),
+ result_buffer_(result_slice),
+ scratch_buffer_(scratch_slice),
+ tuple_result_buffer_(tuple_result_slice) {}
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- CudnnConvParams params;
+ std::vector<se::DeviceMemoryBase> operand_se_buffers;
+ for (const auto& buffer : operand_buffers_) {
+ operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
+ }
+
+ se::DeviceMemoryBase result_buffer =
+ buffer_allocations.GetDeviceAddress(result_buffer_);
- params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
- params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
- params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
-
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_,
+ absl::MakeSpan(operand_se_buffers),
+ result_buffer, scratch, stream));
- // Figure out which of output/input/filter is the result produced by
- // this op, and write the result tuple.
- void* result_ptr = [&] {
- switch (params.kind) {
- case CudnnConvKind::kForward:
- return params.output_buf.opaque();
- case CudnnConvKind::kBackwardInput:
- return params.input_buf.opaque();
- case CudnnConvKind::kBackwardFilter:
- return params.filter_buf.opaque();
- }
- }();
- void* ptrs[] = {result_ptr, scratch.opaque()};
+ void* ptrs[] = {result_buffer.opaque(), scratch.opaque()};
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d7d1f91fba..f53bc54198 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -42,24 +42,12 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // Note that "output" here doesn't refer to the output from running this
- // thunk, but rather to the "output" of a hypothetical forward convolution
- // that corresponds to this input+filter+output triple. That is, the result
- // generated by this thunk is "output" for forward convs, "input" for
- // backward-input convs, and "filter" for backward-filter convs.
+ // operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
- BufferAllocation::Slice input_slice,
- BufferAllocation::Slice filter_slice,
- BufferAllocation::Slice output_slice,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,
- BufferAllocation::Slice tuple_result_slice)
- : Thunk(Kind::kConvolution, cudnn_call),
- cudnn_call_(cudnn_call),
- input_buffer_(std::move(input_slice)),
- filter_buffer_(std::move(filter_slice)),
- output_buffer_(std::move(output_slice)),
- scratch_buffer_(std::move(scratch_slice)),
- tuple_result_buffer_(std::move(tuple_result_slice)) {}
+ BufferAllocation::Slice tuple_result_slice);
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -71,9 +59,8 @@ class ConvolutionThunk : public Thunk {
private:
const HloCustomCallInstruction* cudnn_call_;
- BufferAllocation::Slice input_buffer_;
- BufferAllocation::Slice filter_buffer_;
- BufferAllocation::Slice output_buffer_;
+ std::vector<BufferAllocation::Slice> operand_buffers_;
+ BufferAllocation::Slice result_buffer_;
BufferAllocation::Slice scratch_buffer_;
BufferAllocation::Slice tuple_result_buffer_;
};
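A sketch of how a caller now constructs the thunk: one BufferAllocation::Slice per custom-call operand, in operand order, plus slices for the conv result, the scratch buffer, and the output tuple. The slice_for callback and the MakeConvolutionThunk helper are hypothetical stand-ins for the caller's actual slice lookup (e.g. in the unnested IR emitter).

    #include <functional>
    #include <memory>
    #include <utility>
    #include <vector>

    #include "absl/memory/memory.h"
    #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    namespace xla {
    namespace gpu {

    std::unique_ptr<ConvolutionThunk> MakeConvolutionThunk(
        const HloCustomCallInstruction* cudnn_call,
        const std::function<BufferAllocation::Slice(
            const HloInstruction*, const ShapeIndex&)>& slice_for) {
      // One slice per operand, in the same order as cudnn_call->operands().
      std::vector<BufferAllocation::Slice> operand_slices;
      for (const HloInstruction* operand : cudnn_call->operands()) {
        operand_slices.push_back(slice_for(operand, {}));
      }
      return absl::make_unique<ConvolutionThunk>(
          cudnn_call, std::move(operand_slices),
          /*result_slice=*/slice_for(cudnn_call, {0}),
          /*scratch_slice=*/slice_for(cudnn_call, {1}),
          /*tuple_result_slice=*/slice_for(cudnn_call, {}));
    }

    }  // namespace gpu
    }  // namespace xla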
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
index 6e2e330edd..c3f58508dd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
@@ -52,7 +52,7 @@ namespace gpu {
// The GPU backend does not implement a lowering for the batchnorm HLOs -- it
// expects them to be lowered to cudnn calls via this pass or to HLO soup via
// BatchNormRewriter.
-class CudnnBatchNormRewriter : public HloPassInterface {
+class CudnnBatchNormRewriter : public HloModulePass {
public:
absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index f528e62b17..590c0a7d54 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -76,54 +76,24 @@ StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
return se::DeviceMemory<uint8>(buffer_addr);
}
-// Determines whether we can safely perform a winograd non-fused convolution for
-// the given input and output shapes. This works around b/68264959, an integer
-// overflow in cuDNNv5 and cuDNNv6.
-bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape,
- const Shape& output_shape,
- const ConvolutionDimensionNumbers& dnums,
- se::StreamExecutor* stream_exec) {
- // Skip this check for cudnn7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
- if (version.ok() && version.ValueOrDie().major_version() >= 7) {
- return true;
- }
-
- int64 batch = input_shape.dimensions(dnums.input_batch_dimension());
- int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension());
- int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0));
- int64 in_cols =
- dnums.input_spatial_dimensions_size() == 1
- ? 1
- : input_shape.dimensions(dnums.input_spatial_dimensions(1));
- int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension());
-
- int64 total_size = CeilOfRatio(batch, int64{16}) *
- std::max(in_depths, out_depths) * in_cols * in_rows *
- sizeof(float);
-
- const int64 threshold = 1L << 31;
- return total_size < threshold;
-}
-
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
- bool with_winograd_nonfused,
se::StreamExecutor* stream_exec) {
std::vector<AlgorithmDesc> algorithms;
+ bool succ = false;
switch (kind) {
case CudnnConvKind::kBackwardFilter:
- CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
- with_winograd_nonfused, &algorithms));
+ succ =
+ stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kBackwardInput:
- CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
- with_winograd_nonfused, &algorithms));
+ succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kForward:
- CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused,
- &algorithms));
+ case CudnnConvKind::kForwardActivation:
+ succ = stream_exec->GetConvolveAlgorithms(true, &algorithms);
break;
}
+ DCHECK(succ);
return algorithms;
}
@@ -175,21 +145,13 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
-StatusOr<std::tuple<int64, bool, int64>>
+StatusOr<CudnnConvolutionAlgorithmPicker::AutotuneResult>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
- const HloCustomCallInstruction* instr) {
- CudnnConvParams params;
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
-
- const Shape& input_shape = *params.input_shape;
- const Shape& filter_shape = *params.filter_shape;
- const Shape& output_shape = *params.output_shape;
-
- CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
- CHECK_EQ(input_shape.element_type(), output_shape.element_type());
+ HloCustomCallInstruction* instr) {
// TODO(timshen): for now only check fp16. It can be expanded to other types,
// with some work on the HLO routines.
- const bool cross_check_enabled = input_shape.element_type() == xla::F16;
+ const bool cross_check_enabled =
+ instr->shape().tuple_shapes(0).element_type() == xla::F16;
// Don't run this function concurrently on the same GPU.
//
@@ -257,51 +219,43 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// use a ScratchAllocator for this instead of calling allocator_ directly so
// that our allocations don't leak.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(params.input_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(params.filter_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(params.output_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
- initialize_buffer(params.input_buf);
- initialize_buffer(params.filter_buf);
- initialize_buffer(params.output_buf);
-
- DeviceMemoryBase* result_buf = [&] {
- switch (params.kind) {
- case CudnnConvKind::kBackwardFilter:
- return &params.filter_buf;
- case CudnnConvKind::kBackwardInput:
- return &params.input_buf;
- case CudnnConvKind::kForward:
- return &params.output_buf;
- }
- }();
+ std::vector<se::DeviceMemoryBase> operand_buffers;
+ for (const auto* operand : instr->operands()) {
+ TF_ASSIGN_OR_RETURN(auto buffer,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(operand->shape())));
+ initialize_buffer(buffer);
+ operand_buffers.push_back(buffer);
+ }
+ TF_ASSIGN_OR_RETURN(
+ auto result_buffer,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
+ initialize_buffer(result_buffer);
- const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
- input_shape, output_shape, *params.dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
+ TF_ASSIGN_OR_RETURN(auto backend_config,
+ instr->backend_config<CudnnConvBackendConfig>());
optional<F16BufferComparator> comparator;
  // Use the first algorithm that's supported as a reference. There isn't a
  // particular reason to prefer it, as any algorithm would suffice; using it
  // as the reference does not mean it is considered correct, though.
optional<AlgorithmDesc> first_algorithm;
- for (const AlgorithmDesc& alg :
- GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
+ for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- params.algorithm = AlgorithmConfig(alg);
- bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
- &profile_result)
+ backend_config.set_algorithm(alg.algo_id());
+ backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled());
+ TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config));
+ bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers),
+ result_buffer, &scratch_allocator,
+ &stream, &profile_result)
.ok();
if (launch_ok && profile_result.is_valid()) {
@@ -312,7 +266,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
.xla_gpu_crash_on_verification_failures();
if (comparator.has_value()) {
StatusOr<bool> result = comparator->CompareEqual(
- se::DeviceMemory<Eigen::half>(*result_buf));
+ se::DeviceMemory<Eigen::half>(result_buffer));
if (!result.ok()) {
LOG(ERROR) << "Unable to compare "
<< AlgorithmToString(*first_algorithm) << " against "
@@ -330,7 +284,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
}
} else if (cross_check_enabled) {
auto comp = F16BufferComparator::Create(
- se::DeviceMemory<Eigen::half>(*result_buf), compiler_, allocator,
+ se::DeviceMemory<Eigen::half>(result_buffer), compiler_, allocator,
&stream);
if (comp.ok()) {
comparator.emplace(comp.ConsumeValueOrDie());
@@ -362,9 +316,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< AlgorithmToString(best_result.algorithm()) << ", takes "
<< best_result.elapsed_time_in_ms() << "ms, and uses "
<< best_result_bytes_used << "B of scratch memory.";
- return std::make_tuple(best_result.algorithm().algo_id(),
- best_result.algorithm().tensor_ops_enabled(),
- best_result_bytes_used);
+ return AutotuneResult{best_result.algorithm().algo_id(),
+ best_result.algorithm().tensor_ops_enabled(),
+ best_result_bytes_used,
+ absl::Milliseconds(best_result.elapsed_time_in_ms())};
}
return InternalError(
@@ -377,40 +332,34 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
HloInstruction* instr) {
CHECK(IsCustomCallToDnnConvolution(*instr));
- StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc =
+ StatusOr<AutotuneResult> best_algo_or =
PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
-
- if (!alg_scratch_and_tc.ok()) {
- LOG(ERROR) << alg_scratch_and_tc.status();
+ if (!best_algo_or.ok()) {
+ LOG(ERROR) << best_algo_or.status();
return false;
}
- int64 algorithm;
- bool tensor_ops_enabled;
- int64 scratch_bytes;
-
- std::tie(algorithm, tensor_ops_enabled, scratch_bytes) =
- alg_scratch_and_tc.ConsumeValueOrDie();
-
- VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
- << NumBytesToString(scratch_bytes)
+ auto best_algo = std::move(best_algo_or).ValueOrDie();
+ VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm
+ << " and " << NumBytesToString(best_algo.scratch_bytes)
<< " of scratch memory: " << instr->ToString()
- << " tensor_ops_enabled: " << tensor_ops_enabled;
+ << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled;
// Replace instr with a new CustomCall which has the correct algorithm, and
// whose output shape has the appropriate amount of scratch memory.
HloComputation* computation = instr->parent();
- Shape new_call_shape =
- ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0),
- ShapeUtil::MakeShape(U8, {scratch_bytes})});
+ Shape new_call_shape = ShapeUtil::MakeTupleShape(
+ {instr->shape().tuple_shapes(0),
+ ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})});
- CudnnConvBackendConfig backend_config;
- backend_config.set_algorithm(algorithm);
- backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ instr->backend_config<CudnnConvBackendConfig>());
+ backend_config.set_algorithm(best_algo.algorithm);
+ backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled);
HloInstruction* new_call = computation->AddInstruction(
- instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0),
- instr->mutable_operand(1)}));
+ instr->CloneWithNewOperands(new_call_shape, instr->operands()));
+
TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index f79b113f8f..136c32210a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -30,7 +31,7 @@ namespace gpu {
// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
// each and adding explicit scratch space to the CustomCalls.
-class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
+class CudnnConvolutionAlgorithmPicker : public HloModulePass {
public:
// If the `allocator` parameter is not null, we will use it to allocate temp
// memory while timing the various convolution algorithms. If it's null,
@@ -47,10 +48,16 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<bool> Run(HloModule* module) override;
private:
+ struct AutotuneResult {
+ int64 algorithm;
+ bool tensor_ops_enabled;
+ int64 scratch_bytes;
+ absl::Duration runtime;
+ };
+
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
- StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- const HloCustomCallInstruction* instr);
+ StatusOr<AutotuneResult> PickBestAlgorithm(HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 228379a248..ef29237301 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -35,6 +36,32 @@ namespace gpu {
namespace {
+HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape,
+ HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
+ HloComputation* computation = lhs->parent();
+
+ // This call returns a tuple of (conv_result, scratch_memory), where
+ // conv_result is the actual result of the convolution, and scratch_memory is
+ // temporary memory used by cudnn.
+ //
+ // At the moment, we don't know how much scratch memory this conv is going to
+ // use, so we put u8[0] in this place. Later on another pass will choose
+ // which conv algorithm to use, and at that point we'll modify the shape of
+ // this second tuple element.
+ Shape call_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
+
+ HloInstruction* custom_call = computation->AddInstruction(
+ HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
+ custom_call->set_window(window);
+ custom_call->set_convolution_dimension_numbers(dnums);
+ custom_call->set_feature_group_count(feature_group_count);
+ return custom_call;
+}
+
bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
const ConvolutionDimensionNumbers& dnums =
conv->convolution_dimension_numbers();
@@ -450,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) {
return std::make_tuple(true, new_window, dnums, rhs);
}
+CudnnConvBackendConfig GetDefaultBackendConfig() {
+ CudnnConvBackendConfig config;
+ config.set_conv_result_scale(1);
+ return config;
+}
+
// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -462,24 +495,24 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
- return CreateCudnnConvBackwardFilter(
- conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
- window, dnums, conv->feature_group_count());
+ return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
+ conv->mutable_operand(0), conv->mutable_operand(1),
+ window, dnums, conv->feature_group_count());
}
std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- return CreateCudnnConvBackwardInput(conv->shape(),
- conv->mutable_operand(0), rhs, window,
- dnums, conv->feature_group_count());
+ return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
+ conv->mutable_operand(0), rhs, window, dnums,
+ conv->feature_group_count());
}
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
- return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
- conv->mutable_operand(1), conv->window(),
- conv->convolution_dimension_numbers(),
- conv->feature_group_count());
+ return CreateCudnnConv(
+ kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0),
+ conv->mutable_operand(1), conv->window(),
+ conv->convolution_dimension_numbers(), conv->feature_group_count());
}
return nullptr;
@@ -489,6 +522,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
return false;
}
+ TF_RETURN_IF_ERROR(
+ custom_call->set_backend_config(GetDefaultBackendConfig()));
+
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
// the conv result and replace `conv` with it.
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
index fbe7e98494..8d7c6fdab5 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
@@ -24,7 +24,7 @@ namespace gpu {
// Rewrites plain convolutions, backwards-filter convolutions, and
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
-class CudnnConvolutionRewriter : public HloPassInterface {
+class CudnnConvolutionRewriter : public HloModulePass {
public:
absl::string_view name() const override {
return "cudnn-convolution-rewriter";
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 2a86ac265e..89dd1bb272 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -37,6 +39,42 @@ using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;
+struct CudnnConvParams {
+  // These fields describe cuDNN's fused convolution, whose result is defined
+  // as:
+  //   activation(conv_result_scale * conv(x, w) +
+  //              side_input_scale * side_input + broadcast(bias))
+  //
+  // The most common fused conv is forward convolution followed by relu or
+  // identity.
+  //
+  // bias_buf is a one-dimensional array whose length equals the number of
+  // output features. It is broadcast to the output shape before being added
+  // to the final result.
+ //
+ // side_input_buf, if valid, must have the same shape as the output buffer.
+ struct FusionParams {
+ se::dnn::ActivationMode mode;
+ double side_input_scale;
+ se::DeviceMemoryBase bias_buf;
+ se::DeviceMemoryBase side_input_buf; // nullable
+ };
+
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+ double conv_result_scale;
+
+ absl::optional<FusionParams> fusion;
+};
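
To make the fused epilogue above concrete, here is a minimal standalone C++ sketch, independent of XLA and StreamExecutor; the function and buffer names are illustrative only, and it assumes the raw convolution result has already been computed, the activation is relu, and the innermost dimension of the flat buffers is the feature dimension. It evaluates activation(conv_result_scale * conv + side_input_scale * side_input + broadcast(bias)) element-wise.

// Minimal sketch of the fused-convolution epilogue described above. The
// buffers are flat, feature-innermost arrays of equal size; bias has one
// entry per output feature and is broadcast across all other dimensions.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

void ApplyFusedEpilogueRelu(const std::vector<float>& conv_result,
                            const std::vector<float>& side_input,
                            const std::vector<float>& bias,
                            double conv_result_scale, double side_input_scale,
                            int features, std::vector<float>* output) {
  for (std::size_t i = 0; i < conv_result.size(); ++i) {
    float biased = conv_result_scale * conv_result[i] +
                   side_input_scale * side_input[i] +
                   bias[i % features];      // broadcast(bias) over features
    (*output)[i] = std::max(0.0f, biased);  // relu activation
  }
}

int main() {
  // One output pixel with two output features.
  std::vector<float> conv_result = {2.0f, -3.0f};
  std::vector<float> side_input = {1.0f, 1.0f};
  std::vector<float> bias = {0.5f, 0.5f};
  std::vector<float> output(2);
  ApplyFusedEpilogueRelu(conv_result, side_input, bias,
                         /*conv_result_scale=*/1.0, /*side_input_scale=*/0.5,
                         /*features=*/2, &output);
  std::printf("%.2f %.2f\n", output[0], output[1]);  // prints: 3.00 0.00
}

With conv_result_scale = 1, side_input_scale = 0, and a zero bias this collapses to a plain convolution followed by the activation, which is why the unfused forward conv remains a special case of the fused call.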
+
// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
// returning it (in its entirety) the first time Allocate() is called.
class ScratchBufAllocator : public se::ScratchAllocator {
@@ -92,9 +130,9 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
- VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
- VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
- VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }";
+ VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape);
+ VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape);
+ VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape);
VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
@@ -186,23 +224,73 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
switch (kind) {
case CudnnConvKind::kForward:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveWithAlgorithm(
input_descriptor, input_buf, filter_descriptor, filter_buf,
convolution_descriptor, output_descriptor, &output_buf,
scratch_allocator, algorithm, profile_result);
break;
case CudnnConvKind::kBackwardInput:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveBackwardDataWithAlgorithm(
filter_descriptor, filter_buf, output_descriptor, output_buf,
convolution_descriptor, input_descriptor, &input_buf,
scratch_allocator, algorithm, profile_result);
break;
case CudnnConvKind::kBackwardFilter:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveBackwardFilterWithAlgorithm(
input_descriptor, input_buf, output_descriptor, output_buf,
convolution_descriptor, filter_descriptor, &filter_buf,
scratch_allocator, algorithm, profile_result);
break;
+ case CudnnConvKind::kForwardActivation: {
+ BatchDescriptor bias_desc;
+ bias_desc.set_count(1)
+ .set_height(1)
+ .set_width(1)
+ .set_feature_map_count(
+ output_shape.dimensions(dnums.output_feature_dimension()))
+ .set_layout(output_dl);
+
+ se::DeviceMemory<T> side_input(params.fusion->side_input_buf);
+ // If there is no side input, use output as the side input.
+ if (side_input.is_null()) {
+ if (params.fusion->side_input_scale != 0) {
+ return InternalError(
+ "Side input scale is not 0, yet no side input buffer is "
+ "provided");
+ }
+ // Since side-input scale is 0, the values in the side input don't
+ // matter. The simplest thing to do would be to pass in a null buffer
+ // for the side input, but cudnn doesn't allow this. cudnn does promise
+ // that if side-input-scale is 0 the side input won't be read, so we
+ // just pass in the output buffer, since it's handy and has the correct
+ // size.
+ side_input = output_buf;
+ }
+
+ stream->ThenFusedConvolveWithAlgorithm(
+ input_descriptor, input_buf, params.conv_result_scale,
+ filter_descriptor, filter_buf, convolution_descriptor, side_input,
+ params.fusion->side_input_scale, bias_desc,
+ DeviceMemory<T>(params.fusion->bias_buf), params.fusion->mode,
+ output_descriptor, &output_buf, scratch_allocator, algorithm,
+ profile_result);
+ break;
+ }
}
if (!stream->ok()) {
@@ -214,32 +302,104 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
return Status::OK();
}
-} // anonymous namespace
+// Returns the cudnn convolution parameters generated from conv, which must be a
+// custom-call to a cudnn convolution.
+StatusOr<CudnnConvParams> GetCudnnConvParams(
+ const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer) {
+ CudnnConvParams params;
-string CudnnConvKindToString(CudnnConvKind kind) {
- switch (kind) {
- case CudnnConvKind::kForward:
- return "forward";
- case CudnnConvKind::kBackwardFilter:
- return "backward_filter";
- case CudnnConvKind::kBackwardInput:
- return "backward_input";
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ const auto& target = conv->custom_call_target();
+ const auto& lhs_shape = conv->operand(0)->shape();
+ const auto& rhs_shape = conv->operand(1)->shape();
+ const auto& conv_result_shape = conv->shape().tuple_shapes(0);
+
+ params.window = &conv->window();
+ params.dnums = &conv->convolution_dimension_numbers();
+ params.feature_group_count = conv->feature_group_count();
+ params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+ params.conv_result_scale = backend_config.conv_result_scale();
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params.kind = CudnnConvKind::kForward;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params.kind = CudnnConvKind::kBackwardInput;
+ params.input_shape = &conv_result_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &lhs_shape;
+ params.input_buf = result_buffer;
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = operand_buffers[0];
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params.kind = CudnnConvKind::kBackwardFilter;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &conv_result_shape;
+ params.output_shape = &rhs_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = result_buffer;
+ params.output_buf = operand_buffers[1];
+ } else if (target == kCudnnConvBiasActivationForwardCallTarget) {
+ params.kind = CudnnConvKind::kForwardActivation;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.fusion.emplace();
+ auto& fusion = *params.fusion;
+ if (backend_config.activation_mode() <
+ static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
+ fusion.mode = static_cast<se::dnn::ActivationMode>(
+ backend_config.activation_mode());
+ } else {
+ return InternalError("Bad activation mode: %s",
+ backend_config.ShortDebugString());
+ }
+ fusion.side_input_scale = backend_config.side_input_scale();
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ params.fusion->bias_buf = operand_buffers[2];
+ if (operand_buffers.size() >= 4) {
+ params.fusion->side_input_buf = operand_buffers[3];
+ }
+ } else {
+ return InternalError("Unexpected custom call target: %s", target);
}
+ return params;
}
-Status RunCudnnConvolution(CudnnConvParams params,
+} // anonymous namespace
+
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(params, &scratch_allocator, stream,
- profile_result);
+ return RunCudnnConvolution(conv, operand_buffers, result_buffer,
+ &scratch_allocator, stream, profile_result);
}
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = params.output_shape->element_type();
+ TF_ASSIGN_OR_RETURN(CudnnConvParams params,
+ GetCudnnConvParams(conv, operand_buffers, result_buffer));
+
+ PrimitiveType output_primitive_type =
+ conv->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
case F16:
return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index 381aa37a1b..61aec1cecc 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -27,52 +30,8 @@ namespace gpu {
// This file contains low-level routines for running cudnn convolutions.
-// Different types of convolutions supported by cudnn.
-//
-// A way to think about these is that a convolution is defined by three arrays
-// -- the "input", the "filter", and the "output" -- and given any two of these,
-// we can compute the third. For example, a backward-input convolution takes as
-// input a filter and an "output" and produces an "input" such that if one were
-// to do a forward convolution of "input" using filter, the result would be
-// something with the same shape as "output".
-//
-// This way of thinking is not correct if you look at the values produced. For
-// example, a backward-input convolution is not actually the mathematical
-// inverse of a forward convolution. But it's right as far as the shapes and
-// "connectivity" (i.e. which elements of the input affect which elements of
-// the output) are concerned.
-enum class CudnnConvKind {
- kForward, // input + filter => output
- kBackwardInput, // filter + output => input
- kBackwardFilter, // input + output => filter
-};
-
-struct CudnnConvParams {
- CudnnConvKind kind;
- const Shape* input_shape;
- const Shape* filter_shape;
- const Shape* output_shape;
- se::DeviceMemoryBase input_buf;
- se::DeviceMemoryBase filter_buf;
- se::DeviceMemoryBase output_buf;
- const Window* window;
- const ConvolutionDimensionNumbers* dnums;
- int64 feature_group_count;
- se::dnn::AlgorithmConfig algorithm;
-};
-
-// Converts a CudnnConvKind value to a string.
-string CudnnConvKindToString(CudnnConvKind kind);
-
// Calls into cudnn to run the specified convolution.
//
-// Note that depending on the value of CudnnConvKind, the result of this call
-// may be written into input_buf, filter_buf, or output_buf!
-//
-// At the moment convolution with half data type is implemented with cudnn
-// PSEUDO_HALF configuration, that is, the input values are half and the
-// internal computation type is float.
-//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
// theory the second one shouldn't be necessary -- users of this function could
@@ -83,11 +42,15 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
new file mode 100644
index 0000000000..3761c19cfc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
@@ -0,0 +1,278 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Describes a matched pattern:
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+// where side_input has the shape of the output buffer, and bias is a 1D array
+// whose length equals the number of output features.
+struct ConvWithRelu {
+ HloInstruction* maximum;
+ HloCustomCallInstruction* conv;
+ HloInstruction* bias;
+ HloInstruction* side_input;
+ HloConstantInstruction* alpha_conv;
+ HloConstantInstruction* alpha_side_input;
+};
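
As a sanity check on the variants this struct captures, the following standalone C++ sketch (plain scalars rather than HLO; all names are hypothetical) mirrors the defaulting that TryRewriteToCudnnForwardRelu applies below: an omitted alpha_conv behaves like 1, an omitted side input like a side-input scale of 0, and an omitted bias like 0, so the simplest matched variant max(0, conv(x, w)) is just the general pattern with the defaults filled in.

// Sketch of the matched pattern with the rewriter's defaults. Not XLA code;
// std::optional stands in for "this piece of the pattern was not present".
#include <algorithm>
#include <cassert>
#include <optional>

float FusedRelu(float conv, std::optional<float> alpha_conv,
                std::optional<float> side_input,
                std::optional<float> alpha_side_input,
                std::optional<float> bias) {
  float a1 = alpha_conv.value_or(1.0f);    // omitted alpha1 -> 1
  float side = side_input.value_or(0.0f);  // omitted side input -> 0
  // Side-input scale defaults to 1 when a side input is present, 0 otherwise.
  float a2 = side_input ? alpha_side_input.value_or(1.0f) : 0.0f;
  float b = bias.value_or(0.0f);           // omitted bias -> broadcast(0)
  return std::max(0.0f, a1 * conv + a2 * side + b);
}

int main() {
  // max(0, conv) is the general pattern with every optional piece omitted.
  assert(FusedRelu(2.5f, std::nullopt, std::nullopt, std::nullopt,
                   std::nullopt) == 2.5f);
  // Fully general form: max(0, 0.5 * -1 + 0.25 * 4 + 0.5) == 1.
  assert(FusedRelu(-1.0f, 0.5f, 4.0f, 0.25f, 0.5f) == 1.0f);
}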
+
+absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::GetTupleElement;
+ using match::Maximum;
+ using match::MultiplyAnyOrder;
+ using match::Op;
+
+ // The pattern we want to match:
+ // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+ //
+  // together with its variants that commute/reassociate the adds, multiplies,
+  // and max, and that omit alpha1, side_input, alpha2, or bias.
+
+ HloInstruction* relu_input;
+
+ // Match max(0, relu_input).
+ auto zero_pattern = Broadcast(match::ConstantScalar(0));
+ if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
+ !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
+ return absl::nullopt;
+ }
+ HloInstruction* conv_instr = nullptr;
+ HloInstruction* alpha_conv_instr = nullptr;
+ HloInstruction* alpha_side_input_instr = nullptr;
+ HloInstruction* bias_broadcast_instr = nullptr;
+ HloInstruction* bias = nullptr;
+ HloInstruction* side_input = nullptr;
+
+ // These nodes will not be in the returned value, but we need to check them
+ // for single use.
+ HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
+ *mul1 = nullptr, *mul2 = nullptr;
+
+ const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
+ const auto conv_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
+ auto conv_pattern = GetTupleElement(
+ &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
+ }();
+ const auto side_input_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
+    // If the bias is already matched, match an arbitrary additional input as
+    // the side input. Note this may force a cheap operation (e.g. broadcast)
+    // to be materialized into a large buffer, as large as the output buffer.
+ //
+ // TODO(timshen): If in practice there are significant false positives, we
+ // should fix it.
+ auto side_input_pattern = Op(&side_input);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
+ side_input_pattern);
+ }();
+
+ {
+ // Try to match any of the following form of add, in any association:
+ // addends[0]
+ // addends[0] + addends[1]
+ // addends[0] + addends[1] + addends[2]
+ //
+ // Then try to match each addend with one of the three patterns: bias, conv,
+ // or side_input. Notice that side_input matching must go last, as it
+ // also matches a conv or a bias.
+ HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
+ auto add3_pattern = [&] {
+ auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
+ return AnyOf<HloInstruction>(
+ AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
+ Op(&addends[0]));
+ }();
+ CHECK(Match(relu_input, add3_pattern));
+ for (auto addend : addends) {
+ if (addend) {
+ if (bias == nullptr && Match(addend, bias_pattern)) {
+ CHECK(bias);
+ } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
+ CHECK(conv_instr);
+ } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
+ CHECK(side_input);
+ } else {
+ return absl::nullopt;
+ }
+ }
+ }
+ }
+
+ if (conv_instr == nullptr) {
+ return absl::nullopt;
+ }
+
+ for (HloInstruction* instr :
+ {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
+ if (instr && instr->user_count() > 1) {
+ return absl::nullopt;
+ }
+ }
+
+ auto conv = Cast<HloCustomCallInstruction>(conv_instr);
+ auto bias_broadcast =
+ CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);
+
+ if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
+ return absl::nullopt;
+ }
+
+ if (bias_broadcast) {
+ // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
+ if (bias_broadcast_instr->dimensions().size() != 1) {
+ return absl::nullopt;
+ }
+ if (bias_broadcast_instr->dimensions(0) !=
+ conv->convolution_dimension_numbers().output_feature_dimension()) {
+ return absl::nullopt;
+ }
+ }
+
+ return ConvWithRelu{
+ instr,
+ conv,
+ bias,
+ side_input,
+ CastOrNull<HloConstantInstruction>(alpha_conv_instr),
+ CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
+}
+
+StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
+ ConvWithRelu match) {
+ auto conv = match.conv;
+
+ HloComputation* computation = conv->parent();
+ PrimitiveType element_type = conv->operand(0)->shape().element_type();
+
+ const auto get_alpha_value =
+ [](HloConstantInstruction* instr) -> StatusOr<double> {
+ TF_ASSIGN_OR_RETURN(
+ auto alpha,
+ Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
+ return alpha.GetFirstElement<double>();
+ };
+
+ double alpha_conv = 1;
+ if (match.alpha_conv) {
+ TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
+ }
+
+ double alpha_side_input;
+ if (match.side_input) {
+ if (match.alpha_side_input) {
+ TF_ASSIGN_OR_RETURN(alpha_side_input,
+ get_alpha_value(match.alpha_side_input));
+ } else {
+ alpha_side_input = 1;
+ }
+ } else {
+ CHECK(match.alpha_side_input == nullptr);
+ alpha_side_input = 0;
+ }
+
+ auto bias = match.bias;
+ if (!bias) {
+ auto zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+
+ int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions(
+ conv->convolution_dimension_numbers().output_feature_dimension());
+ bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShapeWithDescendingLayout(element_type,
+ {num_output_feature}),
+ zero, {}));
+ }
+
+ CHECK(bias);
+ std::vector<HloInstruction*> args = {conv->mutable_operand(0),
+ conv->mutable_operand(1), bias};
+ if (match.side_input) {
+ args.push_back(match.side_input);
+ }
+ auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
+ conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
+ new_conv->set_window(conv->window());
+ new_conv->set_convolution_dimension_numbers(
+ conv->convolution_dimension_numbers());
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ config.set_activation_mode(
+ static_cast<int64>(se::dnn::ActivationMode::kRelu));
+ config.set_conv_result_scale(alpha_conv);
+ config.set_side_input_scale(alpha_side_input);
+ TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
+
+ VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name();
+ return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
+ new_conv, 0);
+}
+
+} // namespace
+
+StatusOr<bool> CudnnFusedConvolutionRewriter::Run(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ std::vector<ConvWithRelu> matches;
+ int num_forward_convs = 0;
+ for (auto instr : computation->instructions()) {
+ auto match = FindConvWithRelu(instr);
+ if (match.has_value()) {
+ matches.push_back(*match);
+ }
+ if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+ if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
+ num_forward_convs++;
+ }
+ }
+ }
+ VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
+ << " out of " << num_forward_convs << " forward convs.";
+ std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
+ replacements;
+ for (const ConvWithRelu& match : matches) {
+ TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
+ replacements.push_back({match.maximum, std::move(new_instr)});
+ changed = true;
+ }
+ for (auto& replacement : replacements) {
+ TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+ replacement.first, std::move(replacement.second)));
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
index 498d4a9495..bd12aadded 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_options.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
@@ -13,21 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-
-// Helper functions for querying options that are specific to the GPU backend.
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
namespace gpu {
-// Returns true if we should use heuristics to assign convolution layouts, as
-// opposed to always assigning NCHW.
-bool ConvUseLayoutHeuristic(const HloModuleConfig& config);
+class CudnnFusedConvolutionRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "cudnn-fused-convolution-rewriter";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
} // namespace gpu
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index c1aaa4bf04..6dcdaf1cfe 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
const Window& window = hlo->window();
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(window)) {
- return Unimplemented(
- "Dilation for reduce-window not implemented on GPU. "
- "See b/31410564.");
- }
-
PrimitiveType operand_element_type = operand->shape().element_type();
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
@@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* stridden_index = NSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(
+ NSWAdd(stridden_index,
+ NSWMul(window_index[i],
+ index_typed_const(
+ window.dimensions(i).window_dilation()))),
+ index_typed_const(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation())),
+ index_typed_const(0));
+ in_bounds = And(in_bounds, dilation_condition);
+
+ // Apply base dilation to the index.
input_index[i] =
- NSWSub(NSWAdd(stridden_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ SDiv(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
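
The index arithmetic introduced in this hunk can be checked with a small standalone C++ sketch (no LLVM IR builder; the struct and function names are illustrative). For one window dimension it maps an output index and a window position to an input index under stride, low padding, window dilation, and base dilation, returning -1 exactly when the emitted in_bounds predicate would be false, i.e. when the position falls in a base-dilation hole or outside the operand.

// Scalar model of the per-dimension index mapping emitted above.
#include <cstdio>

struct WindowDim {
  int stride, padding_low, window_dilation, base_dilation;
};

// Returns the operand index touched by (output_index, window_index), or -1 if
// the position lands in a base-dilation hole or in the padding.
int MapToInputIndex(int output_index, int window_index, const WindowDim& dim,
                    int bound) {
  int strided = output_index * dim.stride;
  int idx = strided + window_index * dim.window_dilation - dim.padding_low;
  // Positions that are not multiples of base_dilation fall between real
  // elements of the dilated operand and contribute nothing (the SRem check).
  if (idx % dim.base_dilation != 0) return -1;
  idx /= dim.base_dilation;  // The SDiv: undo the base dilation.
  // Outside [0, bound) we are in the padding and can also skip the element.
  if (idx < 0 || idx >= bound) return -1;
  return idx;
}

int main() {
  WindowDim dim{/*stride=*/2, /*padding_low=*/1, /*window_dilation=*/3,
                /*base_dilation=*/2};
  for (int w = 0; w < 3; ++w) {
    std::printf("output 1, window %d -> input %d\n", w,
                MapToInputIndex(/*output_index=*/1, w, dim, /*bound=*/4));
  }
  // Prints -1 (hole), 2, -1 (hole) for window positions 0, 1, 2.
}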
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 7e3f5775b8..f19996edfe 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -32,7 +32,7 @@ namespace gpu {
// 2) The result of merging the fusion instruction into its users would not
// increase bytes transferred.
//
-class FusionMerger : public HloPassInterface {
+class FusionMerger : public HloModulePass {
public:
absl::string_view name() const override { return "fusion merger"; }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 75f414e47f..e2ab00ce41 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <set>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -27,22 +28,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace gpu {
-StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
- HloInstruction* hlo) {
- HloInstruction*& copy = hlo_to_copy_map_[hlo];
- if (copy == nullptr) {
- TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
- }
- return copy;
-}
-
StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
CopyInsertion generic_copy_insertion;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 8ffae18fe8..4c7e38ffeb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -25,20 +25,11 @@ namespace gpu {
// Besides the modifications made by the generic xla::CopyInsertion, this
// GPU-specific copy insertion also materializes operands of library calls by
// inserting kCopy instructions.
-class GpuCopyInsertion : public HloPassInterface {
+class GpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
-
- protected:
- // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
- // duplicate copies.
- StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
-
- // A map containing all copies inserted to materialize operands of library
- // calls. The key is the copied instruction and the value is the copy.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 31a9f9b1be..5742632782 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
@@ -197,7 +198,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
}
module_spec.AddCudaPtxInMemory(ptx().c_str());
- tensorflow::gtl::FlatMap<int64, se::DeviceMemoryBase> globals;
+ absl::flat_hash_map<int64, se::DeviceMemoryBase> globals;
se::ModuleHandle module_handle;
executor->LoadModule(module_spec, &module_handle);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 38b0f8f15b..0e276282e4 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -101,7 +101,7 @@ class GpuExecutable : public Executable {
const PointsToSet& GetRootPointsToSet() const;
using BufferAllocToDeviceMemoryMap =
- tensorflow::gtl::FlatMap<BufferAllocation::Index, se::DeviceMemoryBase>;
+ absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>;
// Loads the PTX or CUBIN for this executable into `executor` and resolves the
// globals corresponding to constant buffers. Returns a map mapping buffer
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index bbb3340760..9c64b4d10c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -23,7 +23,7 @@ namespace xla {
// This pass should run early in the HLO pipeline. It checks for HLO constructs
// that are not supported by the GPU backend and cannot be removed via HLO
// transformations (e.g., sparse layouts).
-class GpuHloSupportChecker : public HloPassInterface {
+class GpuHloSupportChecker : public HloModulePass {
public:
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index d033faee8d..1ffe855750 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -18,11 +18,12 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -90,45 +91,46 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
// operands and the output shape. Depending on the underlying algorithm, one of
// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
- HloInstruction* instr, LayoutConstraints* constraints) {
- CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString();
- Shape input_shape;
- Shape filter_shape;
- Shape output_shape;
- const auto& target = instr->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
- input_shape = instr->operand(0)->shape();
- filter_shape = instr->operand(1)->shape();
- output_shape = instr->shape().tuple_shapes(0);
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_shape = instr->shape().tuple_shapes(0);
- filter_shape = instr->operand(1)->shape();
- output_shape = instr->operand(0)->shape();
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_shape = instr->operand(0)->shape();
- filter_shape = instr->shape().tuple_shapes(0);
- output_shape = instr->operand(1)->shape();
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << instr->custom_call_target();
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
+ Shape lhs_shape = instr->operand(0)->shape();
+ Shape rhs_shape = instr->operand(1)->shape();
+ Shape result_shape = instr->shape().tuple_shapes(0);
+
+ Shape* input_shape;
+ Shape* filter_shape;
+ Shape* output_shape;
+
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
+ input_shape = &lhs_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &result_shape;
+ break;
+ case CudnnConvKind::kBackwardInput:
+ input_shape = &result_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &lhs_shape;
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ input_shape = &lhs_shape;
+ filter_shape = &result_shape;
+ output_shape = &rhs_shape;
+ break;
}
{
DataLayout input;
FilterLayout filter;
DataLayout output;
- if (ConvUseLayoutHeuristic(instr->GetModule()->config())) {
- std::tie(input, filter, output) =
- HeuristicLayoutAssignment(instr, stream_executor_);
- } else {
- input = DataLayout::kBatchDepthYX;
- filter = FilterLayout::kOutputInputYX;
- output = DataLayout::kBatchDepthYX;
- }
+ std::tie(input, filter, output) =
+ HeuristicLayoutAssignment(instr, stream_executor_);
TF_ASSIGN_OR_RETURN(
- std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(),
- *output_shape.mutable_layout()),
+ std::tie(*input_shape->mutable_layout(),
+ *filter_shape->mutable_layout(),
+ *output_shape->mutable_layout()),
StreamExecutorConvLayoutsToXlaLayouts(
instr->convolution_dimension_numbers(), input, filter, output));
}
@@ -141,24 +143,23 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
instr, /*index=*/{0}));
// Set layouts of the instructions' shapes.
- if (target == kCudnnConvForwardCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(output_shape.layout(), *call_result_buf));
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(input_shape.layout(), *call_result_buf));
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf));
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << instr->custom_call_target();
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1));
+ TF_RETURN_IF_ERROR(
+ constraints->SetBufferLayout(result_shape.layout(), *call_result_buf));
+  // instr->operand(2), if it exists, is the bias buffer. There is no need to
+  // assign a layout to it, as it has only one dimension.
+
+  // instr->operand(3), if it exists, is the side input buffer.
+ if (instr->operand_count() == 4) {
+ if (kind != CudnnConvKind::kForwardActivation) {
+ return InternalError(
+ "Invalid convolution. Conv has a side input, but kind is not fused "
+ "conv forward: %s",
+ instr->ToString());
+ }
+ // The side input layout must match the output layout.
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3));
}
return Status::OK();
}
@@ -173,8 +174,8 @@ Status GpuLayoutAssignment::AddBackendConstraints(
++iterator) {
HloInstruction* instruction = *iterator;
if (IsCustomCallToDnnConvolution(*instruction)) {
- TF_RETURN_IF_ERROR(
- AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
+ TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
+ Cast<HloCustomCallInstruction>(instruction), constraints));
}
// For batched dot we require the default layout.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index ce24af1cf8..4ba7989e9c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -29,8 +30,11 @@ namespace gpu {
class GpuLayoutAssignment : public LayoutAssignment {
public:
explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func,
se::StreamExecutor* stream_executor)
- : LayoutAssignment(entry_computation_layout),
+ : LayoutAssignment(entry_computation_layout,
+ std::move(instruction_can_change_layout_func)),
stream_executor_(stream_executor) {}
~GpuLayoutAssignment() override {}
@@ -47,7 +51,7 @@ class GpuLayoutAssignment : public LayoutAssignment {
private:
Status AddBackendConstraintsToDnnConvCustomCall(
- HloInstruction* instr, LayoutConstraints* constraints);
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints);
se::StreamExecutor* stream_executor_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index fbc8ddf599..04681cfcec 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -75,7 +75,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
ShapeLayout(result_shape_with_layout);
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
@@ -163,7 +164,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -233,7 +235,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -314,7 +317,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first and fourth operands to the batchnorm call should have the
@@ -348,8 +352,9 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- GpuLayoutAssignment layout_assignment(&computation_layout,
- backend().default_stream_executor());
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
Shape expected_shape =
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 4d5d8e99f8..b61f038739 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -125,8 +126,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
}
// Compute the precise number of operands to the new fusion.
- tensorflow::gtl::FlatSet<const HloInstruction*> operands(
- a->operands().begin(), a->operands().end());
+ absl::flat_hash_set<const HloInstruction*> operands(a->operands().begin(),
+ a->operands().end());
operands.insert(b->operands().begin(), b->operands().end());
// If there's an edge between `a` and `b`, don't count it: We're fusing that
// producer -> consumer relationship.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 22f43bc08b..ec3d8f9405 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -129,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget =
"__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
"__cudnn$convBackwardFilter";
+const char* const kCudnnConvBiasActivationForwardCallTarget =
+ "__cudnn$convBiasActivationForward";
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
@@ -137,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
const auto& target = hlo.custom_call_target();
return target == kCudnnConvForwardCallTarget ||
target == kCudnnConvBackwardInputCallTarget ||
- target == kCudnnConvBackwardFilterCallTarget;
+ target == kCudnnConvBackwardFilterCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget;
}
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
@@ -145,59 +148,6 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
IsCustomCallToDnnConvolution(hlo);
}
-static HloInstruction* CreateCudnnConv(const char* call_target,
- const Shape& shape, HloInstruction* lhs,
- HloInstruction* rhs,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- HloComputation* computation = lhs->parent();
-
- // This call returns a tuple of (conv_result, scratch_memory), where
- // conv_result is the actual result of the convolution, and scratch_memory is
- // temporary memory used by cudnn.
- //
- // At the moment, we don't know how much scratch memory this conv is going to
- // use, so we put u8[0] in this place. Later on another pass will choose
- // which conv algorithm to use, and at that point we'll modify the shape of
- // this second tuple element.
- Shape call_shape =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
-
- HloInstruction* custom_call = computation->AddInstruction(
- HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
- custom_call->set_window(window);
- custom_call->set_convolution_dimension_numbers(dnums);
- custom_call->set_feature_group_count(feature_group_count);
- return custom_call;
-}
-
-HloInstruction* CreateCudnnConvForward(const Shape& shape,
- HloInstruction* input,
- HloInstruction* kernel,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
- window, dnums, feature_group_count);
-}
-
-HloInstruction* CreateCudnnConvBackwardInput(
- const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
- reverse_filter, window, dnums, feature_group_count);
-}
-
-HloInstruction* CreateCudnnConvBackwardFilter(
- const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
- output, window, dnums, feature_group_count);
-}
-
bool IsReductionToVector(const HloInstruction& reduce) {
if (HloOpcode::kReduce != reduce.opcode()) {
return false;
@@ -288,41 +238,35 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
-Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
- CudnnConvParams* params) {
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
- const auto& target = custom_call->custom_call_target();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
-
- params->window = &custom_call->window();
- params->dnums = &custom_call->convolution_dimension_numbers();
- params->feature_group_count = custom_call->feature_group_count();
- params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
- backend_config.algorithm(), backend_config.tensor_ops_enabled()));
-
+StatusOr<CudnnConvKind> GetCudnnConvKind(
+ const HloCustomCallInstruction* instr) {
+ absl::string_view target = instr->custom_call_target();
if (target == kCudnnConvForwardCallTarget) {
- params->kind = CudnnConvKind::kForward;
- params->input_shape = &lhs_shape;
- params->filter_shape = &rhs_shape;
- params->output_shape = &conv_result_shape;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- params->kind = CudnnConvKind::kBackwardInput;
- params->input_shape = &conv_result_shape;
- params->filter_shape = &rhs_shape;
- params->output_shape = &lhs_shape;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- params->kind = CudnnConvKind::kBackwardFilter;
- params->input_shape = &lhs_shape;
- params->filter_shape = &conv_result_shape;
- params->output_shape = &rhs_shape;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
+ return CudnnConvKind::kForward;
+ }
+ if (target == kCudnnConvBackwardInputCallTarget) {
+ return CudnnConvKind::kBackwardInput;
+ }
+ if (target == kCudnnConvBackwardFilterCallTarget) {
+ return CudnnConvKind::kBackwardFilter;
+ }
+ if (target == kCudnnConvBiasActivationForwardCallTarget) {
+ return CudnnConvKind::kForwardActivation;
+ }
+ return InternalError("Unexpected call target: %s", target);
+}
+
+string CudnnConvKindToString(CudnnConvKind kind) {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ return "forward";
+ case CudnnConvKind::kBackwardFilter:
+ return "backward_filter";
+ case CudnnConvKind::kBackwardInput:
+ return "backward_input";
+ case CudnnConvKind::kForwardActivation:
+ return "forward with activation";
}
- return Status::OK();
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 09c455cc1e..a64a616ab1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,6 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -30,6 +29,33 @@ limitations under the License.
namespace xla {
namespace gpu {
+// Different types of convolutions supported by cudnn.
+//
+// A way to think about these is that a convolution is defined by three arrays
+// -- the "input", the "filter", and the "output" -- and given any two of these,
+// we can compute the third. For example, a backward-input convolution takes as
+// input a filter and an "output" and produces an "input" such that if one were
+// to do a forward convolution of "input" using filter, the result would be
+// something with the same shape as "output".
+//
+// This way of thinking is not correct if you look at the values produced. For
+// example, a backward-input convolution is not actually the mathematical
+// inverse of a forward convolution. But it's right as far as the shapes and
+// "connectivity" (i.e. which elements of the input affect which elements of
+// the output) are concerned.
+enum class CudnnConvKind {
+ kForward, // input + filter => output
+ kBackwardInput, // filter + output => input
+ kBackwardFilter, // input + output => filter
+ kForwardActivation, // activation(conv(input, filter) + broadcast(bias) +
+ // (optionally) side_input) => output
+};
+
+StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);
+
+// Converts a CudnnConvKind value to a string.
+string CudnnConvKindToString(CudnnConvKind kind);
+
constexpr int64 kWarpSize = 32;
// Returns true if `hlo` will be implemented as a call to BLAS gemm.
@@ -95,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
extern const char* const kCudnnConvForwardCallTarget;
extern const char* const kCudnnConvBackwardInputCallTarget;
extern const char* const kCudnnConvBackwardFilterCallTarget;
+extern const char* const kCudnnConvBiasActivationForwardCallTarget;
// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
// routine.
@@ -104,28 +131,6 @@ extern const char* const kCudnnConvBackwardFilterCallTarget;
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
-// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv.
-// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If
-// you want just the conv result, you'll need to get-tuple-element the value
-// returned by this function.
-//
-// The created cudnn call will use the default cudnn algorithm and no scratch
-// space.
-HloInstruction* CreateCudnnConvForward(const Shape& shape,
- HloInstruction* input,
- HloInstruction* kernel,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-HloInstruction* CreateCudnnConvBackwardInput(
- const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-HloInstruction* CreateCudnnConvBackwardFilter(
- const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-
// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
bool ImplementedAsLibraryCall(const HloInstruction& hlo);
@@ -150,11 +155,6 @@ llvm::Value* EmitPrintf(absl::string_view fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
-// Populates params using conv, which must be a custom-call to a cudnn
-// convolution. Does not modify any buffers in the params.
-Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
- CudnnConvParams* params);
-
} // namespace gpu
} // namespace xla
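
With PopulateCudnnConvParams removed from this header, a caller that still needs the shape roles has to derive them from the kind. A hedged sketch of that mapping, mirroring the deleted code; ConvShapeRoles and AssignConvRoles are illustrative names only, with lhs/rhs being custom-call operands 0 and 1 and conv_result being tuple element 0 of the call's shape:

    struct ConvShapeRoles {
      const Shape* input;
      const Shape* filter;
      const Shape* output;
    };

    ConvShapeRoles AssignConvRoles(CudnnConvKind kind, const Shape& lhs,
                                   const Shape& rhs, const Shape& conv_result) {
      switch (kind) {
        case CudnnConvKind::kForward:
        case CudnnConvKind::kForwardActivation:
          return {&lhs, &rhs, &conv_result};  // input + filter => output
        case CudnnConvKind::kBackwardInput:
          return {&conv_result, &rhs, &lhs};  // filter + output => input
        case CudnnConvKind::kBackwardFilter:
          return {&lhs, &conv_result, &rhs};  // input + output => filter
      }
      return {nullptr, nullptr, nullptr};  // unreachable for valid kinds
    }
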
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b669881026..c792dd2ddb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
- auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
+ std::vector<BufferAllocation::Slice> operand_slices;
+ operand_slices.reserve(custom_call->operand_count());
+ for (const auto* operand : custom_call->operands()) {
+ operand_slices.push_back(GetAllocationSlice(*operand));
+ }
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- const auto& target = custom_call->custom_call_target();
- BufferAllocation::Slice input_slice, filter_slice, output_slice;
-
- if (target == kCudnnConvForwardCallTarget) {
- input_slice = lhs_slice;
- filter_slice = rhs_slice;
- output_slice = conv_result_slice;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_slice = conv_result_slice;
- filter_slice = rhs_slice;
- output_slice = lhs_slice;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_slice = lhs_slice;
- filter_slice = conv_result_slice;
- output_slice = rhs_slice;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
- }
-
thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
- Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
- output_slice, scratch_slice, tuple_result_slice));
+ Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
+ conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
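
The thunk now receives every custom-call operand as a vector instead of exactly an lhs and an rhs slice. As a hedged sketch of the operand counts this has to accommodate, assuming, per the enum comment in ir_emission_utils.h, that a fused forward call carries input, filter, and bias plus an optional side input; MinOperandsForKind is illustrative and not part of the patch:

    int64 MinOperandsForKind(CudnnConvKind kind) {
      switch (kind) {
        case CudnnConvKind::kForward:
        case CudnnConvKind::kBackwardInput:
        case CudnnConvKind::kBackwardFilter:
          return 2;  // the two arrays named by the kind
        case CudnnConvKind::kForwardActivation:
          return 3;  // input, filter, bias; a side input would make it 4
      }
      return 0;  // unreachable for valid kinds
    }
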
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c21f76f6eb..835924024b 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -101,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
HloInstruction* instr2) {
- tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ absl::flat_hash_set<HloInstruction*> in_list;
for (auto instr : instr1->operands()) {
if (!IsProfitableOperand(instr)) {
continue;
@@ -148,7 +148,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
bool changed = false;
RecomputeReachability();
- tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
+ absl::flat_hash_set<HloInstruction*> to_fuse;
// Keep a list of the instructions to fuse after making all the fusion
// decisions. We first aggressively add instructions to potential_fusion_list,
// then filter out instructions that will be no longer fusible because of
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index dfdcf1875d..ac6c2c5565 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@@ -208,6 +209,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
+ pipeline.AddPass<CudnnFusedConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
@@ -230,14 +232,17 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// a layout-sensitive verifier!
HloPassPipeline pipeline("layout assignment");
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_entry_computation_layout(), stream_exec);
+ hlo_module->mutable_entry_computation_layout(),
+ LayoutAssignment::InstructionCanChangeLayout, stream_exec);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
{
HloPassPipeline pipeline("post-layout_assignment");
- pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false);
+ pipeline.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false,
+ LayoutAssignment::InstructionCanChangeLayout);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -283,8 +288,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassFix<HloPassPipeline> fusion("fusion");
- fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false);
+ fusion.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false,
+ LayoutAssignment::InstructionCanChangeLayout);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
@@ -296,7 +303,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
HloPassPipeline reduce_pipeline("reduce-precision");
reduce_pipeline.AddInvariantChecker<HloVerifier>(
- /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false);
+ /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false,
+ LayoutAssignment::InstructionCanChangeLayout);
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -322,8 +330,10 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false);
+ pipeline.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false,
+ LayoutAssignment::InstructionCanChangeLayout);
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
@@ -398,11 +408,11 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
"prefers >= 9.2.88). Compilation of XLA kernels below will likely "
"fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas "
"binary is sufficient.";
- } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) {
+ } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) {
LOG(WARNING)
<< "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "."
<< vdot
- << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to "
+ << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to "
"miscompile XLA code, leading to incorrect results or "
"invalid-address errors.\n\nYou do not need to update to CUDA "
"9.2.88; cherry-picking the ptxas binary is sufficient.";
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index 8e97774750..c4a0b727cd 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/node_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/executable.h"
@@ -140,10 +141,10 @@ class NVPTXCompiler : public LLVMCompiler {
tensorflow::condition_variable compilation_done_cv_;
};
- // Don't even think about switching this to FlatMap; iterator stability is
- // critical here.
- std::unordered_map<CompilationCacheKey, CompilationCacheValue,
- CompilationCacheHash, CompilationCacheEq>
+ // Don't even think about switching this to flat_hash_map; iterator stability
+ // is critical here.
+ absl::node_hash_map<CompilationCacheKey, CompilationCacheValue,
+ CompilationCacheHash, CompilationCacheEq>
compilation_cache_ GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler);
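
The updated comment names the same guarantee the old one relied on: node_hash_map stores each value in its own heap node, so pointers and iterators into the cache stay valid while other threads insert under the mutex, whereas flat_hash_map stores values inline and may relocate them on rehash. A standalone sketch of the property; StabilityDemo is illustrative only:

    #include <string>
    #include "absl/container/flat_hash_map.h"
    #include "absl/container/node_hash_map.h"

    void StabilityDemo() {
      absl::node_hash_map<int, std::string> stable;
      std::string* v = &stable[1];              // stays valid across insertions
      for (int i = 2; i < 1000; ++i) stable[i] = "x";
      *v = "still fine";                        // OK: node storage never moves

      absl::flat_hash_map<int, std::string> flat;
      std::string* w = &flat[1];
      for (int i = 2; i < 1000; ++i) flat[i] = "x";
      // Using *w here would be undefined behavior: rehashing may have moved it.
      (void)w;
    }
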
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index b0061fa655..e3869b5c36 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -36,15 +37,32 @@ static constexpr int64 kDesiredNumFeaturesFactor = 8;
// there's additional room for speedups. Achieving those speedups without also
// slowing other things down will likely require a more sophisticated heuristic,
// possibly some form of auto-tuning.
-static constexpr double kMaxBytesTouchedIncrease = 1.2;
+//
+// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4"
+// special case inside PadShape won't fire.
+static constexpr double kMaxBytesTouchedIncrease = 1.35;
// Pads the given dimensions in the given shape up to a multiple of
// kDesiredNumFeaturesFactor.
static Shape PadShape(Shape s, absl::Span<const int64> dims) {
for (int64 dim : dims) {
int64 dim_to_pad_size = s.dimensions(dim);
- int64 new_dim_to_pad_size =
- RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+
+ // Round dim_to_pad_size up to the next multiple of
+ // kDesiredNumFeaturesFactor.
+ //
+ // Special case: dims of size 3 are rounded up to 4, not
+ // kDesiredNumFeaturesFactor. Empirically (and on the advice of nvidia),
+ // this helps, but as of writing, it's not supported by anything in the
+ // cudnn docs.
+ int64 new_dim_to_pad_size;
+ if (dim_to_pad_size == 3) {
+ new_dim_to_pad_size = 4;
+ } else {
+ new_dim_to_pad_size =
+ RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+ }
+
s.set_dimensions(dim, new_dim_to_pad_size);
}
return s;
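
The arithmetic behind the new constants, as a standalone sketch: for an operand whose only padded dimension grows from 3 to 4, bytes touched grow by exactly 4/3, which fits under the new 1.35 budget but not the old 1.2 one, while rounding 3 all the way up to the usual factor of 8 would never be accepted.

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t dim = 3;
      std::cout << "pad to 4: " << 4.0 / dim << "x bytes\n";  // ~1.33x, allowed
      std::cout << "pad to 8: " << 8.0 / dim << "x bytes\n";  // ~2.67x, rejected
      return 0;
    }
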
@@ -209,7 +227,11 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
std::vector<HloInstruction*> convs;
for (HloInstruction* instr : comp->instructions()) {
if (IsCustomCallToDnnConvolution(*instr) &&
- instr->operand(0)->shape().element_type() == F16) {
+ instr->operand(0)->shape().element_type() == F16 &&
+ // TODO(timshen): Disable for fused conv for now. Implement it if it's
+ // needed.
+ Cast<HloCustomCallInstruction>(instr)->custom_call_target() !=
+ kCudnnConvBiasActivationForwardCallTarget) {
convs.push_back(instr);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 11dc56a64f..e592a3774e 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -30,7 +30,7 @@ namespace gpu {
// targeting before running this pass.
//
// TODO(jlebar): Also pad dots.
-class PadForTensorCores : public HloPassInterface {
+class PadForTensorCores : public HloModulePass {
public:
absl::string_view name() const override { return "pad for tensor cores"; }
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 2a6415d0b6..b42a19e3a2 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -30,7 +30,8 @@ namespace gpu {
namespace {
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
- CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
+ CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
+ conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
return window_util::HasSymmetricPadding(conv.window()) &&
!window_util::HasNegativePadding(conv.window()) &&
!window_util::HasDilation(conv.window());
@@ -161,12 +162,14 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
// The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
// out the shape of conv_result.
- Shape old_conv_shape = conv->shape().tuple_shapes(0);
-
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = CreateCudnnConvForward(
- old_conv_shape, new_input, new_kernel, new_conv_window,
- conv->convolution_dimension_numbers(), conv->feature_group_count());
+ std::vector<HloInstruction*> operands(conv->operands().begin(),
+ conv->operands().end());
+ operands[0] = new_input;
+ operands[1] = new_kernel;
+ auto new_conv = conv->parent()->AddInstruction(
+ conv->CloneWithNewOperands(conv->shape(), operands));
+ new_conv->set_window(new_conv_window);
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
@@ -242,10 +245,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// The shape of the backward_conv CustomCall is a tuple (conv_result,
// scratch_buffer). Extract out the shape of conv_result.
- Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
- HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
- backward_conv_shape, padded_input, output, new_backward_conv_window,
- backward_conv_dnums, backward_conv->feature_group_count());
+ HloInstruction* new_backward_conv =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ backward_conv->shape(), {padded_input, output}));
+ new_backward_conv->set_window(new_backward_conv_window);
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -308,9 +311,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* output = backward_conv->mutable_operand(0);
HloInstruction* filter = backward_conv->mutable_operand(1);
- HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
- new_backward_conv_shape, output, filter, new_backward_conv_window,
- backward_conv_dnums, backward_conv->feature_group_count());
+ HloInstruction* new_backward_conv_call =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ ShapeUtil::MakeTupleShape(
+ {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
+ {output, filter}));
+ new_backward_conv_call->set_window(new_backward_conv_window);
// The CustomCall created above returns a tuple (conv_result, scratch_memory).
// Extract out the two elements.
@@ -380,7 +386,8 @@ StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
}
for (HloInstruction* instruction : convs) {
const auto& target = instruction->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget) {
changed |= CanonicalizeForwardConvolution(instruction);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
changed |= CanonicalizeBackwardFilterConvolution(instruction);
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index a622e894ed..25cdf64c4c 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -24,7 +24,7 @@ namespace gpu {
// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
// inserts Pad instructions before Convolution instructions with uncanonicalized
// padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloPassInterface {
+class PadInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "pad insertion"; }
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
index cf9f102d31..375f68a159 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -62,13 +62,8 @@ LaunchDimensions CalculateLaunchDimensions(
//
// <num threads per block> * <max blocks per core> = <max threads per core>
- auto threads_per_core = device_desc.threads_per_core_limit();
- auto blocks_per_core = device_desc.blocks_per_core_limit();
- int64 threads_per_block;
- if (threads_per_core != 0 && blocks_per_core != 0) {
- threads_per_block = device_desc.threads_per_core_limit() /
- device_desc.blocks_per_core_limit();
- } else {
+ int64 threads_per_block = device_desc.threads_per_block_limit();
+ if (threads_per_block == 0) {
static std::atomic<int64> log_count{0};
if (log_count.fetch_add(1) < 8) {
LOG(WARNING) << "Attempting to calculate launch dimensions for GPU "
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h
index c2df83aaa4..52d38b6f20 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace gpu {
@@ -34,7 +34,7 @@ class StreamAssignment {
private:
int stream_count_ = 1; // At least the main stream.
- tensorflow::gtl::FlatMap<const HloInstruction*, int> hlo_to_stream_number_;
+ absl::flat_hash_map<const HloInstruction*, int> hlo_to_stream_number_;
};
// Assigns GPU streams to instructions in `module`.
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index db4a33dc56..a725533567 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -25,15 +25,17 @@ filegroup(
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
cc_library(
name = "gpu_codegen_test",
testonly = True,
srcs = ["gpu_codegen_test.cc"],
hdrs = ["gpu_codegen_test.h"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
@@ -48,9 +50,7 @@ cc_library(
tf_cc_test(
name = "gpu_copy_test",
srcs = ["gpu_copy_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -67,9 +67,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_ftz_test",
srcs = ["gpu_ftz_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/core:test_main",
@@ -79,9 +77,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_index_test",
srcs = ["gpu_index_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -102,9 +98,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_infeed_test",
srcs = ["infeed_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -125,9 +119,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_kernel_tiling_test",
srcs = ["gpu_kernel_tiling_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo",
@@ -142,7 +134,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_ldg_test",
srcs = ["gpu_ldg_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -159,9 +151,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_noalias_test",
srcs = ["gpu_noalias_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -178,9 +168,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_fusion_test",
srcs = ["gpu_fusion_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
@@ -194,9 +182,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_unrolling_test",
srcs = ["gpu_unrolling_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
@@ -211,9 +197,7 @@ tf_cc_test(
name = "gpu_alignment_test",
testonly = True,
srcs = ["gpu_alignment_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:gpu_plugin",
@@ -225,3 +209,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cudnn_fused_convolution_rewriter_test",
+ srcs = ["cudnn_fused_convolution_rewriter_test.cc"],
+ tags = tf_cuda_tests_tags(),
+ deps = [
+ ":gpu_codegen_test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
new file mode 100644
index 0000000000..5632cac186
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
@@ -0,0 +1,283 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/strings/str_replace.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CudnnFusedConvolutionRewriterTest : public HloTestBase {
+ protected:
+ string GetOptimizedHlo(absl::string_view hlo_string) {
+ return backend()
+ .compiler()
+ ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest())
+ .ConsumeValueOrDie(),
+ backend().default_stream_executor(),
+ backend().memory_allocator())
+ .ConsumeValueOrDie()
+ ->ToString();
+ }
+
+ void TestMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convForward"))
+ << optimized_hlo_string;
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo_string;
+ EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
+ << optimized_hlo_string;
+ }
+ }
+
+ void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ string optimized_hlo = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convForward"))
+ << optimized_hlo;
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo;
+ }
+ }
+};
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) {
+ // max(0, conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) {
+ // max(0, conv(x, w) + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) {
+ // max(0, conv(x, w) + side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) {
+ // max(0, conv(x, w) + side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) {
+ // max(0, 0.999994934 * conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+ scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) {
+ // max(0, conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest,
+ TestScaledConvAndScaledSideInputWithBias) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) {
+ // max(0.1, conv(x, w)) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ point_one = TYPE[] constant(0.1)
+ point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) {
+ // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input1 = TYPE[1,3,3,64] parameter(2)
+ side_input2 = TYPE[1,3,3,64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input2)
+ add2 = TYPE[1,3,3,64] add(add1, side_input1)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
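
A scalar model of what these tests expect to be fused, following the kForwardActivation comment in ir_emission_utils.h (activation of the conv plus broadcast bias plus optional side input) and the scaled variants exercised above. FusedConvElement is illustrative, not the rewriter's API, and the default scales of 1.0 are an assumption:

    #include <algorithm>

    // One output element of the fused op, modeled on the patterns above.
    // `conv_out` stands in for one element of conv(input, filter).
    float FusedConvElement(float conv_out, float bias, float side_input,
                           float alpha_conv = 1.0f, float alpha_side = 1.0f) {
      return std::max(0.0f,
                      alpha_conv * conv_out + alpha_side * side_input + bias);
    }
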
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index e0f3a7e0e2..9220865867 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -18,14 +18,16 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
-using tensorflow::gtl::FlatMap;
-using tensorflow::gtl::FlatSet;
+using absl::flat_hash_map;
+using absl::flat_hash_set;
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
@@ -56,7 +58,7 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
@@ -88,7 +90,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
/*schedule=*/nullptr, memory_by_computation);
@@ -115,8 +117,10 @@ Status HeapSimulator::RunComputation(
// 'used_buffers' is the reverse map - it tracks which buffers were used by an
// instruction, so that we can remove the instructions from a buffer's live
// set after they are visited.
- FlatMap<const BufferValue*, FlatSet<const HloInstruction*>> live_buffers;
- FlatMap<const HloInstruction*, FlatSet<const BufferValue*>> used_buffers;
+ flat_hash_map<const BufferValue*, flat_hash_set<const HloInstruction*>>
+ live_buffers;
+ flat_hash_map<const HloInstruction*, flat_hash_set<const BufferValue*>>
+ used_buffers;
auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
const HloInstruction* user,
const BufferValue* buffer) {
@@ -213,7 +217,7 @@ Status HeapSimulator::RunComputation(
VLOG(4) << " Removing user " << instruction->name() << " from buffer "
<< operand_buffer->ToString();
auto it = live_buffers.find(operand_buffer);
- FlatSet<const HloInstruction*>* live_set = &it->second;
+ flat_hash_set<const HloInstruction*>* live_set = &it->second;
live_set->erase(instruction);
if (live_set->empty()) {
live_buffers.erase(it);
@@ -235,7 +239,8 @@ Status HeapSimulator::RunComputation(
// that we should assign.
// Make sure each buffer get reused at most once.
- FlatSet<const BufferValue*> reused_buffers;
+ flat_hash_set<const BufferValue*> reused_buffers;
+ int64 alloc_size_by_instruction = 0;
for (const BufferValue* buffer : buffers_defined_by_instruction) {
if (IgnoreBuffer(buffer)) {
continue;
@@ -268,14 +273,15 @@ Status HeapSimulator::RunComputation(
if (!shared) {
VLOG(3) << " Allocating: " << buffer->ToString();
+ alloc_size_by_instruction += size_fn_(*buffer);
Alloc(buffer, instruction);
}
}
// Account for the memory used by subcomputations when estimating the
// current heap size.
if (memory_by_computation_ != nullptr) {
- algorithm_->AccountForSubcomputationMemory(instruction,
- *memory_by_computation_);
+ algorithm_->AccountForSubcomputationMemory(
+ instruction, alloc_size_by_instruction, *memory_by_computation_);
}
// If all computations in the module have been scheduled, we can save memory
@@ -323,7 +329,7 @@ Status HeapSimulator::RunComputation(
to_free.reserve(live_buffers.size());
for (const auto& buffer_pending : live_buffers) {
const BufferValue* buffer = buffer_pending.first;
- const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
+ const flat_hash_set<const HloInstruction*>& pending = buffer_pending.second;
CHECK_EQ(pending.size(), 1) << *buffer;
CHECK(*pending.begin() == nullptr) << *buffer;
to_free.push_back(buffer);
@@ -345,7 +351,7 @@ HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
const HloSchedule* schedule,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation)
: no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
@@ -381,10 +387,8 @@ void HeapSimulator::Alloc(const BufferValue* buffer,
allocated_buffers_.insert(buffer);
const int64 size = size_fn_(*buffer);
- const HloInstruction* instruction_to_calc_aliasing =
- memory_by_computation_ == nullptr ? nullptr : instruction;
- algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing);
- no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing);
+ algorithm_->Alloc(buffer, size);
+ no_fragmentation_stats_->Alloc(buffer, size);
FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
nullptr);
}
@@ -522,21 +526,9 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
}
}
-void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size,
- const HloInstruction* instruction) {
- // The output buffer of while/call/conditional is always aliased with the
- // output buffer of the root instruction in the body. Don't double count.
- if (instruction == nullptr ||
- (instruction->opcode() != HloOpcode::kWhile &&
- instruction->opcode() != HloOpcode::kCall &&
- instruction->opcode() != HloOpcode::kConditional)) {
- Alloc(buffer, size);
- }
-}
-
void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
- const HloInstruction* instruction,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const HloInstruction* instruction, int64 alloc_size_by_instruction,
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
// We only count the memory usage of the largest subcomputation, instead of
// adding them all, because subcomputations won't execute in parallel.
@@ -550,6 +542,14 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
}
}
}
+ if (max_subcomputation_bytes > 0 &&
+ (instruction->opcode() == HloOpcode::kWhile ||
+ instruction->opcode() == HloOpcode::kCall ||
+ instruction->opcode() == HloOpcode::kConditional)) {
+ // The output buffer of while/call/conditional is always aliased with the
+ // output buffer of the root instruction in the body. Don't double count.
+ max_subcomputation_bytes -= alloc_size_by_instruction;
+ }
max_heap_size_ =
std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}
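
A worked instance of this adjustment with made-up sizes: a kWhile allocates its 16-byte result, and its body peaks at 48 bytes including the aliased output, so the subcomputation contribution is trimmed by alloc_size_by_instruction to avoid counting those 16 bytes twice.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      // Heap right after Alloc() of the while's own 16-byte result buffer.
      const int64_t current_heap_size = 100;
      const int64_t alloc_size_by_instruction = 16;
      int64_t max_subcomputation_bytes = 48;  // body peak, includes aliased output
      // Aliasing correction, as in AccountForSubcomputationMemory above.
      max_subcomputation_bytes -= alloc_size_by_instruction;
      const int64_t max_heap_size = std::max<int64_t>(
          current_heap_size, current_heap_size + max_subcomputation_bytes);
      std::cout << max_heap_size << "\n";  // 132 rather than a double-counted 148
      return 0;
    }
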
@@ -736,4 +736,209 @@ HeapSimulator::Result LazyBestFitHeap::Finish() {
return result_;
}
+void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer,
+ int64 size) {
+ // Degenerate case: 0-sized buffers are always allocated at offset 0.
+ if (size == 0) {
+ result_.chunk_map.emplace(buffer, Chunk{0, 0});
+ return;
+ }
+ auto emplace_result = buffer_intervals_.emplace(
+ buffer, BufferInterval{buffer, size, current_time_, -1});
+ DCHECK(emplace_result.second);
+ ++current_time_;
+}
+
+void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer,
+ int64 size) {
+ // Degenerate case: 0-sized buffers are always allocated at offset 0.
+ if (size == 0) {
+ return;
+ }
+ BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
+ DCHECK_EQ(buffer_interval.buffer, buffer);
+ DCHECK_EQ(buffer_interval.size, size);
+ DCHECK_EQ(buffer_interval.end, -1);
+ buffer_interval.end = current_time_;
+ ++current_time_;
+}
+
+namespace {
+
+// Node in BufferIntervalTree that stores the alloc and free times of a buffer,
+// and the chunk assigned to it.
+struct BufferIntervalTreeNode {
+ // Alloc time.
+ int64 start;
+ // Free time.
+ int64 end;
+ // Maximum free time of all nodes in the subtree where this node is the root.
+ int64 subtree_end;
+ // Allocated chunk for the buffer.
+ HeapSimulator::Chunk chunk;
+ // Left child.
+ BufferIntervalTreeNode* left;
+ // Right child.
+ BufferIntervalTreeNode* right;
+};
+
+// An interval tree that can query buffers overlapping in time.
+class BufferIntervalTree {
+ public:
+ explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {}
+
+ using Chunk = HeapSimulator::Chunk;
+
+ // Adds a buffer to the interval tree, with the time interval and allocated
+ // chunk specified.
+ void Add(int64 start, int64 end, const Chunk& chunk) {
+ int index = node_count_;
+ DCHECK_LT(index, node_storage_.size());
+ ++node_count_;
+
+ node_storage_[index] =
+ BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr};
+
+ if (index == 0) {
+ // This is root.
+ return;
+ }
+
+ BufferIntervalTreeNode* parent = &node_storage_[0];
+ while (true) {
+ parent->subtree_end = std::max(parent->subtree_end, end);
+ if (parent->start > start) {
+ if (parent->left == nullptr) {
+ parent->left = &node_storage_[index];
+ return;
+ }
+ parent = parent->left;
+ } else {
+ if (parent->right == nullptr) {
+ parent->right = &node_storage_[index];
+ return;
+ }
+ parent = parent->right;
+ }
+ }
+ }
+
+ // Returns vector of allocated chunks that overlap with the given time
+ // interval.
+ std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) {
+ std::vector<Chunk> result;
+ if (node_count_ == 0) {
+ return result;
+ }
+ std::vector<BufferIntervalTreeNode*> visiting_stack;
+ visiting_stack.push_back(&node_storage_[0]);
+ while (!visiting_stack.empty()) {
+ BufferIntervalTreeNode* top = visiting_stack.back();
+ visiting_stack.pop_back();
+ if (start > top->subtree_end) {
+ continue;
+ }
+ if (top->left != nullptr) {
+ visiting_stack.push_back(top->left);
+ }
+ if (top->start <= end && top->end >= start) {
+ result.push_back(top->chunk);
+ }
+ if (end < top->start) {
+ continue;
+ }
+ if (top->right != nullptr) {
+ visiting_stack.push_back(top->right);
+ }
+ }
+ return result;
+ }
+
+ private:
+ int64 node_count_ = 0;
+ std::vector<BufferIntervalTreeNode> node_storage_;
+};
+
+} // namespace
+
+HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
+ std::vector<BufferInterval> sorted_buffer_intervals;
+ for (auto& entry : buffer_intervals_) {
+ sorted_buffer_intervals.push_back(entry.second);
+ }
+ std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(),
+ [](const BufferInterval& x, const BufferInterval& y) {
+ if (x.size != y.size) {
+ return x.size > y.size;
+ }
+ if (x.end - x.start != y.end - y.start) {
+ return x.end - x.start > y.end - y.start;
+ }
+ return x.buffer->id() < y.buffer->id();
+ });
+
+ BufferIntervalTree interval_tree(sorted_buffer_intervals.size());
+ for (auto& buffer_interval : sorted_buffer_intervals) {
+ auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
+ buffer_interval.start, buffer_interval.end);
+ std::sort(
+ chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(),
+ [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
+
+ // Find the minimum free chunk that can hold this buffer.
+ Chunk min_fit_chunk{-1, INT64_MAX};
+ auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
+ if (free_size < buffer_interval.size) {
+ return;
+ }
+
+ if (free_size < min_fit_chunk.size) {
+ min_fit_chunk = {free_offset, free_size};
+ }
+ };
+
+ int64 offset = 0;
+ for (auto& chunk : chunks_overlapping_in_time) {
+ if (offset < chunk.offset) {
+ use_free_chunk_if_smaller(offset, chunk.offset - offset);
+ }
+ offset =
+ std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
+ }
+ use_free_chunk_if_smaller(offset, result_.heap_size - offset);
+
+ if (min_fit_chunk.offset == -1) {
+ // Increase the heap size to fit in the last free chunk.
+ result_.heap_size = offset + buffer_interval.size;
+ min_fit_chunk = {offset, buffer_interval.size};
+ }
+
+ min_fit_chunk.size = buffer_interval.size;
+ const auto emplace_result =
+ result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk);
+ DCHECK(emplace_result.second);
+
+ interval_tree.Add(buffer_interval.start, buffer_interval.end,
+ min_fit_chunk);
+ }
+ return result_;
+}
+
+HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
+ DCHECK(!algorithms_.empty());
+ std::vector<Result> results(algorithms_.size());
+ int64 min_size = INT64_MAX;
+ int min_size_index = -1;
+ for (int i = 0; i < algorithms_.size(); ++i) {
+ results[i] = algorithms_[i]->Finish();
+ if (results[i].heap_size < min_size) {
+ min_size = results[i].heap_size;
+ min_size_index = i;
+ }
+ }
+
+ DCHECK_GE(min_size_index, 0);
+ return results[min_size_index];
+}
+
} // namespace xla
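
A hand-worked trace of GlobalDecreasingSizeBestFitHeap on three buffers with alignment 1; the code only pins down the arithmetic and does not drive the class itself:

    #include <cassert>
    #include <cstdint>

    // Buffers: A size 64 live [0,3), B size 32 live [1,2), C size 64 live [4,5).
    // Sort order (decreasing size, ties broken by longer interval): A, C, B.
    //   A: no overlapping chunks, placed at offset 0, heap_size = 64.
    //   C: its interval does not overlap A's, so the free chunk [0,64) is reused.
    //   B: overlaps A's chunk [0,64); no free chunk fits, so the heap grows and
    //      B lands at offset 64, heap_size = 96.
    int main() {
      const int64_t heap_size = 96;  // vs. 64 + 32 + 64 = 160 with no time sharing
      assert(heap_size < 160);
      return 0;
    }
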
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index ffbf947d5a..dbbf43082f 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -30,8 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -58,7 +58,7 @@ class HeapSimulator {
// Result represents the result of the heap simulation.
struct Result {
// The assignment of buffers to chunks.
- tensorflow::gtl::FlatMap<const BufferValue*, Chunk> chunk_map;
+ absl::flat_hash_map<const BufferValue*, Chunk> chunk_map;
// The total size in bytes of the heap, containing all assigned chunks.
int64 heap_size = 0;
@@ -100,7 +100,7 @@ class HeapSimulator {
const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation = nullptr);
// Run the heap simulation with the given algorithm, assuming the given
@@ -130,7 +130,7 @@ class HeapSimulator {
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn,
const Options& options = Options(),
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation = nullptr);
private:
@@ -140,7 +140,7 @@ class HeapSimulator {
HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn,
const Options& options, const HloSchedule* schedule = nullptr,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation = nullptr);
~HeapSimulator();
@@ -172,7 +172,7 @@ class HeapSimulator {
// handle subcomputations. It would be good to unify the handling of
// subcomputations, but it's not clear how.
const HloSchedule* schedule_;
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation_;
// In addition to Alloc and Free, the heap simulator exposes a concept of
@@ -193,12 +193,12 @@ class HeapSimulator {
const BufferValue* canonical = nullptr;
int64 refcount = 0;
};
- tensorflow::gtl::FlatMap<const BufferValue*, std::shared_ptr<SharedGroup>>
+ absl::flat_hash_map<const BufferValue*, std::shared_ptr<SharedGroup>>
shared_buffers_;
// Hold some sets for error-checking the sequence of Alloc and Free calls.
- tensorflow::gtl::FlatSet<const BufferValue*> allocated_buffers_;
- tensorflow::gtl::FlatSet<const BufferValue*> freed_buffers_;
+ absl::flat_hash_set<const BufferValue*> allocated_buffers_;
+ absl::flat_hash_set<const BufferValue*> freed_buffers_;
// Debugging information filled in while the heap simulator runs.
HeapSimulatorTrace debug_trace_;
@@ -218,12 +218,6 @@ class HeapAlgorithm {
// Alloc allocates a buffer of 'size' bytes.
virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
- // NoFragmentationStatsHeap overrides this method.
- virtual void Alloc(const BufferValue* buffer, int64 size,
- const HloInstruction* instruction) {
- Alloc(buffer, size);
- }
-
// Takes memory usage of subcomputations into account when calculating the
// memory usage of a computation. Currently, we don't handle buffer aliasing
// between computations entirely correctly. We are careful to not double count
@@ -235,7 +229,9 @@ class HeapAlgorithm {
// analysis, it's not worth making major changes to HeapSimulator now.
virtual void AccountForSubcomputationMemory(
const HloInstruction* instruction,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ // The total number of bytes allocated by instruction.
+ int64 alloc_size_by_instruction,
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {}
// Free de-allocates a previously allocated buffer.
@@ -257,12 +253,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
void Alloc(const BufferValue* buffer, int64 size) override;
- void Alloc(const BufferValue* buffer, int64 size,
- const HloInstruction* instruction) override;
-
void AccountForSubcomputationMemory(
- const HloInstruction* instruction,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const HloInstruction* instruction, int64 alloc_size_by_instruction,
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) override;
void Free(const BufferValue* buffer, int64 size) override;
@@ -351,6 +344,67 @@ class LazyBestFitHeap : public HeapAlgorithm {
std::set<Chunk, OrderChunkByIncreasingSize> free_;
};
+// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers,
+// then allocates them in order of decreasing size, regardless of alloc/free
+// time. It internally tracks the allocated buffers and their live intervals;
+// when allocating a buffer, it finds the best-fit free chunk available over
+// the buffer's entire live interval.
+class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
+ public:
+ GlobalDecreasingSizeBestFitHeap(int64 alignment) : alignment_(alignment) {}
+ ~GlobalDecreasingSizeBestFitHeap() override {}
+
+ void Alloc(const BufferValue* buffer, int64 size) override;
+ void Free(const BufferValue* buffer, int64 size) override;
+ Result Finish() override;
+
+ private:
+ int64 alignment_;
+ Result result_;
+
+ // The current time represented as an integer. It increments by 1 at each
+ // Alloc or Free call.
+ int64 current_time_ = 0;
+
+ // BufferInterval stores a buffer's size and time interval.
+ struct BufferInterval {
+ const BufferValue* buffer;
+ int64 size;
+ // Alloc time of the buffer.
+ int64 start;
+ // Free time of the buffer.
+ int64 end;
+ };
+ absl::flat_hash_map<const BufferValue*, BufferInterval> buffer_intervals_;
+};
+
+// A heap algorithm that chooses the best results from other algorithms added to
+// it.
+class ChooseBestHeapAlgorithm : public HeapAlgorithm {
+ public:
+ ChooseBestHeapAlgorithm(
+ std::unique_ptr<std::vector<std::unique_ptr<HeapAlgorithm>>> algorithms)
+ : algorithms_(std::move(*algorithms)) {}
+ ~ChooseBestHeapAlgorithm() override {}
+
+ void Alloc(const BufferValue* buffer, int64 size) override {
+ for (auto& algorithm : algorithms_) {
+ algorithm->Alloc(buffer, size);
+ }
+ }
+
+ void Free(const BufferValue* buffer, int64 size) override {
+ for (auto& algorithm : algorithms_) {
+ algorithm->Free(buffer, size);
+ }
+ }
+
+ Result Finish() override;
+
+ private:
+ std::vector<std::unique_ptr<HeapAlgorithm>> algorithms_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_
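
A minimal sketch of wiring the two new classes together, assuming the existing LazyBestFitHeap constructor and an illustrative alignment of 64; MakeChooserForBufferAssignment is a hypothetical helper, not part of this change:

    #include <memory>
    #include <utility>
    #include <vector>

    #include "absl/memory/memory.h"
    #include "tensorflow/compiler/xla/service/heap_simulator.h"

    // Assumes we are inside namespace xla.
    std::unique_ptr<HeapAlgorithm> MakeChooserForBufferAssignment() {
      auto algorithms =
          absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
      algorithms->push_back(absl::make_unique<LazyBestFitHeap>(/*alignment=*/64));
      algorithms->push_back(
          absl::make_unique<GlobalDecreasingSizeBestFitHeap>(/*alignment=*/64));
      // Finish() on the chooser returns whichever result has the smallest heap.
      return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
    }
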
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 957c4a6891..e30e7667f3 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
@@ -98,6 +98,124 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}
+TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
+ // HloModule SubcomputationAccounting
+
+ // %WhileBody (body_param: f32[4]) -> f32[4] {
+ // %body_param = f32[4]{0} parameter(0)
+ // %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
+ // ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param, f32[4]{0}
+ // %constant.1)
+ // }
+
+ // %WhileCond (cond_param: f32[4]) -> pred[] {
+ // %cond_param = f32[4]{0} parameter(0)
+ // %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
+ // %reshape = f32[] reshape(f32[1]{0} %slice)
+ // %constant = f32[] constant(0)
+ // ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant)
+ // }
+
+ // ENTRY %SubcomputationAccounting () -> f32[2,4] {
+ //   %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
+ //   %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} %constant.3), dimensions={0,1}
+ //   %constant.2 = f32[4]{0} constant({1, 1, 1, 1})
+ //   %while = f32[4]{0} while(f32[4]{0} %constant.2), condition=%WhileCond, body=%WhileBody
+ //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={1}
+ //   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
+ // }
+
+ auto module = CreateNewVerifiedModule();
+ const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+ const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+ // reshape(slice(param)) != 0
+ // Needs 5 bytes
+ auto cond_builder = HloComputation::Builder("WhileCond");
+ HloInstruction* cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+ HloInstruction* slice =
+ cond_builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
+ HloInstruction* reshape =
+ cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
+ HloInstruction* zero = cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+ HloInstruction* cond_comparison =
+ cond_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero));
+ auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+ // param - 1
+ // Needs 16 bytes
+ auto body_builder = HloComputation::Builder("WhileBody");
+ HloInstruction* body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "body_param"));
+ HloInstruction* one_vector =
+ body_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
+ HloInstruction* subtract =
+ body_builder.AddInstruction(HloInstruction::CreateBinary(
+ r1f32, HloOpcode::kSubtract, body_param, one_vector));
+ auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+ // transpose(matrix) + bcast(while)
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* while_init =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
+ // Creates 16 bytes, ignoring subcomputations
+ HloInstruction* while_loop =
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ r1f32, cond_computation, body_computation, while_init));
+
+ // Creates 32 bytes and frees 16
+ HloInstruction* bcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));
+
+ HloInstruction* matrix = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
+ {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
+ // Creates 32 bytes
+ HloInstruction* transpose = builder.AddInstruction(
+ HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));
+
+ // Creates 32 bytes and frees 64
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));
+
+ auto entry_computation = module->AddEntryComputation(builder.Build());
+
+ HloSchedule schedule(module.get());
+ std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero,
+ cond_comparison};
+ std::vector<HloInstruction*> while_body_vec = {body_param, one_vector,
+ subtract};
+ std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast,
+ matrix, transpose, add};
+ schedule.set_sequence(cond_computation, cond_vec);
+ schedule.set_sequence(body_computation, while_body_vec);
+ schedule.set_sequence(entry_computation, entry_comp_vec);
+
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
+ memory_by_computation[cond_computation] = 5;
+ memory_by_computation[body_computation] = 16;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+ TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+ // HeapSimulator accounts for subcomputations. The output buffer is aliased,
+ // so we don't double count.
+ EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, schedule.sequence(entry_computation),
+ *points_to_analysis, size_fn, &memory_by_computation)
+ .ValueOrDie());
+}
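Reading the expected value off the inline comments: the 32-byte broadcast and the 32-byte transpose are live together, which is where the 64-byte peak comes from; the 16-byte while buffer plus the body subcomputation's 16 bytes total only 32 and never raise it, and because the output buffer is aliased the subcomputation accounting does not count it twice.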
+
const char kAlloc[] = "Alloc";
const char kFree[] = "Free";
const char kFinish[] = "Finish";
@@ -174,7 +292,7 @@ class HeapSimulatorTracker {
// Construct the module sequence grouped by computation.
HloSchedule schedule(module_.get());
- tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
+ absl::flat_hash_map<const HloInstruction*, int> reverse_position;
for (int i = 0; i < full_module_sequence.size(); ++i) {
const HloInstruction* instruction = full_module_sequence[i];
schedule.GetOrCreateSequence(instruction->parent())
@@ -1021,5 +1139,135 @@ TEST_F(LazyBestFitHeapTest, Alignment) {
EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset);
}
+class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {};
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(0, result.heap_size);
+ EXPECT_EQ(0, result.chunk_map.size());
+}
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
+ // space
+ // ^
+ // | +---a---+
+ // | +-------+
+ // | +---c---+
+ // | +-------+
+ // | | b |
+ // | +-------+
+ // | +-------+
+ // | | |
+ // | | d |
+ // | +-------+
+ // -----------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 30);
+ heap.Alloc(buffer_c_, 20);
+ heap.Alloc(buffer_d_, 40);
+ heap.Free(buffer_a_, 10);
+ heap.Free(buffer_b_, 30);
+ heap.Free(buffer_c_, 20);
+ heap.Free(buffer_d_, 40);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(100, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
+
+ EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
+}
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
+ // space
+ // ^
+ // | +-------+
+ // | +---b---+
+ // | +-------+
+ // | | |
+ // | | d |
+ // | +---a---+ +-------+
+ // |
+ // | +-------+
+ // | | |
+ // | | c |
+ // | | |
+ // | +-------+
+ // ---------------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 20);
+ heap.Alloc(buffer_c_, 50);
+ heap.Free(buffer_a_, 10);
+ heap.Alloc(buffer_d_, 40);
+ heap.Free(buffer_b_, 20);
+ heap.Free(buffer_c_, 50);
+ heap.Free(buffer_d_, 40);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(120, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
+
+ EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
+}
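Walking through the expectations with alignment 20: c (50) is placed first at offset 0; d (40) cannot start at 0, and 50 is not a multiple of 20, so it goes to offset 60; b (20) overlaps both c and d in time and lands at offset 100; a (10) is freed before d is allocated, so it can reuse offset 60 below b, and the heap tops out at 100 + 20 = 120 bytes.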
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
+ // space
+ // ^
+ // | +-------+
+ // | +---b---+
+ // | +-------+
+ // | | d |
+ // | +--a--+ +-------+
+ // | +-------+
+ // | | |
+ // | | c |
+ // | +-------+
+ // | +-------+
+ // | | |
+ // | | e |
+ // | | |
+ // | +-------+
+ // ---------------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 20);
+ heap.Alloc(buffer_c_, 40);
+ heap.Free(buffer_a_, 10);
+ heap.Alloc(buffer_d_, 30);
+ heap.Alloc(buffer_e_, 50);
+ heap.Free(buffer_b_, 20);
+ heap.Free(buffer_c_, 40);
+ heap.Free(buffer_d_, 30);
+ heap.Free(buffer_e_, 50);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(140, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);
+
+ EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
+}
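The point of this case is the final placement: after e (50) at offset 0, c (40) at 50, d (30) at 90, and b (20) at 120, buffer a (10) sees two free regions during its lifetime, [0, 50) and [90, 120). Best-fit picks the smaller 30-byte gap, so a lands at offset 90 rather than 0, and the heap stays at 140 bytes, matching the expectations above.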
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index b19ec12638..1ea26ddd5b 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 53
+// Next ID: 56
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -124,9 +124,13 @@ message HloInstructionProto {
// The string representation of the infeed configuration.
bytes infeed_config = 27;
- // Name of a global symbol to call, only present for kCustomCall.
+ // Name of an external target (e.g., a global symbol) to call; only present
+ // for kCustomCall.
string custom_call_target = 28;
+ // Opaque string, only present for kCustomCall.
+ string custom_call_opaque = 53;
+
// Shape of outfeed request.
xla.Shape outfeed_shape = 29;
@@ -176,6 +180,10 @@ message HloInstructionProto {
// Collective permute field.
repeated SourceTarget source_target_pairs = 52;
+
+ // Sharding for kDomain instructions.
+ xla.OpSharding domain_entry_sharding = 54;
+ xla.OpSharding domain_exit_sharding = 55;
}
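For illustration only, the generated C++ setters that correspond to the new fields; the values are made up, entry_sharding and exit_sharding are assumed to be xla::OpSharding objects already in scope, and the field names come straight from the message above.

    xla::HloInstructionProto instr;
    instr.set_custom_call_target("my_backend_symbol");     // external target to call
    instr.set_custom_call_opaque("tile=128;layout=nhwc");  // passed to the backend verbatim
    *instr.mutable_domain_entry_sharding() = entry_sharding;
    *instr.mutable_domain_exit_sharding() = exit_sharding;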
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index 0986da65cb..c3da12e273 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -119,7 +121,7 @@ class BufferValueMap {
}
// Return a set of all the values in the given buffer.
- const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
+ const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer(
BufferNumber buffer_number) const {
return buffers_.at(buffer_number);
}
@@ -142,7 +144,7 @@ class BufferValueMap {
// Move the given value into the given buffer.
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
- tensorflow::gtl::FlatSet<const HloValue*>& old_value_set =
+ absl::flat_hash_set<const HloValue*>& old_value_set =
buffers_.at(old_buffer_number);
old_value_set.erase(&value);
if (old_value_set.empty()) {
@@ -290,13 +292,11 @@ class BufferValueMap {
const HloDataflowAnalysis& dataflow_;
// A map containing the set of values contained in each buffer.
- tensorflow::gtl::FlatMap<BufferNumber,
- tensorflow::gtl::FlatSet<const HloValue*>>
+ absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>>
buffers_;
// A map indicating which buffer each value is contained in.
- tensorflow::gtl::FlatMap<const HloValue*, BufferNumber>
- value_to_buffer_number_;
+ absl::flat_hash_map<const HloValue*, BufferNumber> value_to_buffer_number_;
// The buffer number of the next buffer to be created.
BufferNumber next_buffer_number_ = 0;
@@ -352,7 +352,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous(
bool HloAliasAnalysis::InstructionBuffersAreDistinct(
const HloInstruction* instruction) const {
- tensorflow::gtl::FlatSet<const HloBuffer*> buffers_seen;
+ absl::flat_hash_set<const HloBuffer*> buffers_seen;
for (const auto& pair :
dataflow_analysis_->GetInstructionValueSet(instruction)) {
const HloValueSet& value_set = pair.second;
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
index e345804537..372f99ff01 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
@@ -110,7 +111,7 @@ class HloAliasAnalysis {
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
// A map indicating which buffer a value is contained in.
- tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_;
+ absl::flat_hash_map<const HloValue*, HloBuffer*> value_to_buffer_;
// A lazily constructed vector containing all HloBuffers sorted by
// HloBuffer::Id.
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc
index 6c11a073b7..9c3aa0e64d 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.cc
+++ b/tensorflow/compiler/xla/service/hlo_buffer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h
index 658643b427..24910ca07b 100644
--- a/tensorflow/compiler/xla/service/hlo_clone_context.h
+++ b/tensorflow/compiler/xla/service/hlo_clone_context.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <string>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -73,12 +73,12 @@ class HloCloneContext {
return FindOrDie(computations_, old_computation);
}
- const tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>&
+ const absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
cloned_instructions() const {
return instructions_;
}
- const tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>&
+ const absl::flat_hash_map<const HloComputation*, HloComputation*>&
cloned_computations() const {
return computations_;
}
@@ -86,10 +86,8 @@ class HloCloneContext {
private:
HloModule* module_;
string suffix_;
- tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>
- instructions_;
- tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>
- computations_;
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*> instructions_;
+ absl::flat_hash_map<const HloComputation*, HloComputation*> computations_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 601a008d9f..c2041c4667 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -24,6 +24,8 @@ limitations under the License.
#include <sstream>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
@@ -39,7 +41,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -122,30 +123,6 @@ HloInstruction* HloComputation::AddParameter(
return instructions_.back().get();
}
-namespace {
-
-// Returns the new name for a fusion parameter when we change its number.
-//
-// Fusion parameters are named foo.param_1, bar.param_2, etc. We are
-// renumbering the parameters, so replace the final number in the name with
-// the updated value.
-string RenameFusionParameter(const string& original_name, int64 new_param_no) {
- const string param_underscore = ".param_";
- size_t index = original_name.rfind(param_underscore);
- if (index == string::npos) {
- return original_name;
- }
- string after_param = original_name.substr(index + param_underscore.size());
- int64 numeric_suffix;
- if (absl::SimpleAtoi(after_param, &numeric_suffix)) {
- return StrCat(original_name.substr(0, index + param_underscore.size()),
- new_param_no);
- }
- return original_name;
-}
-
-} // namespace
-
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
@@ -158,11 +135,9 @@ Status HloComputation::RemoveParameter(int64 param_no) {
while (param_no < param_instructions_.size()) {
param_instruction = param_instructions_[param_no];
- string param_name =
- RenameFusionParameter(param_instruction->name(), param_no);
HloInstruction* new_instr =
AddInstructionInternal(HloInstruction::CreateParameter(
- param_no, param_instruction->shape(), param_name));
+ param_no, param_instruction->shape(), StrCat("param_", param_no)));
TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
param_instructions_[param_no] = new_instr;
TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
@@ -186,11 +161,9 @@ Status HloComputation::RemoveUnusedParameters() {
if (removed > 0) {
const int64 param_no = i - removed;
- string param_name =
- RenameFusionParameter(param_instruction->name(), param_no);
- HloInstruction* new_instr =
- AddInstructionInternal(HloInstruction::CreateParameter(
- param_no, param_instruction->shape(), param_name));
+ HloInstruction* new_instr = AddInstructionInternal(
+ HloInstruction::CreateParameter(param_no, param_instruction->shape(),
+ StrCat("param_", param_no)));
TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
param_instructions_[param_no] = new_instr;
TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
@@ -272,10 +245,11 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
<< "instruction " << instruction->name()
<< " has control successors and cannot be removed";
- TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
- auto inst_it = instruction_iterators_.at(instruction);
- (*inst_it)->set_parent(nullptr);
- instructions_.erase(inst_it);
+ auto inst_it = instruction_iterators_.find(instruction);
+ TF_RET_CHECK(inst_it != instruction_iterators_.end());
+ (*inst_it->second)->set_parent(nullptr);
+ instructions_.erase(inst_it->second);
+ instruction_iterators_.erase(inst_it);
return Status::OK();
}
@@ -304,10 +278,9 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
namespace {
// Helper which builds a post order of the HLO call graph.
-void ComputeComputationPostOrder(
- HloComputation* computation,
- tensorflow::gtl::FlatSet<HloComputation*>* visited,
- std::vector<HloComputation*>* post_order) {
+void ComputeComputationPostOrder(HloComputation* computation,
+ absl::flat_hash_set<HloComputation*>* visited,
+ std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -324,7 +297,7 @@ void ComputeComputationPostOrder(
void HloComputation::ComputeInstructionPostOrder(
const HloComputation::ChannelDependencyMap& channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
- tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const {
+ absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
std::vector<HloInstruction*> dfs_stack;
dfs_stack.push_back(root);
while (!dfs_stack.empty()) {
@@ -421,7 +394,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
- tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited;
+ absl::flat_hash_map<HloInstruction*, VisitState> visited;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
// Trace instructions aren't handled by the DFS visitor. Add trace
@@ -442,7 +415,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
- tensorflow::gtl::FlatSet<HloComputation*> visited;
+ absl::flat_hash_set<HloComputation*> visited;
std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
@@ -532,9 +505,9 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
- tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
- tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id;
+ const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
+ absl::flat_hash_map<int64, HloInstruction*> instruction_map;
+ absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
std::vector<std::unique_ptr<HloInstruction>> instructions;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
@@ -562,6 +535,28 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
+ TF_RETURN_IF_ERROR([&]() -> Status {
+ std::vector<bool> parameters_seen(parameter_count);
+ int parameters_seen_count = 0;
+ for (auto& instruction : instructions) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ int64 param_no = instruction->parameter_number();
+ TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
+ << "Invalid parameter number. Expected [0, " << parameter_count
+ << "), got " << param_no;
+ TF_RET_CHECK(!parameters_seen[param_no])
+ << "Parameter number " << param_no
+ << " already allocated in this computation";
+ parameters_seen[param_no] = true;
+ parameters_seen_count++;
+ }
+ }
+ TF_RET_CHECK(parameters_seen_count == parameter_count)
+ << "Not all parameters in range [0, " << parameter_count
+ << ") were referenced";
+ return Status::OK();
+ }());
+
auto computation = absl::WrapUnique(
new HloComputation(proto.name(), parameter_count, &instructions, root,
/*fusion_instruction=*/nullptr));
@@ -916,13 +911,14 @@ std::unique_ptr<HloComputation> HloComputation::Clone(
return CloneWithReplacements(
/*replacements=*/std::unordered_map<const HloInstruction*,
std::unique_ptr<HloInstruction>>(),
- context, suffix);
+ /*extras=*/{}, context, suffix);
}
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloCloneContext* context, const string& suffix) {
+ absl::Span<HloInstruction*> extras, HloCloneContext* context,
+ const string& suffix) {
std::unique_ptr<HloCloneContext> context_ptr;
if (context == nullptr) {
context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
@@ -944,6 +940,9 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
std::vector<HloInstruction*> postorder;
+ for (HloInstruction* instr : extras) {
+ postorder.push_back(instr);
+ }
for (HloInstruction* instr : MakeInstructionPostOrder()) {
if (HloInstruction* replacement = replace(instr)) {
postorder.push_back(replacement);
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index a880e9ab30..d87ab4bda1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -25,6 +25,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -40,8 +42,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -188,7 +188,7 @@ class HloComputation {
// calls.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
+ const absl::flat_hash_map<int64, HloComputation*>& computation_map);
// Gets the instructions in this computation.
//
@@ -227,7 +227,7 @@ class HloComputation {
void UpdateReachabilityThroughInstruction(
const HloInstruction* instruction, HloReachabilityMap* reachability_map);
- int64 instruction_count() const { return instructions_.size(); }
+ int64 instruction_count() const { return instruction_iterators_.size(); }
// Creates and returns a list of the embedded computations called by this
// computation. This includes all embedded computations called directly or
@@ -333,10 +333,13 @@ class HloComputation {
//
// If replacements maps a key to nullptr, we remove that instruction from the
// new computation.
+ // If additional instructions are used by instructions in the replacement
+ // map, they must be passed in post-order in the extras span.
std::unique_ptr<HloComputation> CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloCloneContext* context = nullptr, const string& suffix = "clone");
+ absl::Span<HloInstruction*> extras, HloCloneContext* context = nullptr,
+ const string& suffix = "clone");
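A hedged illustration of the new call shape only; old_instr, new_instr, extra, and computation are hypothetical names, and the semantics follow the comment above: extra is an instruction that a replacement uses but that is not otherwise reachable, listed in post-order.

    std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements;
    replacements[old_instr] = std::move(new_instr);  // new_instr uses `extra`
    HloInstruction* extras_array[] = {extra};
    std::unique_ptr<HloComputation> clone = computation->CloneWithReplacements(
        std::move(replacements), absl::MakeSpan(extras_array),
        /*context=*/nullptr, /*suffix=*/"clone");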
// Returns true if the given instruction can be removed from the computation.
// Parameter instructions cannot be removed without violating invariants of
@@ -411,14 +414,14 @@ class HloComputation {
// cross-replica-sum the union of the dependencies for all participating
// instructions.
using ChannelDependencyMap =
- tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>;
+ absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
ChannelDependencyMap ComputeChannelDependencies() const;
enum VisitState { kVisiting, kVisited };
void ComputeInstructionPostOrder(
const HloComputation::ChannelDependencyMap& channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
- tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const;
+ absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;
string name_;
int64 unique_id_;
@@ -436,7 +439,7 @@ class HloComputation {
// instruction pointer to location in the list for fast lookup.
using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
InstructionList instructions_;
- std::unordered_map<const HloInstruction*, InstructionList::iterator>
+ absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
instruction_iterators_;
std::vector<HloInstruction*> param_instructions_;
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index f837816cea..4f898ce61c 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -76,6 +76,26 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
continue;
}
+ // Don't constant fold unless it's a net positive or the output is small.
+ if (ShapeUtil::IsArray(instruction->shape())) {
+ int64 elements_in_removed_operands = 0;
+ for (HloInstruction* operand : instruction->operands()) {
+ if (operand->user_count() == 1 &&
+ ShapeUtil::IsArray(operand->shape())) {
+ elements_in_removed_operands +=
+ ShapeUtil::ElementsIn(operand->shape());
+ }
+ }
+ int64 elements_in_constant =
+ ShapeUtil::ElementsIn(instruction->shape());
+
+ static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000;
+ if (elements_in_constant > elements_in_removed_operands &&
+ elements_in_constant > kMaximumConstantSizeElements) {
+ continue;
+ }
+ }
+
Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
@@ -84,6 +104,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
<< instruction->ToString();
continue;
}
+ VLOG(4) << "Constant folded: " << instruction->ToString();
TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateConstant(std::move(result))));
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 4557983a9c..4a624cc7b8 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -23,7 +23,7 @@ namespace xla {
// A pass which performs constant folding in order to avoid unnecessary
// computation on constants.
-class HloConstantFolding : public HloPassInterface {
+class HloConstantFolding : public HloModulePass {
public:
absl::string_view name() const override { return "constant_folding"; }
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 3e0def5d26..e45f905f71 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -242,5 +242,25 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
}
+const char* const kConstantFoldLargePad = R"(
+ HloModule ConstantFoldLargePad
+
+ ENTRY r {
+ a = f32[1,1,1] constant(f32[1,1,1]{{{7}}})
+ b = f32[] constant(42)
+ ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63
+ })";
+
+TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(kConstantFoldLargePad));
+ HloConstantFolding const_folder;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ EXPECT_FALSE(result);
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Pad(op::Constant(), op::Constant()));
+}
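The numbers line up with the new heuristic: the pad's output has 2048 * 2048 * 128 = 536,870,912 elements, while folding would remove only the two one-element constant operands, so the output is both larger than the removed operands and far above the 2,000,000-element cap, and the fold is skipped.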
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index b76c50bb5b..b2005d3c21 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@@ -201,6 +202,44 @@ StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
HloInstruction::CreateMap(map_shape, operands, map_computation));
}
+StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
+ HloInstruction* init_value,
+ HloOpcode binary_opcode,
+ HloModule* module) {
+ DCHECK_NE(nullptr, module);
+ std::vector<int64> all_dims(ShapeUtil::Rank(operand->shape()));
+ std::iota(all_dims.begin(), all_dims.end(), 0);
+
+ auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
+ HloComputation* reduce_computation;
+ {
+ HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
+ auto lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto rhs = b.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ b.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
+ reduce_computation = module->AddEmbeddedComputation(b.Build());
+ }
+
+ return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
+ scalar_shape, operand, init_value, all_dims, reduce_computation));
+}
+
+StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) {
+ HloComputation* computation = pred->parent();
+ DCHECK_EQ(computation, on_true->parent());
+ DCHECK_EQ(computation, on_false->parent());
+ TF_ASSIGN_OR_RETURN(Shape select_shape,
+ ShapeInference::InferTernaryOpShape(
+ HloOpcode::kSelect, pred, on_true, on_false));
+ return computation->AddInstruction(HloInstruction::CreateTernary(
+ select_shape, HloOpcode::kSelect, pred, on_true, on_false));
+}
+
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index b22058abb4..8e5ddbbd50 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -107,6 +108,35 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
HloComputation* map_computation);
+// Creates a Reduce HLO instruction and adds it to the computation containing
+// the operand. This will create the sub-computation needed for the reduction in
+// the given module. binary_opcode should represent a binary operation.
+StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
+ HloInstruction* init_value,
+ HloOpcode binary_opcode,
+ HloModule* module);
+
+// Creates a Select HLO instruction and adds it to the computation containing
+// the predicate. The on_true and on_false instructions must also be contained
+// in the same computation.
+StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false);
+
+// Creates an R1 Constant HLO instruction of the given PrimitiveType with the
+// given values and adds it to the given computation.
+template <typename NativeT>
+StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
+ PrimitiveType type,
+ absl::Span<const NativeT> values) {
+ Literal literal = LiteralUtil::CreateR1<NativeT>(values);
+ if (literal.shape().element_type() != type) {
+ TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
+ }
+ return computation->AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+}
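A small hypothetical call sequence using the three helpers declared above; operand, zero, pred, fallback, computation, and module are assumed to exist with compatible shapes (operand an f32 array, zero/pred/fallback scalars) and are not part of this patch.

    // Sum every element of `operand` into a scalar, then pick between that
    // sum and `fallback` based on `pred`, and build a small index constant.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * sum,
        MakeReduceHlo(operand, /*init_value=*/zero, HloOpcode::kAdd, module));
    TF_ASSIGN_OR_RETURN(
        HloInstruction * chosen,
        MakeSelectHlo(pred, /*on_true=*/sum, /*on_false=*/fallback));
    TF_ASSIGN_OR_RETURN(HloInstruction * indices,
                        MakeR1ConstantHlo<int32>(computation, S32, {0, 1, 2}));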
+
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index b59c9ba3ed..e602107cbe 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace xla {
@@ -137,8 +137,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
// HLO instructions are grouped into equivalency classes by using the
// cse_equal predicate defined above. This set holds a representative
// instruction for each class.
- tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash),
- decltype(cse_equal)>
+ absl::flat_hash_set<HloInstruction*, decltype(&CseHash),
+ decltype(cse_equal)>
representatives(/*N=*/computation->instruction_count() + 1, &CseHash,
cse_equal);
for (auto instruction : computation->MakeInstructionPostOrder()) {
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index a28c03599a..e4857fd3fd 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -25,7 +25,7 @@ namespace xla {
// and identical instructions with the same operands are commoned. The pass
// iterates over the instructions in topological order which enables the pass to
// find arbitrarily large common expressions.
-class HloCSE : public HloPassInterface {
+class HloCSE : public HloModulePass {
public:
// If is_layout_sensitive is true, then the simplifier preserves layout during
// transformation. Otherwise, layout is ignored.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 6a63681996..c22adcdd8d 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@@ -91,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis(
bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
const HloInstruction* inst) {
- tensorflow::gtl::FlatSet<const HloInstruction*> visited;
+ absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(inst);
while (!stack.empty()) {
@@ -159,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
void HloDataflowAnalysis::DeleteMarkedValues() {
#ifndef NDEBUG
// Verify that no marked-for-deletion values are in any of the value sets.
- tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
- value_ids_to_delete_.end());
+ absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
+ value_ids_to_delete_.end());
for (const auto& pair : value_sets_) {
const HloInstruction* instruction = pair.first;
const InstructionValueSet& instruction_value_set = pair.second;
@@ -355,23 +356,6 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
return false;
}
-bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) {
- CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
- if (!slice->IsInPlaceSlice()) {
- return false;
- }
- // If this slice is lowered to an in-place version, then it forwards the
- // operand value to the output.
- const InstructionValueSet& operand_set =
- GetInstructionValueSet(slice->operand(0));
- InstructionValueSet& slice_set = GetInstructionValueSet(slice);
- if (operand_set != slice_set) {
- slice_set = operand_set;
- return true;
- }
- return false;
-}
-
bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
CHECK_EQ(send->opcode(), HloOpcode::kSend);
bool changed = false;
@@ -640,8 +624,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
switch (instruction->opcode()) {
case HloOpcode::kBitcast:
return UpdateBitcastValueSet(instruction);
- case HloOpcode::kSlice:
- return UpdateSliceValueSet(instruction);
case HloOpcode::kDomain:
return UpdateDomainValueSet(instruction);
case HloOpcode::kCopy:
@@ -673,7 +655,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
void HloDataflowAnalysis::Propagate() {
std::queue<HloInstruction*> worklist;
- tensorflow::gtl::FlatSet<HloInstruction*> workset;
+ absl::flat_hash_set<HloInstruction*> workset;
auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
if (workset.insert(instruction).second) {
worklist.push(instruction);
@@ -813,11 +795,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_all_values();
}
break;
- case HloOpcode::kSlice:
- if (!instruction->IsInPlaceSlice()) {
- define_all_values();
- }
- break;
case HloOpcode::kWhile:
case HloOpcode::kCall:
case HloOpcode::kConditional:
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index e62c1c2ac8..abac398c04 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -182,7 +182,6 @@ class HloDataflowAnalysis {
// Updates the value set for a particular instruction type. Returns whether
// the instruction value set changed.
bool UpdateBitcastValueSet(HloInstruction* bitcast);
- bool UpdateSliceValueSet(HloInstruction* slice);
bool UpdateCallValueSet(HloInstruction* call);
bool UpdateConditionalValueSet(HloInstruction* conditional);
bool UpdateCopyValueSet(HloInstruction* copy);
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
index 1fe69b1395..4012042672 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -33,7 +33,7 @@ namespace xla {
//
// This pass does not remove dead parameter instructions, as parameter
// instructions cannot be deleted.
-class HloDCE : public HloPassInterface {
+class HloDCE : public HloModulePass {
public:
~HloDCE() override {}
absl::string_view name() const override { return "dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index d36631fc2f..c0bf1b9e16 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -30,7 +30,7 @@ namespace xla {
// used to break an HLO graph edge connecting two instructions with different
// sharding. If a set of connected instructions have all the same sharding, no
// kDomain instruction will be placed.
-class HloDomainIsolator : public HloPassInterface {
+class HloDomainIsolator : public HloModulePass {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 113fd18eae..c6d02f9f67 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <algorithm>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -40,18 +42,19 @@ namespace xla {
return std::move(domain_map);
}
-bool HloDomainMap::InSameDomain(HloInstruction* instruction1,
- HloInstruction* instruction2) const {
+bool HloDomainMap::InSameDomain(const HloInstruction* instruction1,
+ const HloInstruction* instruction2) const {
int64 domain_id1 = GetDomainId(instruction1);
int64 domain_id2 = GetDomainId(instruction2);
return domain_id1 >= 0 && domain_id1 == domain_id2;
}
-int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
+int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const {
return FindOrDefault(instruction_to_domain_, instruction, -1);
}
-int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
+int64 HloDomainMap::GetDomainMetadataId(
+ const HloInstruction* instruction) const {
return FindOrDie(domain_metadata_id_, instruction);
}
@@ -106,8 +109,8 @@ Status HloDomainMap::PopulateDomainMetadataMap() {
auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
return a->Matches(*b);
};
- tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
- decltype(equal)>
+ absl::flat_hash_map<const DomainMetadata*, int64, decltype(hash),
+ decltype(equal)>
domain_metadata(1024, hash, equal);
for (auto& domain : instruction_domains_) {
@@ -198,7 +201,8 @@ StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
return std::move(domain);
}
-bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
+bool HloDomainMap::IsDomainInstruction(
+ const HloInstruction* instruction) const {
if (instruction->opcode() != HloOpcode::kDomain) {
return false;
}
@@ -216,7 +220,7 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
/* static */ std::vector<HloInstruction*>
HloDomainMap::MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const absl::flat_hash_set<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order) {
std::vector<HloInstruction*> instructions;
instructions.reserve(instruction_set.size());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index 56b557d7ce..bce7d1aa7c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -19,14 +19,14 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -58,27 +58,26 @@ class HloDomainMap {
}
// Checks whether two instructions are within the same domain.
- bool InSameDomain(HloInstruction* instruction1,
- HloInstruction* instruction2) const;
+ bool InSameDomain(const HloInstruction* instruction1,
+ const HloInstruction* instruction2) const;
// Checks whether instruction is a kDomain instruction of the kind we are
// currently processing.
- bool IsDomainInstruction(HloInstruction* instruction) const;
+ bool IsDomainInstruction(const HloInstruction* instruction) const;
// Retrieves the domain identifier of the instruction, or -1 in case
// instruction is not found within any domain.
- int64 GetDomainId(HloInstruction* instruction) const;
+ int64 GetDomainId(const HloInstruction* instruction) const;
// Returns the unique id of the domain metadata for the domain the given
// instruction belongs to. The given instruction must not be a kDomain
// instruction since each domain instruction is associated with 2 domains.
- int64 GetDomainMetadataId(HloInstruction* instruction) const;
+ int64 GetDomainMetadataId(const HloInstruction* instruction) const;
private:
// Map used for representing instruction ordering, i.e.
// order_map[a] < order_map[b] means a must be ordered before b.
- using InstructionOrderMap =
- tensorflow::gtl::FlatMap<const HloInstruction*, int64>;
+ using InstructionOrderMap = absl::flat_hash_map<const HloInstruction*, int64>;
HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
@@ -111,7 +110,7 @@ class HloDomainMap {
// Out of an instruction set, returns a vector of all the ones which are not
// a kDomain kind.
static std::vector<HloInstruction*> MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const absl::flat_hash_set<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order);
// Populates domain_metadata_id_ that maps each HloInstruction to the unique
@@ -120,8 +119,8 @@ class HloDomainMap {
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
- tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
- tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
+ absl::flat_hash_map<const HloInstruction*, int64> instruction_to_domain_;
+ absl::flat_hash_map<const HloInstruction*, int64> domain_metadata_id_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index 302807f816..d3c83c15ae 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -42,7 +42,7 @@ class DomainMetadata {
// operand/user pathways, without crossing a kDomain instruction of a given
// kind. The reach_set can contain kDomain instructions of other kinds, if
// two domains of different kind intersect each other.
- tensorflow::gtl::FlatSet<HloInstruction*> reach_set;
+ absl::flat_hash_set<HloInstruction*> reach_set;
// The same instructions in reach_set, but purged from kDomain instructions
// and ordered according to their computation graph post-order, i.e.
@@ -55,8 +55,8 @@ class DomainMetadata {
// whose dataflow enters the reach set (domain), while the exit_domains
// contains the set of kDomain instructions whose dataflow exit the reach
// set.
- tensorflow::gtl::FlatSet<HloInstruction*> enter_domains;
- tensorflow::gtl::FlatSet<HloInstruction*> exit_domains;
+ absl::flat_hash_set<HloInstruction*> enter_domains;
+ absl::flat_hash_set<HloInstruction*> exit_domains;
};
virtual ~DomainMetadata() = default;
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
index 97bc8ef604..0fc30fb86c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -26,7 +26,7 @@ namespace xla {
// Removes all the kDomain instructions of a given kind from the input module,
// and calls the normalizer to propagate the properties to the possibly newly
// created instructions.
-class HloDomainRemover : public HloPassInterface {
+class HloDomainRemover : public HloModulePass {
public:
// Creates a new HloDomainRemover object tasked at removing all the kDomain
// instructions of a given kind.
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
index 81d6d69a8c..bea5cba38d 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -29,7 +29,7 @@ namespace xla {
// Verifies that the domain instructions are consistent, and that each domain is
// surrounded by the same metadata.
-class HloDomainVerifier : public HloPassInterface {
+class HloDomainVerifier : public HloModulePass {
public:
HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
index 44ded2c2fa..4d2a942925 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -25,7 +25,7 @@ namespace xla {
// inserting Convert ops. This allows a backend to support an element type while
// only actually implementing the Convert op for that element type. This is
// generally not the fastest approach, but it works.
-class HloElementTypeConverter : public HloPassInterface {
+class HloElementTypeConverter : public HloModulePass {
public:
// eliminate_type is the type to eliminate as the input or output of ops,
// using Convert ops to replace it with replace_with_type.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 06b6d5b559..eec8d242fa 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -496,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
return Status::OK();
}
+Status HloEvaluator::HandleReal(HloInstruction* real) {
+ auto operand = real->operand(0);
+ switch (operand->shape().element_type()) {
+ case BF16: {
+ auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
+ real, [](bfloat16 elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case C64: {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ real, [](complex64 elem_operand) { return std::real(elem_operand); },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F16: {
+ auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
+ real, [](Eigen::half elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F32: {
+ auto result_or = ElementWiseUnaryOpImpl<float, float>(
+ real, [](float elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F64: {
+ auto result_or = ElementWiseUnaryOpImpl<double, double>(
+ real, [](double elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ default:
+ LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
+ << PrimitiveType_Name(operand->shape().element_type());
+ }
+
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleImag(HloInstruction* imag) {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
+ GetEvaluatedLiteralFor(imag->operand(0)));
+
+ TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
+ return Status::OK();
+}
+
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
HloOpcode opcode = compare->opcode();
auto lhs = compare->operand(0);
@@ -1173,80 +1228,85 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
<< "Sort keys and values must have the same dimensions";
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for rank-1 and rank-2 shapes, rank is: "
- << rank;
TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort";
- // We need to sort and array of keys and an array of values, where the
+ // We need to sort an array of keys and an array of values, where the
// sorted order of the values is determined by the keys. The simplest(?)
// way to do this is to go to an array-of-pairs representation, sort the
// array using the keys, and then go back to pair-of-arrays.
VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
VLOG(3) << "HandleSort values_literal: " << values_literal.ToString();
- auto sort_r1 = [](const Literal& keys_literal,
- const Literal& values_literal) {
- const auto& keys_data = keys_literal.data<KeyType>();
- const auto& values_data = values_literal.data<ValueType>();
-
- using kv_pair = std::pair<KeyType, ValueType>;
- std::vector<kv_pair> key_value_vector;
- CHECK_EQ(keys_data.size(), values_data.size());
- key_value_vector.reserve(keys_data.size());
- for (int i = 0; i < keys_data.size(); ++i) {
- key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i]));
- }
- std::sort(key_value_vector.begin(), key_value_vector.end(),
- [](const kv_pair& a, const kv_pair& b) {
- return SafeLess<KeyType>(a.first, b.first);
- });
- std::vector<KeyType> result_keys;
- std::vector<ValueType> result_values;
- for (const auto& key_value : key_value_vector) {
- result_keys.push_back(key_value.first);
- result_values.push_back(key_value.second);
- }
- Literal result_keys_literal(keys_literal.shape());
- result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
- Literal result_values_literal(values_literal.shape());
- result_values_literal.PopulateR1(
- absl::Span<const ValueType>(result_values));
- return std::make_pair(std::move(result_keys_literal),
- std::move(result_values_literal));
- };
-
- Literal result_tuple;
- if (rank == 1) {
- auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple =
- LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal keys_result_literal(keys_literal.shape());
- Literal values_result_literal(values_literal.shape());
- int64 r1_length = keys_literal.shape().dimensions(1);
- for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- TF_ASSIGN_OR_RETURN(auto values_r1_slice,
- values_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
- TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first.Reshape({1, r1_length}));
- TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
- sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
- sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
- }
- result_tuple =
- LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
+ if (rank == 0) {
+ // Nothing to sort.
+ return LiteralUtil::MakeTuple({&keys_literal, &values_literal});
}
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys_literal.shape(), zero_base,
+ AsInt64Slice(keys_literal.shape().dimensions()), increment,
+ [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+          // Extract a slice from the keys and values literals that corresponds
+          // exactly to the row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto keys_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& keys_data = keys_to_sort.data<KeyType>();
+ TF_ASSIGN_OR_RETURN(auto values_to_sort,
+ values_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& values_data = values_to_sort.data<ValueType>();
+ using kv_pair = std::pair<KeyType, ValueType>;
+ std::vector<kv_pair> key_value_vector;
+ key_value_vector.reserve(keys_data.size());
+ for (int i = 0; i < keys_data.size(); ++i) {
+ key_value_vector.push_back(
+ std::make_pair(keys_data[i], values_data[i]));
+ }
+ std::sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<KeyType>(a.first, b.first);
+ });
+ std::vector<KeyType> result_keys;
+ std::vector<ValueType> result_values;
+ for (const auto& key_value : key_value_vector) {
+ result_keys.push_back(key_value.first);
+ result_values.push_back(key_value.second);
+ }
+ Literal sorted_keys(ShapeUtil::MakeShape(
+ keys_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_keys.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal sorted_values(ShapeUtil::MakeShape(
+ values_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_values.PopulateR1(absl::Span<const ValueType>(result_values));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ std::vector<int64> start_indices(rank, 0);
+ TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped,
+ sorted_keys.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys_reshaped, start_indices, indices, slice_dimensions));
+ TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped,
+ sorted_values.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+
+ Literal result_tuple;
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
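
The array-of-pairs trick used for each slice above can be shown in isolation. The sketch below sorts one keys row and reorders the matching values row by key order; it assumes plain std::vector buffers rather than Literal slices and is not the evaluator's code:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Sort one row of keys and carry the matching values along, by round-tripping
// through an array of (key, value) pairs.
void SortKeyValueRow(std::vector<float>* keys, std::vector<int32_t>* values) {
  std::vector<std::pair<float, int32_t>> kv;
  kv.reserve(keys->size());
  for (size_t i = 0; i < keys->size(); ++i) {
    kv.emplace_back((*keys)[i], (*values)[i]);
  }
  std::sort(kv.begin(), kv.end(),
            [](const auto& a, const auto& b) { return a.first < b.first; });
  for (size_t i = 0; i < kv.size(); ++i) {
    (*keys)[i] = kv[i].first;
    (*values)[i] = kv[i].second;
  }
}

int main() {
  std::vector<float> keys = {3.f, 1.f, 2.f};
  std::vector<int32_t> values = {30, 10, 20};
  SortKeyValueRow(&keys, &values);
  for (int32_t v : values) std::cout << v << " ";  // 10 20 30
  std::cout << "\n";
  return 0;
}
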
@@ -1292,15 +1352,6 @@ StatusOr<Literal> EvaluateSort(HloInstruction* sort,
} // namespace
Status HloEvaluator::HandleSort(HloInstruction* sort) {
- const int64 sort_dim = sort->dimensions(0);
- const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
- if (sort_dim != rank - 1) {
- return Unimplemented(
- "Trying to sort along dimension %d, which is not the last "
- "dimension",
- sort_dim);
- }
-
if (!ShapeUtil::IsTuple(sort->shape())) {
return DefaultAction(sort);
} else {
@@ -1327,7 +1378,7 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
"unsupported");
}
}
- return reduce->Visit(typed_visitors_.at(first_element_type).get());
+ return reduce->Visit(typed_visitors_[first_element_type].get());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 21e676d671..07f8d0aad4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -134,7 +134,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Wraps around instruction handling to infer types before dispatching to
// the corresponding typed Visitor.
Status DefaultAction(HloInstruction* hlo) override {
- return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
+ return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
}
Status Preprocess(HloInstruction* hlo) override;
@@ -184,6 +184,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSort(HloInstruction* sort) override;
+ Status HandleReal(HloInstruction* real) override;
+
+ Status HandleImag(HloInstruction* imag) override;
+
Status HandleReduce(HloInstruction* reduce) override;
// Returns the already-evaluated literal result for the instruction.
@@ -206,8 +210,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// post-ordering.
// Must be cleared for each evaluation.
// Storing Literal in place requires the container to have pointer stability so
- // we cannot use FlatMap any more.
- std::unordered_map<const HloInstruction*, Literal> evaluated_;
+ // we cannot use flat_hash_map any more.
+ absl::node_hash_map<const HloInstruction*, Literal> evaluated_;
private:
template <typename ReturnT, typename NativeT>
@@ -237,12 +241,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
}
// Map from a primitive type to its associated (templated) DfsHloVisitor.
- // Note: the hash function here is only needed because current gcc std::hash
- // does not specialize for enum types. This should however be fixed in the
- // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5
- tensorflow::gtl::FlatMap<PrimitiveType, std::unique_ptr<DfsHloVisitor>,
- std::hash<int>>
- typed_visitors_;
+ std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];
// Caches pointers to input literals, assuming they are in post-order.
// Literals are not owned by this class, and they must outlive the lifetime of
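
The switch to absl::node_hash_map matters because evaluated_ hands out pointers/references to the stored Literals, and node-based containers keep those addresses valid across rehashes (a flat_hash_map may move its values). A small sketch of that guarantee, assuming Abseil is available and using std::string as a stand-in for Literal:

#include <cstdio>
#include <string>

#include "absl/container/node_hash_map.h"

int main() {
  absl::node_hash_map<int, std::string> evaluated;
  evaluated[0] = "result-0";
  // Take the address of a stored value, as callers of GetEvaluatedLiteralFor do.
  const std::string* stable = &evaluated[0];
  // Grow the map far past its initial capacity, forcing several rehashes.
  for (int i = 1; i < 1000; ++i) evaluated[i] = "x";
  // Node-based maps keep each value at a fixed heap address, so this is safe.
  std::printf("%s\n", stable->c_str());  // prints "result-0"
  return 0;
}
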
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 01e88566a5..608a42bb60 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -66,6 +66,20 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
.ConsumeValueOrDie();
}
+ // Evaluate function that takes in a local module instead of using module_
+ // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is
+ // removed, this should be the default Evaluate function.
+ Literal EvaluateWithModule(
+ HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
+ if (use_bfloat16_) {
+ // In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
+ auto type_converter = HloElementTypeConverter(F32, BF16);
+ type_converter.Run(module).ValueOrDie();
+ }
+ return evaluator_->Evaluate(*module->entry_computation(), arg_literals)
+ .ConsumeValueOrDie();
+ }
+
std::unique_ptr<HloEvaluator> evaluator_;
void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
@@ -1449,6 +1463,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
+TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) {
+ HloComputation::Builder b(TestName());
+
+ // arg:
+ // f32[3,3] {
+ // { 1, 2, 3 },
+ // { 5, 6, 7 },
+ // { 9, 10, 11 },
+ // }
+ auto arg_array = absl::make_unique<Array2D<float>>(3, 3);
+ arg_array->FillUnique(1.0f);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
+
+ HloInstruction* arg_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+ auto init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
+
+ HloComputation::Builder max_computation("max");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ max_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
+ auto max_func = module().AddEmbeddedComputation(max_computation.Build());
+
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(1);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(2);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1});
+ b.AddInstruction(HloInstruction::CreateReduceWindow(
+ shape, arg_instruction, init_value, window, max_func));
+
+ module().AddEntryComputation(b.Build());
+
+ Literal result = Evaluate();
+
+ auto expected = LiteralUtil::CreateR2<float>({{11}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
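
The expected {{11}} follows directly from the dilated-window index math changed later in this diff: base = count * stride + tap * window_dilation - padding_low, so the single 2x2 window with window dilation 2 taps rows and columns {0, 2}. A standalone sketch of that arithmetic on the same 3x3 input (illustrative only, not the evaluator's ReduceWindow code):

#include <algorithm>
#include <iostream>

int main() {
  const float input[3][3] = {{1, 2, 3}, {5, 6, 7}, {9, 10, 11}};
  const int size = 2, stride = 1, window_dilation = 2, padding_low = 0;
  const int count_i = 0, count_j = 0;  // the single output position
  float result = 0.f;                  // the init value of the max reduction
  for (int tap_i = 0; tap_i < size; ++tap_i) {
    for (int tap_j = 0; tap_j < size; ++tap_j) {
      // base = count * stride + tap * window_dilation - padding_low
      int bi = count_i * stride + tap_i * window_dilation - padding_low;
      int bj = count_j * stride + tap_j * window_dilation - padding_low;
      if (bi < 0 || bi >= 3 || bj < 0 || bj >= 3) continue;  // outside the base
      result = std::max(result, input[bi][bj]);
    }
  }
  std::cout << result << "\n";  // 11, matching the expected {{11}}
  return 0;
}
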
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
HloComputation::Builder b(TestName());
@@ -2530,6 +2596,114 @@ ENTRY main {
expected, Evaluate({&operand, &scatter_indices, &updates})));
}
+TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter_NegativeIndices
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ // No updates should happen for the negative indices.
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({-1, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}}),
+ EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) {
+ const string hlo_text = R"(
+HloModule BatchDynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ // No updates should happen for the OOB indices.
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}}),
+ EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd_OobUpdateWindow
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[1,2] parameter(1)
+ updates = s32[1,2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
+ // Given the update window size of 2,2 and the index of 0,2, the update window
+ // will be OOB. So, nothing should be updated.
+ Literal expected = operand.Clone();
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ expected, EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
+
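
The three tests above pin down the new scatter semantics: an update whose window starts at a negative index or would hang past the operand edge is dropped rather than clamped. A small standalone sketch of that in-bounds test (shapes and names are illustrative, not the evaluator's):

#include <cstdint>
#include <iostream>
#include <vector>

// An update window starting at start_index is applied only if it fits entirely
// inside the operand; any overhang (negative or past the edge) drops the update.
bool UpdateWindowInBounds(const std::vector<int64_t>& start_index,
                          const std::vector<int64_t>& window_size,
                          const std::vector<int64_t>& operand_dims) {
  for (size_t i = 0; i < operand_dims.size(); ++i) {
    if (start_index[i] < 0 ||
        start_index[i] > operand_dims[i] - window_size[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  const std::vector<int64_t> dims = {3, 3};
  std::cout << UpdateWindowInBounds({2, 1}, {1, 1}, dims) << "\n";   // 1: applied
  std::cout << UpdateWindowInBounds({-1, 0}, {1, 1}, dims) << "\n";  // 0: dropped
  std::cout << UpdateWindowInBounds({0, 2}, {2, 2}, dims) << "\n";   // 0: dropped
  return 0;
}
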
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise comparison with 2 bfloat16 operands.
TEST_P(HloEvaluatorTest, DoesCompareBF16) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 8fb17a0033..a450dc6ff5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include <cmath>
+
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
@@ -41,7 +43,9 @@ template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
// It's UB to use std::sort with std::less<float>, because of NaNs. Define
-// "safe" less functions which are actually strict weak orders.
+// "safe" less functions which are actually strict weak orders. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0.
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr>
@@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
return a < b;
}
-template <typename NativeT,
- typename std::enable_if<
- std::is_floating_point<NativeT>::value ||
- std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (std::isnan(b)) {
- return !std::isnan(a);
- } else {
- return a < b;
+ bool lhs_is_negative = std::signbit(a);
+ bool rhs_is_negative = std::signbit(b);
+ // If the signs are different, we can just compare the signs.
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(a);
+ bool rhs_nan = std::isnan(b);
+ // Exactly one number is nan?
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;
+ }
+ return !rhs_is_negative;
}
+ return a < b;
}
-template <typename NativeT, typename std::enable_if<std::is_same<
- NativeT, Eigen::half>::value>::type* = nullptr>
+template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, Eigen::half>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (Eigen::half_impl::isnan(b)) {
- return !Eigen::half_impl::isnan(a);
- } else {
- return a < b;
- }
+ return SafeLess(static_cast<float>(a), static_cast<float>(b));
}
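
The rewritten SafeLess gives floats a strict weak order that std::sort can rely on: sign bit first, then NaN-ness, then the usual less-than. A self-contained sketch of the same ordering (plain float only, not the evaluator's template):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <vector>

// Negatives (by sign bit) before positives, -0.0 before +0.0, and NaNs pushed
// to whichever end their sign bit indicates.
bool SafeLessF32(float a, float b) {
  const bool a_neg = std::signbit(a), b_neg = std::signbit(b);
  if (a_neg != b_neg) return a_neg;
  const bool a_nan = std::isnan(a), b_nan = std::isnan(b);
  if (a_nan != b_nan) return a_nan ? a_neg : !b_neg;
  return a < b;  // false when both are NaN, i.e. they compare equivalent
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  std::vector<float> v = {0.0f, nan, -1.0f, -0.0f, 2.0f, -nan};
  std::sort(v.begin(), v.end(), SafeLessF32);
  for (float x : v) std::cout << x << " ";  // -nan -1 -0 0 2 nan
  std::cout << "\n";
  return 0;
}
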
// Templated DfsHloVisitor for use by HloEvaluator.
@@ -78,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
// to this rule, notably:
// - HandleCompare and HandleIsFinite: where the resulting literal type is
// always boolean.
+// - HandleImag and HandleReal: where the resulting literal type is always
+//   real-valued; the operand is complex for HandleImag, and real or complex
+//   for HandleReal.
// These operations are handled outside of the parent HloEvaluator handlers
// instead of from within TypedVisitor.
//
@@ -318,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleFloor<ReturnT>(floor);
}
- Status HandleImag(HloInstruction* imag) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag],
- ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) {
- return std::imag(elem_operand);
- }));
- return Status::OK();
- }
-
Status HandleLog(HloInstruction* log) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
@@ -673,14 +678,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
- Status HandleReal(HloInstruction* real) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[real],
- ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) {
- return std::real(elem_operand);
- }));
- return Status::OK();
- }
-
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleRemainder(HloInstruction* remainder) {
@@ -1527,47 +1524,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {
auto keys = sort->operand(0);
- auto rank = ShapeUtil::Rank(keys->shape());
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for R1 and R2 shapes";
TF_RET_CHECK(sort->operand_count() == 1)
<< "Typed visitor does not support key-value sort";
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
-
- auto sort_r1 = [this](const Literal& keys_literal) {
- VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
- const auto& keys_data = keys_literal.data<ReturnT>();
-
- std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
- std::sort(result_data.begin(), result_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- Literal result_literal(keys_literal.shape());
- result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
- return result_literal;
- };
-
- if (rank == 1) {
- parent_->evaluated_[sort] = std::move(sort_r1(keys_literal));
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal result_literal(keys_literal.shape());
- int64 r1_length = keys->shape().dimensions(1);
- for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result = sort_r1(r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
- r1_result, {0, 0}, {row, 0}, {1, r1_length}));
- }
- parent_->evaluated_[sort] = std::move(result_literal);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys->shape().dimensions(sort_dim);
+ int64 rank = ShapeUtil::Rank(keys->shape());
+ if (rank == 0) {
+ // Nothing to sort.
+ parent_->evaluated_[sort] = keys_literal.Clone();
+ return Status::OK();
}
+ Literal result_literal(keys_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()),
+ increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+            // Extract a slice from the literal that corresponds exactly to
+            // the row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto row_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& row_data = row_to_sort.data<NativeT>();
+
+ std::vector<NativeT> result_data(row_data.begin(), row_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const NativeT& a, const NativeT& b) {
+ return SafeLess<NativeT>(a, b);
+ });
+ Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(),
+ {sort_dim_elements}));
+ sorted_row.PopulateR1(absl::Span<const NativeT>(result_data));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped,
+ sorted_row.Reshape(slice_dimensions));
+ std::vector<int64> start_indices(rank, 0);
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ sorted_row_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+ parent_->evaluated_[sort] = std::move(result_literal);
return Status::OK();
}
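
The ForEachIndexWithStatus loop above visits every 1-D slice along sort_dim by stepping the other dimensions by 1 and the sort dimension by its full extent. The same idea in a standalone sketch over a row-major buffer, with explicit loops standing in for the index iterator (illustrative only):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> dims = {2, 3, 2};  // shape [2,3,2], row-major
  const int sort_dim = 1;                       // sort along the middle dimension
  std::vector<float> data(2 * 3 * 2);
  for (size_t i = 0; i < data.size(); ++i) data[i] = float(data.size() - i);

  auto linear = [&](int64_t i, int64_t j, int64_t k) {
    return (i * dims[1] + j) * dims[2] + k;
  };
  // Visit every index with the sort dimension pinned to 0 (what the increment
  // of sort_dim_elements achieves above), gather the slice, sort, write back.
  for (int64_t i = 0; i < dims[0]; ++i) {
    for (int64_t k = 0; k < dims[2]; ++k) {
      std::vector<float> row;
      for (int64_t j = 0; j < dims[sort_dim]; ++j) {
        row.push_back(data[linear(i, j, k)]);
      }
      std::sort(row.begin(), row.end());
      for (int64_t j = 0; j < dims[sort_dim]; ++j) {
        data[linear(i, j, k)] = row[j];
      }
    }
  }
  std::cout << data[linear(0, 0, 0)] << "\n";  // 8: smallest of {12, 10, 8}
  return 0;
}
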
@@ -2265,19 +2270,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// be 1.
int64 update_dim_size =
update_dim == -1 ? 1 : updates_shape.dimensions(update_dim);
- // Clamp the scatter index so that the scatter region fits in the
- // operand. input_scatter_index_clamped[i] =
- // clamp(input_scatter_index[i], 0,
- // operand_shape.dimensions(i) -
- // update_dim_size);
- input_scatter_index_clamped[i] =
- std::min(operand_shape.dimensions(i) - update_dim_size,
- std::max(0LL, input_scatter_index[i]));
+ // If any part of the update region is out-of-bounds, then do not
+ // perform any update on the input.
+ if ((input_scatter_index[i] < 0) ||
+ (input_scatter_index[i] >
+ operand_shape.dimensions(i) - update_dim_size)) {
+ return true;
+ }
}
for (int i = 0, e = input_index.size(); i < e; i++) {
- input_index[i] = input_scatter_index_clamped[i] + input_window_index[i];
- DCHECK_GE(input_index[i], 0);
- DCHECK_LT(input_index[i], operand_shape.dimensions(i));
+ input_index[i] = input_scatter_index[i] + input_window_index[i];
}
auto result_value_literal =
@@ -2611,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> base_index(rank);
bool out_of_bound = false;
for (int64 i = 0; i < rank; ++i) {
- base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
- window_index[i] - window.dimensions(i).padding_low();
+ base_index[i] =
+ window_count_index[i] * window.dimensions(i).stride() +
+ window_index[i] * window.dimensions(i).window_dilation() -
+ window.dimensions(i).padding_low();
+      // We are not in the base area if the dilation places us between base
+      // elements.
+ if (base_index[i] % window.dimensions(i).base_dilation() != 0) {
+ out_of_bound = true;
+ break;
+ }
+ // Apply the dilation to the base area.
+ base_index[i] /= window.dimensions(i).base_dilation();
if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
out_of_bound = true;
break;
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index de3d7a1677..ce4cad4235 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -90,8 +90,9 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
HloInstructionInfo* instruction_info =
computation_info->add_instruction_infos();
instruction_info->set_long_name(hlo->ToString());
- instruction_info->set_short_name(
- hlo->ToString(HloPrintOptions().set_compact_operands(true)));
+ instruction_info->set_short_name(hlo->ToString(
+ HloPrintOptions().set_compact_operands(true).set_print_operand_names(
+ false)));
instruction_info->set_category(hlo->ToCategory());
instruction_info->set_flop_count(cost_analysis.flop_count(*hlo));
instruction_info->set_transcendental_count(
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 287ba84b3b..13a74fd8a1 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1110,7 +1110,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
instr->metadata().source_line()));
}
- return StrJoin(lines, "<br/>");
+ return StrJoin(lines, "\n");
}
string HloDotDumper::GetInstructionNodeBackendConfig(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index e905f2983a..2f6db7cd7c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
@@ -37,14 +39,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
@@ -59,8 +60,8 @@ using absl::StrJoin;
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const HloInstructionProto& proto,
- const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
+ const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
+ const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
@@ -80,6 +81,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const auto computations = [&computation_map, &proto](int index) {
return computation_map.at(proto.called_computation_ids(index));
};
+
+ TF_RET_CHECK(std::all_of(
+ proto.operand_ids().begin(), proto.operand_ids().end(),
+ [&instruction_map](int64 id) { return instruction_map.contains(id); }))
+ << proto.name() << " instruction contains invalid operand id(s)";
+
+ TF_RET_CHECK(std::all_of(
+ proto.called_computation_ids().begin(),
+ proto.called_computation_ids().end(),
+ [&computation_map](int64 id) { return computation_map.contains(id); }))
+ << proto.name() << " instruction references invalid computation id(s)";
+
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
+
switch (opcode) {
// Ops migrated to subclasses.
case HloOpcode::kBatchNormTraining:
@@ -266,7 +281,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "Expect 1 called computation for fusion instruction but sees "
<< proto.called_computation_ids_size();
const int64 fusion_id = proto.called_computation_ids(0);
- auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
+ auto* fused_computation =
+ tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
TF_RET_CHECK(fused_computation != nullptr)
<< "No fusion computation with id " << fusion_id;
instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(),
@@ -302,6 +318,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
} break;
case HloOpcode::kOutfeed:
TF_RET_CHECK(proto.operand_ids_size() == 2);
+ TF_RETURN_IF_ERROR(
+ ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape()));
instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
operands(1), proto.outfeed_config());
break;
@@ -379,7 +397,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
case HloOpcode::kCustomCall:
instruction = CreateCustomCall(proto.shape(), all_operands(),
- proto.custom_call_target());
+ proto.custom_call_target(),
+ proto.custom_call_opaque());
if (proto.has_window()) {
static_cast<HloCustomCallInstruction*>(instruction.get())
->set_window(proto.window());
@@ -446,8 +465,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kIota:
- TF_RET_CHECK(proto.dimensions_size() <= 1)
- << "Iota instruction should have at most 1 dimension but sees "
+ TF_RET_CHECK(proto.dimensions_size() == 1)
+ << "Iota instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
break;
@@ -465,31 +484,34 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.dot_dimension_numbers(), precision_config);
break;
}
- case HloOpcode::kDomain:
+ case HloOpcode::kDomain: {
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Domain instruction should have 1 operands but sees "
<< proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_domain_entry_sharding())
+        << "Domain instruction must have domain_entry_sharding";
+ TF_RET_CHECK(proto.has_domain_exit_sharding())
+        << "Domain instruction must have domain_exit_sharding";
+ TF_ASSIGN_OR_RETURN(
+ HloSharding entry_hlo_sharding,
+ HloSharding::FromProto(proto.domain_entry_sharding()));
+ TF_ASSIGN_OR_RETURN(HloSharding exit_hlo_sharding,
+ HloSharding::FromProto(proto.domain_exit_sharding()));
instruction = absl::make_unique<HloDomainInstruction>(
- proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
- /*user_side_metadata=*/nullptr);
+ proto.shape(), operands(0),
+ absl::make_unique<ShardingMetadata>(
+ std::make_shared<const HloSharding>(entry_hlo_sharding)),
+ absl::make_unique<ShardingMetadata>(
+ std::make_shared<const HloSharding>(exit_hlo_sharding)));
break;
+ }
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
- << "No instruction with id " << operand_id;
instruction->AppendOperand(instruction_map.at(operand_id));
}
- for (const int64 predecessor_id : proto.control_predecessor_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
- << "No instruction with id " << predecessor_id;
- TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
- ->AddControlDependencyTo(instruction.get()));
- }
if (instruction->opcode() != HloOpcode::kFusion) {
for (const int64 computation_id : proto.called_computation_ids()) {
- TF_RET_CHECK(ContainsKey(computation_map, computation_id))
- << "No computation with id " << computation_id;
instruction->called_computations_.push_back(
computation_map.at(computation_id));
}
@@ -501,6 +523,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
}
+ for (const int64 predecessor_id : proto.control_predecessor_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
+ << "No instruction with id " << predecessor_id;
+ TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
+ ->AddControlDependencyTo(instruction.get()));
+ }
+
TF_RET_CHECK(!proto.name().empty());
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
@@ -1108,9 +1137,9 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target) {
- return absl::make_unique<HloCustomCallInstruction>(shape, operands,
- custom_call_target);
+ absl::string_view custom_call_target, absl::string_view opaque) {
+ return absl::make_unique<HloCustomCallInstruction>(
+ shape, operands, custom_call_target, opaque);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
@@ -1431,7 +1460,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const {
HloInstruction::InstructionVector HloInstruction::unique_operands() const {
InstructionVector unique;
- tensorflow::gtl::FlatSet<const HloInstruction*> seen;
+ absl::flat_hash_set<const HloInstruction*> seen;
for (HloInstruction* operand : operands()) {
if (seen.insert(operand).second) {
unique.push_back(operand);
@@ -2005,7 +2034,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
options.is_in_nested_computation()) {
str.push_back(PrintName(
canonical_name_map->LookupOrInsert(operand->name()), options));
- } else if (!options.compact_operands()) {
+ } else if (options.print_operand_names()) {
str.push_back(PrintName(operand->name(), options));
}
StrAppend(out, StrJoin(str, " "));
@@ -2423,7 +2452,7 @@ template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
const InternalCompareFunction* operand_order,
bool ignore_control_predecessors) {
- visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());
+ visitor->ReserveVisitStates(root->GetModule()->instruction_count());
// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
//
@@ -2660,14 +2689,14 @@ class HloInstruction::FusionReusesParamElements {
// the value of this parameter, which would save stack space but not allow us
// to finish early if we find a reuse.
static UseKind Compute(int64 i, const HloInstruction& hlo) {
- tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
+ absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
return ComputeInternal(i, hlo, &memoization_cache);
}
private:
static UseKind ComputeInternal(
int64 i, const HloInstruction& hlo,
- tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
+ absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
if (hlo_param->parameter_number() == i) {
return UseKind::kUse;
@@ -2910,6 +2939,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
+bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
+ const HloInstruction* const& rhs) const {
+ if (rhs == nullptr) {
+ // Nothing compares less than nullptr.
+ return false;
+ }
+ if (lhs == nullptr) {
+ return true;
+ }
+ auto lhs_module = lhs->GetModule();
+ auto rhs_module = rhs->GetModule();
+ CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
+ (lhs_module != nullptr && rhs_module != nullptr));
+ if (lhs_module != nullptr &&
+ lhs_module->unique_id() != rhs_module->unique_id()) {
+ return lhs_module->unique_id() < rhs_module->unique_id();
+ }
+ return lhs->unique_id() < rhs->unique_id();
+}
+
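
The new out-of-line comparator makes the ordering usable across modules by comparing the owning module's unique_id before the instruction's. A toy sketch of that ordering with stand-in structs (not the real HloModule/HloInstruction types):

#include <iostream>
#include <map>

struct FakeModule { int unique_id; };
struct FakeInstruction { const FakeModule* module; int unique_id; };

struct PtrLess {
  bool operator()(const FakeInstruction* lhs, const FakeInstruction* rhs) const {
    if (rhs == nullptr) return false;  // nothing compares less than nullptr
    if (lhs == nullptr) return true;
    if (lhs->module->unique_id != rhs->module->unique_id) {
      return lhs->module->unique_id < rhs->module->unique_id;
    }
    return lhs->unique_id < rhs->unique_id;  // ids only unique within a module
  }
};

int main() {
  FakeModule m0{0}, m1{1};
  FakeInstruction a{&m1, 7}, b{&m0, 7};  // same instruction id, different modules
  std::map<const FakeInstruction*, const char*, PtrLess> order;
  order[&a] = "module 1";
  order[&b] = "module 0";
  std::cout << order.begin()->second << "\n";  // "module 0": lower module id first
  return 0;
}
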
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
@@ -3027,10 +3076,6 @@ const std::vector<int64>& HloInstruction::slice_strides() const {
return Cast<HloSliceInstruction>(this)->slice_strides();
}
-bool HloInstruction::IsInPlaceSlice() const {
- return Cast<HloSliceInstruction>(this)->IsInPlaceSlice();
-}
-
const Literal& HloInstruction::literal() const {
return Cast<HloConstantInstruction>(this)->literal();
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4f6cac1396..374862c4b6 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -32,6 +32,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@@ -50,7 +51,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -80,6 +80,7 @@ class HloPrintOptions {
print_backend_config_(true),
compact_operands_(false),
print_operand_shape_(true),
+ print_operand_names_(true),
print_program_shape_(true),
print_percent_(true),
print_control_dependencies_(true),
@@ -107,6 +108,7 @@ class HloPrintOptions {
.set_print_metadata(false)
.set_print_backend_config(false)
.set_compact_operands(true)
+ .set_print_operand_names(false)
.set_print_operand_shape(true)
.set_print_program_shape(false)
.set_print_percent(false)
@@ -144,6 +146,12 @@ class HloPrintOptions {
return *this;
}
+ // If true, the operand names will be printed.
+ HloPrintOptions& set_print_operand_names(bool value) {
+ print_operand_names_ = value;
+ return *this;
+ }
+
// If true, program shape of hlo computations will be printed.
HloPrintOptions& set_print_program_shape(bool value) {
print_program_shape_ = value;
@@ -162,8 +170,8 @@ class HloPrintOptions {
return *this;
}
- // If true, only a part of operands will be printed out, and their names will
- // be omitted (note that in this case the text will not be parsable).
+ // If true, only a part of operands will be printed out (note that in this
+ // case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
compact_operands_ = value;
return *this;
@@ -197,6 +205,7 @@ class HloPrintOptions {
bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
+ bool print_operand_names() const { return print_operand_names_; }
bool print_program_shape() const { return print_program_shape_; }
bool print_percent() const { return print_percent_; }
bool print_control_dependencies() const {
@@ -215,6 +224,7 @@ class HloPrintOptions {
bool print_backend_config_;
bool compact_operands_;
bool print_operand_shape_;
+ bool print_operand_names_;
bool print_program_shape_;
bool print_percent_;
bool print_control_dependencies_;
@@ -247,7 +257,7 @@ class CanonicalNameMap {
private:
int64 index;
- tensorflow::gtl::FlatMap<string, string> canonical_name_map;
+ absl::flat_hash_map<string, string> canonical_name_map;
};
// HLO instructions are the atomic unit of the high-level compiler's IR.
@@ -350,8 +360,8 @@ class HloInstruction {
// calls.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
const HloInstructionProto& proto,
- const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
+ const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
+ const absl::flat_hash_map<int64, HloComputation*>& computation_map);
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
@@ -718,10 +728,11 @@ class HloInstruction {
HloComputation* computation);
// Creates a custom call instruction that applies the given custom call target
- // to the given operands. "shape" is the resultant shape.
+ // to the given operands. "opaque" can be an arbitrary string with a
+ // backend-specific interpretation. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target);
+ absl::string_view custom_call_target, absl::string_view opaque = "");
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
@@ -1319,9 +1330,6 @@ class HloInstruction {
int64 slice_strides(int64 dimension) const;
const std::vector<int64>& slice_strides() const;
- // Delegates to HloSliceInstruction::IsInPlaceSlice.
- bool IsInPlaceSlice() const;
-
// Returns the literal associated with this instruction.
const Literal& literal() const;
@@ -1616,6 +1624,10 @@ class HloInstruction {
InstructionVector operands_;
// The set of control predecessors of this instruction.
+ // Note that the order of the instructions in the vector influences the order
+ // computed in HloComputation::ComputeInstructionPostOrder, which may
+ // influence the result of the compilation by changing the scheduling. We are
+ // not sure if it matters.
std::vector<HloInstruction*> control_predecessors_;
// The users of this instruction. Users are HLOs where this instruction is an
@@ -1689,21 +1701,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo. Exception: null pointer values compare less than non-null.
-//
-// Note that this cannot be used for HLO instructions across multiple modules
-// since the id of HLO instructions are only unique within each HLO module.
struct HloPtrComparator {
bool operator()(const HloInstruction* const& lhs,
- const HloInstruction* const& rhs) const {
- if (rhs == nullptr) {
- // Nothing compares less than nullptr.
- return false;
- }
- if (lhs == nullptr) {
- return true;
- }
- return lhs->unique_id() < rhs->unique_id();
- }
+ const HloInstruction* const& rhs) const;
};
template <typename ValueT>
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e92882c22a..152d8eacdb 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
@@ -27,8 +28,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/window_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
@@ -213,6 +214,7 @@ HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
HloInstructionProto HloSendRecvInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_channel_id(channel_id_);
+ proto.set_is_host_transfer(is_host_transfer_);
return proto;
}
@@ -641,14 +643,6 @@ HloTransposeInstruction::HloTransposeInstruction(
absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kTranspose, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
- CHECK_EQ(shape.dimensions().size(), dimensions.size());
- CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
- CHECK(std::equal(operand->shape().dimensions().begin(),
- operand->shape().dimensions().end(),
- Permute(dimensions, shape.dimensions()).begin()))
- << "shape: " << ShapeUtil::HumanString(shape)
- << ", operand->shape(): " << ShapeUtil::HumanString(shape)
- << ", dimensions: {" << StrJoin(dimensions, ", ") << "}";
AppendOperand(operand);
}
@@ -1042,7 +1036,8 @@ HloInstruction* HloFusionInstruction::AddFusionOperand(
const int64 param_no = operand_count();
// Name the parameter after the instruction it represents in the outer
// (non-fusion) computation.
- string param_name = StrCat(new_operand->name(), ".param_", param_no);
+ // string param_name = StrCat(new_operand->name(), ".param_", param_no);
+ string param_name = StrCat("param_", param_no);
HloInstruction* fused_parameter =
fused_instructions_computation()->AddParameter(
HloInstruction::CreateParameter(param_no, new_operand->shape(),
@@ -1098,7 +1093,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
// Note that we add the unfused instructions to this->parent_ computation.
// This is necessary because the unique_id needs for an instruction and
// it's only added when inserting to the computation.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new;
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
std::vector<HloInstruction*> unfused_instructions;
auto computation_to_merge =
instruction_to_merge->fused_instructions_computation();
@@ -1391,7 +1386,7 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
}
Status HloFusionInstruction::DeduplicateFusionOperands() {
- tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices;
+ absl::flat_hash_map<const HloInstruction*, int> operand_indices;
std::vector<int> operands_to_remove;
for (int i = 0; i < operand_count(); ++i) {
auto emplace_result = operand_indices.emplace(operand(i), i);
@@ -1488,7 +1483,6 @@ HloParameterInstruction::CloneWithNewOperandsImpl(
HloGetTupleElementInstruction::HloGetTupleElementInstruction(
const Shape& shape, HloInstruction* operand, int64 index)
: HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
- CHECK(ShapeUtil::IsTuple(operand->shape()));
AppendOperand(operand);
}
@@ -1610,9 +1604,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
: HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
outfeed_shape_(outfeed_shape),
outfeed_config_(outfeed_config) {
- CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
- << "Outfeed shape " << outfeed_shape
- << " must be compatible with operand shape " << operand->shape();
AppendOperand(operand);
AppendOperand(token_operand);
}
@@ -1830,9 +1821,10 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target)
+ absl::string_view custom_call_target, absl::string_view opaque)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ opaque_(opaque.begin(), opaque.end()),
feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
@@ -1849,6 +1841,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_custom_call_opaque(opaque_);
proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1872,6 +1865,11 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
// an HloComputation.
extra.push_back(
StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
+ // If the opaque string becomes enormous we may want to reconsider printing
+ // this inline and consider other options.
+ if (!opaque_.empty()) {
+ extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
+ }
return extra;
}
@@ -1897,7 +1895,8 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
if (feature_group_count_ != casted_other.feature_group_count_) {
return false;
}
- return custom_call_target_ == casted_other.custom_call_target_;
+ return custom_call_target_ == casted_other.custom_call_target_ &&
+ opaque_ == casted_other.opaque_;
}
std::unique_ptr<HloInstruction>
@@ -1905,7 +1904,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
auto cloned = absl::make_unique<HloCustomCallInstruction>(
- shape, new_operands, custom_call_target());
+ shape, new_operands, custom_call_target(), opaque());
if (window_ != nullptr) {
cloned->set_window(*window_);
}
@@ -2301,4 +2300,23 @@ std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], operand_side_metadata_->Clone(),
user_side_metadata_->Clone());
}
+
+HloInstructionProto HloDomainInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ auto operand_side_sharding =
+ dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
+ if (operand_side_sharding) {
+ *proto.mutable_domain_entry_sharding() =
+ operand_side_sharding->sharding()->ToProto();
+ }
+
+ auto user_side_sharding =
+ dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
+ if (user_side_sharding) {
+ *proto.mutable_domain_exit_sharding() =
+ user_side_sharding->sharding()->ToProto();
+ }
+
+ return proto;
+}
} // namespace xla
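
The opaque string added to custom calls in this file is carried through ToProto and cloning untouched; only the backend implementing the custom-call target gives it meaning. A hypothetical sketch of that division of labor (the registry and signature below are illustrative, not XLA's custom-call API):

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Hypothetical backend registry: the custom_call_target picks the routine, and
// the opaque bytes are forwarded to it verbatim.
using CustomCallFn = std::function<void(const std::string& opaque)>;

int main() {
  std::map<std::string, CustomCallFn> registry;
  registry["my_backend_op"] = [](const std::string& opaque) {
    // Only the backend decides what the opaque bytes mean (e.g. a serialized
    // config proto or a key=value string).
    std::cout << "my_backend_op(opaque=\"" << opaque << "\")\n";
  };
  const std::string custom_call_target = "my_backend_op";
  const std::string opaque = "tile_size=128";
  registry.at(custom_call_target)(opaque);
  return 0;
}
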
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 2d7bc83855..e169604072 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -546,17 +546,6 @@ class HloSliceInstruction : public HloInstruction {
}
const std::vector<int64>& slice_strides() const { return slice_strides_; }
- // Returns the flag that describes whether a slice must be lowered into an
- // offset into the original operand.
- bool IsInPlaceSlice() const { return is_in_place_slice_; }
-
- // Sets and returns the flag that describes whether a slice must be lowered
- // into an offset into the original operand.
- bool SetIsInPlaceSlice(bool value) {
- is_in_place_slice_ = value;
- return value;
- }
-
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
@@ -573,9 +562,6 @@ class HloSliceInstruction : public HloInstruction {
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
std::vector<int64> slice_strides_;
-
- // Describes whether the slice can be lowered to an offset into the operand.
- bool is_in_place_slice_ = false;
};
class HloConstantInstruction : public HloInstruction {
@@ -910,7 +896,6 @@ class HloOutfeedInstruction : public HloInstruction {
absl::string_view outfeed_config);
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const {
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
return outfeed_shape_;
}
// Returns the config for the Outfeed instruction.
@@ -1070,7 +1055,8 @@ class HloCustomCallInstruction : public HloInstruction {
public:
explicit HloCustomCallInstruction(const Shape& shape,
absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target);
+ absl::string_view custom_call_target,
+ absl::string_view opaque);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1090,6 +1076,7 @@ class HloCustomCallInstruction : public HloInstruction {
convolution_dimension_numbers_ =
absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
+ const string& opaque() const { return opaque_; }
const string& custom_call_target() const { return custom_call_target_; }
void set_feature_group_count(int64 feature_group_count) {
feature_group_count_ = feature_group_count;
@@ -1109,8 +1096,10 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // Name of a global symbol to call, only present for kCustomCall.
+ // Name of a global symbol to call.
string custom_call_target_;
+ // Opaque string interpreted by the backend.
+ string opaque_;
// Describes the window in a windowed operation such as convolution.
std::unique_ptr<Window> window_;
// Describes the dimension numbers used for a convolution.
@@ -1337,6 +1326,9 @@ class HloDomainInstruction : public HloInstruction {
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata);
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
// Retrieves the operand side metadata of a kDomain instruction.
const DomainMetadata& operand_side_metadata() const {
return *operand_side_metadata_;
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 3a1dd471c6..5bf055f3c0 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -219,6 +219,33 @@ void PropagateLivenessToParameterCallers(
}
}
+// Makes sure that if a live instruction is within a computation used in control
+// flow operations, the related control-flow instructions are also marked live.
+void PropagateLivenessThroughControlFlow(
+ const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+ Workset* workset, CallGraph* call_graph) {
+ const CallGraphNode& call_graph_node =
+ call_graph->GetNode(instruction->parent());
+ if (call_graph_node.context() == CallContext::kSequential) {
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ HloInstruction* caller = callsite.instruction();
+ if (caller->opcode() == HloOpcode::kWhile) {
+ // If a live instruction is within the %while body or condition
+ // computation, mark the predicate value returned by the condition
+ // computation live as well.
+ MarkLiveAtIndex(caller->while_condition()->root_instruction(), {},
+ live_index_map, worklist, workset);
+ } else if (caller->opcode() == HloOpcode::kConditional) {
+ // If a live instruction is within the true or false branches of a
+ // conditional, we mark the predicate operand live as well.
+ MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist,
+ workset);
+ }
+ }
+ }
+}
+
} // namespace
HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
@@ -257,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() {
} else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
&workset);
- } else if (instruction->opcode() == HloOpcode::kWhile &&
- ShapeUtil::IsTuple(instruction->shape())) {
+ } else if (instruction->opcode() == HloOpcode::kWhile) {
PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
&workset);
- } else if (instruction->opcode() == HloOpcode::kParameter &&
- ShapeUtil::IsTuple(instruction->shape())) {
+ } else if (instruction->opcode() == HloOpcode::kParameter) {
PropagateLivenessToParameterCallers(instruction, &live_index_map_,
&worklist, &workset,
call_graph_.get());
@@ -277,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() {
MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset);
}
}
+ PropagateLivenessThroughControlFlow(instruction, &live_index_map_,
+ &worklist, &workset, call_graph_.get());
}
}
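A hedged usage sketch of the analysis this file implements, assuming HloLivenessAnalysis::Run returns StatusOr<std::unique_ptr<HloLivenessAnalysis>> as elsewhere in the codebase; `module` is an HloModule* assumed to be in scope:

// Run liveness and report instructions whose top-level value ({} index) is
// live. With the change above, an outfeed inside a while body also keeps the
// loop predicate (the condition computation's root) live.
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLivenessAnalysis> liveness,
                    HloLivenessAnalysis::Run(*module));
for (const HloComputation* computation : module->computations()) {
  for (const HloInstruction* instruction : computation->instructions()) {
    if (liveness->IsLive(instruction, /*shape_index=*/{})) {
      VLOG(1) << "live: " << instruction->name();
    }
  }
}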
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
index 01b625c29c..e0ae1173c6 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
@@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2}));
}
+TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ InnerWhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ InnerWhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ OuterWhileCondition {
+ cond_param.2 = (s32[]) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
+ constant.5 = s32[] constant(5)
+ ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5)
+ }
+ OuterWhileBody {
+ body_param.2 = (s32[]) parameter(0)
+ get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0
+ constant.6 = s32[] constant(0)
+ tuple.2 = (s32[]) tuple(constant.6)
+ inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition,
+ body=InnerWhileBody
+ constant.7 = s32[] constant(1)
+ add.2 = s32[] add(get-tuple-element.8, constant.7)
+ ROOT rtuple = (s32[]) tuple(add.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=OuterWhileCondition,
+ body=OuterWhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index c7ec88d450..5cee865b7a 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -74,7 +76,7 @@ class ListScheduler {
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
ListScheduler scheduler(computation, points_to_analysis, size_function,
memory_by_computation);
@@ -99,7 +101,7 @@ class ListScheduler {
ListScheduler(const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation)
: computation_(computation),
points_to_analysis_(points_to_analysis),
@@ -110,7 +112,7 @@ class ListScheduler {
// LogicalBuffer is in an operand of the instruction as indicated by
// points-to analysis.
for (auto* instruction : computation.instructions()) {
- tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses;
+ absl::flat_hash_set<const LogicalBuffer*> instr_uses;
for (auto* operand : instruction->operands()) {
points_to_analysis.GetPointsToSet(operand).ForEachElement(
[&](const ShapeIndex& /*index*/,
@@ -193,13 +195,15 @@ class ListScheduler {
return entry;
}
- // Returns the number of bytes freed if the HLO instruction is scheduled.
- // If the instruction calls subcomputations, we count the memory used by the
- // subcomputations as memory "defined" by the instruction. This is not
- // entirely accurate, because subcomputation memory will be freed after the
- // instruction finishes. But it is more accurate than not taking
- // subcomputations into account at all. In the future, we may improve
- // accounting for subcomputation memory (b/65409243).
+ // Returns the number of bytes freed *after* the HLO instruction finishes.
+ // The current List algorithm only considers two states for an instruction:
+ // right before it runs, and after it finishes. We don't represent memory
+ // usage during the execution of an instruction. But if the instruction calls
+ // subcomputations, they are only live during the instruction's execution.
+ // We end up counting the memory used by subcomputations as memory "defined"
+ // by the instruction. This is not entirely accurate, but it is more accurate
+ // than not taking subcomputations into account at all. In the future, we may
+ // improve accounting for subcomputation memory (b/65409243).
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
int64 freed_bytes = 0;
for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
@@ -221,7 +225,18 @@ class ListScheduler {
}
}
}
- return freed_bytes - entry.bytes_defined - max_subcomputation_bytes;
+ int64 bytes_defined;
+ if (max_subcomputation_bytes > 0 &&
+ (entry.instruction->opcode() == HloOpcode::kWhile ||
+ entry.instruction->opcode() == HloOpcode::kCall ||
+ entry.instruction->opcode() == HloOpcode::kConditional)) {
+ // The output buffer of while/call/conditional is always aliased with the
+ // output buffer of the root instruction in the body. Don't double count.
+ bytes_defined = max_subcomputation_bytes;
+ } else {
+ bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
+ }
+ return freed_bytes - bytes_defined;
}
// Constructs the scheduling priority of the given instruction.
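A small worked example of the aliasing rule above, with illustrative numbers not taken from this diff:

// A kWhile frees 32 bytes of operand buffers, defines a 16-byte result
// buffer, and its subcomputations peak at 64 bytes. The while's result buffer
// is aliased with the body root's output, which is already part of the 64
// bytes, so bytes_defined is 64 rather than 16 + 64 = 80.
int64 freed_bytes = 32;
int64 max_subcomputation_bytes = 64;
int64 bytes_defined = max_subcomputation_bytes;     // aliasing: don't add 16
int64 priority_term = freed_bytes - bytes_defined;  // 32 - 64 = -32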
@@ -234,8 +249,7 @@ class ListScheduler {
// Populate the ready list with instructions which have no operands or
// control predecessors.
- tensorflow::gtl::FlatMap<const HloInstruction*, int64>
- unscheduled_pred_count;
+ absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
for (auto* instruction : computation_.instructions()) {
// TODO(b/34466113): Replace this and above with successors() or
// predecessors() when these methods are added to HloInstruction.
@@ -251,8 +265,8 @@ class ListScheduler {
std::multimap<Priority, ReadyListEntry> ready_queue;
// Map of ready instructions to their iterators in ready_queue.
- tensorflow::gtl::FlatMap<const HloInstruction*,
- std::multimap<Priority, ReadyListEntry>::iterator>
+ absl::flat_hash_map<const HloInstruction*,
+ std::multimap<Priority, ReadyListEntry>::iterator>
ready_instructions;
auto add_to_ready_queue = [&](HloInstruction* inst) {
@@ -262,9 +276,8 @@ class ListScheduler {
};
for (auto* instruction : computation_.instructions()) {
- // Instruction with no operands or control predecessors will
- // not be in the map.
- if (unscheduled_pred_count.count(instruction) == 0) {
+ if (instruction->operands().empty() &&
+ instruction->control_predecessors().empty()) {
add_to_ready_queue(instruction);
}
}
@@ -347,21 +360,19 @@ class ListScheduler {
// Computations are analyzed in post-order. When scheduling an instruction
// that includes subcomputations, such as a while loop, we use this map to
// look up the memory needed by subcomputations.
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation_;
// A map containing the LogicalBuffers that each instruction uses.
- tensorflow::gtl::FlatMap<const HloInstruction*,
- std::vector<const LogicalBuffer*>>
+ absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
buffer_uses_;
// A map containing the count of unscheduled HLOs which using a particular
- // LogicalBuffer. We rely on iterator stability in this map, and that the map
- // entries are std::pair's.
- std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
+ // LogicalBuffer.
+ absl::flat_hash_map<const LogicalBuffer*, int64> unscheduled_use_count_;
// Set of instructions which have been scheduled.
- tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;
+ absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
};
int64 SumLogicalBufferSizes(
@@ -379,7 +390,7 @@ StatusOr<HloInstructionSequence> ScheduleComputationHelper(
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
VLOG(2) << "Computation: " << computation.name();
if (algorithm) {
@@ -396,13 +407,13 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
// These variables are a hack to prevent overflows.
int64 cumulative_total_size = 0;
- int64 total_hlos = computation.parent()->NumUniqueInstructionIds();
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
+ int64 total_hlos = computation.parent()->instruction_count();
+ absl::flat_hash_map<const HloInstruction*, int64> extra_users;
+ absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
if (ListScheduler::IgnoreInstruction(*hlo)) {
extra_users[hlo] = 0;
@@ -419,7 +430,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
total_sizes[hlo] = logical_buffer_size;
cumulative_total_size += logical_buffer_size;
- tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
+ absl::flat_hash_set<const HloInstruction*> unique_operands(
hlo->operands().begin(), hlo->operands().end());
for (const HloInstruction* operand : unique_operands) {
extra_users[hlo] += extra_users[operand];
@@ -467,7 +478,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
return ListScheduler::Run(computation, points_to_analysis, size_function,
memory_by_computation);
@@ -477,7 +488,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
return HloInstructionSequence(computation.MakeInstructionPostOrder());
}
@@ -486,7 +497,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
// We try a few schedulers and choose whichever returns a lower min-memory,
// not accounting for fragmentation.
@@ -549,7 +560,7 @@ StatusOr<HloSchedule> ScheduleModule(
HloSchedule schedule(&module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
- tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
for (const auto* computation : module.MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
@@ -577,7 +588,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
CHECK(!computation.IsFusionComputation());
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
- tensorflow::gtl::FlatMap<const HloComputation*, int64> empty_map;
+ absl::flat_hash_map<const HloComputation*, int64> empty_map;
return ScheduleComputationHelper(computation, *points_to_analysis,
size_function, nullptr, empty_map);
}
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 5e02868eba..a4c1d3db81 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -37,7 +38,7 @@ namespace xla {
typedef std::function<StatusOr<HloInstructionSequence>(
const HloComputation&, const TuplePointsToAnalysis&,
const LogicalBuffer::SizeFunction&,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
+ const absl::flat_hash_map<const HloComputation*, int64>&)>
MemorySchedulerAlgorithm;
// List scheduler
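A hedged sketch of a custom algorithm conforming to the updated MemorySchedulerAlgorithm signature; the function name is hypothetical, and the body reuses the post-order behavior already shown in hlo_memory_scheduler.cc:

StatusOr<HloInstructionSequence> MyPostOrderScheduler(
    const HloComputation& computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // The analyses are available but unused in this trivial algorithm.
  return HloInstructionSequence(computation.MakeInstructionPostOrder());
}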
@@ -45,7 +46,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation);
// DFS-order scheduler
@@ -53,7 +54,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation);
// Naive Post Order scheduler
@@ -61,7 +62,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation);
// The default scheduling algorithm. Runs both the list scheduler
@@ -71,7 +72,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation);
// Returns an HloSchedule which seeks to minimize the memory required for
@@ -90,7 +91,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
// A pass which schedules the HLO instructions in a module. The HloModule's
// schedule field is set to the resulting HloSchedule using
// HloModule::set_schedule.
-class HloMemoryScheduler : public HloPassInterface {
+class HloMemoryScheduler : public HloModulePass {
public:
// size_function is the function returning the number of bytes required for a
// LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
@@ -109,7 +110,7 @@ class HloMemoryScheduler : public HloPassInterface {
// A trivial pass which clears the schedule currently set on the
// HloModule. After this pass runs HloModudle::has_schedule will return false.
-class HloDescheduler : public HloPassInterface {
+class HloDescheduler : public HloModulePass {
public:
HloDescheduler() = default;
~HloDescheduler() override = default;
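A hedged usage sketch of the two passes after the switch to HloModulePass; HloPassPipeline, the HloMemoryScheduler constructor arguments, and the variable `module` are assumed from the surrounding codebase rather than shown in this diff:

// Schedule the module with the list scheduler, then clear the schedule again.
HloPassPipeline pipeline("memory-scheduling");
pipeline.AddPass<HloMemoryScheduler>(
    [](const BufferValue& buffer) {
      return ShapeUtil::ByteSizeOf(buffer.shape());
    },
    ListMemoryScheduler);
TF_RETURN_IF_ERROR(pipeline.Run(module).status());

HloDescheduler descheduler;
TF_RETURN_IF_ERROR(descheduler.Run(module).status());  // has_schedule() -> false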
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
index 1b9e9bfc77..214119fba8 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -146,126 +147,6 @@ ENTRY root {
instructions_by_name.at("e")));
}
-TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
- // %WhileCond (cond_param: f32[4]) -> pred[] {
- // %cond_param = f32[4]{0} parameter(0)
- // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } })
- // ROOT %not-equal-to = pred[] not-equal-to(
- // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant)
- // }
- // %WhileBody (body_param: f32[4]) -> f32[4] {
- // %body_param = f32[4]{0} parameter(0)
- // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
- // ROOT %subtract = f32[4]{0} subtract(
- // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
- // }
- // %ListAccountsForSubcomputations () -> f32[2,4] {
- // %constant.3 = f32[2,4]{1,0} constant(
- // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
- // %transpose = f32[2,4]{1,0} transpose(
- // f32[2,4]{1,0} %constant.3), dimensions={0,1}
- // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
- // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2),
- // condition=%WhileCond,
- // body=%WhileBody
- // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0}
- // ROOT %add = f32[2,4]{1,0} add(
- // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
- // }
-
- auto module = CreateNewModule();
- const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
- const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
-
- // param != 0
- // Needs 17 bytes
- auto cond_builder = HloComputation::Builder("WhileCond");
- HloInstruction* cond_param = cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, r1f32, "cond_param"));
- HloInstruction* zero_vector =
- cond_builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
- cond_builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
- auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
-
- // param - 1
- // Needs 16 bytes
- auto body_builder = HloComputation::Builder("WhileBody");
- HloInstruction* body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, r1f32, "body_param"));
- HloInstruction* one_vector =
- body_builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
- body_builder.AddInstruction(HloInstruction::CreateBinary(
- r1f32, HloOpcode::kSubtract, body_param, one_vector));
- auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
-
- // transpose(matrix) + bcast(while)
- auto builder = HloComputation::Builder(TestName());
- HloInstruction* while_init =
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
- // Creates 16 bytes, ignoring subcomputations
- HloInstruction* while_loop =
- builder.AddInstruction(HloInstruction::CreateWhile(
- r1f32, cond_computation, body_computation, while_init));
-
- // Creates 32 bytes and frees 16
- HloInstruction* bcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(r2f32, while_loop, {0}));
-
- HloInstruction* matrix = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
- {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
- // Creates 32 bytes
- HloInstruction* transpose = builder.AddInstruction(
- HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));
-
- // Creates 32 bytes and frees 64
- HloInstruction* add = builder.AddInstruction(
- HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));
-
- module->AddEntryComputation(builder.Build());
-
- auto size_fn = [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- };
- TF_ASSERT_OK_AND_ASSIGN(
- HloSchedule schedule,
- ScheduleModule(*module, size_fn, ListMemoryScheduler));
- // Verify that all instructions are in the sequence.
- auto entry_computation = module->entry_computation();
- EXPECT_EQ(entry_computation->instruction_count(),
- schedule.sequence(entry_computation).size());
- SequentialHloOrdering ordering(schedule);
- // This schedule is an example of List's greedy heuristics being suboptimal.
- // The while_loop is more expensive than transpose, so it would have been
- // better to schedule it first, instead of during the busy time.
- EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop));
- EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast));
- EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
- EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
-
- tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
- memory_by_computation[cond_computation] = 17;
- memory_by_computation[body_computation] = 16;
- std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
- TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
-
- // HeapSimulator doesn't account for subcomputations
- EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, schedule.sequence(entry_computation),
- *points_to_analysis, size_fn)
- .ValueOrDie());
- // HeapSimulator accounts for subcomputations. The output buffer is aliased,
- // so we don't double count.
- EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, schedule.sequence(entry_computation),
- *points_to_analysis, size_fn, &memory_by_computation)
- .ValueOrDie());
-}
-
TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto builder = HloComputation::Builder(TestName());
const auto TUPLE_SIZE = 1;
@@ -409,7 +290,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
EXPECT_EQ(module->entry_computation()->instruction_count(),
schedule.sequence(module->entry_computation()).size());
- tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
memory_by_computation[cond_computation] = 17;
memory_by_computation[body_computation] = 16;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index b3949f3a6d..93e04eb3db 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -144,7 +146,8 @@ void HloModule::ReplaceComputations(
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduce:
- case HloOpcode::kReduceWindow: {
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kScatter: {
HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
replacements, instruction->to_apply(), nullptr);
if (new_arg != nullptr) {
@@ -285,8 +288,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
<< ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
- tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
- tensorflow::gtl::FlatMap<HloComputation*, int64> to_proto_id;
+ absl::flat_hash_map<int64, HloComputation*> computation_map;
+ absl::flat_hash_map<HloComputation*, int64> to_proto_id;
std::vector<std::unique_ptr<HloComputation>> computations;
HloComputation* entry = nullptr;
for (const HloComputationProto& computation_proto : proto.computations()) {
@@ -327,10 +330,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Because we didn't uniquify the names or the ids, double-check that the
// instruction and computation names and ids are unique from the proto.
- tensorflow::gtl::FlatSet<string> computation_names;
- tensorflow::gtl::FlatSet<string> instruction_names;
- tensorflow::gtl::FlatSet<int> computation_ids;
- tensorflow::gtl::FlatSet<int> instruction_ids;
+ absl::flat_hash_set<string> computation_names;
+ absl::flat_hash_set<string> instruction_names;
+ absl::flat_hash_set<int> computation_ids;
+ absl::flat_hash_set<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3bc2d13781..735804e827 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -63,6 +63,7 @@ class HloModule {
// tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module.
explicit HloModule(const string& name, const HloModuleConfig& config);
+ virtual ~HloModule() {}
// Adds an entry computation to the module. A module can only have one entry
// computation. Returns a pointer to the newly added computation.
@@ -87,6 +88,7 @@ class HloModule {
const std::unordered_map<HloComputation*, HloComputation*>& replacements);
const string& name() const { return name_; }
+ void set_name(string name) { name_ = std::move(name); }
// Returns a deep copy of this module including all computations.
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
@@ -255,7 +257,7 @@ class HloModule {
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_identifiers);
- const string name_;
+ string name_;
HloModuleConfig config_;
HloComputation* entry_computation_ = nullptr;
std::vector<std::unique_ptr<HloComputation>> computations_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc
index f7be5cae22..31d26cc51e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc
@@ -50,9 +50,7 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
auto* while_body_root = while_body_comp->root_instruction();
if (!ShapeUtil::IsTuple(xla_while->shape()) ||
- while_body_root->opcode() != HloOpcode::kTuple ||
- while_body_comp->HasSideEffect() ||
- xla_while->while_condition()->HasSideEffect()) {
+ while_body_root->opcode() != HloOpcode::kTuple) {
// Only run DCE on tuple-shaped while loops where body root is Tuple,
// with no I/O instructions.
VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 12ca2340a6..d472211d2a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -28,7 +28,7 @@ namespace xla {
// Sweeps through live instructions which cross computation boundaries (kWhile),
// and removes code at dead shape indices.
//
-class HloModuleDCE : public HloPassInterface {
+class HloModuleDCE : public HloModulePass {
public:
~HloModuleDCE() override {}
absl::string_view name() const override { return "hlo-module-dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 9c01862a4b..b4aac4c807 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
}
/* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
-HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) {
+HloModuleGroupMetadata::Build(absl::Span<HloModule* const> modules) {
auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules);
TF_RETURN_IF_ERROR(metadata->Build());
return std::move(metadata);
@@ -392,22 +392,28 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
if (!ContainsKey(companion_set_index_, instruction1) &&
!ContainsKey(companion_set_index_, instruction2)) {
companion_sets_.push_back(
- absl::make_unique<std::unordered_set<HloInstruction*>>());
+ absl::make_unique<std::vector<HloInstruction*>>());
auto companion_set = companion_sets_.back().get();
- companion_set->insert(instruction1);
- companion_set->insert(instruction2);
+ companion_set->push_back(instruction1);
+ companion_set->push_back(instruction2);
companion_set_index_[instruction1] = companion_sets_.size() - 1;
companion_set_index_[instruction2] = companion_sets_.size() - 1;
} else if (!ContainsKey(companion_set_index_, instruction1)) {
- companion_sets_[companion_set_index_[instruction2]]->insert(instruction1);
+ companion_sets_[companion_set_index_[instruction2]]->push_back(
+ instruction1);
companion_set_index_[instruction1] = companion_set_index_[instruction2];
} else if (!ContainsKey(companion_set_index_, instruction2)) {
- companion_sets_[companion_set_index_[instruction1]]->insert(instruction2);
+ companion_sets_[companion_set_index_[instruction1]]->push_back(
+ instruction2);
companion_set_index_[instruction2] = companion_set_index_[instruction1];
} else if (companion_set_index_[instruction1] !=
companion_set_index_[instruction2]) {
- companion_sets_[companion_set_index_[instruction1]]->insert(
- Companions(instruction2).begin(), Companions(instruction2).end());
+ // At any point while building the companion sets, each instruction belongs
+ // to at most one companion set, so the union of two companion sets amounts
+ // to concatenating two disjoint sets.
+ absl::c_copy(Companions(instruction2),
+ std::back_inserter(
+ *companion_sets_[companion_set_index_[instruction1]]));
int64 index_to_remove = companion_set_index_[instruction2];
for (HloInstruction* hlo : Companions(instruction2)) {
companion_set_index_[hlo] = companion_set_index_[instruction1];
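A brief illustration of why the vector keeps companion traversal deterministic (sketch only; `metadata` and `instruction` are assumed to be in scope):

// std::unordered_set iteration depends on pointer hashing and can differ
// between runs; the std::vector preserves insertion order, so this loop
// visits companions in the same order every time (see the
// ModuleGroupCompanionOrder test added below).
for (HloInstruction* companion : metadata->Companions(instruction)) {
  VLOG(2) << companion->name();
}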
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 768b0c7eb3..928df0f5a7 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -102,14 +102,14 @@ class HloModuleGroupMetadata {
HloInstruction* recv_done = nullptr;
};
- explicit HloModuleGroupMetadata(const std::vector<HloModule*>& modules)
- : modules_(modules) {}
+ explicit HloModuleGroupMetadata(absl::Span<HloModule* const> modules)
+ : modules_(modules.begin(), modules.end()) {}
~HloModuleGroupMetadata() = default;
// Build and return the metadata for the given modules.
static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build(
- const std::vector<HloModule*>& modules);
+ absl::Span<HloModule* const> modules);
// Returns true if the instruction is one of the 4 channel instructions (Send,
// Recv, SendDone, RecvDone).
@@ -169,14 +169,14 @@ class HloModuleGroupMetadata {
// Returns the companion instructions for the given instruction.
//
// Precondition: IsCompanionWhile(instruction) is true.
- const std::unordered_set<HloInstruction*>& Companions(
+ const std::vector<HloInstruction*>& Companions(
const HloInstruction* instruction) const {
CHECK_EQ(companion_set_index_.count(instruction), 1);
return companion_set(companion_set_index_.at(instruction));
}
// Returns the companion set at the given index.
- const std::unordered_set<HloInstruction*>& companion_set(int64 index) const {
+ const std::vector<HloInstruction*>& companion_set(int64 index) const {
CHECK_LT(index, companion_sets_.size());
return *companion_sets_[index];
}
@@ -187,7 +187,7 @@ class HloModuleGroupMetadata {
}
// Returns the list of all companion sets in the HLO module group.
- const std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>&
+ const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>&
companion_sets() const {
return companion_sets_;
}
@@ -247,37 +247,36 @@ class HloModuleGroupMetadata {
void DumpCollectedStats() const;
// List of all companion instructions sets in the module.
- std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>
- companion_sets_;
+ std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
// Map from each companion while instruction to the index into companion_set_.
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
+ absl::flat_hash_map<const HloInstruction*, int64> companion_set_index_;
// Map from computation to the instruction using it (a kWhile, kConditional).
- tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>
+ absl::flat_hash_map<const HloComputation*, TrackedInstruction>
tracked_instructions_;
// Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of
// communicating instructions within the proper called computation(s).
- tensorflow::gtl::FlatMap<HloInstruction*, std::vector<HloInstruction*>>
+ absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>
tracked_instructions_comms_;
// All channels in the module.
std::vector<Channel> channels_;
// Map from channel ids to the index in channels_.
- tensorflow::gtl::FlatMap<int64, int64> channel_id_map_;
+ absl::flat_hash_map<int64, int64> channel_id_map_;
// Map from all-reduce ids to the all reduce instructions.
- tensorflow::gtl::FlatMap<int64, std::vector<HloInstruction*>> all_reduce_map_;
+ absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_;
// The maximum channel id used in the module group.
int64 max_channel_id_ = -1;
// The modules that this metadata was built from.
- const std::vector<HloModule*>& modules_;
+ const std::vector<HloModule*> modules_;
- tensorflow::gtl::FlatMap<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
+ absl::flat_hash_map<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
points_to_analyses_;
};
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
index ebf790ba6f..b7b12cb72b 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -137,6 +138,69 @@ ENTRY %entry (a: f32[]) -> f32[] {
::testing::ElementsAre(op::Parameter()));
}
+// Tests that the order of companion instructions in the companion set doesn't
+// change across runs.
+TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
+ // A simple while loop template for core i sending to core i+1.
+ constexpr char text[] = R"(
+HloModule module_%d
+
+while_cond {
+ ROOT p = pred[] constant(true)
+}
+
+while_body {
+ param = s32[] parameter(0)
+ token.s = token[] after-all()
+ token.r = token[] after-all()
+ send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
+ send-done = token[] send-done(send), channel_id=%d
+ recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
+ ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
+}
+
+ENTRY entry {
+ while_init = s32[] constant(1)
+ ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
+}
+)";
+
+ // Try creating the module and the metadata kTrialCount times and check the
+ // companion instructions remain in the same order.
+ const int64 kTrialCount = 5;
+ const int64 kDeviceCount = 10;
+ std::vector<int64> companion_order;
+
+ for (int64 t = 0; t < kTrialCount; ++t) {
+ HloModuleGroup group(TestName());
+ for (int64 i = 0; i < kDeviceCount; ++i) {
+ const int64 send_channel = i;
+ const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(absl::StrFormat(text, i, send_channel, send_channel,
+ recv_channel, recv_channel)));
+ group.push_back(std::move(module));
+ }
+ ASSERT_EQ(group.modules().size(), kDeviceCount);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto metadata,
+ HloModuleGroupMetadata::Build(group.modules()));
+ ASSERT_EQ(metadata->companion_sets().size(), 1);
+
+ std::vector<int64> module_ids;
+ for (HloInstruction* companion : *metadata->companion_sets()[0]) {
+ module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
+ }
+
+ if (t == 0) {
+ companion_order = module_ids;
+ } else {
+ EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
+ }
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index d83ee71490..fddeb5f0a2 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -42,7 +42,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
HloInstruction* instruction) {
std::vector<HloInstruction*>
predecessors; // Use a vector to avoid non-determinism.
- tensorflow::gtl::FlatSet<HloInstruction*> unique;
+ absl::flat_hash_set<HloInstruction*> unique;
// Adds to the unique predecessors list; if the predecessors is a companion
// instruction, also add companion instructions; if the predecessors is a
@@ -119,7 +119,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
HloInstruction* instruction) {
std::vector<HloInstruction*>
successors; // Use a vector to avoid non-determinism.
- tensorflow::gtl::FlatSet<HloInstruction*> unique;
+ absl::flat_hash_set<HloInstruction*> unique;
// Adds to the unique successors list; if the successor is a companion
// instruction, also add companion instructions; if the successor is a
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h
index 309c23045d..f21b44bcd9 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -87,7 +87,7 @@ class HloModuleGroupUtil {
// * visit_state: map from each instruction to its visit state.
// * visit_function: function called when each instruction group.
// * root: the root instruction of the traversal.
- using VisitStates = tensorflow::gtl::FlatMap<HloInstruction*, VisitState>;
+ using VisitStates = absl::flat_hash_map<HloInstruction*, VisitState>;
Status VisitTopologicalOrder(VisitStates* visit_state,
const VisitFunction& visit_function,
HloInstruction* root);
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 2d4e38589f..4551a1c2e2 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -31,7 +31,7 @@ string HloOpcodeString(HloOpcode opcode) {
}
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
- static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({
+ static auto* opcode_map = new absl::flat_hash_map<string, HloOpcode>({
#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \
{opcode_name, HloOpcode::enum_name},
HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY)
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index f1dc08bafa..23d41d91d6 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -92,14 +92,18 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
}
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
- // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
- // is live into the module.
+ // Entry parameters are always defined before all other instructions.
const HloModule* module = b.defining_instruction()->parent()->parent();
if (b.defining_instruction()->parent() == module->entry_computation() &&
b.defining_instruction()->opcode() == HloOpcode::kParameter) {
return false;
}
+ if (a.defining_instruction()->parent() == module->entry_computation() &&
+ a.defining_instruction()->opcode() == HloOpcode::kParameter) {
+ return true;
+ }
+
// Phi values require special handling. Because XLA does not have a phi
// instruction, the definition instruction of the phis values are
// placeholders: either the subcomputation parameter (body or condition) or
@@ -316,7 +320,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const {
for (auto predecessor : all) {
if (predecessors_.at(computation)
->IsReachable(predecessor, instruction)) {
- pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
+ pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
}
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index b0361c3f02..66313492eb 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -120,8 +120,8 @@ class PredecessorHloOrdering : public HloOrdering {
// predecessors. An instruction is an element of its own predecessor set.
//
// Subclasses should fill this in to define the desired ordering.
- tensorflow::gtl::FlatMap<const HloComputation*,
- std::unique_ptr<HloReachabilityMap>>
+ absl::flat_hash_map<const HloComputation*,
+ std::unique_ptr<HloReachabilityMap>>
predecessors_;
};
@@ -204,7 +204,7 @@ class SequentialHloOrdering : public HloOrdering {
// this map so more than one instruction may have the same position
// value. This is not a problem because ExecutesBefore also verifies
// instructions are in the same computation.
- tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_;
+ absl::flat_hash_map<const HloInstruction*, int> order_position_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 00970bcda3..b045adc964 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -174,6 +174,26 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
+TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) {
+ // Entry parameters should always be defined before other instructions.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ module->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ DependencyHloOrdering ordering(module.get());
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(param),
+ dataflow->GetValueDefinedAt(constant)));
+ EXPECT_TRUE(!ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(param)));
+}
+
TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// Tests the ordering of values (defined by dataflow analysis) in the body and
// condition of a while instruction. HLO code:
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 11caa89c54..dd62988bcc 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -64,14 +64,11 @@ class HloParser {
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(absl::string_view str, const HloModuleConfig& config)
- : lexer_(str), config_(config) {}
+ explicit HloParser(absl::string_view str) : lexer_(str) {}
- // Runs the parser. Returns false if an error occurred.
- bool Run();
-
- // Returns the parsed HloModule.
- std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
+ // Runs the parser and constructs the resulting HLO in the given (empty)
+ // HloModule. Returns an error status if parsing fails.
+ Status Run(HloModule* module);
// Returns the error information.
string GetError() const { return StrJoin(error_, "\n"); }
@@ -82,28 +79,37 @@ class HloParser {
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
StatusOr<PaddingConfig> ParsePaddingConfigOnly();
- // Stand-alone parsing utility for a single instruction worth of text.
- Status ParseSingleInstruction(HloComputation::Builder* builder,
- string* root_name);
-
private:
- // Locates an instruction with the given name in the instruction_pool_ or
+ using InstrNameTable =
+ std::unordered_map<string, std::pair<HloInstruction*, LocTy>>;
+
+ // Returns the map from the instruction name to the instruction itself and its
+ // location in the current scope.
+ InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
+
+ // Locates an instruction with the given name in the current_name_table() or
// returns nullptr.
//
- // If the missing_instruction_hook_ is registered and a "shape" is provided,
- // the hook will be called and may satisfy the request for the given
- // instruction. This is useful when we reify parameters as they're resolved;
- // i.e. for ParseSingleInstruction.
+ // When the name is not found (or is empty), and the create_missing_instruction_
+ // hook is registered and a "shape" is provided, the hook will be called to
+ // create an instruction. This is useful when we reify parameters as they're
+ // resolved; i.e. for ParseSingleInstruction.
std::pair<HloInstruction*, LocTy>* FindInstruction(
const string& name, const optional<Shape>& shape = nullopt);
+ // Parse a single instruction worth of text.
+ bool ParseSingleInstruction(HloModule* module);
+
// ParseXXX returns false if an error occurred.
- bool ParseHloModule();
- bool ParseComputations();
+ bool ParseHloModule(HloModule* module);
+
+ bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
- bool ParseInstructionList(HloComputation::Builder* builder,
- string* root_name);
+ bool ParseInstructionList(HloComputation** computation,
+ const string& computation_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
+ bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name,
+ LocTy name_loc);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
@@ -284,25 +290,47 @@ class HloParser {
bool AddComputation(const string& name, HloComputation* computation,
LocTy name_loc);
- // The map from the instruction/computation name to the
- // instruction/computation itself and it's location. This does not own the
- // pointers.
- std::unordered_map<string, std::pair<HloInstruction*, LocTy>>
- instruction_pool_;
+ HloLexer lexer_;
+
+ // A stack for the instruction names. The top of the stack stores the
+ // instruction name table for the current scope.
+ //
+ // An instruction's name is unique within its scope (i.e. its parent
+ // computation), but not necessarily unique across all computations in the
+ // module. When computations are nested, the same name can appear in both an
+ // outer computation and an inner computation, so we need a stack to make
+ // sure a name is only visible within its scope.
+ std::vector<InstrNameTable> scoped_name_tables_;
+
+ // A helper class which pushes and pops to an InstrNameTable stack via RAII.
+ class Scope {
+ public:
+ explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
+ : scoped_name_tables_(scoped_name_tables) {
+ scoped_name_tables_->emplace_back();
+ }
+ ~Scope() { scoped_name_tables_->pop_back(); }
+
+ private:
+ std::vector<InstrNameTable>* scoped_name_tables_;
+ };
+
+ // Map from the computation name to the computation itself and its location.
std::unordered_map<string, std::pair<HloComputation*, LocTy>>
computation_pool_;
- HloLexer lexer_;
- std::unique_ptr<HloModule> module_;
std::vector<std::unique_ptr<HloComputation>> computations_;
- const HloModuleConfig config_;
std::vector<string> error_;
- // Function that gets invoked when we try to resolve an instruction
- // instruction_pool_ but fail to do so.
- std::function<std::pair<HloInstruction*, LocTy>*(string,
- const optional<Shape>&)>
- missing_instruction_hook_;
+ // When an operand name cannot be resolved, this function is called to create
+ // a parameter instruction with the given name and shape. It registers the
+ // name, instruction, and a placeholder location in the name table. It returns
+ // the newly-created instruction and the placeholder location. If `name` is
+ // empty, this should create the parameter with a generated name. This is
+ // supposed to be set and used only in ParseSingleInstruction.
+ std::function<std::pair<HloInstruction*, LocTy>*(const string& name,
+ const Shape& shape)>
+ create_missing_instruction_;
};
bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
@@ -349,24 +377,50 @@ bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
-bool HloParser::Run() {
+Status HloParser::Run(HloModule* module) {
lexer_.Lex();
- return ParseHloModule();
+ if (lexer_.GetKind() == TokKind::kw_HloModule) {
+ // This means that the text contains a full HLO module.
+ if (!ParseHloModule(module)) {
+ return InvalidArgument(
+ "Syntax error when trying to parse the text as a HloModule:\n%s",
+ GetError());
+ }
+ return Status::OK();
+ }
+ // This means that the text is a single HLO instruction.
+ if (!ParseSingleInstruction(module)) {
+ return InvalidArgument(
+ "Syntax error when trying to parse the text as a single "
+ "HloInstruction:\n%s",
+ GetError());
+ }
+ return Status::OK();
}
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
const string& name, const optional<Shape>& shape) {
- std::pair<HloInstruction*, LocTy>* instr =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ std::pair<HloInstruction*, LocTy>* instr = nullptr;
+ if (!name.empty()) {
+ instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
+ }
+
// Potentially call the missing instruction hook.
- if (instr == nullptr && missing_instruction_hook_ != nullptr) {
- return missing_instruction_hook_(name, shape);
+ if (instr == nullptr && create_missing_instruction_ != nullptr &&
+ scoped_name_tables_.size() == 1) {
+ if (!shape.has_value()) {
+ Error(lexer_.GetLoc(),
+ "Operand had no shape in HLO text; cannot create parameter for "
+ "single-instruction module.");
+ return nullptr;
+ }
+ return create_missing_instruction_(name, *shape);
}
return instr;
}
// ::= 'HloModule' name computations
-bool HloParser::ParseHloModule() {
+bool HloParser::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
@@ -385,22 +439,20 @@ bool HloParser::ParseHloModule() {
return false;
}
- module_ = absl::make_unique<HloModule>(name, config_);
-
- if (!ParseComputations()) {
+ module->set_name(name);
+ if (!ParseComputations(module)) {
return false;
}
if (is_scheduled.has_value() && *is_scheduled) {
- TF_CHECK_OK(
- module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
return true;
}
// computations ::= (computation)+
-bool HloParser::ParseComputations() {
+bool HloParser::ParseComputations(HloModule* module) {
HloComputation* entry_computation = nullptr;
do {
if (!ParseComputation(&entry_computation)) {
@@ -416,21 +468,20 @@ bool HloParser::ParseComputations() {
if ((entry_computation != nullptr &&
computations_[i].get() != entry_computation) ||
(entry_computation == nullptr && i != computations_.size() - 1)) {
- module_->AddEmbeddedComputation(std::move(computations_[i]));
+ module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
- auto computation =
- module_->AddEntryComputation(std::move(computations_[i]));
+ auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameters and result layouts were set to default layout. Here we
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
@@ -447,7 +498,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
if (!ParseName(&name)) {
return false;
}
- auto builder = absl::make_unique<HloComputation::Builder>(name);
LocTy shape_loc = nullptr;
Shape shape;
@@ -455,40 +505,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
return false;
}
- string root_name;
- if (!ParseInstructionList(builder.get(), &root_name)) {
+ HloComputation* computation = nullptr;
+ if (!ParseInstructionList(&computation, name)) {
return false;
}
- std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name);
- // This means some instruction was marked as ROOT but we didn't find it in the
- // pool, which should not happen.
- if (!root_name.empty() && root_node == nullptr) {
- LOG(FATAL) << "instruction " << root_name
- << " was marked as ROOT but the parser has not seen it before";
- }
-
- HloInstruction* root = root_node == nullptr ? nullptr : root_node->first;
- // Now root can be either an existing instruction or a nullptr. If it's a
- // nullptr, the implementation of Builder will set the last instruction as
- // root instruction.
- computations_.emplace_back(builder->Build(root));
- HloComputation* computation = computations_.back().get();
-
- if (!root) {
- root = computation->root_instruction();
- } else {
- CHECK_EQ(root, computation->root_instruction());
- }
-
// If param_list_to_shape was present, check compatibility.
- if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
+ if (shape_loc != nullptr &&
+ !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
return Error(
shape_loc,
- StrCat("Shape of computation ", name, ", ",
- ShapeUtil::HumanString(shape),
- ", is not compatible with that of its root instruction ",
- root_name, ", ", ShapeUtil::HumanString(root->shape())));
+ StrCat(
+ "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
+ ", is not compatible with that of its root instruction ",
+ computation->root_instruction()->name(), ", ",
+ ShapeUtil::HumanString(computation->root_instruction()->shape())));
}
if (is_entry_computation) {
@@ -497,43 +528,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
}
*entry_computation = computation;
}
- instruction_pool_.clear();
return AddComputation(name, computation, name_loc);
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
-bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
- string* root_name) {
+bool HloParser::ParseInstructionList(HloComputation** computation,
+ const string& computation_name) {
+ Scope scope(&scoped_name_tables_);
+ HloComputation::Builder builder(computation_name);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction list.")) {
return false;
}
+ string root_name;
do {
- if (!ParseInstruction(builder, root_name)) {
+ if (!ParseInstruction(&builder, &root_name)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
- return ParseToken(TokKind::kRbrace,
- "expects '}' at the end of instruction list.");
+ if (!ParseToken(TokKind::kRbrace,
+ "expects '}' at the end of instruction list.")) {
+ return false;
+ }
+ HloInstruction* root = nullptr;
+ if (!root_name.empty()) {
+ std::pair<HloInstruction*, LocTy>* root_node =
+ tensorflow::gtl::FindOrNull(current_name_table(), root_name);
+
+    // This means some instruction was marked as ROOT but we didn't find it in
+    // the name table, which should not happen.
+ if (root_node == nullptr) {
+ LOG(FATAL) << "instruction " << root_name
+ << " was marked as ROOT but the parser has not seen it before";
+ }
+ root = root_node->first;
+ }
+
+ // Now root can be either an existing instruction or a nullptr. If it's a
+ // nullptr, the implementation of Builder will set the last instruction as
+ // the root instruction.
+ computations_.emplace_back(builder.Build(root));
+ *computation = computations_.back().get();
+ return true;
}
// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
string* root_name) {
string name;
- Shape shape;
- HloOpcode opcode;
- std::vector<HloInstruction*> operands;
-
LocTy maybe_root_loc = lexer_.GetLoc();
bool is_root = EatIfPresent(TokKind::kw_ROOT);
const LocTy name_loc = lexer_.GetLoc();
if (!ParseName(&name) ||
- !ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
- !ParseShape(&shape) || !ParseOpcode(&opcode)) {
+ !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
return false;
}
@@ -544,6 +594,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
*root_name = name;
}
+  return ParseInstructionRhs(builder, name, name_loc);
+}
+
+bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
+ const string& name, LocTy name_loc) {
+ Shape shape;
+ HloOpcode opcode;
+ std::vector<HloInstruction*> operands;
+
+ if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
+ return false;
+ }
+
// Add optional attributes.
std::unordered_map<string, AttrConfig> attrs;
optional<OpSharding> sharding;
@@ -1274,11 +1337,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kCustomCall: {
optional<string> custom_call_target;
+ optional<string> opaque;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
+ attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
@@ -1287,8 +1352,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
- shape, operands, *custom_call_target));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateCustomCall(shape, operands, *custom_call_target,
+ opaque.has_value() ? *opaque : ""));
if (window.has_value()) {
instruction->set_window(*window);
}
@@ -2151,7 +2217,20 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
}
}
if (!ParseName(&name)) {
- return false;
+ // When parsing a single instruction (as opposed to a whole module), an
+ // HLO may have one or more operands with a shape but no name:
+ //
+ // foo = add(f32[10], f32[10])
+ //
+ // create_missing_instruction_ is always non-null when parsing a single
+ // instruction, and is responsible for creating kParameter instructions
+ // for these operands.
+ if (shape.has_value() && create_missing_instruction_ != nullptr &&
+ scoped_name_tables_.size() == 1) {
+ name = "";
+ } else {
+ return false;
+ }
}
std::pair<HloInstruction*, LocTy>* instruction =
FindInstruction(name, shape);
@@ -2304,9 +2383,17 @@ bool HloParser::ParseAttributeHelper(
return true;
}
case AttrTy::kHloComputation: {
- HloComputation* result;
- if (!ParseComputationName(&result)) {
- return false;
+ HloComputation* result = nullptr;
+ if (lexer_.GetKind() == TokKind::kLbrace) {
+ // This means it is a nested computation.
+ if (!ParseInstructionList(&result, /*computation_name=*/"_")) {
+ return false;
+ }
+ } else {
+ // This means it is a computation name.
+ if (!ParseComputationName(&result)) {
+ return false;
+ }
}
static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
return true;
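With this change, a kHloComputation-typed attribute such as to_apply, condition, body, or calls can be written either as the name of a computation defined elsewhere in the module or as an inline instruction list parsed by ParseInstructionList under a fresh scope. For illustration only (the first string is lifted from the SameNameDiffComputations test below; the second uses invented names in the spirit of the SingleOpWithNested tests):

  // Name form: the attribute refers to a computation defined elsewhere.
  const char* const kToApplyByName =
      "ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add";

  // Inline form: the computation body follows the '=' directly.
  const char* const kToApplyInline = R"(
  ROOT r = f32[] reduce(f32[10] p, f32[] z), dimensions={0}, to_apply=
  {
    a = f32[] parameter(0)
    b = f32[] parameter(1)
    ROOT s = f32[] add(f32[] a, f32[] b)
  })";

Because each inline body gets its own name table, operand names used inside it must also be defined inside it; the SingleOpWithNested_DoesNotExist test below checks exactly that.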
@@ -3139,7 +3226,7 @@ bool HloParser::EatIfPresent(TokKind kind) {
bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
LocTy name_loc) {
- auto result = instruction_pool_.insert({name, {instruction, name_loc}});
+ auto result = current_name_table().insert({name, {instruction, name_loc}});
if (!result.second) {
Error(name_loc, StrCat("instruction already exists: ", name));
return Error(/*loc=*/result.first->second.second,
@@ -3209,91 +3296,96 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
return padding_config;
}
-Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
- string* root_name) {
- TF_RET_CHECK(missing_instruction_hook_ == nullptr);
+bool HloParser::ParseSingleInstruction(HloModule* module) {
+ if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
+ LOG(FATAL) << "Parser state is not clean. Please do not call any other "
+ "methods before calling ParseSingleInstruction.";
+ }
+ HloComputation::Builder builder(module->name());
// The missing instruction hook we register creates the shaped instruction on
// the fly as a parameter and returns it.
int64 parameter_count = 0;
- missing_instruction_hook_ =
- [this, builder, &parameter_count](
- string name,
- const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* {
- if (!shape.has_value()) {
- Error(lexer_.GetLoc(),
- StrCat("Operand ", name,
- " had no shape in HLO text; cannot create parameter for "
- "single-instruction module."));
- return nullptr;
- }
- HloInstruction* parameter = builder->AddInstruction(
- HloInstruction::CreateParameter(parameter_count++, *shape, name));
- instruction_pool_[name] = {parameter, lexer_.GetLoc()};
- return tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ create_missing_instruction_ =
+ [this, &builder, &parameter_count](
+ const string& name,
+ const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
+ string new_name = name.empty() ? StrCat("_", parameter_count) : name;
+ HloInstruction* parameter = builder.AddInstruction(
+ HloInstruction::CreateParameter(parameter_count++, shape, new_name));
+ current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
+ return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
};
- // Prime the lexer.
- lexer_.Lex();
-
// Parse the instruction with the registered hook.
- if (!ParseInstruction(builder, root_name)) {
- return InvalidArgument("Syntax error:\n%s", GetError());
+ Scope scope(&scoped_name_tables_);
+ if (CanBeShape()) {
+ // This means that the instruction's left-hand side is probably omitted,
+ // e.g.
+ //
+ // f32[10] fusion(...), calls={...}
+    if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
+ return false;
+ }
+ } else {
+ // This means that the instruction's left-hand side might exist, e.g.
+ //
+ // foo = f32[10] fusion(...), calls={...}
+ string root_name;
+ if (!ParseInstruction(&builder, &root_name)) {
+ return false;
+ }
}
- return Status::OK();
+
+ module->AddEntryComputation(builder.Build());
+ for (auto& comp : computations_) {
+ module->AddEmbeddedComputation(std::move(comp));
+ }
+ return true;
}
} // namespace
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config) {
- HloParser parser(str, config);
- if (!parser.Run()) {
- return InvalidArgument("Syntax error:\n%s", parser.GetError());
- }
- return parser.ConsumeHloModule();
+ auto module = absl::make_unique<HloModule>(/*name=*/"_", config);
+ HloParser parser(str);
+ TF_RETURN_IF_ERROR(parser.Run(module.get()));
+ return std::move(module);
}
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
- HloModuleConfig config;
- return ParseHloString(str, config);
+ auto module = absl::make_unique<HloModule>(/*name=*/"_", HloModuleConfig());
+ HloParser parser(str);
+ TF_RETURN_IF_ERROR(parser.Run(module.get()));
+ return std::move(module);
}
-StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
- absl::string_view str, absl::string_view name) {
- HloModuleConfig config;
- HloParser parser(str, config);
- auto builder = absl::make_unique<HloComputation::Builder>(string(name));
- string root_name;
- TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
- std::unique_ptr<HloComputation> computation = builder->Build();
- auto module = absl::make_unique<HloModule>(string(name), config);
- module->AddEntryComputation(std::move(computation));
- return std::move(module);
+Status ParseHloString(absl::string_view str, HloModule* module) {
+ TF_RET_CHECK(module->computation_count() == 0);
+ HloParser parser(str);
+ TF_RETURN_IF_ERROR(parser.Run(module));
+ return Status::OK();
}
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseShardingOnly();
}
StatusOr<Window> ParseWindow(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseConvolutionDimensionNumbersOnly();
}
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParsePaddingConfigOnly();
}
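Taken together, the reworked entry points can be exercised roughly as follows. This is an illustrative sketch only, assuming this file's existing includes; the helper function name, module name, and HLO strings are invented, and error handling uses the usual TF_ status macros together with the overloads declared in hlo_parser.h below:

  Status ParserUsageSketch() {
    // 1) Full HloModule text: the parser fills in a default-configured module.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                        ParseHloString(R"(HloModule m

  ENTRY e {
    ROOT c = s32[] constant(42)
  })"));

    // 2) Single-instruction text, with or without a left-hand side; shaped but
    //    unnamed operands are turned into parameters by
    //    create_missing_instruction_.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> canonical,
                        ParseHloString("f32[2]{0} add(f32[2]{0}, f32[2]{0})"));
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModule> named,
        ParseHloString("%sum = f32[2]{0} add(f32[2]{0} %x, f32[2]{0} %y)"));

    // 3) Parsing in place into an existing, still-empty module.
    HloModule existing(/*name=*/"sketch", HloModuleConfig());
    TF_RETURN_IF_ERROR(ParseHloString(canonical->ToString(), &existing));
    return Status::OK();
  }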
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 1882a184da..81eeb9f13b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -30,18 +30,18 @@ namespace xla {
// For details about the syntax accepted by this parser, see
// g3doc/hlo_parser.md.
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with the given config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config);
-// Parses the text for a single HLO operation into an HLO module with a function
-// that runs that operation (with the same parameters) as its entry computation.
-StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
- absl::string_view str, absl::string_view name = "single_op");
+// Given a string in the HloModule::ToString() format, parses the string and
+// builds the HloModule in place at the given module pointer. 'module' must
+// point to an empty module (no computations).
+Status ParseHloString(absl::string_view str, HloModule* module);
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with default config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index cca50fab54..255123d331 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1004,6 +1004,18 @@ ENTRY CustomCall {
)"
},
+// CustomCall with opaque value.
+{
+"CustomCallWithOpaque",
+R"(HloModule custom_call
+
+ENTRY CustomCall {
+ constant = f32[1]{0} constant({12345})
+ ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque"
+}
+
+)"
+},
// Variables with non-default names
{
"NonDefaultNames",
@@ -1151,49 +1163,80 @@ ENTRY Sort {
// clang-format on
}
-class HloParserTest : public ::testing::Test,
- public ::testing::WithParamInterface<TestData> {
+// The test class for the tests defined above, which round-trip through the
+// parser and ToString, is templatized on two bool parameters:
+//
+// short_form : used for the "short" test cases which use the ShortParsable
+// output form.
+// proto_round_trip : whether the module should also be round-tripped through
+// HloProto form. This provides much better coverage for the proto
+// serialization/deserialization.
+//
+// The proto_round_trip=true case technically covers the Parser->ToString
+// roundtrip as well, but keeping the Parser->ToString roundtrip as its own
+// test provides better isolation and could catch subtle bugs that are hidden
+// by the interaction between the textual and proto roundtripping.
+template <bool short_form, bool proto_round_trip>
+class HloParameterizedParserTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<TestData> {
protected:
- static void ExpectHasSubstr(string_view s, string_view expected) {
- EXPECT_TRUE(absl::StrContains(s, expected))
- << "'" << s << "' does not contain '" << expected << "'";
- }
-
// Expects "ToString(ParseHloString(string)) == string", that is, parses the
// string, asserts that it succeeded, stringifies the parsed module, and
// checks that it equals the original string.
void ExpectEqual() {
const string& original = GetParam().module_string;
- auto result = ParseHloString(original);
- TF_ASSERT_OK(result.status());
- EXPECT_EQ(original, result.ValueOrDie()->ToString(
- HloPrintOptions().set_print_large_constants(true)));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(original));
+ if (proto_round_trip) {
+ TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
+ module->ToProto(), module->config()));
+ }
+ if (short_form) {
+ EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
+ } else {
+ EXPECT_EQ(
+ original,
+ module->ToString(HloPrintOptions().set_print_large_constants(true)));
+ }
}
};
-class HloParserShortTest : public HloParserTest {
- protected:
- void ExpectEqualShort() {
- const string& original = GetParam().module_string;
- auto result = ParseHloString(original);
- TF_ASSERT_OK(result.status());
- EXPECT_EQ(original,
- result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
- }
-};
+// These using shenanigans are required because the TEST_P macro doesn't like
+// template instantiations which contain commas.
+using HloParserTestLong = HloParameterizedParserTest<false, false>;
+using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
+using HloParserTestShort = HloParameterizedParserTest<true, false>;
+using HloParserTestShortProto = HloParameterizedParserTest<true, true>;
-TEST_P(HloParserTest, Run) { ExpectEqual(); }
+TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
+TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
+TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
+TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }
-TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); }
-
-INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
-
-INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation,
+ HloParserTestLongProto,
+ ::testing::ValuesIn(CreateTestCases()),
+ TestDataToString);
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
+ ::testing::ValuesIn(CreateShortTestCases()),
+ TestDataToString);
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation,
+ HloParserTestShortProto,
::testing::ValuesIn(CreateShortTestCases()),
TestDataToString);
+class HloParserTest : public ::testing::Test {
+ protected:
+ static void ExpectHasSubstr(string_view s, string_view expected) {
+ EXPECT_TRUE(absl::StrContains(s, expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+ }
+};
+
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = ParseHloString(original);
@@ -1261,7 +1304,7 @@ TEST_F(HloParserTest, MoreConstants) {
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
- %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
+ %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4}
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
@@ -1720,6 +1763,25 @@ ENTRY entry {
"was parsing 8:39: error: instruction does not exist: aparam");
}
+TEST_F(HloParserTest, SameNameDiffComputations) {
+ const string original = R"(HloModule same_names:
+add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT result = f32[] add(p0, p1)
+}
+
+ENTRY ReduceR3ToR2 {
+ p0 = f32[8,16,256]{2,1,0} parameter(0)
+ p1 = f32[] constant(0)
+ ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original));
+ ASSERT_NE(module->entry_computation(), nullptr);
+ EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
+
TEST_F(HloParserTest, ParseSharding) {
const string original = "{maximal device=42}";
TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
@@ -1773,27 +1835,142 @@ TEST(HloParserSingleOpTest, SingleOp) {
const string text =
"%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
"f32[2,4]{1,0} %x)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Parameter(0), op::Parameter(1)));
}
-TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) {
+TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
+ const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text);
+ ASSERT_TRUE(!module.status().ok());
+ LOG(INFO) << "Status: " << module.status();
+ EXPECT_THAT(module.status().ToString(),
+ ::testing::HasSubstr("expects '=' in instruction"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
- StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text);
ASSERT_TRUE(!module.status().ok());
LOG(INFO) << "Status: " << module.status();
- EXPECT_THAT(
- module.status().ToString(),
- ::testing::HasSubstr("Operand broadcast had no shape in HLO text"));
+ EXPECT_THAT(module.status().ToString(),
+ ::testing::HasSubstr("Operand had no shape in HLO text"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpNoNames) {
+ const string text =
+ "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Multiply(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST(HloParserSingleOpTest, CanonicalOp) {
+ const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Multiply(op::Parameter(0), op::Parameter(1)));
+ EXPECT_EQ(
+ computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
+ text);
+}
+
+TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
+ const string text =
+ R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+}, body=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_EQ(
+ computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
+ text);
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested) {
+ const string text =
+ R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
+{
+ %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
+ %param_1 = f32[2]{0} parameter(1)
+ %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
+ ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Fusion(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ result = f32[] add(f32[] x, f32[] y)
+})";
+ auto status = ParseHloString(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("does not exist: x"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ f32[] add(f32[] x, f32[] y)
+})";
+ auto status = ParseHloString(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ result = f32[] add(f32[], f32[])
+})";
+ auto status = ParseHloString(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
}
TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
const string text =
R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index f1ad0f9b01..fdaac34386 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -25,15 +26,45 @@ limitations under the License.
namespace xla {
// Base class for HLO passes. These are used with the HloPassPipeline to
-// organize a sequence of passes.
+// organize a sequence of passes. An HLO pass should not extend this class
+// directly; it should extend HloModulePass or HloModuleGroupPass.
class HloPassInterface {
public:
virtual ~HloPassInterface() = default;
virtual absl::string_view name() const = 0;
- // Run the pass on the given HLO module. Return whether it modified the
+ // Run the pass on the given HLO module. Returns whether it modified the
// module.
virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+ // Run the pass on the given HLO module group. Returns whether it modified the
+  // module group. Ideally, the module group variant would also be named
+  // "Run", but C++ does not handle overloaded virtual methods well.
+ virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
+};
+
+// Base class for passes which are module-scoped.
+class HloModulePass : public HloPassInterface {
+ public:
+ // Runs the pass on a module group by iterating through each module in the
+ // group.
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+ bool changed = false;
+ for (HloModule* module : module_group->modules()) {
+ TF_ASSIGN_OR_RETURN(bool module_changed, Run(module));
+ changed |= module_changed;
+ }
+ return changed;
+  }
+};
+
+// Base class for passes which are module-group scoped. These passes cannot run
+// on an HLO module.
+class HloModuleGroupPass : public HloPassInterface {
+ public:
+ StatusOr<bool> Run(HloModule* module) override {
+ return InternalError("Module group pass cannot be run on a module");
+ }
};
} // namespace xla
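A minimal sketch of what a pass looks like under the new hierarchy. The class, its name string, and its (no-op) behavior are invented for illustration; hlo_pass_pipeline_test.cc further below defines real examples such as FooToBarModulePass.

  class RenameNothingPass : public HloModulePass {
   public:
    absl::string_view name() const override { return "rename-nothing"; }

    StatusOr<bool> Run(HloModule* module) override {
      bool changed = false;
      for (HloComputation* computation : module->computations()) {
        for (HloInstruction* instruction : computation->instructions()) {
          // Inspect or rewrite `instruction` here and set `changed` if needed.
          (void)instruction;
        }
      }
      return changed;
    }
  };

Because HloModulePass provides RunOnModuleGroup by iterating over the group's modules, the same pass can be added to a pipeline that is driven through either Run(HloModule*) or RunOnModuleGroup(HloModuleGroup*).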
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 6e4ed0de62..5e004ce78a 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,7 +17,8 @@ limitations under the License.
#include <functional>
-#include "absl/strings/str_cat.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -25,112 +26,131 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-namespace {
-using absl::StrAppend;
-using absl::StrCat;
-
-void DumpModuleGraph(const HloModule& module, const string& message) {
- hlo_graph_dumper::MaybeDumpHloModule(module, message);
- VLOG(3) << "HLO " << message << ":";
- XLA_VLOG_LINES(3, module.ToString());
+template <typename HloT>
+Status HloPassPipeline::RunInvariantCheckers(
+ HloT* hlo, absl::string_view after_pass_name) {
+ for (auto& invariant_checker : invariant_checkers_) {
+ VLOG(1) << " Invariant checker " << invariant_checker->name();
+ StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
+ VLOG(1) << " Invariant checker done " << invariant_checker->name();
+ if (!changed_status.ok()) {
+ VLOG(2) << "Failed invariant check:";
+ XLA_VLOG_LINES(2, hlo->ToString());
+ return Status(changed_status.status().code(),
+ absl::StrCat(changed_status.status().error_message(),
+ "\n\nFailed after ", after_pass_name));
+ }
+ TF_RET_CHECK(!changed_status.ValueOrDie())
+ << "invariant checkers must not change the graph";
+ }
+ return Status::OK();
}
-void DumpModuleProto(const HloModule& module, const string& dump_to,
- const string& pipeline_name, const string& pass_name) {
- static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
- static auto* const module_id_to_pass_number =
- new tensorflow::gtl::FlatMap<int64, int64>();
-
- tensorflow::mutex_lock lock(mu);
- const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+template <typename HloT>
+StatusOr<bool> HloPassPipeline::RunPassesInternal(
+ HloT* hlo, absl::Span<HloPassInterface* const> passes) {
+ string last_pass_name = "pipeline-start";
+ TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
+ bool changed = false;
+ for (HloPassInterface* pass : passes) {
+ VLOG(1) << " HLO pass " << pass->name();
+ MaybeDumpHlo(*hlo,
+ /*after_pass_name=*/last_pass_name,
+ /*before_pass_name=*/pass->name());
+ TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
+ changed |= pass_changed;
+ TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
+ last_pass_name = string(pass->name());
+ }
+ MaybeDumpHlo(*hlo,
+ /*after_pass_name=*/last_pass_name,
+ /*before_pass_name=*/"pipeline-end");
+ return changed;
+}
- const string mod_name = SanitizeFileName(
- absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
- pass_number, pipeline_name, pass_name));
+std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
+ const DebugOptions& debug_options) {
+ auto repeated_field = debug_options.xla_disable_hlo_passes();
+ absl::flat_hash_set<string> disabled_pass_names(repeated_field.begin(),
+ repeated_field.end());
+ if (!disabled_pass_names.empty()) {
+ VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
+ << absl::StrJoin(disabled_pass_names, ", ");
+ }
- TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
- dump_to, mod_name));
+ std::vector<HloPassInterface*> enabled_passes;
+ for (auto& pass : passes_) {
+ if (disabled_pass_names.count(string(pass->name())) == 0) {
+ enabled_passes.push_back(pass.get());
+ }
+ }
+ return enabled_passes;
}
-} // namespace
-StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
- run_called_ = true;
+void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name) {
+ const string& proto_dump_path =
+ module.config().debug_options().xla_dump_per_pass_hlo_proto_to();
+ if (!proto_dump_path.empty()) {
+ static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+ static auto* const module_id_to_pass_number =
+ new absl::flat_hash_map<int64, int64>();
+
+ tensorflow::mutex_lock lock(mu);
+ const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+
+ const string filename = SanitizeFileName(
+ absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+ pass_number, name(), after_pass_name));
+
+ TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(
+ MakeHloProto(module), proto_dump_path, filename));
+ }
- VLOG(1) << "Running HLO pass pipeline " << name();
+ const string message =
+ StrCat("after ", after_pass_name, ", before ", before_pass_name);
+ hlo_graph_dumper::MaybeDumpHloModule(module, message);
+ VLOG(3) << "HLO " << message << ":";
+ XLA_VLOG_LINES(3, module.ToString());
+}
- auto repeated_field =
- module->config().debug_options().xla_disable_hlo_passes();
- tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
- repeated_field.end());
- if (!disabled_passes.empty()) {
- VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
- << absl::StrJoin(disabled_passes, ", ");
+void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name) {
+ for (const HloModule* module : module_group.modules()) {
+ MaybeDumpHlo(*module, after_pass_name, before_pass_name);
}
+}
- auto run_invariant_checkers = [this,
- module](const string& message) -> Status {
- for (auto& invariant_checker : invariant_checkers_) {
- VLOG(1) << " Invariant checker " << invariant_checker->name();
- StatusOr<bool> changed_status = invariant_checker->Run(module);
- VLOG(1) << " Invariant checker done " << invariant_checker->name();
- if (!changed_status.ok()) {
- VLOG(2) << "Module failed invariant check:";
- XLA_VLOG_LINES(2, module->ToString());
- return Status(changed_status.status().code(),
- StrCat(changed_status.status().error_message(),
- "\n\nFailed ", message));
- }
- TF_RET_CHECK(!changed_status.ValueOrDie())
- << "invariant checkers must not change the graph";
- }
- return Status::OK();
- };
+StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
+ run_called_ = true;
- string prefix = StrCat(name(), ": pipeline start");
- bool changed = false;
- string message;
- TF_RETURN_IF_ERROR(
- run_invariant_checkers(StrCat("before running pipeline: ", name())));
- const string xla_dump_per_pass_hlo_proto_to =
- module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
- if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
- "pipeline_start");
- }
+ VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
+ << name();
- for (auto& pass : passes_) {
- if (disabled_passes.count(string(pass->name())) > 0) {
- VLOG(1) << " Skipping HLO pass " << pass->name()
- << ", disabled by --xla_disable_hlo_passes";
- continue;
- }
+ return RunPassesInternal(module,
+ GetEnabledPasses(module->config().debug_options()));
+}
- VLOG(1) << " HLO pass " << pass->name();
+StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
+ run_called_ = true;
- // Emit label containing: "after foo-pass, before bar-pass".
- message.clear();
- StrAppend(&message, prefix, ", before ", pass->name());
- DumpModuleGraph(*module, message);
-
- TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
- TF_RETURN_IF_ERROR(
- run_invariant_checkers(StrCat("after running pass: ", pass->name())));
- if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
- string(pass->name()));
- }
+ VLOG(1) << "Running HLO pass pipeline on module group "
+ << module_group->name() << ": " << name();
- changed |= changed_this_pass;
- prefix.clear();
- StrAppend(&prefix, name(), ": after ", pass->name());
+ if (module_group->modules().empty()) {
+ VLOG(1) << "Module group is empty. Nothing to do.";
+ return false;
}
- DumpModuleGraph(*module, prefix + ", pipeline end");
- return changed;
+
+ return RunPassesInternal(
+ module_group,
+ GetEnabledPasses(module_group->module(0).config().debug_options()));
}
} // namespace xla
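The refactoring above funnels both Run and RunOnModuleGroup through a single templated RunPassesInternal, with the module/module-group difference isolated in the overloaded RunHelper functions declared in hlo_pass_pipeline.h below. A standalone toy version of that pattern (none of these types are XLA's) looks like this:

  #include <iostream>
  #include <vector>

  struct Module { const char* name; };
  struct ModuleGroup { std::vector<Module*> modules; };

  // Overloads play the role of HloPassPipeline::RunHelper.
  bool RunHelper(int pass_id, Module* m) {
    std::cout << "pass " << pass_id << " on module " << m->name << "\n";
    return false;  // pretend nothing changed
  }
  bool RunHelper(int pass_id, ModuleGroup* g) {
    bool changed = false;
    for (Module* m : g->modules) changed |= RunHelper(pass_id, m);
    return changed;
  }

  // Plays the role of RunPassesInternal: the loop body is written once and
  // overload resolution picks the right helper for HloT.
  template <typename HloT>
  bool RunPasses(HloT* hlo, int num_passes) {
    bool changed = false;
    for (int pass_id = 0; pass_id < num_passes; ++pass_id) {
      changed |= RunHelper(pass_id, hlo);
    }
    return changed;
  }

  int main() {
    Module m{"m0"};
    ModuleGroup g{{&m}};
    RunPasses(&m, /*num_passes=*/2);
    RunPasses(&g, /*num_passes=*/2);
    return 0;
  }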
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index 1d41a4dac1..09e7033ea4 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface {
return *pass;
}
- // Run all passes on the given HLO module.
StatusOr<bool> Run(HloModule* module) override;
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
private:
+  // Returns the set of passes which are enabled. DebugOptions can selectively
+  // disable passes via the --xla_disable_hlo_passes flag.
+ std::vector<HloPassInterface*> GetEnabledPasses(
+ const DebugOptions& debug_options);
+
+ // Maybe dumps the given module or module group depending on flag values
+  // contained in the DebugOptions of the module config.
+ void MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+ void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+
+ // Runs the invariant checker on the given HLO. HloT can be either HloModule
+ // or HloModuleGroup.
+ template <typename HloT>
+ Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
+
+ // Helper which runs the given pass on the given HLO. HloT can be either
+ // HloModule or HloModuleGroup.
+ template <typename HloT>
+ StatusOr<bool> RunPassesInternal(HloT* hlo,
+ absl::Span<HloPassInterface* const> passes);
+
+  // Helpers which run the given pass on the given HLO construct. These
+ // helpers enable templating of the core of the pipeline logic by providing
+ // HloModule and HloModuleGroup specific methods with the same name.
+ static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
+ return pass->Run(module);
+ }
+ static StatusOr<bool> RunHelper(HloPassInterface* pass,
+ HloModuleGroup* module_group) {
+ return pass->RunOnModuleGroup(module_group);
+ }
+
const string name_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
new file mode 100644
index 0000000000..ee8cb12b23
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
@@ -0,0 +1,259 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloPassPipelineTest : public HloVerifiedTestBase {
+ protected:
+ StatusOr<HloModuleGroup> ParseModuleGroup(
+ absl::Span<const string> hlo_strings) {
+ HloModuleGroup group(TestName());
+ for (const string& hlo_string : hlo_strings) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ group.push_back(std::move(module));
+ }
+ return std::move(group);
+ }
+};
+
+// A module pass which renames instructions named 'foo' to 'bar'.
+class FooToBarModulePass : public HloModulePass {
+ absl::string_view name() const override { return "foo2bar"; }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "foo") {
+ instruction->SetAndSanitizeName("bar");
+ changed = true;
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+// A module group pass which renames instructions named 'baz' to 'qux'.
+class BazToQuxModuleGroupPass : public HloModuleGroupPass {
+ absl::string_view name() const override { return "baz2qux"; }
+
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+ bool changed = false;
+ for (HloModule* module : module_group->modules()) {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "baz") {
+ instruction->SetAndSanitizeName("qux");
+ changed = true;
+ }
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+// An invariant checker pass which returns an error if there exists an
+// instruction named 'bar'.
+class BarBlowerUpper : public HloModulePass {
+ absl::string_view name() const override { return "bar-blower-upper"; }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "bar") {
+ return InternalError("Module has instruction named bar");
+ }
+ }
+ }
+ return false;
+ }
+};
+
+TEST_F(HloPassPipelineTest, ModulePassChanged) {
+ // Test an HLO module pass which changes a module.
+ const string module_str = R"(
+HloModule ModulePassChanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<FooToBarModulePass>();
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->name(), "foo");
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_TRUE(changed);
+ EXPECT_EQ(root->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
+ // Test an HLO module pass which does not change a module.
+ const string module_str = R"(
+HloModule ModulePassUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT blahblah = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<FooToBarModulePass>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(HloPassPipelineTest, MixedPipeline) {
+ // Test a pipeline with both a module pass and a module group pass.
+ const string module_0_str = R"(
+HloModule MixedPipeline.1
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT baz = f32[] multiply(a, b)
+}
+)";
+ const string module_1_str = R"(
+HloModule MixedPipeline.0
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
+ ParseModuleGroup({module_0_str, module_1_str}));
+
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<BazToQuxModuleGroupPass>();
+ pipeline.AddPass<FooToBarModulePass>();
+
+ HloInstruction* root0 =
+ module_group.module(0).entry_computation()->root_instruction();
+ HloInstruction* root1 =
+ module_group.module(1).entry_computation()->root_instruction();
+ EXPECT_EQ(root0->name(), "baz");
+ EXPECT_EQ(root1->name(), "foo");
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ pipeline.RunOnModuleGroup(&module_group));
+ EXPECT_TRUE(changed);
+
+ EXPECT_EQ(root0->name(), "qux");
+ EXPECT_EQ(root1->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, InvariantChecker) {
+ const string module_str = R"(
+HloModule InvariantChecker
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ {
+ // Run a pipeline with just the invariant checker. It should not fail
+ // because there is no 'bar' instruction in the module.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+ }
+
+ {
+ // Run a pipeline which renames 'foo' to 'bar' then an invariant checker
+ // which fails if there is an instruction named 'bar'.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+ pipeline.AddPass<FooToBarModulePass>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Module has instruction named bar"));
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Failed after foo2bar"));
+ }
+
+ {
+ // Run the invariant-checker only pipeline again. It should fail this time.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Module has instruction named bar"));
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Failed after pipeline-start"));
+ }
+}
+
+TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
+ // Running a module group pass on a module should produce an error.
+ const string module_str = R"(
+HloModule ModuleGroupPassOnModule
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<BazToQuxModuleGroupPass>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Module group pass cannot be run on a module"));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index b66a2aa4bd..5a5f01f8fd 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -19,11 +19,11 @@ limitations under the License.
#include <list>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -154,7 +154,7 @@ class HloReachabilityMap {
// Dense assignment from HloInstruction* to number. These numbers index
// into the bit_vectors_ vector and into the bits within a BitVector.
- tensorflow::gtl::FlatMap<const HloInstruction*, int> indices_;
+ absl::flat_hash_map<const HloInstruction*, int> indices_;
// Bitvectors holding the reachability to each instruction. The bit vector for
// instruction X includes ones for each instruction which X is reachable from.
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index bd6dd79b67..5ac43808ee 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <set>
#include <string>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -75,7 +77,7 @@ bool IsRematerializable(const HloInstruction* instruction) {
// cache before, and eventually calling the IsRematerializable() API.
bool CanBeRematerialized(
const HloInstruction* instruction,
- tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
+ absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
auto it = remat_able->find(instruction);
if (it != remat_able->end()) {
return it->second;
@@ -268,7 +270,7 @@ class InstructionList {
Item* first_;
// Item for each instruction.
- tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_;
+ absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
};
// Return the items which use the given LogicalBuffer. Sets
@@ -503,7 +505,7 @@ MemoryUsageTracker::MemoryUsageTracker(
PointsToSet::BufferSet live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
- tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
+ absl::flat_hash_map<const LogicalBuffer*, BufferId>
logical_buffer_to_buffer_id;
for (auto* item = instruction_list_.first(); item != nullptr;
@@ -854,7 +856,7 @@ int64 RematerializationCost(const HloInstruction* instruction,
Item* PickRematerializationCandidate(
const MemoryUsageTracker& memory_tracker,
const InstructionList& instruction_list, int64 memory_limit_bytes,
- tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
+ absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
Item* best_item = nullptr;
int64 best_cost = 0;
@@ -980,10 +982,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// rematerialization is essentially a move). If the next rematerialization of
// the instruction is also a move then the rematerialization is added to the
// blacklist.
- tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;
+ absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
// The map from instructions to their rematerializable status.
- tensorflow::gtl::FlatMap<const HloInstruction*, bool> remat_able;
+ absl::flat_hash_map<const HloInstruction*, bool> remat_able;
// The peak memory of the computation at any point in the instruction
// sequence.
@@ -1198,6 +1200,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
<< HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
+ // Initialize pass object state.
+ computation_peak_memory_.clear();
+ rematerialized_computations_.clear();
+ instructions_rematerialized_ = 0;
+ net_instructions_added_ = 0;
+
TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index e2aaf18b3e..70d83c04f0 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -15,6 +15,8 @@
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -33,7 +35,7 @@ namespace xla {
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
-class HloRematerialization : public HloPassInterface {
+class HloRematerialization : public HloModulePass {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
@@ -115,14 +117,13 @@ class HloRematerialization : public HloPassInterface {
// computations called from sequential context
// (CallContext::kSequential). These values are updated as rematerialization
// occurs.
- tensorflow::gtl::FlatMap<const HloComputation*, int64>
- computation_peak_memory_;
+ absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
// Set of computations which have had rematerialization
// applied. Rematerialization is only applied once per computation.
- tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_;
+ absl::flat_hash_set<const HloComputation*> rematerialized_computations_;
// Count of the total instructions rematerialized.
int64 instructions_rematerialized_ = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
index 3fc5dbeb02..9972eb2077 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -30,7 +32,7 @@ namespace xla {
/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
const HloModule* module, const HloScheduleProto& proto) {
- tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
+ absl::flat_hash_map<int64, const HloComputation*> id_to_computation;
for (const HloComputation* computation : module->computations()) {
id_to_computation[computation->unique_id()] = computation;
}
@@ -44,7 +46,7 @@ namespace xla {
<< "No computation exists in HLO module with id " << computation_id;
const HloComputation* computation = comp_it->second;
- tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
+ absl::flat_hash_map<int64, const HloInstruction*> id_to_instruction;
for (const HloInstruction* instruction : computation->instructions()) {
id_to_instruction[instruction->unique_id()] = instruction;
}
@@ -112,13 +114,13 @@ Status HloSchedule::UpdateComputationSchedule(
const HloComputation* computation) {
// Map from unique ID to HloInstruction pointer for instructions in the
// computation.
- tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
+ absl::flat_hash_map<int, const HloInstruction*> id_to_instruction;
for (const HloInstruction* instruction : computation->instructions()) {
InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
}
// Set of all HloInstructions in the schedule.
- tensorflow::gtl::FlatSet<int> ids_in_schedule;
+ absl::flat_hash_set<int> ids_in_schedule;
for (int id : sequences_.at(computation->unique_id()).ids()) {
InsertOrDie(&ids_in_schedule, id);
}
@@ -126,15 +128,13 @@ Status HloSchedule::UpdateComputationSchedule(
// Map from HloInstruction X to newly added instructions (instruction is in
// computation, but not in schedule) which use X. If an instruction is not in
// the map, then it has no users which are newly added instructions.
- tensorflow::gtl::FlatMap<const HloInstruction*,
- std::vector<const HloInstruction*>>
+ absl::flat_hash_map<const HloInstruction*, std::vector<const HloInstruction*>>
new_instruction_uses;
// For each newly added instruction, this is the count of the instruction's
// operands that have not yet been scheduled. When this value reaches zero,
// then the instruction may be placed in the schedule.
- tensorflow::gtl::FlatMap<const HloInstruction*, int>
- unscheduled_operand_count;
+ absl::flat_hash_map<const HloInstruction*, int> unscheduled_operand_count;
// Create a worklist of newly added instructions which are ready to be added
// to the schedule. Initialize worklist with those that have zero operands.
@@ -211,15 +211,15 @@ Status HloSchedule::Update() {
if (sequences_.size() > nonfusion_computations.size()) {
// Schedule contains some computations which have been removed from the
// HloModule. Remove them from the schedule as well.
- tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids;
+ absl::flat_hash_set<int64> nonfusion_computations_ids;
for (const HloComputation* computation : nonfusion_computations) {
nonfusion_computations_ids.insert(computation->unique_id());
}
for (auto it = sequences_.begin(); it != sequences_.end();) {
if (nonfusion_computations_ids.count(it->first) == 0) {
- it = sequences_.erase(it);
+ sequences_.erase(it++);
} else {
- it++;
+ ++it;
}
}
}
@@ -254,7 +254,7 @@ Status HloSchedule::Verify() const {
// For each computation verify the set of instructions is the same and that
// each dependency and control edge is honored.
for (const HloComputation* computation : nonfusion_computations) {
- tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
+ absl::flat_hash_map<const HloInstruction*, int> instruction_position;
int pos = 0;
for (const HloInstruction* instruction :
sequence(computation).instructions()) {
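A side note on the iterator handling changed above: sequences_ is now an absl::flat_hash_map, whose single-iterator erase() returns void rather than the next iterator (unlike std::unordered_map::erase()), so the loop advances the iterator before erasing. A minimal sketch of the idiom, with ShouldRemove standing in for the real membership test against nonfusion_computations_ids:

  // Hypothetical predicate name; the real loop checks nonfusion_computations_ids.
  for (auto it = sequences_.begin(); it != sequences_.end();) {
    if (ShouldRemove(it->first)) {
      sequences_.erase(it++);  // copy `it`, advance it, then erase the copy
    } else {
      ++it;
    }
  }
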
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
index 270fe6039f..0a714101ee 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -103,8 +104,7 @@ class HloSchedule {
// Returns a map from HloComputation unique ID to instruction sequence. The
// map contains all sequences in the schedule.
- const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences()
- const {
+ const absl::flat_hash_map<int64, HloInstructionSequence>& sequences() const {
return sequences_;
}
@@ -148,7 +148,7 @@ class HloSchedule {
// A map from computation unique ID to instruction sequence. Unique IDs are
// used rather than HloComputation pointers because HLO pointers are not
// unique across HLO transformations because pointers may be recycled.
- tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_;
+ absl::flat_hash_map<int64, HloInstructionSequence> sequences_;
};
std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index de7e6b53d4..188f4acc79 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -369,10 +370,28 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
return HloSharding(tuple_shardings);
} else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
return Replicate();
- } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL ||
- proto.tile_assignment_devices().size() == 1) {
+ } else if (proto.tile_assignment_devices().size() == 1) {
return HloSharding(proto.tile_assignment_devices(0));
}
+
+ TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL)
+ << "Maximal sharding is expected to have single device assignment, but "
+ << proto.tile_assignment_devices().size() << " has provided.";
+
+ TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
+ TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
+
+  // The product of the tile assignment tensor dimensions must equal
+  // tile_assignment_devices.size().
+ int64 product_of_dimensions = 1;
+ for (auto dimension : proto.tile_assignment_dimensions()) {
+ TF_RET_CHECK(dimension > 0);
+ product_of_dimensions =
+ MultiplyWithoutOverflow(product_of_dimensions, dimension);
+ TF_RET_CHECK(product_of_dimensions > 0);
+ }
+ TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
+
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),
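The validation added above guards against both non-positive dimensions and int64 overflow when multiplying the tile-assignment dimensions, since a wrapped-around product could otherwise appear to match tile_assignment_devices().size(). A minimal standalone sketch of the same check, assuming (as the `> 0` test implies) that MultiplyWithoutOverflow returns a non-positive sentinel on overflow; ValidatedTileDeviceCount is an illustrative name only:

  StatusOr<int64> ValidatedTileDeviceCount(const OpSharding& proto) {
    int64 product_of_dimensions = 1;
    for (int64 dimension : proto.tile_assignment_dimensions()) {
      TF_RET_CHECK(dimension > 0);
      // A non-positive result signals overflow and fails the next check.
      product_of_dimensions =
          MultiplyWithoutOverflow(product_of_dimensions, dimension);
      TF_RET_CHECK(product_of_dimensions > 0);
    }
    TF_RET_CHECK(product_of_dimensions ==
                 proto.tile_assignment_devices().size());
    return product_of_dimensions;
  }
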
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index d1cf644f82..fa34bddde1 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -22,7 +22,7 @@ namespace xla {
// Unify subcomputations of a `HloModule`: if any computations are equal, choose
// one arbitrarily to use and delete the others.
-class HloSubcomputationUnification : public HloPassInterface {
+class HloSubcomputationUnification : public HloModulePass {
public:
absl::string_view name() const override {
return "subcomputation-unification";
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 773fc7d225..59594ab2f0 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -131,6 +131,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
CHECK_LE(operand_number, 2);
return operand_number == 0 || index.empty();
+ case HloOpcode::kDomain:
case HloOpcode::kTuple:
// These instructions always pass through their operands transparently.
return false;
@@ -166,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses(
positions_.insert(positions_.end(), positions.begin(), positions.end());
// Gather the computation roots at which this value appears.
- tensorflow::gtl::FlatSet<HloInstruction*> root_positions;
+ absl::flat_hash_set<HloInstruction*> root_positions;
for (const HloPosition& position : positions_) {
if (position.instruction ==
position.instruction->parent()->root_instruction()) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 0f6ecd42f6..496fe1795d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <set>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -315,7 +315,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
int64 output_dimension = broadcast->dimensions()[operand_dimension];
TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) &&
(broadcast->shape().dimensions(output_dimension) ==
- operand_shape.dimensions(operand_dimension)))
+ operand_shape.dimensions(operand_dimension)))
<< broadcast->ToString() << " operand shape " << operand_shape;
}
return Status::OK();
@@ -549,6 +549,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
case HloOpcode::kTupleSelect:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
+ case HloOpcode::kSort:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
break;
@@ -764,7 +765,136 @@ Status VerifyHloStructure(HloModule* module) {
return Status::OK();
}
-Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
+namespace {
+
+// Returns true if the given Shape has a TOKEN shape as any subshape.
+bool ShapeContainsToken(const Shape& shape) {
+ bool contains_token = false;
+ ShapeUtil::ForEachSubshape(
+ shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
+ if (ShapeUtil::IsToken(subshape)) {
+ contains_token = true;
+ }
+ });
+ return contains_token;
+}
+
+// Verifies that all types entering and exiting the entry computation are
+// legal.
+Status VerifyEntryAndExitShapes(const HloModule& module) {
+ // Tokens cannot be passed as entry parameters.
+ // TODO(b/80000000): Remove this constraint.
+ for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
+ HloInstruction* param =
+ module.entry_computation()->parameter_instruction(i);
+ if (ShapeContainsToken(param->shape())) {
+ return InternalError(
+ "Entry parameter %d is or contains a token shape: %s", i,
+ ShapeUtil::HumanString(param->shape()));
+ }
+ }
+ return Status::OK();
+}
+
+// Checks if the given two instructions share the same channel id.
+Status CheckSameChannel(const HloInstruction* instr1,
+ const HloInstruction* instr2) {
+ if (instr1->channel_id() != instr2->channel_id()) {
+ return InternalError(
+ "Expected to have the same channel id, actual channel ids are: %s "
+ "(%d), %s (%d)",
+ instr1->ToString(), instr1->channel_id(), instr2->ToString(),
+ instr2->channel_id());
+ }
+ return Status::OK();
+}
+
+// Checks if the given two instructions have the same is_host_transfer
+// attribute value. Instructions must be send/recv instructions or their
+// 'done' variant.
+Status CheckSameIsHostTransfer(const HloInstruction* instr1,
+ const HloInstruction* instr2) {
+ const HloSendRecvInstruction* send_recv1 =
+ DynCast<const HloSendRecvInstruction>(instr1);
+ const HloSendRecvInstruction* send_recv2 =
+ DynCast<const HloSendRecvInstruction>(instr2);
+ TF_RET_CHECK(send_recv1 != nullptr);
+ TF_RET_CHECK(send_recv2 != nullptr);
+ if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
+ return InternalError(
+ "Expected instructions to have the same is-host-transfer property: "
+ "%s, "
+ "%s ",
+ instr1->ToString(), instr2->ToString());
+ }
+ return Status::OK();
+}
+
+// Checks various invariants of send and recv instructions.
+Status VerifySendsAndRecvs(const HloModule& module) {
+ absl::flat_hash_map<int64, const HloInstruction*> host_channels;
+ // Host send/recv instructions must have their own unique channel.
+ auto check_unique_host_channel = [&](const HloInstruction* instruction) {
+ const HloSendRecvInstruction* sendrecv =
+ DynCast<const HloSendRecvInstruction>(instruction);
+ if (sendrecv->is_host_transfer()) {
+ auto it_inserted =
+ host_channels.insert({sendrecv->channel_id(), sendrecv});
+ if (!it_inserted.second) {
+ return FailedPrecondition(
+ "Channel %d is used for multiple host send/recv instructions: "
+ "%s "
+ "and "
+ "%s",
+ sendrecv->channel_id(), sendrecv->ToString(),
+ it_inserted.first->second->ToString());
+ }
+ }
+
+ return Status::OK();
+ };
+
+  // Each Send/Recv instruction must have a single user: the corresponding
+  // SendDone/RecvDone, with a matching channel.
+ for (const HloComputation* computation : module.computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend: {
+ TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
+ TF_RET_CHECK(instruction->users().size() == 1);
+ const HloInstruction* send_done = instruction->users().front();
+ TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
+ TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
+ break;
+ }
+ case HloOpcode::kRecv: {
+ TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
+ TF_RET_CHECK(instruction->users().size() == 1);
+ const HloInstruction* recv_done = instruction->users().front();
+ TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
+ TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
+ break;
+ }
+ case HloOpcode::kSendDone:
+ TF_RET_CHECK(instruction->operands().size() == 1);
+ TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
+ break;
+ case HloOpcode::kRecvDone:
+ TF_RET_CHECK(instruction->operands().size() == 1);
+ TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// CHECKs various invariants of a fusion instruction.
+Status CheckFusionInstruction(HloInstruction* fusion) {
// The parent fusion instruction of the fusion computation must be 'fusion'.
HloComputation* fused_computation = fusion->fused_instructions_computation();
if (fusion != fused_computation->FusionInstruction()) {
@@ -867,50 +997,32 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
}
}
+ TF_RET_CHECK(fusion->called_computations() ==
+ absl::Span<HloComputation* const>(
+ {fusion->fused_instructions_computation()}))
+ << "Fusion HLO calls computations other than the "
+ "fused_instructions_computation: "
+ << fusion->ToString() << " fusion->fused_instructions_computation(): "
+ << fusion->fused_instructions_computation()->ToString()
+ << " fusion->called_computations(): "
+ << ComputationsToString(fusion->called_computations());
+
+ for (const auto& fused : fusion->fused_instructions()) {
+ TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
+ << "Fused HLO was missing a parent: " << fused->ToString()
+ << " parent: " << fused->parent()
+ << " computation: " << fusion->parent();
+ }
+
// TODO(b/65423525): We'd like to check that all operands are distinct.
// This is currently disabled due to the invariant being violated by
// multi-output fusion.
return Status::OK();
}
-Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
- auto* while_cond = instruction->while_condition();
- auto* while_body = instruction->while_body();
- if (while_cond->num_parameters() != 1) {
- return FailedPrecondition(
- "While condition must have exactly 1 parameter; had %d : %s",
- while_cond->num_parameters(), while_cond->ToString());
- }
- if (while_body->num_parameters() != 1) {
- return FailedPrecondition(
- "While body must have exactly 1 parameter; had %d : %s",
- while_body->num_parameters(), while_body->ToString());
- }
- if (instruction->operand_count() != 1) {
- return FailedPrecondition(
- "While loop must have exactly one operand; had %d : %s",
- instruction->operand_count(), instruction->ToString());
- }
- return Status::OK();
-}
-
-Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
- if (instruction->true_computation()->num_parameters() != 1) {
- return FailedPrecondition(
- "True computation %s of %s must have 1 parameter insted of %d",
- instruction->true_computation()->name(), instruction->ToString(),
- instruction->true_computation()->num_parameters());
- }
- if (instruction->false_computation()->num_parameters() != 1) {
- return FailedPrecondition(
- "False computation %s of %s must have 1 parameter insted of %d",
- instruction->false_computation()->name(), instruction->ToString(),
- instruction->false_computation()->num_parameters());
- }
- return Status::OK();
-}
-
-Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
+// Checks that the non-scalar operand shapes are compatible with the output
+// shape, i.e., that there are no implicit broadcasts of size-one dimensions.
+Status CheckElementwiseInstruction(HloInstruction* instruction) {
const Shape& out_shape = instruction->shape();
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
@@ -927,199 +1039,158 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
return Status::OK();
}
-namespace {
+// Visitor which verifies various fields on the HLO instruction. This class does
+// not check result shape as that is checked in the ShapeVerifier.
+class InstructionVerifier : public DfsHloVisitorWithDefault {
+ public:
+ explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
+ : instruction_can_change_layout_func_(
+ instruction_can_change_layout_func) {}
-// Returns true if the given Shape has a TOKEN shape as any subshape.
-bool ShapeContainsToken(const Shape& shape) {
- bool contains_token = false;
- ShapeUtil::ForEachSubshape(
- shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
- if (ShapeUtil::IsToken(subshape)) {
- contains_token = true;
- }
- });
- return contains_token;
-}
+ Status DefaultAction(HloInstruction*) override { return Status::OK(); }
-// Verifies that all types entering and exiting the entry computation are
-// legal.
-Status VerifyEntryAndExitShapes(const HloModule& module) {
- // Tokens cannot be passed as entry parameters.
- // TODO(b/80000000): Remove this constraint.
- for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
- HloInstruction* param =
- module.entry_computation()->parameter_instruction(i);
- if (ShapeContainsToken(param->shape())) {
- return InternalError(
- "Entry parameter %d is or contains a token shape: %s", i,
- ShapeUtil::HumanString(param->shape()));
- }
+ Status HandleFusion(HloInstruction* fusion) override {
+ return CheckFusionInstruction(fusion);
}
- return Status::OK();
-}
-// Checks if the given two instructions share the same channel id.
-Status CheckSameChannel(const HloInstruction* instr1,
- const HloInstruction* instr2) {
- if (instr1->channel_id() != instr2->channel_id()) {
- return InternalError(
- "Expected to have the same channel id, actual channel ids are: %s "
- "(%d), %s (%d)",
- instr1->ToString(), instr1->channel_id(), instr2->ToString(),
- instr2->channel_id());
+ Status HandleBroadcast(HloInstruction* broadcast) override {
+ // If you see this failure then someone has confused the difference
+ // between the HLO broadcast op, and the UserComputation broadcast
+ // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
+ // or ComputationLowerer::Visit()
+ TF_RET_CHECK(broadcast->dimensions().size() ==
+ ShapeUtil::Rank(broadcast->operand(0)->shape()))
+ << "Broadcast HLO (" << broadcast->ToShortString()
+ << ") has invalid number of dimensions: "
+ << broadcast->dimensions().size()
+ << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape());
+ return Status::OK();
}
- return Status::OK();
-}
-// Checks if the given two instructions have the same is_host_transfer
-// attribute value. Intsructions must be send/recv instructions or their
-// 'done' variant.
-Status CheckSameIsHostTransfer(const HloInstruction* instr1,
- const HloInstruction* instr2) {
- const HloSendRecvInstruction* send_recv1 =
- DynCast<const HloSendRecvInstruction>(instr1);
- const HloSendRecvInstruction* send_recv2 =
- DynCast<const HloSendRecvInstruction>(instr2);
- TF_RET_CHECK(send_recv1 != nullptr);
- TF_RET_CHECK(send_recv2 != nullptr);
- if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
- return InternalError(
- "Expected instructions to have the same is-host-transfer property: "
- "%s, "
- "%s ",
- instr1->ToString(), instr2->ToString());
+ Status HandleWhile(HloInstruction* xla_while) override {
+ auto* while_cond = xla_while->while_condition();
+ auto* while_body = xla_while->while_body();
+ if (while_cond->num_parameters() != 1) {
+ return FailedPrecondition(
+ "While condition must have exactly 1 parameter; had %d : %s",
+ while_cond->num_parameters(), while_cond->ToString());
+ }
+ if (while_body->num_parameters() != 1) {
+ return FailedPrecondition(
+ "While body must have exactly 1 parameter; had %d : %s",
+ while_body->num_parameters(), while_body->ToString());
+ }
+ if (xla_while->operand_count() != 1) {
+ return FailedPrecondition(
+ "While loop must have exactly one operand; had %d : %s",
+ xla_while->operand_count(), xla_while->ToString());
+ }
+ return Status::OK();
}
- return Status::OK();
-}
-// Checks various invariants of send and recv instructions.
-Status VerifySendsAndRecvs(const HloModule& module) {
- tensorflow::gtl::FlatMap<int64, const HloInstruction*> host_channels;
- // Host send/recv instructions must have their own unique channel.
- auto check_unique_host_channel = [&](const HloInstruction* instruction) {
- const HloSendRecvInstruction* sendrecv =
- DynCast<const HloSendRecvInstruction>(instruction);
- if (sendrecv->is_host_transfer()) {
- auto it_inserted =
- host_channels.insert({sendrecv->channel_id(), sendrecv});
- if (!it_inserted.second) {
- return FailedPrecondition(
- "Channel %d is used for multiple host send/recv instructions: "
- "%s "
- "and "
- "%s",
- sendrecv->channel_id(), sendrecv->ToString(),
- it_inserted.first->second->ToString());
- }
+ Status HandleConditional(HloInstruction* conditional) override {
+ if (conditional->true_computation()->num_parameters() != 1) {
+ return FailedPrecondition(
+ "True computation %s of %s must have 1 parameter insted of %d",
+ conditional->true_computation()->name(), conditional->ToString(),
+ conditional->true_computation()->num_parameters());
}
+ if (conditional->false_computation()->num_parameters() != 1) {
+ return FailedPrecondition(
+ "False computation %s of %s must have 1 parameter insted of %d",
+ conditional->false_computation()->name(), conditional->ToString(),
+ conditional->false_computation()->num_parameters());
+ }
+ return Status::OK();
+ }
+
+ Status HandleElementwiseUnary(HloInstruction* instruction) override {
+ return CheckElementwiseInstruction(instruction);
+ }
+
+ Status HandleElementwiseBinary(HloInstruction* instruction) override {
+ return CheckElementwiseInstruction(instruction);
+ }
+ Status HandleGetTupleElement(HloInstruction* gte) override {
+ TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape()));
return Status::OK();
- };
+ }
- // Send/Recv instruction must have a single user: the corresponding
- // SendDone/RecvDone. with matching channel.
- for (const HloComputation* computation : module.computations()) {
- for (const HloInstruction* instruction : computation->instructions()) {
- switch (instruction->opcode()) {
- case HloOpcode::kSend: {
- TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
- TF_RET_CHECK(instruction->users().size() == 1);
- const HloInstruction* send_done = instruction->users().front();
- TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
- TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
- TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
- break;
- }
- case HloOpcode::kRecv: {
- TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
- TF_RET_CHECK(instruction->users().size() == 1);
- const HloInstruction* recv_done = instruction->users().front();
- TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
- TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
- TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
- break;
+ Status HandleTranspose(HloInstruction* transpose) override {
+ const Shape& shape = transpose->shape();
+ const HloInstruction* operand = transpose->operand(0);
+ TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
+ TF_RET_CHECK(shape.dimensions().size() ==
+ transpose->operand(0)->shape().dimensions().size());
+ TF_RET_CHECK(std::equal(
+ operand->shape().dimensions().begin(),
+ operand->shape().dimensions().end(),
+ Permute(transpose->dimensions(), shape.dimensions()).begin()))
+ << "shape: " << shape << ", operand->shape(): " << shape
+ << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
+ << "}";
+ return Status::OK();
+ }
+
+ Status Preprocess(HloInstruction* instruction) override {
+ auto previous = instructions_by_name_.find(instruction->name());
+ TF_RET_CHECK(previous == instructions_by_name_.end())
+ << "HLO has name that is not unique within module:\n"
+ << instruction->ToString()
+ << " in computation: " << instruction->parent()->name()
+ << "\nPrevious HLO with same name:\n"
+ << previous->second->ToString()
+ << " in computation: " << previous->second->parent()->name();
+ instructions_by_name_[instruction->name()] = instruction;
+ return Status::OK();
+ }
+
+ Status Postprocess(HloInstruction* instruction) override {
+ if (instruction_can_change_layout_func_ &&
+ LayoutUtil::IsDenseArray(instruction->shape()) &&
+ !instruction_can_change_layout_func_(instruction)) {
+ const Shape& result_shape = instruction->shape();
+ const Layout& result_layout = result_shape.layout();
+ for (HloInstruction* operand : instruction->operands()) {
+ const Shape& operand_shape = operand->shape();
+ if (LayoutUtil::IsDenseArray(operand_shape) &&
+ ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) {
+ const Layout& operand_layout = operand_shape.layout();
+ TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
+ << "Instruction shouldn't change layouts "
+ << instruction->ToString() << " From "
+ << ShapeUtil::HumanString(result_shape) << " To "
+ << ShapeUtil::HumanString(operand_shape);
}
- case HloOpcode::kSendDone:
- TF_RET_CHECK(instruction->operands().size() == 1);
- TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
- break;
- case HloOpcode::kRecvDone:
- TF_RET_CHECK(instruction->operands().size() == 1);
- TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
- break;
- default:
- break;
}
}
+
+ return Status::OK();
}
- return Status::OK();
-}
+
+ private:
+ absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
+ // Determines whether an instruction can change layouts.
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func_;
+};
} // namespace
StatusOr<bool> HloVerifier::Run(HloModule* module) {
+ TF_RET_CHECK(!module->name().empty());
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
- tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions;
-
for (auto* computation : module->computations()) {
- for (const auto& instruction : computation->instructions()) {
- TF_RET_CHECK(instruction->parent() == computation);
- if (instruction->opcode() == HloOpcode::kFusion) {
- TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
- TF_RET_CHECK(instruction->called_computations() ==
- absl::Span<HloComputation* const>(
- {instruction->fused_instructions_computation()}))
- << "Fusion HLO calls computations other than the "
- "fused_instructions_computation: "
- << instruction->ToString()
- << " instruction->fused_instructions_computation(): "
- << instruction->fused_instructions_computation()->ToString()
- << " instruction->called_computations(): "
- << ComputationsToString(instruction->called_computations());
-
- for (const auto& fused : instruction->fused_instructions()) {
- TF_RET_CHECK(fused->parent() ==
- instruction->fused_instructions_computation())
- << "Fused HLO was missing a parent: " << fused->ToString()
- << " parent: " << fused->parent()
- << " computation: " << computation;
- }
- } else if (instruction->opcode() == HloOpcode::kBroadcast) {
- // If you see this failure then someone has confused the difference
- // between the HLO broadcast op, and the UserComputation broadcast
- // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
- // or ComputationLowerer::Visit()
- TF_RET_CHECK(instruction->dimensions().size() ==
- ShapeUtil::Rank(instruction->operand(0)->shape()))
- << "Broadcast HLO (" << instruction->ToShortString()
- << ") has invalid number of dimensions: "
- << instruction->dimensions().size()
- << " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
- } else if (instruction->opcode() == HloOpcode::kWhile) {
- TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
- } else if (instruction->opcode() == HloOpcode::kConditional) {
- TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction));
- } else if (instruction->opcode() !=
- HloOpcode::kRng /* Rng operands are always scalar. */
- && instruction->IsElementwise()) {
- TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction));
- }
-
- auto previous = instructions.find(instruction->name());
- TF_RET_CHECK(previous == instructions.end())
- << "HLO has name that is not unique within module:\n"
- << instruction->ToString()
- << " in computation: " << computation->name()
- << "\nPrevious HLO with same name:\n"
- << previous->second->ToString()
- << " in computation: " << previous->second->parent()->name();
- instructions[instruction->name()] = instruction;
- }
-
std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
+
+ InstructionVerifier instruction_verifier(
+ instruction_can_change_layout_func_);
+ TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
}
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 42e3027bf1..cb49cb95ba 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -151,15 +151,21 @@ class ShapeVerifier : public DfsHloVisitor {
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
-class HloVerifier : public HloPassInterface {
+class HloVerifier : public HloModulePass {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
- explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = {})
: shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
return absl::make_unique<ShapeVerifier>(layout_sensitive,
allow_mixed_precision);
- }) {}
+ }),
+ instruction_can_change_layout_func_(
+ std::move(instruction_can_change_layout_func)) {
+ CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive);
+ }
// Uses custom shape verification.
explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
@@ -172,22 +178,15 @@ class HloVerifier : public HloPassInterface {
StatusOr<bool> Run(HloModule* module) override;
private:
- // CHECKs various invariants of a fusion instruction.
- Status CheckFusionInstruction(HloInstruction* fusion) const;
-
- Status CheckWhileInstruction(HloInstruction* instruction);
-
- Status CheckConditionalInstruction(HloInstruction* instruction);
-
- // Checks that the non-scalar operand shapes are compatible to the output
- // shape, i.e., that there are no implicit broadcasts of size-one dimensions.
- Status CheckElementwiseInstruction(HloInstruction* instruction);
-
// Creates a ShapeVerifier that checks that shapes match inferred
// expectations. This is a factory function because ShapeVerifier,
// being a DfsHloVisitor, is stateful. We want a clean object
// for each run of the verifier.
ShapeVerifierFactory shape_verifier_factory_;
+
+ // Determines whether an instruction can change layouts.
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func_;
};
} // namespace xla
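For orientation, a hypothetical call site for the extended constructor, mirroring the layout-sensitive test fixture added below; note the CHECK in the constructor only permits a non-null predicate when layout_sensitive is true:

  // `module` is assumed to be a std::unique_ptr<HloModule> built elsewhere.
  HloVerifier verifier(/*layout_sensitive=*/true,
                       /*allow_mixed_precision=*/false,
                       LayoutAssignment::InstructionCanChangeLayout);
  TF_CHECK_OK(verifier.Run(module.get()).status());
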
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 8f0423bb1c..afe01e5487 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase {
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
+class HloVerifierTestLayoutSensitive : public HloTestBase {
+ public:
+ HloVerifierTestLayoutSensitive()
+ : HloTestBase(/*verifier_layout_sensitive=*/true,
+ /*allow_mixed_precision_in_hlo_verifier=*/false,
+ LayoutAssignment::InstructionCanChangeLayout) {}
+};
+
TEST_F(HloVerifierTest, NullInstructionParent) {
HloComputation::Builder builder(TestName());
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
HasSubstr("non-positive base area dilation factor"));
}
+static const char* const kAddWithLayoutChangeHlo = R"(
+ HloModule AddWithLayoutChange
+ ENTRY AddWithLayoutChange {
+ par0 = f32[3,4]{1,0} parameter(0)
+ par1 = f32[3,4]{0,1} parameter(1)
+ ROOT add0 = f32[3,4]{1,0} add(par0,par1)
+ }
+ )";
+
+TEST_F(HloVerifierTest, AddWithLayoutChange) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
+ const char* const kSliceWithLayoutChangeHlo = R"(
+ HloModule SliceWithLayoutChange
+ ENTRY SliceWithLayoutChange {
+ par0 = f32[4,5]{0,1} parameter(0)
+ par1 = s32[2] parameter(1)
+ ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1),
+ dynamic_slice_sizes={3,4}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kSliceWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
+ const char* const kConcatWithLayoutChangeHlo = R"(
+ HloModule ConcatWithLayoutChange
+ ENTRY ConcatWithLayoutChange {
+ par0 = f32[3,5]{0,1} parameter(0)
+ par1 = f32[3,3]{1,0} parameter(1)
+ ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
+ dimensions={1}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kConcatWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
index 85bb4a8b24..9c48b7db61 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
@@ -25,7 +25,7 @@ namespace xla {
// Pass which replaces all implicit broadcasts with their equivalent sequence of
// explicit broadcast and reshape instructions.
-class ImplicitBroadcastRemover : public HloPassInterface {
+class ImplicitBroadcastRemover : public HloModulePass {
public:
ImplicitBroadcastRemover() {}
~ImplicitBroadcastRemover() override {}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 06f0e1ed25..1ebb331977 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -23,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
namespace gtl = ::tensorflow::gtl;
@@ -95,7 +96,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
absl::InlinedVector<const HloInstruction*, 4> stack;
enum DfsState { kDiscovered, kVisited };
- gtl::FlatMap<const HloInstruction*, DfsState> dfs_state_map;
+ absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map;
stack.push_back(root);
InsertOrDie(&dfs_state_map, root, kDiscovered);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index df9cbab915..e5aa67fd85 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <type_traits>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/util/ptr_util.h"
namespace xla {
@@ -360,13 +360,13 @@ class IndexedArrayAnalysis {
std::vector<std::unique_ptr<Array>> owned_tensors_;
std::vector<Literal> owned_literals_;
- tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
+ absl::flat_hash_map<const HloInstruction*, Array*> cache_;
};
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
// unconditionally add to the regular HLO pass pipeline.
-class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
+class IndexedArrayAnalysisPrinterPass : public HloModulePass {
public:
absl::string_view name() const override;
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 3fdc2cee9a..69a4c160ee 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -22,11 +22,12 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -188,13 +189,20 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
bool InstructionFusion::CanFuseOnAllPaths(
HloInstruction* producer, HloInstruction* consumer,
- const HloInstructionSet& do_not_duplicate) {
+ const HloInstructionSet& do_not_fuse,
+ absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
+ result_cache) {
if (consumer == producer) {
return true;
}
if (!consumer->IsFusible()) {
return false;
}
+ auto cache_it = result_cache->find(std::make_pair(producer, consumer));
+ if (cache_it != result_cache->end()) {
+ return cache_it->second;
+ }
+ bool result = true;
for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
auto* consumer_operand = consumer->mutable_operand(i);
// If the operand is not on a path to the producer, it doesn't matter
@@ -202,20 +210,23 @@ bool InstructionFusion::CanFuseOnAllPaths(
if (!reachability_->IsReachable(producer, consumer_operand)) {
continue;
}
- if (do_not_duplicate.count(consumer_operand) > 0 ||
- !ShouldFuse(consumer, i)) {
- return false;
+ if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
+ result = false;
+ break;
}
// The producer is reachable from consumer_operand which means we need
// to be able to fuse consumer_operand into consumer in order for
// producer to be fusible into consumer on all paths.
// Perform the recursive step: make sure producer can be fused into
// consumer_operand on all paths.
- if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) {
- return false;
+ if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
+ result_cache)) {
+ result = false;
+ break;
}
}
- return true;
+ result_cache->emplace(std::make_pair(producer, consumer), result);
+ return result;
}
InstructionFusion::HloInstructionSet
@@ -231,6 +242,8 @@ InstructionFusion::ComputeGloballyUnfusible(
// fusing operations that require duplication later depending on
// is_expensive_().
HloInstructionSet do_not_duplicate;
+ absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>
+ can_fuse_on_all_paths_result_cache;
for (HloInstruction* consumer : post_order) {
for (HloInstruction* producer : consumer->operands()) {
if (do_not_duplicate.count(producer) > 0) {
@@ -286,7 +299,8 @@ InstructionFusion::ComputeGloballyUnfusible(
// A will be not allowed to be fused into B, as it cannot be fused via
// all paths.
if (producer->IsFusible() &&
- CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) {
+ CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
+ &can_fuse_on_all_paths_result_cache)) {
continue;
}
do_not_duplicate.insert(producer);
@@ -417,7 +431,7 @@ class ReversePostOrderFusionQueue : public FusionQueue {
private:
std::vector<HloInstruction*> post_order_;
- tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
+ absl::flat_hash_map<HloInstruction*, int> post_order_index_;
};
} // namespace
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index c1fde8ecfc..f14c667520 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -1,3 +1,4 @@
+#include "absl/container/flat_hash_map.h"
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
+#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -24,39 +26,12 @@ limitations under the License.
namespace xla {
-// A queue interface that allows implementations to choose fusion candidates in
-// custom order.
-class FusionQueue {
- public:
- FusionQueue() = default;
- virtual ~FusionQueue() = default;
-
- // Dequeues the next fusion candidates: a consumer and the list of producers
- // as operand indices.
- virtual std::pair<HloInstruction*, std::vector<int64>>
- DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
-
- // A callback passed to the queue implementation right before the producer is
- // fused into the consumer.
- virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
-
- // A callback passed to the queue implementation right after the fusion is
- // created. Note that original_producer could have been destroyed.
- virtual void OnFusingInstruction(HloInstruction* fusion,
- HloInstruction* original_producer,
- HloInstruction* original_consumer) {}
-
- // A callback passed to the queue implementation to notify the removal of an
- // instruction.
- virtual void RemoveInstruction(HloInstruction* instruction) = 0;
-};
-
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
// code generation. Derived classes define ShouldFuse method to select which
// instructions to fuse.
-class InstructionFusion : public HloPassInterface {
+class InstructionFusion : public HloModulePass {
public:
explicit InstructionFusion(
std::function<bool(const HloInstruction& instruction)> is_expensive,
@@ -151,8 +126,15 @@ class InstructionFusion : public HloPassInterface {
// Whether or not we can fuse producer into consumer on all paths
// from the producer to the consumer where nodes are HLOs and edges are uses.
- bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer,
- const HloInstructionSet& do_not_fuse);
+ //
+  // result_cache maps each <producer, consumer> pair to the result of a
+  // previous call on that pair, so repeated queries are answered without
+  // recomputation.
+ bool CanFuseOnAllPaths(
+ HloInstruction* producer, HloInstruction* consumer,
+ const HloInstructionSet& do_not_fuse,
+ absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
+ result_cache);
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
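Without the cache, CanFuseOnAllPaths recurses from the consumer toward the producer along every use edge, so diamond-shaped HLO graphs can revisit the same <producer, consumer> pair many times; memoizing the per-pair answer bounds the work by the number of distinct pairs. A generic sketch of the pattern, with the recursive operand walk elided and replaced by a placeholder:

  bool MemoizedCanFuse(
      HloInstruction* producer, HloInstruction* consumer,
      absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
          result_cache) {
    auto it = result_cache->find({producer, consumer});
    if (it != result_cache->end()) return it->second;  // answered before
    bool result = true;  // placeholder for the recursive operand checks
    result_cache->emplace(std::make_pair(producer, consumer), result);
    return result;
  }
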
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 146c9052f1..1484e14df1 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -45,8 +45,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
- "//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:layout_assignment",
+ "//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index bb69cb9c47..7c79eb7d79 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -28,9 +28,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
-#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
+#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -44,7 +44,8 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<LayoutAssignment>(
- hlo_module->mutable_entry_computation_layout());
+ hlo_module->mutable_entry_computation_layout(),
+ LayoutAssignment::InstructionCanChangeLayout);
return pipeline.Run(hlo_module).status();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 082bf8bffe..cc4a342e9d 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ CHECK(get_channel_constraints(instruction))
+ << "Multi-module layout assignment requires ChannelLayoutConstraints";
+ int64 all_reduce_id = instruction->all_reduce_id().value();
+ if (!get_channel_constraints(instruction)
+ ->IsChannelConstrained(all_reduce_id)) {
+ continue;
+ }
+ // TODO(b/68493863): Change to use SetOperandLayout().
+ const Shape& buffer_shape = instruction->operand(0)->shape();
+ TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape));
+ Shape new_buffer_shape =
+ get_channel_constraints(instruction)
+ ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(new_buffer_shape, instruction));
}
}
@@ -776,21 +792,27 @@ StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
<< " instruction: " << instruction->ToString();
if (ShapeUtil::IsTuple(instruction->shape())) {
- // Deep-copy tuples.
+ // Copy tuple elements which have differing layouts.
std::vector<HloInstruction*> element_copies;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
++i) {
+ const Shape& target_shape =
+ ShapeUtil::GetSubshape(shape_with_layout, {i});
+ const Shape& instr_shape =
+ ShapeUtil::GetSubshape(instruction->shape(), {i});
HloInstruction* gte = instruction->parent()->AddInstruction(
- HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
- i));
- SetupCopiedInstruction(*instruction, gte, {i});
- // Recurse to copy each elements.
- TF_ASSIGN_OR_RETURN(
- HloInstruction * element_copy,
- CreateCopyWithNewLayout(
- ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
- element_copies.push_back(element_copy);
+ HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));
+
+ if (ShapeUtil::Equal(target_shape, instr_shape)) {
+ // Shapes and layouts are equal, no need to copy.
+ element_copies.push_back(gte);
+ } else {
+ SetupCopiedInstruction(*instruction, gte, {i});
+ // Recurse to copy each element.
+ TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
+ CreateCopyWithNewLayout(target_shape, gte));
+ element_copies.push_back(element_copy);
+ }
}
// Gather element copies into a tuple with a new Tuple instruction.
HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
@@ -958,10 +980,15 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
LayoutAssignment::LayoutAssignment(
ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func,
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
+
saved_entry_computation_layout_(*entry_computation_layout),
- channel_layout_constraints_(channel_constraints) {
+ channel_layout_constraints_(channel_constraints),
+ instruction_can_change_layout_func_(
+ std::move(instruction_can_change_layout_func)) {
if (channel_layout_constraints_ != nullptr) {
// Save a copy of the input ChannelLayoutConstraints so that we can reset it
// if we have to undo previous operations (ClearPreviousPassSideEffects()).
@@ -982,7 +1009,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
ShapeUtil::Rank(instruction->shape()) &&
- InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) {
+ !instruction_can_change_layout_func_(instruction)) {
// Propagate the result layout to the operand layout if the instruction
// requires the same layout out for the result and the operand.
//
@@ -1060,7 +1087,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) &&
- InstructionRequiresInputLayoutEqualToOutputLayout(user)) {
+ !instruction_can_change_layout_func_(user)) {
// Assign users the same layout as the operand.
return absl::make_unique<Layout>(operand_layout);
}
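Both propagation helpers now consult the stored predicate instead of the former virtual method: a layout is copied between operand and result only when the two have the same rank and the instruction is not allowed to change layouts. A small self-contained sketch of that gate (toy Node type, not HloInstruction):

#include <functional>
#include <iostream>
#include <optional>
#include <string>

struct Node {
  int rank;
  std::string layout;  // e.g. "{1,0}"
};

using CanChangeLayout = std::function<bool(const Node&)>;

// Propagate the result layout to an operand only if the instruction must keep
// operand and result layouts identical and the ranks line up.
std::optional<std::string> OperandLayoutFromResult(const Node& result,
                                                   const Node& operand,
                                                   const CanChangeLayout& can_change) {
  if (operand.rank == result.rank && !can_change(result)) {
    return result.layout;  // elementwise-style op: operand must match result
  }
  return std::nullopt;     // no constraint propagated
}

int main() {
  CanChangeLayout never = [](const Node&) { return false; };
  Node result{2, "{0,1}"}, operand{2, "{1,0}"};
  auto chosen = OperandLayoutFromResult(result, operand, never);
  std::cout << (chosen ? *chosen : "unconstrained") << "\n";  // {0,1}
}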
@@ -1512,19 +1539,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
// Verify all layouts in the shape have been set.
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
}
-
- // Copy the root instruction's result if its layout does not match the result
- // layout constraint.
- if (constraints.ResultLayout() != nullptr &&
- !constraints.ResultLayout()->MatchesLayoutInShape(
- computation->root_instruction()->shape())) {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_root,
- CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
- computation->root_instruction()));
- computation->set_root_instruction(new_root);
- }
-
return Status::OK();
}
@@ -1654,6 +1668,18 @@ Status LayoutAssignment::RunOnComputation(
TF_RETURN_IF_ERROR(
ConstrainChannelLayouts(computation, channel_constraints));
}
+
+ // Copy the root instruction's result if its layout does not match the result
+ // layout constraint.
+ if (constraints.ResultLayout() != nullptr &&
+ !constraints.ResultLayout()->MatchesLayoutInShape(
+ computation->root_instruction()->shape())) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_root,
+ CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+ computation->root_instruction()));
+ computation->set_root_instruction(new_root);
+ }
return Status::OK();
}
@@ -1709,6 +1735,30 @@ Status LayoutAssignment::ConstrainChannelLayouts(
ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
*send_shape = shape;
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ const Layout* layout =
+ get_channel_constraints(instruction)
+ ->ConstrainChannel(instruction->all_reduce_id().value(),
+ instruction->shape().layout());
+ if (layout != nullptr) {
+ // We found an already constrained layout which does not match the one
+ // the channel wants to impose. Either add a new kCopy, or use the
+ // existing one to marshal the correct shape.
+ HloInstruction* operand = instruction->mutable_operand(0);
+ Shape shape = operand->shape();
+ *shape.mutable_layout() = *layout;
+ if (operand->opcode() != HloOpcode::kCopy) {
+ HloInstruction* copy = operand->parent()->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+ RegisterAddedCopy(copy);
+ SetupCopiedInstruction(*operand, copy, {});
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+ operand = copy;
+ } else {
+ *operand->mutable_shape() = shape;
+ }
+ *instruction->mutable_shape() = shape;
+ }
}
}
return Status::OK();
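For cross-module all-reduce, the new branch registers the instruction's layout under its all_reduce_id and, when a different layout was already pinned for that channel, marshals the operand through a kCopy (or reuses an existing copy) so all participants agree. The constrain-or-copy decision can be sketched standalone with a plain map standing in for ChannelLayoutConstraints (a hypothetical helper, not the real class):

#include <iostream>
#include <map>
#include <string>

// Hypothetical channel-constraint table: all_reduce_id -> layout string.
// Returns nullptr if the proposed layout was accepted (first participant, or
// identical to what is already pinned); otherwise returns the pinned layout
// that the caller must marshal to via a copy.
const std::string* ConstrainChannel(std::map<int, std::string>& constraints,
                                    int all_reduce_id,
                                    const std::string& proposed_layout) {
  auto it = constraints.find(all_reduce_id);
  if (it == constraints.end()) {
    constraints[all_reduce_id] = proposed_layout;
    return nullptr;  // first participant pins the layout for the channel
  }
  return it->second == proposed_layout ? nullptr : &it->second;
}

int main() {
  std::map<int, std::string> constraints;
  std::cout << (ConstrainChannel(constraints, 0, "{0,1}") ? "copy" : "ok") << "\n";  // ok
  // A later participant arriving with a different layout would need a kCopy
  // inserted on its operand so the channel-wide layout {0,1} is respected.
  std::cout << (ConstrainChannel(constraints, 0, "{1,0}") ? "copy" : "ok") << "\n";  // copy
}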
@@ -1803,7 +1853,8 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
return true;
}
-bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
+/* static */
+bool LayoutAssignment::InstructionCanChangeLayout(
const HloInstruction* instruction) {
switch (instruction->opcode()) {
case HloOpcode::kAbs:
@@ -1869,7 +1920,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
case HloOpcode::kTanh:
case HloOpcode::kTupleSelect:
case HloOpcode::kWhile:
- return true;
+ return false;
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
@@ -1900,7 +1951,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
case HloOpcode::kTrace:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
- return false;
+ return true;
}
}
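Note that the rename flips the predicate's sense: the old InstructionRequiresInputLayoutEqualToOutputLayout returned true for elementwise-style opcodes, whereas the static InstructionCanChangeLayout returns false for those and true for layout-changing ones such as transpose or reshape. The same shape of switch, reduced to a toy opcode enum (the real switch covers every HloOpcode):

#include <iostream>

enum class Opcode { kAdd, kTanh, kTranspose, kReshape, kConvolution };

// True if the op may legally produce a different layout than its same-rank
// operands; false for ops that must keep operand and result layouts equal.
bool InstructionCanChangeLayout(Opcode op) {
  switch (op) {
    case Opcode::kAdd:
    case Opcode::kTanh:
      return false;  // elementwise: layout must be preserved
    case Opcode::kTranspose:
    case Opcode::kReshape:
    case Opcode::kConvolution:
      return true;   // free to pick (or be forced into) a different layout
  }
  return true;  // unreachable with the enum above; keeps compilers happy
}

int main() {
  std::cout << InstructionCanChangeLayout(Opcode::kAdd) << " "
            << InstructionCanChangeLayout(Opcode::kTranspose) << "\n";  // 0 1
}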
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index cf545031d3..2d48e12263 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -25,6 +25,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -38,8 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -228,8 +228,8 @@ class LayoutConstraints {
// Array-shaped buffers which have not yet been constrained.
std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;
- mutable tensorflow::gtl::FlatMap<const HloInstruction*,
- std::unique_ptr<PointsToSet::BufferSet>>
+ mutable absl::flat_hash_map<const HloInstruction*,
+ std::unique_ptr<PointsToSet::BufferSet>>
buffer_sets_cache_;
HloComputation* computation_;
@@ -281,11 +281,16 @@ class ChannelLayoutConstraints {
// HLO pass which assigns layouts to all instructions in the HLO module while
// satisfying all necessary invariants and minimizing cost.
-class LayoutAssignment : public HloPassInterface {
+class LayoutAssignment : public HloModulePass {
public:
// entry_computation_layout is modified to populate a layout for the result in
// the case that no particular layout is requested.
//
+ // instruction_can_change_layout_func is a function object that determines
+ // whether an instruction can change layouts. An instruction not being able to
+ // change layout means that it requires operands with the same rank as the
+ // output to have the same layout as the output.
+ //
// channel_constraints is both an input and output. Any sends or recvs that
// are present in channel_constraints will be laid out as constrained. Any
// unconstrained sends or recvs will be laid out as locally optimal and their
@@ -295,6 +300,8 @@ class LayoutAssignment : public HloPassInterface {
// within any module passed to `Run`.
explicit LayoutAssignment(
ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = InstructionCanChangeLayout,
ChannelLayoutConstraints* channel_constraints = nullptr);
~LayoutAssignment() override {}
absl::string_view name() const override { return "layout-assignment"; }
@@ -303,10 +310,10 @@ class LayoutAssignment : public HloPassInterface {
// (any layouts were changed).
StatusOr<bool> Run(HloModule* module) override;
- // Returns true if the instruction requires that operands with the same rank
- // as the output have to have the same layout as the output.
- virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
- const HloInstruction* instruction);
+ // Determines whether an instruction can change layouts. An instruction not
+ // being able to change layout means that it requires operands with the same
+ // rank as the output to have the same layout as the output.
+ static bool InstructionCanChangeLayout(const HloInstruction* instruction);
protected:
// These methods, invoked by PropagateConstraints, propagate a layout
@@ -504,7 +511,7 @@ class LayoutAssignment : public HloPassInterface {
// Every copy added to the module by the layout assignment pass is registered
// here.
- tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
+ absl::flat_hash_set<HloInstruction*> added_copies_;
// The pointer to the channel layout constraints passed in with the
// constructor. If not nullptr, this is an input/output argument.
@@ -521,8 +528,10 @@ class LayoutAssignment : public HloPassInterface {
// The set of HLO instructions which lacked any layout constraint, thus
// receiving propagated default layouts.
- tensorflow::gtl::FlatSet<const HloInstruction*>
- unconstrained_layout_instructions_;
+ absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;
+
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 752a61476d..2c549cd872 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -55,7 +55,8 @@ class LayoutAssignmentTest : public HloVerifiedTestBase {
ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr) {
LayoutAssignment layout_assignment(
- entry_computation_layout, /*channel_constraints=*/channel_constraints);
+ entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ /*channel_constraints=*/channel_constraints);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
@@ -860,6 +861,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
+TEST_F(LayoutAssignmentTest, AllReduceLayoutMismatch) {
+ // Pin non-matching layouts to parameter and root.
+ const char* module_str = R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry_computation {
+ param = (f32[2,2]) parameter(0)
+ gte = f32[2,2] get-tuple-element(param), index=0
+ ar.0 = f32[2,2] cross-replica-sum(gte),
+ all_reduce_id=0, replica_groups={{0}}, to_apply=add,
+ sharding={maximal device=0}
+ const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}})
+ ROOT ar.1 = f32[2,2] cross-replica-sum(const),
+ all_reduce_id=0, replica_groups={{0}}, to_apply=add,
+ sharding={maximal device=1}
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape());
+ Shape param_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
+ TF_ASSERT_OK(
+ computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
+ param_shape));
+ computation_layout.mutable_result_layout()->ResetLayout(
+ LayoutUtil::MakeLayout({1, 0}));
+
+ ChannelLayoutConstraints channel_constraints;
+ AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+
+ EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1));
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
+}
+
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
const char* module_str = R"(
HloModule CopySliceOperandToAvoidImplicitLayoutChange
@@ -998,5 +1043,64 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
op::ShapeWithLayout(shape_copy))));
}
+TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
+ // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
+ // The mismatch forces a copy of the tuple. The tuple contains a token, so
+ // layout assignment will fail if it tries to copy the whole tuple.
+ const char* module_str = R"(
+ HloModule TupleCopyOnLayoutMismatch
+
+ condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
+ tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
+ counter.1 = s32[] get-tuple-element(tup.1), index=0
+ five = s32[] constant(5)
+ ROOT lt = pred[] less-than(counter.1, five)
+ }
+
+ body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
+ tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
+ counter.2 = s32[] get-tuple-element(tup.2), index=0
+ tok.2 = token[] get-tuple-element(tup.2), index=1
+
+ ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
+ next_tok = token[] get-tuple-element(ifeed.2), index=1
+ next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0
+
+ one = s32[] constant(1)
+ next_counter = s32[] add(counter.2, one)
+ ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
+ }
+
+ ENTRY main () -> f32[512,1024]{0,1} {
+ start_tok = token[] after-all()
+
+ ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
+ itok = token[] get-tuple-element(ifeed.3), index=1
+ ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0
+
+ zero = s32[] constant(0)
+ itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)
+
+ loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
+ ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
+ }
+ )";
+
+ ParseAndVerifyModule(module_str);
+ ComputationLayout computation_layout(
+ module().entry_computation()->ComputeProgramShape());
+
+ // Sanity check to verify that there's a layout mismatch.
+ EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0));
+
+ AssignLayouts(&module(), &computation_layout);
+
+ // Make sure that layout assignment did not magically eliminate the mismatch,
+ // in which case the test didn't prove anything.
+ EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 540bbb7c7a..6223a34b12 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -38,6 +38,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm//:core",
],
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
index e5370eca56..643ecd0fba 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
-#include <unordered_set>
+#include <map>
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer(
add_buffers_to_worklist(operand);
}
- tensorflow::gtl::FlatSet<BufferAllocation::Slice,
- BufferAllocation::Slice::Hasher>
- buffers;
+ std::set<BufferAllocation::Slice> buffers;
for (const LogicalBuffer* buffer : worklist) {
// Skip buffers which cannot be added to the noalias set.
if (!assignment.HasAllocation(*buffer) ||
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
index 8d9fa99d82..2b46b3c396 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -16,14 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
namespace llvm_ir {
@@ -77,14 +76,14 @@ class AliasAnalysis {
// A map from a buffer slice to metadata corresponding to its alias.scope
// metadata. The index kParameterAliasSet is used to hold aliasing
// information for parameters.
- tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*,
- BufferAllocation::Slice::Hasher>
+ absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
+ BufferAllocation::Slice::Hasher>
alias_scope_metadata_;
// A map from a buffer slice to metadata corresponding to its noalias
// metadata.
- tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*,
- BufferAllocation::Slice::Hasher>
+ absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
+ BufferAllocation::Slice::Hasher>
noalias_metadata_;
};
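The container migration in this file keeps the custom hasher: BufferAllocation::Slice::Hasher simply moves over as the third template argument of absl::flat_hash_map. A minimal standalone illustration of that pattern, assuming Abseil is on the include path (the Slice type here is a stand-in, not the real BufferAllocation::Slice):

#include <cstddef>
#include <iostream>
#include <string>

#include "absl/container/flat_hash_map.h"

// Illustrative key type carrying its own hasher, mirroring how the custom
// hasher is passed as the third template argument after the migration from
// tensorflow::gtl::FlatMap.
struct Slice {
  int allocation_index;
  int offset;
  bool operator==(const Slice& other) const {
    return allocation_index == other.allocation_index && offset == other.offset;
  }
  struct Hasher {
    size_t operator()(const Slice& s) const {
      return static_cast<size_t>(s.allocation_index) * 31 + s.offset;
    }
  };
};

int main() {
  absl::flat_hash_map<Slice, std::string, Slice::Hasher> metadata;
  Slice key{0, 128};
  metadata[key] = "alias.scope";
  std::cout << metadata.size() << "\n";  // 1
}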
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index eaa09591b7..ec52a24d78 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -54,7 +54,7 @@ Status LogicalBufferAnalysis::Analyze() {
// so reserve 10% more than the number of instructions to avoid frequent
// resizes.
logical_buffers_.clear();
- logical_buffers_.reserve((module_->NumUniqueInstructionIds() * 11) / 10);
+ logical_buffers_.reserve((module_->instruction_count() * 11) / 10);
// We filter out fusion computations, and get to them through fusion
// instructions. This is because it's possible to have orphaned (unreachable)
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/map_inliner.cc
index 5fd779ebf9..2200ef054a 100644
--- a/tensorflow/compiler/xla/service/inliner.cc
+++ b/tensorflow/compiler/xla/service/map_inliner.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/inliner.h"
+#include "tensorflow/compiler/xla/service/map_inliner.h"
#include <memory>
#include <string>
@@ -32,10 +32,10 @@ limitations under the License.
namespace xla {
-// InlinerVisitor traverses the HLO computation and inlines maps.
-class InlinerVisitor : public DfsHloVisitorWithDefault {
+// MapInlinerVisitor traverses the HLO computation and inlines maps.
+class MapInlinerVisitor : public DfsHloVisitorWithDefault {
public:
- explicit InlinerVisitor(HloComputation* computation)
+ explicit MapInlinerVisitor(HloComputation* computation)
: computation_(computation) {}
// Default visitor action is to do nothing and return OK.
@@ -49,48 +49,44 @@ class InlinerVisitor : public DfsHloVisitorWithDefault {
StatusOr<bool> Run(HloComputation* computation);
private:
- // Current HloComputation instance the InlinerVisitor is traversing.
+ // Current HloComputation instance the MapInlinerVisitor is traversing.
HloComputation* computation_;
// Whether algebraic simplification has occurred.
bool changed_ = false;
};
-StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) {
+StatusOr<bool> MapInlinerVisitor::Run(HloComputation* computation) {
changed_ = false;
computation_ = computation;
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
return changed_;
}
-Status InlinerVisitor::HandleMap(HloInstruction* map) {
+Status MapInlinerVisitor::HandleMap(HloInstruction* map) {
HloComputation* function = map->to_apply();
HloInstruction& root = *function->root_instruction();
- // TODO(b/29249531): Add DCE pass to remove unused HloComputations.
// Only inline functions that are a single operation, until a better
// profitability model for inlining is defined.
if (hlo_query::AllOperandsAreParameters(root)) {
if (root.opcode() == HloOpcode::kFusion ||
- root.opcode() == HloOpcode::kParameter ||
root.opcode() == HloOpcode::kTrace) {
// Cloning not supported for these instructions.
return Status::OK();
}
VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
<< root.ToShortString();
- // If the input is a constant then the shape of the constant could be
- // different than the map shape. Hence, a broadcast is needed, else the
- // cloned operand with new shape and operands work.
- if (root.opcode() != HloOpcode::kConstant) {
- std::vector<HloInstruction*> params;
- for (int64 o = 0; o < root.operands().size(); o++) {
- params.push_back(map->operands()[root.operand(o)->parameter_number()]);
- }
- HloInstruction* placed_instruction = computation_->AddInstruction(
- root.CloneWithNewOperands(map->shape(), params));
+ if (root.opcode() == HloOpcode::kParameter) {
+ // If the root is a parameter, then use the corresponding operand as the
+ // result of the computation.
TF_RETURN_IF_ERROR(
- computation_->ReplaceInstruction(map, placed_instruction));
- } else {
+ map->ReplaceAllUsesWith(map->operands()[root.parameter_number()]));
+ TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map));
+ } else if (root.opcode() == HloOpcode::kConstant) {
+ // If the input is a constant, its shape could differ from the map shape,
+ // so a broadcast is needed; otherwise the cloned operand with the new
+ // shape and operands would suffice.
+ //
// The constant is in an embedded computation and needs to be recreated
// as part of the computation that the broadcast is inserted into.
HloInstruction* constant = computation_->AddInstruction(root.Clone());
@@ -98,6 +94,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
HloInstruction::CreateBroadcast(map->shape(), constant, {}));
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(map, placed_instruction));
+ } else {
+ std::vector<HloInstruction*> params;
+ for (int64 o = 0; o < root.operands().size(); o++) {
+ params.push_back(map->operands()[root.operand(o)->parameter_number()]);
+ }
+ HloInstruction* placed_instruction = computation_->AddInstruction(
+ root.CloneWithNewOperands(map->shape(), params));
+ TF_RETURN_IF_ERROR(
+ computation_->ReplaceInstruction(map, placed_instruction));
}
changed_ = true;
return Status::OK();
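HandleMap now dispatches on the map computation's root in three ways: a parameter root makes the map collapse to the matching operand (exercised by the MapParameter test below), a constant root is rematerialized and broadcast, and anything else is cloned with the map's operands substituted in. A standalone sketch of that dispatch using plain strings instead of HLO instructions (purely illustrative names):

#include <iostream>
#include <string>
#include <vector>

// Illustrative only: the mapped computation is reduced to the opcode of its
// root and, for a parameter root, the parameter number it forwards.
struct MapCall {
  std::string root_opcode;            // "parameter", "constant", or an op name
  int root_parameter_number = 0;
  std::vector<std::string> operands;  // names of the map's operands
};

std::string InlineMap(const MapCall& map) {
  if (map.root_opcode == "parameter") {
    // map({X, Y}, (p0, p1) -> p1)  ==>  Y : forward the matching operand.
    return map.operands[map.root_parameter_number];
  }
  if (map.root_opcode == "constant") {
    // A constant root is recreated and broadcast to the map shape.
    return "broadcast(constant)";
  }
  // Otherwise clone the single-op body with the map's operands substituted in,
  // e.g. map({X, Y}, add) ==> add(X, Y).
  return map.root_opcode + "(" + map.operands[0] + ", " + map.operands[1] + ")";
}

int main() {
  std::cout << InlineMap({"parameter", 1, {"X", "Y"}}) << "\n";  // Y
  std::cout << InlineMap({"add", 0, {"X", "Y"}}) << "\n";        // add(X, Y)
}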
@@ -106,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
return Status::OK();
}
-StatusOr<bool> Inliner::Run(HloModule* module) {
- InlinerVisitor visitor(/*computation=*/nullptr);
+StatusOr<bool> MapInliner::Run(HloModule* module) {
+ MapInlinerVisitor visitor(/*computation=*/nullptr);
bool changed = false;
for (HloComputation* computation : module->computations()) {
TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation));
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/map_inliner.h
index efa8ed3abc..b679118118 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/map_inliner.h
@@ -13,27 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
-// A pass which performs inlining. Which can result, for example, in functions
-// that were previously being mapped by Map instead directly applied to the
-// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
-class Inliner : public HloPassInterface {
+// A pass which performs map inlining. This replaces kMap instructions with
+// their equivalent sequence of array operations. For example:
+// map({X, Y}, add) -> add(X, Y).
+class MapInliner : public HloModulePass {
public:
- ~Inliner() override = default;
- absl::string_view name() const override { return "inline"; }
+ ~MapInliner() override = default;
+ absl::string_view name() const override { return "map-inline"; }
- // Run inlining on the given computation. Returns whether the computation was
- // changed.
+ // Run map inlining on the given computation. Returns whether the computation
+ // was changed.
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc
index 7e967f035c..84059dd0f7 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/map_inliner_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/inliner.h"
+#include "tensorflow/compiler/xla/service/map_inliner.h"
#include <memory>
#include <utility>
@@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using InlinerTest = HloVerifiedTestBase;
+using MapInlinerTest = HloVerifiedTestBase;
// Test that `map` with `max` is transformed to `max`
-TEST_F(InlinerTest, MapMax) {
+TEST_F(MapInlinerTest, MapMax) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
auto max_builder = HloComputation::Builder(TestName());
@@ -63,7 +63,7 @@ TEST_F(InlinerTest, MapMax) {
hlo_module->AddEmbeddedComputation(std::move(max_f32));
hlo_module->AddEntryComputation(std::move(computation));
- Inliner inliner;
+ MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Maximum(lhs, rhs));
@@ -75,7 +75,7 @@ TEST_F(InlinerTest, MapMax) {
}
// Test that `constant` function is changed to `broadcast`.
-TEST_F(InlinerTest, MapConstant) {
+TEST_F(MapInlinerTest, MapConstant) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
auto const2_builder = HloComputation::Builder(TestName());
@@ -97,7 +97,7 @@ TEST_F(InlinerTest, MapConstant) {
hlo_module->AddEmbeddedComputation(std::move(const2_f32));
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
- Inliner inliner;
+ MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
root = hlo_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Broadcast(op::Constant()));
@@ -108,7 +108,7 @@ TEST_F(InlinerTest, MapConstant) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
-TEST_F(InlinerTest, MapSubtractOppositeOrder) {
+TEST_F(MapInlinerTest, MapSubtractOppositeOrder) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
// Note that the parameter ordinals are in the opposite order to their
@@ -135,7 +135,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
hlo_module->AddEmbeddedComputation(std::move(max_f32));
hlo_module->AddEntryComputation(std::move(computation));
- Inliner inliner;
+ MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Subtract(rhs, lhs));
@@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
+TEST_F(MapInlinerTest, MapParameter) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto param_builder = HloComputation::Builder(TestName());
+ param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
+ param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
+ auto param_f32 = param_builder.Build();
+
+ auto builder = HloComputation::Builder("MapParamFunction");
+ auto lhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
+ auto rhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));
+
+ auto computation = builder.Build();
+ auto hlo_module = CreateNewVerifiedModule();
+ hlo_module->AddEmbeddedComputation(std::move(param_f32));
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ MapInliner inliner;
+ EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
+
+ // Verify execution on CPU.
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
+ auto expected = LiteralUtil::CreateR0<float>(4);
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index b9ec31c497..2ca527bc4c 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -15,10 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/multi_output_fusion.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -50,7 +50,7 @@ StatusOr<bool> MultiOutputFusion::Run(HloModule* module) {
all_fusion_candidates_.push_back(instruction);
std::vector<HloInstruction*> candidates;
- tensorflow::gtl::FlatSet<HloInstruction*> candidates_set;
+ absl::flat_hash_set<HloInstruction*> candidates_set;
VLOG(10) << "Looking at instruction: " << instruction->name();
for (auto operand : instruction->operands()) {
// Filter out the non-interesting instructions -- they
@@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
// Update the fusible list for fusion. Variable new_fusibles keeps
// track of the new or changed entries.
std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
- tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ absl::flat_hash_set<HloInstruction*> in_list;
auto it = fusion_node.fusibles.begin();
while (it != fusion_node.fusibles.end()) {
HloInstruction* instr = it->first;
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index d2c52651c4..9508ab2ed1 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -44,7 +45,7 @@ namespace xla {
// Note that the reachability map is updated based on the original computation.
// This works because the reachability is monotonically increasing with
// instruction fusion.
-class MultiOutputFusion : public HloPassInterface {
+class MultiOutputFusion : public HloModulePass {
public:
MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
@@ -126,7 +127,7 @@ class MultiOutputFusion : public HloPassInterface {
std::vector<FusionCandidate> candidates_;
// A map that maps an instruction to the index_.
- tensorflow::gtl::FlatMap<HloInstruction*, int> candidates_index_;
+ absl::flat_hash_map<HloInstruction*, int> candidates_index_;
// The reachability map of current computation.
std::unique_ptr<HloReachabilityMap> reachability_;
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index bd8fb17a23..ac2f79674f 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -39,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) {
}
/*static*/ string NameUniquer::GetSanitizedName(const string& name) {
+ if (name.empty()) {
+ return "";
+ }
string result = name;
- CHECK(!result.empty()) << "name should not be empty";
char c = static_cast<unsigned char>(result[0]);
if (!isalpha(c) && c != '_') {
result[0] = '_';
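GetSanitizedName now tolerates an empty name instead of CHECK-failing. A self-contained approximation of the behavior visible in this hunk (only the empty-name and leading-character rules are shown; handling of the remaining characters lives outside the hunk):

#include <cctype>
#include <iostream>
#include <string>

// Approximation: empty names are returned as-is, and a leading character that
// is not a letter or '_' is replaced by '_'.
std::string GetSanitizedName(const std::string& name) {
  if (name.empty()) {
    return "";
  }
  std::string result = name;
  unsigned char c = static_cast<unsigned char>(result[0]);
  if (!std::isalpha(c) && c != '_') {
    result[0] = '_';
  }
  return result;
}

int main() {
  std::cout << '[' << GetSanitizedName("") << "] "  // [] -- no CHECK failure
            << GetSanitizedName("3x") << ' '        // _x
            << GetSanitizedName("param0") << "\n";  // param0
}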
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index 6dd89c240f..8909d0f4fe 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <string>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -69,7 +69,7 @@ class NameUniquer {
int64 next_ = 0;
// Set of all the identifiers which have been used.
- tensorflow::gtl::FlatSet<int64> used_;
+ absl::flat_hash_set<int64> used_;
};
// The string to use to separate the prefix of the name from the uniquing
@@ -78,7 +78,7 @@ class NameUniquer {
// Map from name prefix to the generator data structure which tracks used
// identifiers and generates new ones.
- tensorflow::gtl::FlatMap<string, SequentialIdGenerator> generated_names_;
+ absl::flat_hash_map<string, SequentialIdGenerator> generated_names_;
TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer);
};
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 4869db79e7..380cde0e6a 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -17,8 +17,12 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
#include "absl/strings/string_view.h"
+#include "absl/utility/utility.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -116,15 +120,82 @@ namespace xla {
// .WithOperand(1, Op(&c))
// .WithOperand(2, Op(&d))
//
+
+struct MatchOption {
+ // If true, actually capture matched item into the user pointer.
+ bool capture;
+};
+
template <typename Value, typename Pattern>
-bool Match(Value* value, const Pattern& pattern) {
- return pattern.Match(value);
+bool Match(Value* value, const Pattern& pattern,
+ MatchOption option = {/*.capture=*/true}) {
+ if (option.capture) {
+ auto new_option = option;
+ new_option.capture = false;
+ if (!pattern.Match(value, new_option)) {
+ return false;
+ }
+ }
+ return pattern.Match(value, option);
}
namespace match {
namespace detail {
+template <typename Item, typename... Patterns>
+class AllOfPattern {
+ public:
+ explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
+
+ bool Match(const Item* item, MatchOption option) const {
+ bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ // This invariant is guaranteed by the top-level Match and AnyOf.
+ DCHECK(matched || !option.capture);
+ return matched;
+ }
+
+ bool Match(Item* item, MatchOption option) const {
+ bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ // This invariant is guaranteed by the top-level Match and AnyOf.
+ DCHECK(matched || !option.capture);
+ return matched;
+ }
+
+ private:
+ template <typename ItemType, size_t index>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, index>) const {
+ return std::get<index>(patterns_).Match(item, option) &&
+ MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
+ }
+
+ template <typename ItemType>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, sizeof...(Patterns)>) const {
+ return true;
+ }
+
+ std::tuple<Patterns...> patterns_;
+};
+
+} // namespace detail
+
+// Returns a pattern that represents the conjunction of all input patterns. All
+// patterns need to match in order to have the AllOf pattern match.
+//
+// TODO(timshen): Currently AllOf is still nested, e.g. AllOf<AllOf<A>, B> is
+// not AllOf<A, B>. We might want to flatten the AllOf type structure if the
+// C++ compile error message gets annoying.
+template <typename Item, typename... Patterns>
+detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf(
+ const Patterns&... patterns) {
+ return detail::AllOfPattern<typename std::remove_const<Item>::type,
+ Patterns...>(patterns...);
+}
+
+namespace detail {
+
template <typename LayoutType, typename Impl>
class LayoutPattern;
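The new top-level Match first runs the whole pattern with capture disabled and re-runs it with capture on only once that dry run succeeds, so a failed match never leaves partially written capture pointers behind. The mechanism, boiled down to a tiny standalone pattern over ints (toy types, not the XLA pattern classes):

#include <iostream>

struct MatchOption {
  bool capture;  // if false, a successful match writes nothing
};

// Toy pattern: matches ints above a threshold, capturing the match on success.
struct GreaterThan {
  int threshold;
  int** capture_slot;  // where to store a pointer to the matched value

  bool Match(int* value, MatchOption option) const {
    if (*value <= threshold) return false;
    if (option.capture && capture_slot != nullptr) *capture_slot = value;
    return true;
  }
};

// Two-phase top-level Match: a capture-free dry run first, then the capturing
// run, so a failed match cannot leave stale captures behind.
template <typename Pattern>
bool Match(int* value, const Pattern& pattern,
           MatchOption option = {/*.capture=*/true}) {
  if (option.capture) {
    MatchOption no_capture = option;
    no_capture.capture = false;
    if (!pattern.Match(value, no_capture)) return false;
  }
  return pattern.Match(value, option);
}

int main() {
  int x = 7;
  int* matched = nullptr;
  std::cout << Match(&x, GreaterThan{5, &matched}) << " "   // 1, matched == &x
            << Match(&x, GreaterThan{9, &matched}) << "\n"; // 0, no new capture
}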
@@ -132,57 +203,61 @@ class LayoutPattern;
// nullptr.
class LayoutPatternBaseImpl {
public:
- bool Match(const ::xla::Layout* layout) const { return layout != nullptr; }
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return layout != nullptr;
+ }
};
// A LayoutPattern implementation that matches only if the layout equals a
// Layout proto.
-template <typename Previous>
class LayoutPatternEqualImpl {
public:
- explicit constexpr LayoutPatternEqualImpl(const Previous& previous,
- const ::xla::Layout* layout)
- : previous_(previous), layout_(layout) {}
+ explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
+ : layout_(layout) {}
- bool Match(const ::xla::Layout* layout) const {
- return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout);
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return LayoutUtil::Equal(*layout_, *layout);
}
private:
- Previous previous_;
const ::xla::Layout* layout_;
};
// A LayoutPattern implementation that matches only if the layout has a given
// format.
-template <typename Previous>
class LayoutPatternFormatImpl {
public:
- explicit constexpr LayoutPatternFormatImpl(const Previous& previous,
- Format format)
- : previous_(previous), format_(format) {}
+ explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
- bool Match(const ::xla::Layout* layout) const {
- return previous_.Match(layout) && layout->format() == format_;
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return layout->format() == format_;
}
private:
- Previous previous_;
Format format_;
};
// A pattern that matches Layouts.
template <typename LayoutType, typename Impl>
class LayoutPattern {
+ private:
+ template <typename NewImpl>
+ LayoutPattern<LayoutType, AllOfPattern<::xla::Layout, Impl, NewImpl>>
+ AppendImpl(NewImpl new_impl) const {
+ return LayoutPattern<LayoutType,
+ AllOfPattern<::xla::Layout, Impl, NewImpl>>(
+ AllOf<Layout>(impl_, std::move(new_impl)), matched_layout_);
+ }
+
public:
explicit constexpr LayoutPattern(const Impl& impl,
LayoutType** matched_layout)
: impl_(impl), matched_layout_(matched_layout) {}
// Returns true and captures the layout iff it matches the pattern.
- bool Match(const ::xla::Layout* layout) const {
- if (impl_.Match(layout)) {
- if (matched_layout_) {
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ if (impl_.Match(layout, option)) {
+ if (option.capture && matched_layout_) {
*matched_layout_ = layout;
}
return true;
@@ -191,9 +266,9 @@ class LayoutPattern {
}
// Returns true and captures the layout iff it matches the pattern.
- bool Match(::xla::Layout* layout) const {
- if (impl_.Match(layout)) {
- if (matched_layout_) {
+ bool Match(::xla::Layout* layout, MatchOption option) const {
+ if (impl_.Match(layout, option)) {
+ if (option.capture && matched_layout_) {
*matched_layout_ = layout;
}
return true;
@@ -203,24 +278,21 @@ class LayoutPattern {
// Modifies the pattern to match only if the layout equals the given proto.
// The layout must outlive the returned pattern.
- constexpr LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>> EqualTo(
- const ::xla::Layout* layout) const {
- return LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>>(
- LayoutPatternEqualImpl<Impl>(impl_, layout), matched_layout_);
+ constexpr auto EqualTo(const ::xla::Layout* layout) const
+ -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) {
+ return AppendImpl(LayoutPatternEqualImpl(layout));
}
// Modifies the pattern to match only if the layout has a dense format.
- constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>
- WithDenseFormat() const {
- return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>(
- LayoutPatternFormatImpl<Impl>(impl_, DENSE), matched_layout_);
+ constexpr auto WithDenseFormat() const
+ -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) {
+ return AppendImpl(LayoutPatternFormatImpl(DENSE));
}
// Modifies the pattern to match only if the layout has a sparse format.
- constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>
- WithSparseFormat() const {
- return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>(
- LayoutPatternFormatImpl<Impl>(impl_, SPARSE), matched_layout_);
+ constexpr auto WithSparseFormat() const
+ -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
+ return AppendImpl(LayoutPatternFormatImpl(SPARSE));
}
private:
@@ -228,8 +300,72 @@ class LayoutPattern {
LayoutType** matched_layout_;
};
+template <typename Item, typename... Patterns>
+class AnyOfPattern {
+ public:
+ explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
+
+ bool Match(const Item* item, MatchOption option) const {
+ return MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ }
+
+ bool Match(Item* item, MatchOption option) const {
+ return MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ }
+
+ private:
+ template <typename ItemType, size_t index>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, index>) const {
+ auto new_option = option;
+ new_option.capture = false;
+ // Try to match the sub-pattern without capturing behavior.
+ if (std::get<index>(patterns_).Match(item, new_option)) {
+ // Capture the branch.
+ if (option.capture) {
+ // TODO(timshen): Currently the behavior can be exponential. Optimize it
+ // with memoization or recording the matched sub-pattern index, if it
+ // takes too long to run.
+ //
+ // Specifically, the "memoization" approach is to create an empty
+ // container with the key (pattern, instruction), and value as whether
+ // matched or not.
+ //
+ // Alternatively, we may run the pattern matching with captures off, but
+ // instead record a "trace" somewhere, indicating how exactly the
+ // pattern matches the input. For example, the trace information for
+ // AnyOf will be a runtime number indicating which sub-pattern is matched.
+ // Then we run another pass to do captures only with the help of the
+ // trace.
+ bool ret = std::get<index>(patterns_).Match(item, option);
+ DCHECK(ret);
+ }
+ return true;
+ }
+ return MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
+ }
+
+ template <typename ItemType>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, sizeof...(Patterns)>) const {
+ return false;
+ }
+
+ std::tuple<Patterns...> patterns_;
+};
+
} // namespace detail
+// Returns a pattern that represents the logical disjunction of the input
+// patterns. The returned pattern matches from left to right, and stops on the
+// first match.
+template <typename Item, typename... Patterns>
+detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
+ const Patterns&... patterns) {
+ return detail::AnyOfPattern<typename std::remove_const<Item>::type,
+ Patterns...>(patterns...);
+}
+
// Creates a layout pattern that will capture the matched layout in the
// argument.
inline constexpr detail::LayoutPattern<const ::xla::Layout,
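AllOfPattern and AnyOfPattern both hold their sub-patterns in a std::tuple and walk it with an integral_constant index: AllOf short-circuits on the first failure, AnyOf on the first success (and then re-runs only the winning branch with capture enabled). The tuple-recursion trick, shown standalone over simple int predicates — an AnyOf variant would flip the && and terminal-true into || and terminal-false:

#include <iostream>
#include <tuple>
#include <type_traits>

// Conjunction of predicates over int, stored in a tuple and visited by index.
template <typename... Patterns>
class AllOfPattern {
 public:
  explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}

  bool Match(int value) const {
    return MatchImpl(value, std::integral_constant<size_t, 0>());
  }

 private:
  template <size_t index>
  bool MatchImpl(int value, std::integral_constant<size_t, index>) const {
    // Evaluate the index-th pattern, then recurse to index + 1; && gives the
    // left-to-right short-circuit.
    return std::get<index>(patterns_)(value) &&
           MatchImpl(value, std::integral_constant<size_t, index + 1>());
  }

  bool MatchImpl(int, std::integral_constant<size_t, sizeof...(Patterns)>) const {
    return true;  // ran out of patterns: the conjunction holds
  }

  std::tuple<Patterns...> patterns_;
};

template <typename... Patterns>
AllOfPattern<Patterns...> AllOf(const Patterns&... patterns) {
  return AllOfPattern<Patterns...>(patterns...);
}

int main() {
  auto positive = [](int v) { return v > 0; };
  auto even = [](int v) { return v % 2 == 0; };
  auto p = AllOf(positive, even);
  std::cout << p.Match(4) << " " << p.Match(3) << " " << p.Match(-2) << "\n";  // 1 0 0
}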
@@ -258,172 +394,145 @@ class ShapePattern;
// nullptr.
class ShapePatternBaseImpl {
public:
- bool Match(const ::xla::Shape* shape) const { return shape != nullptr; }
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return shape != nullptr;
+ }
};
// A ShapePattern implementation that matches only if the shape equals a Shape
// proto.
-template <typename Previous>
class ShapePatternEqualImpl {
public:
- explicit constexpr ShapePatternEqualImpl(const Previous& previous,
- const ::xla::Shape* shape)
- : previous_(previous), shape_(shape) {}
+ explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
+ : shape_(shape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Equal(*shape_, *shape);
}
private:
- Previous previous_;
const ::xla::Shape* shape_;
};
// A ShapePattern implementation that matches only if the shape is compatible to
// a Shape proto.
-template <typename Previous>
class ShapePatternCompatibleImpl {
public:
- explicit constexpr ShapePatternCompatibleImpl(const Previous& previous,
- const ::xla::Shape* shape)
- : previous_(previous), shape_(shape) {}
+ explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
+ : shape_(shape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Compatible(*shape_, *shape);
}
private:
- Previous previous_;
const ::xla::Shape* shape_;
};
// A ShapePattern implementation that matches only if the shape has a given
// element type.
-template <typename Previous>
class ShapePatternElementTypeImpl {
public:
- explicit constexpr ShapePatternElementTypeImpl(const Previous& previous,
- PrimitiveType element_type)
- : previous_(previous), element_type_(element_type) {}
+ explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
+ : element_type_(element_type) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && shape->element_type() == element_type_;
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return shape->element_type() == element_type_;
}
private:
- Previous previous_;
PrimitiveType element_type_;
};
// A ShapePattern implementation that matches only if the shape is scalar.
-template <typename Previous>
class ShapePatternIsScalarImpl {
public:
- explicit constexpr ShapePatternIsScalarImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsScalarImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsScalar(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsScalar(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape is an array
-template <typename Previous>
class ShapePatternIsArrayImpl {
public:
- explicit constexpr ShapePatternIsArrayImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsArrayImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsArray(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsArray(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape is a tuple.
-template <typename Previous>
class ShapePatternIsTupleImpl {
public:
- explicit constexpr ShapePatternIsTupleImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsTupleImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsTuple(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsTuple(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape has a given
// rank.
-template <typename Previous>
class ShapePatternRankImpl {
public:
- explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank)
- : previous_(previous), rank_(rank) {}
+ explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_;
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Rank(*shape) == rank_;
}
private:
- Previous previous_;
int64 rank_;
};
// A ShapePattern implementation that matches only if the shape has a layout
// that matches a given pattern.
-template <typename Previous, typename LayoutType, typename LayoutImpl>
+template <typename LayoutType, typename LayoutImpl>
class ShapePatternLayoutImpl {
public:
explicit constexpr ShapePatternLayoutImpl(
- const Previous& previous,
const LayoutPattern<LayoutType, LayoutImpl>& layout)
- : previous_(previous), layout_(layout) {}
+ : layout_(layout) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) &&
- layout_.Match(&shape->layout());
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return LayoutUtil::HasLayout(*shape) &&
+ layout_.Match(&shape->layout(), option);
}
- bool Match(Shape* shape) const {
- return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) &&
- layout_.Match(shape->mutable_layout());
+ bool Match(Shape* shape, MatchOption option) const {
+ return LayoutUtil::HasLayout(*shape) &&
+ layout_.Match(shape->mutable_layout(), option);
}
private:
- Previous previous_;
LayoutPattern<LayoutType, LayoutImpl> layout_;
};
// A ShapePattern implementation that matches only if the shape has a subshape
// that matches a given pattern.
-template <typename Previous, typename SubshapeType, typename SubshapeImpl>
+template <typename SubshapeType, typename SubshapeImpl>
class ShapePatternSubshapeImpl {
public:
explicit ShapePatternSubshapeImpl(
- const Previous& previous, ShapeIndexView index,
+ ShapeIndexView index,
const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
- : previous_(previous), index_(index), subshape_(subshape) {}
+ : index_(index), subshape_(subshape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) &&
- subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_));
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IndexIsValid(*shape, index_) &&
+ subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option);
}
- bool Match(::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) &&
- subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_));
+ bool Match(::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IndexIsValid(*shape, index_) &&
+ subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_),
+ option);
}
private:
- Previous previous_;
ShapeIndexView index_;
ShapePattern<SubshapeType, SubshapeImpl> subshape_;
};
@@ -431,14 +540,22 @@ class ShapePatternSubshapeImpl {
// A pattern that matches Shapes.
template <typename ShapeType, typename Impl>
class ShapePattern {
+ private:
+ template <typename NewImpl>
+ ShapePattern<ShapeType, AllOfPattern<::xla::Shape, Impl, NewImpl>> AppendImpl(
+ NewImpl new_impl) const {
+ return ShapePattern<ShapeType, AllOfPattern<::xla::Shape, Impl, NewImpl>>(
+ AllOf<Shape>(impl_, std::move(new_impl)), matched_shape_);
+ }
+
public:
explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
: impl_(impl), matched_shape_(matched_shape) {}
// Returns true and captures the shape iff it matches the pattern.
- bool Match(const ::xla::Shape* shape) const {
- if (impl_.Match(shape)) {
- if (matched_shape_) {
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ if (impl_.Match(shape, option)) {
+ if (option.capture && matched_shape_) {
*matched_shape_ = shape;
}
return true;
@@ -447,9 +564,9 @@ class ShapePattern {
}
// Returns true and captures the shape iff it matches the pattern.
- bool Match(::xla::Shape* shape) const {
- if (impl_.Match(shape)) {
- if (matched_shape_) {
+ bool Match(::xla::Shape* shape, MatchOption option) const {
+ if (impl_.Match(shape, option)) {
+ if (option.capture && matched_shape_) {
*matched_shape_ = shape;
}
return true;
@@ -459,108 +576,90 @@ class ShapePattern {
// Modifies the pattern to match only if the shape equals the given proto.
// The layout must outlive the returned pattern.
- constexpr ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>> EqualTo(
- const ::xla::Shape* shape) const {
- return ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>>(
- ShapePatternEqualImpl<Impl>(impl_, shape), matched_shape_);
+ constexpr auto EqualTo(const ::xla::Shape* shape) const
+ -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) {
+ return AppendImpl(ShapePatternEqualImpl(shape));
}
// Modifies the pattern to match only if the shape is compatible to the given
// proto. The layout must outlive the returned pattern.
- constexpr ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>
- CompatibleTo(const ::xla::Shape* shape) const {
- return ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>(
- ShapePatternCompatibleImpl<Impl>(impl_, shape), matched_shape_);
+ constexpr auto CompatibleTo(const ::xla::Shape* shape) const
+ -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) {
+ return AppendImpl(ShapePatternCompatibleImpl(shape));
}
// Modifies the pattern to match only if the shape has the given element type.
- constexpr ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>
- WithElementType(PrimitiveType element_type) const {
- return ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>(
- ShapePatternElementTypeImpl<Impl>(impl_, element_type), matched_shape_);
+ constexpr auto WithElementType(PrimitiveType element_type) const
+ -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) {
+ return AppendImpl(ShapePatternElementTypeImpl(element_type));
}
// Modifies the pattern to match only if the shape is scalar.
- constexpr ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>> IsScalar()
- const {
- return ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>>(
- ShapePatternIsScalarImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsScalar() const
+ -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) {
+ return AppendImpl(ShapePatternIsScalarImpl());
}
// Modifies the pattern to match only if the shape is an array.
- constexpr ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>> IsArray()
- const {
- return ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>>(
- ShapePatternIsArrayImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsArray() const
+ -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) {
+ return AppendImpl(ShapePatternIsArrayImpl());
}
// Modifies the pattern to match only if the shape is a tuple.
- constexpr ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>> IsTuple()
- const {
- return ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>>(
- ShapePatternIsTupleImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsTuple() const
+ -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) {
+ return AppendImpl(ShapePatternIsTupleImpl());
}
// Modifies the pattern to match only if the shape has the given rank.
- constexpr ShapePattern<ShapeType, ShapePatternRankImpl<Impl>> WithRank(
- int64 rank) const {
- return ShapePattern<ShapeType, ShapePatternRankImpl<Impl>>(
- ShapePatternRankImpl<Impl>(impl_, rank), matched_shape_);
+ constexpr auto WithRank(int64 rank) const
+ -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) {
+ return AppendImpl(ShapePatternRankImpl(rank));
}
// Modifies the pattern to match only if the shape has a layout that matches
// the given pattern.
template <typename LayoutType, typename LayoutImpl>
- constexpr ShapePattern<ShapeType,
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>
- WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
- return ShapePattern<ShapeType,
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>(
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>(impl_, layout),
- matched_shape_);
- }
-
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternEqualImpl<LayoutPatternBaseImpl>>>
- WithLayoutEqualTo(const ::xla::Layout* layout) const {
+ auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const
+ -> decltype(this->AppendImpl(
+ ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) {
+ return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
+ }
+
+ constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const
+ -> decltype(this->WithLayout(Layout().EqualTo(layout))) {
return WithLayout(Layout().EqualTo(layout));
}
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternFormatImpl<LayoutPatternBaseImpl>>>
- IsDenseArray() const {
+ constexpr auto IsDenseArray() const
+ -> decltype(this->WithLayout(Layout().WithDenseFormat())) {
return WithLayout(Layout().WithDenseFormat());
}
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternFormatImpl<LayoutPatternBaseImpl>>>
- IsSparseArray() const {
+ constexpr auto IsSparseArray() const
+ -> decltype(this->WithLayout(Layout().WithSparseFormat())) {
return WithLayout(Layout().WithSparseFormat());
}
// Modifies the pattern to match only if the shape has a subshape that matches
// the given pattern.
template <typename SubshapeType, typename SubshapeImpl>
+ auto WithSubshape(ShapeIndexView index,
+ const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
+ const -> decltype(this->AppendImpl(
+ ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index,
+ subshape))) {
+ return AppendImpl(
+ ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
+ }
+
ShapePattern<ShapeType,
- ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>
- WithSubshape(ShapeIndexView index,
- const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
- return ShapePattern<
- ShapeType, ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>(
- ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>(impl_, index,
- subshape),
- matched_shape_);
- }
-
- ShapePattern<ShapeType, ShapePatternSubshapeImpl<
- Impl, const ::xla::Shape,
- ShapePatternEqualImpl<ShapePatternBaseImpl>>>
+ AllOfPattern<Shape, Impl,
+ ShapePatternSubshapeImpl<
+ const ::xla::Shape,
+ AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
+ ShapePatternEqualImpl>>>>
WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
return WithSubshape(index,
ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
@@ -568,9 +667,12 @@ class ShapePattern {
.EqualTo(shape));
}
- ShapePattern<ShapeType, ShapePatternSubshapeImpl<
- Impl, const ::xla::Shape,
- ShapePatternCompatibleImpl<ShapePatternBaseImpl>>>
+ ShapePattern<ShapeType,
+ AllOfPattern<Shape, Impl,
+ ShapePatternSubshapeImpl<
+ const ::xla::Shape,
+ AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
+ ShapePatternCompatibleImpl>>>>
WithSubshapeCompatibleTo(ShapeIndexView index,
const ::xla::Shape* shape) const {
return WithSubshape(index,
@@ -611,159 +713,169 @@ class HloInstructionPattern;
// instruction is not nullptr.
class HloInstructionPatternBaseImpl {
public:
- bool Match(const ::xla::HloInstruction* inst) const {
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
return inst != nullptr;
}
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a given name.
-template <typename Previous>
class HloInstructionPatternNameImpl {
public:
- explicit HloInstructionPatternNameImpl(const Previous& previous,
- absl::string_view name)
- : previous_(previous), name_(name) {}
+ explicit HloInstructionPatternNameImpl(absl::string_view name)
+ : name_(name) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->name() == name_;
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->name() == name_;
}
private:
- Previous previous_;
absl::string_view name_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a given opcode.
-template <typename Previous>
class HloInstructionPatternOpcodeImpl {
public:
- explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous,
- HloOpcode opcode,
+ explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
bool invert)
- : previous_(previous), opcode_(opcode), invert_(invert) {}
+ : opcode_(opcode), invert_(invert) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_));
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return (invert_ ^ (inst->opcode() == opcode_));
}
private:
- Previous previous_;
HloOpcode opcode_;
bool invert_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a shape that matches a given pattern.
-template <typename Previous, typename ShapeType, typename ShapeImpl>
+template <typename ShapeType, typename ShapeImpl>
class HloInstructionPatternShapeImpl {
public:
explicit constexpr HloInstructionPatternShapeImpl(
- const Previous& previous, const ShapePattern<ShapeType, ShapeImpl>& shape)
- : previous_(previous), shape_(shape) {}
+ const ShapePattern<ShapeType, ShapeImpl>& shape)
+ : shape_(shape) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && shape_.Match(&inst->shape());
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return shape_.Match(&inst->shape(), option);
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && shape_.Match(inst->mutable_shape());
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return shape_.Match(inst->mutable_shape(), option);
}
private:
- Previous previous_;
ShapePattern<ShapeType, ShapeImpl> shape_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has an operand that matches a given pattern.
-template <typename Previous, typename OperandType, typename OperandImpl>
+template <typename OperandType, typename OperandImpl>
class HloInstructionPatternOperandImpl {
public:
explicit constexpr HloInstructionPatternOperandImpl(
- const Previous& previous, int64 operand_index,
+ int64 operand_index,
const HloInstructionPattern<OperandType, OperandImpl>& operand)
- : previous_(previous), operand_index_(operand_index), operand_(operand) {}
+ : operand_index_(operand_index), operand_(operand) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && operand_index_ < inst->operand_count() &&
- operand_.Match(inst->operand(operand_index_));
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return operand_index_ < inst->operand_count() &&
+ operand_.Match(inst->operand(operand_index_), option);
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && operand_index_ < inst->operand_count() &&
- operand_.Match(inst->mutable_operand(operand_index_));
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return operand_index_ < inst->operand_count() &&
+ operand_.Match(inst->mutable_operand(operand_index_), option);
}
private:
- Previous previous_;
int64 operand_index_;
HloInstructionPattern<OperandType, OperandImpl> operand_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// is a fusion node with a particular kind.
-template <typename Previous>
class HloInstructionPatternFusionKindImpl {
public:
explicit constexpr HloInstructionPatternFusionKindImpl(
- const Previous& previous, ::xla::HloInstruction::FusionKind kind)
- : previous_(previous), kind_(kind) {}
+ ::xla::HloInstruction::FusionKind kind)
+ : kind_(kind) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion &&
- inst->fusion_kind() == kind_;
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_;
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion &&
- inst->fusion_kind() == kind_;
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_;
}
private:
- Previous previous_;
::xla::HloInstruction::FusionKind kind_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// is a kGetTupleElement with a particular tuple index.
-template <typename Previous>
class HloInstructionPatternTupleIndexImpl {
public:
- explicit constexpr HloInstructionPatternTupleIndexImpl(
- const Previous& previous, int64 tuple_index)
- : previous_(previous), tuple_index_(tuple_index) {}
+ explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
+ : tuple_index_(tuple_index) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) &&
- inst->opcode() == HloOpcode::kGetTupleElement &&
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kGetTupleElement &&
inst->tuple_index() == tuple_index_;
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) &&
- inst->opcode() == HloOpcode::kGetTupleElement &&
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kGetTupleElement &&
inst->tuple_index() == tuple_index_;
}
private:
- Previous previous_;
int64 tuple_index_;
};
+template <typename ItemType, typename Predicate>
+class HloPredicatePatternImpl {
+ public:
+ explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {}
+
+ bool Match(const ItemType* item, MatchOption option) const {
+ return pred_(item);
+ }
+
+ bool Match(ItemType* item, MatchOption option) const { return pred_(item); }
+
+ private:
+ Predicate pred_;
+};
+
+struct PatternFriend;
+
// A pattern that matches HloInstructions.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern {
+ private:
+ template <typename NewImpl>
+ HloInstructionPattern<HloInstructionType,
+ AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>
+ AppendImpl(NewImpl new_impl) const {
+ return HloInstructionPattern<
+ HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>(
+ AllOf<HloInstruction>(impl_, std::move(new_impl)), matched_inst_);
+ }
+
public:
explicit constexpr HloInstructionPattern(const Impl& impl,
HloInstructionType** matched_inst)
: impl_(impl), matched_inst_(matched_inst) {}
// Returns true and captures the instruction iff it matches the pattern.
- bool Match(const ::xla::HloInstruction* inst) const {
- if (impl_.Match(inst)) {
- if (matched_inst_) {
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ if (impl_.Match(inst, option)) {
+ if (option.capture && matched_inst_) {
*matched_inst_ = inst;
}
return true;
@@ -772,9 +884,9 @@ class HloInstructionPattern {
}
// Returns true and captures the instruction iff it matches the pattern.
- bool Match(::xla::HloInstruction* inst) const {
- if (impl_.Match(inst)) {
- if (matched_inst_) {
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ if (impl_.Match(inst, option)) {
+ if (option.capture && matched_inst_) {
*matched_inst_ = inst;
}
return true;
@@ -783,102 +895,87 @@ class HloInstructionPattern {
}
// Modifies the pattern to match only if the instruction has the given name.
- HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>
- WithName(absl::string_view name) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternNameImpl<Impl>>(
- HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_);
+ auto WithName(absl::string_view name) const
+ -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) {
+ return AppendImpl(HloInstructionPatternNameImpl(name));
}
// Modifies the pattern to match only if the instruction has the given opcode.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- WithOpcode(HloOpcode opcode) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>(
- HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, false),
- matched_inst_);
+ auto WithOpcode(HloOpcode opcode) const
+ -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
+ false))) {
+ return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
}
// Modifies the pattern to match only if the instruction does not have the
// given opcode.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- WithoutOpcode(HloOpcode opcode) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>(
- HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, true),
- matched_inst_);
+ auto WithoutOpcode(HloOpcode opcode) const
+ -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
+ true))) {
+ return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
}
// Modifies the pattern to match only if the instruction is a constant.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- IsConstant() const {
+ constexpr auto IsConstant() const
+ -> decltype(this->WithOpcode(HloOpcode::kConstant)) {
return WithOpcode(HloOpcode::kConstant);
}
// Modifies the pattern to match only if the instruction is not a constant.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- IsNonConstant() const {
+ constexpr auto IsNonConstant() const
+ -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) {
return WithoutOpcode(HloOpcode::kConstant);
}
// Modifies the pattern to match only if the instruction has a shape that
// matches the given pattern.
template <typename ShapeType, typename ShapeImpl>
- constexpr HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>
- WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape) const {
- return HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>(
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>(impl_,
- shape),
- matched_inst_);
+ constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape)
+ const -> decltype(this->AppendImpl(
+ HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) {
+ return AppendImpl(
+ HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
}
// Modifies the pattern to match only if the instruction has an operand that
// matches the given pattern.
template <typename OperandType, typename OperandImpl>
- constexpr HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>
- WithOperand(
+ constexpr auto WithOperand(
int64 operand_index,
- const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
- return HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>(
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>(
- impl_, operand_index, operand),
- matched_inst_);
+ const HloInstructionPattern<OperandType, OperandImpl>& operand) const
+ -> decltype(this->AppendImpl(
+ HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
+ operand_index, operand))) {
+ return AppendImpl(
+ HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
+ operand_index, operand));
}
// Modifies the pattern to match only if the instruction is a fusion node with
// the given kind.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternFusionKindImpl<Impl>>
- WithFusionKind(HloInstruction::FusionKind kind) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternFusionKindImpl<Impl>>(
- HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_);
+ constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const
+ -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) {
+ return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
}
// Modifies the pattern to match only if the instruction is a
// get-tuple-element with the given tuple index.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternTupleIndexImpl<Impl>>
- WithTupleIndex(int64 tuple_index) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternTupleIndexImpl<Impl>>(
- HloInstructionPatternTupleIndexImpl<Impl>(impl_, tuple_index),
- matched_inst_);
+ constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype(
+ this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) {
+ return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
}
private:
+ template <typename Predicate>
+ constexpr auto WithPredicate(Predicate pred) const -> decltype(
+ this->AppendImpl(HloPredicatePatternImpl<HloInstruction, Predicate>(
+ std::move(pred)))) {
+ return AppendImpl(
+ HloPredicatePatternImpl<HloInstruction, Predicate>(std::move(pred)));
+ }
+
+ friend struct PatternFriend;
+
Impl impl_;
HloInstructionType** matched_inst_;
};
@@ -1005,31 +1102,50 @@ XLA_UNOP_PATTERN(Transpose)
.WithOperand(0, std::forward<Lhs>(lhs)) \
.WithOperand(1, std::forward<Rhs>(rhs)); \
}
-XLA_BINOP_PATTERN(Add)
+
+#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
+ XLA_BINOP_PATTERN(NAME) \
+ \
+ template <typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
+ ->decltype(AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs))) { \
+ return AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs)); \
+ } \
+ \
+ template <typename HloInstructionType, typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
+ Rhs&& rhs) \
+ ->decltype(AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \
+ NAME(matched_inst, rhs, lhs))) { \
+ return AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \
+ NAME(matched_inst, rhs, lhs)); \
+ }
+XLA_COMMUTATIVE_BINOP_PATTERN(Add)
XLA_BINOP_PATTERN(Atan2)
XLA_BINOP_PATTERN(Divide)
XLA_BINOP_PATTERN(Complex)
XLA_BINOP_PATTERN(Dot)
-XLA_BINOP_PATTERN(Eq)
+XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
XLA_BINOP_PATTERN(Gather)
XLA_BINOP_PATTERN(Ge)
XLA_BINOP_PATTERN(Gt)
XLA_BINOP_PATTERN(Le)
XLA_BINOP_PATTERN(Lt)
-XLA_BINOP_PATTERN(Maximum)
-XLA_BINOP_PATTERN(Minimum)
-XLA_BINOP_PATTERN(Multiply)
-XLA_BINOP_PATTERN(Ne)
+XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
+XLA_COMMUTATIVE_BINOP_PATTERN(Ne)
XLA_BINOP_PATTERN(Outfeed)
XLA_BINOP_PATTERN(Power)
XLA_BINOP_PATTERN(Remainder)
XLA_BINOP_PATTERN(Send)
XLA_BINOP_PATTERN(Subtract)
-XLA_BINOP_PATTERN(And)
-XLA_BINOP_PATTERN(Or)
+XLA_COMMUTATIVE_BINOP_PATTERN(And)
+XLA_COMMUTATIVE_BINOP_PATTERN(Or)
XLA_BINOP_PATTERN(ShiftLeft)
XLA_BINOP_PATTERN(ShiftRightArithmetic)
XLA_BINOP_PATTERN(ShiftRightLogical)
+#undef XLA_COMMUTATIVE_BINOP_PATTERN
#undef XLA_BINOP_PATTERN
// Helpers for ternary instructions.
@@ -1070,6 +1186,30 @@ XLA_TERNOP_PATTERN(Clamp);
XLA_TERNOP_PATTERN(Select);
#undef XLA_TERNOP_PATTERN
+namespace detail {
+struct PatternFriend {
+ template <typename T>
+ static auto ConstantScalar(T constant) -> decltype(
+ Constant()
+ .WithShape(match::Shape().IsScalar())
+ .WithPredicate(
+ std::declval<std::function<bool(const HloInstruction*)>>())) {
+ std::function<bool(const HloInstruction*)> pred =
+ [constant](const HloInstruction* instr) {
+ const auto& literal = Cast<HloConstantInstruction>(instr)->literal();
+ auto status_or_const = LiteralUtil::CreateR0(constant).Convert(
+ literal.shape().element_type());
+ return status_or_const.ok() &&
+ literal == status_or_const.ConsumeValueOrDie();
+ };
+
+ return Constant()
+ .WithShape(match::Shape().IsScalar())
+ .WithPredicate(std::move(pred));
+ }
+};
+} // namespace detail
+
// Helpers for matching non-constant instructions.
inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
return Op().IsNonConstant();
@@ -1107,6 +1247,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
.WithTupleIndex(tuple_index);
}
+template <typename T>
+inline auto ConstantScalar(T constant)
+ -> decltype(detail::PatternFriend::ConstantScalar(constant)) {
+ return detail::PatternFriend::ConstantScalar(constant);
+}
+
} // namespace match
} // namespace xla
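
The pattern_matcher.h rewrite above replaces the old nested Previous-template chaining with an AppendImpl helper that composes predicates through AllOfPattern, threads a MatchOption through Match so captures can be switched off, and adds *AnyOrder helpers for commutative binary ops plus a ConstantScalar matcher. A minimal call-site sketch of the resulting API (illustrative only; `root` is assumed to be the root HloInstruction of an already-built module, and the function and variable names are made up for the example):

// Illustrative sketch; includes and module construction are assumed.
namespace m = xla::match;

void MatcherSketch(xla::HloInstruction* root) {
  const xla::HloInstruction* addend = nullptr;
  // Chained predicates now compose through AllOfPattern, and the commutative
  // helpers (AddAnyOrder, MultiplyAnyOrder, ...) expand to AnyOf of both
  // operand orders.
  bool matched =
      xla::Match(root, m::AddAnyOrder(m::ConstantScalar(1), m::Op(&addend)));

  // MatchOption lets callers run the same pattern without capturing.
  bool matched_no_capture =
      xla::Match(root, m::Op(&addend), {/*capture=*/false});
  (void)matched;
  (void)matched_no_capture;
}
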
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index a530581c34..3ab7b7fd71 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -211,5 +211,188 @@ TEST(PatternMatcherTest, GetTupleElement) {
EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
}
+TEST(PatternMatcherTest, AnyOf) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_TRUE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+ match::ConstantScalar(1))));
+ EXPECT_TRUE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(1),
+ match::ConstantScalar(0))));
+ EXPECT_FALSE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+ match::ConstantScalar(2))));
+}
+
+TEST(PatternMatcherTest, ConstantScalar) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_TRUE(Match(root, match::ConstantScalar(42)));
+ EXPECT_FALSE(Match(root, match::ConstantScalar(41)));
+ EXPECT_FALSE(Match(root, match::ConstantScalar(0)));
+}
+
+TEST(PatternMatcherTest, NoMatchConstantScalar) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_FALSE(Match(root, match::ConstantScalar(42)));
+}
+
+TEST(PatternMatcherTest, MultiplyAnyOrder) {
+ using match::ConstantScalar;
+ using match::MultiplyAnyOrder;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ lhs = f16[] constant(42)
+ rhs = f16[] constant(52)
+ ROOT multiply = f16[] multiply(lhs, rhs)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+ const HloInstruction* instr;
+
+ EXPECT_TRUE(Match(
+ root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
+ EXPECT_TRUE(Match(
+ root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
+}
+
+TEST(PatternMatcherTest, AnyOfShortCircuit) {
+ using match::AnyOf;
+ using match::Multiply;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ lhs = f16[] constant(42)
+ rhs = f16[] constant(52)
+ ROOT multiply = f16[] multiply(lhs, rhs)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ {
+ const HloInstruction* mul = nullptr;
+ const HloInstruction* any = nullptr;
+
+ ASSERT_TRUE(Match(
+ root, AnyOf<HloInstruction>(Multiply(&mul, Op(), Op()), Op(&any))));
+ EXPECT_NE(nullptr, mul);
+ EXPECT_EQ(nullptr, any);
+ }
+ {
+ const HloInstruction* mul = nullptr;
+ const HloInstruction* any = nullptr;
+
+ ASSERT_TRUE(Match(
+ root, AnyOf<HloInstruction>(Op(&any), Multiply(&mul, Op(), Op()))));
+ EXPECT_NE(nullptr, any);
+ EXPECT_EQ(nullptr, mul);
+ }
+}
+
+TEST(PatternMatcherTest, AllOf) {
+ using match::AllOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar());
+ auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16));
+ ASSERT_TRUE(Match(root, scalar_pattern));
+ ASSERT_TRUE(Match(root, f16_pattern));
+ EXPECT_TRUE(Match(root, AllOf<HloInstruction>(scalar_pattern, f16_pattern)));
+ EXPECT_TRUE(Match(root, AllOf<HloInstruction>(f16_pattern, scalar_pattern)));
+ EXPECT_FALSE(
+ Match(root, AllOf<HloInstruction>(Broadcast(Op()), f16_pattern)));
+ EXPECT_FALSE(
+ Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
+}
+
+TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
+ using match::AllOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ ROOT v = f16[] constant(42)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* constant = nullptr;
+ ASSERT_FALSE(
+ Match(root, AllOf<HloInstruction>(Constant(&constant), Broadcast(Op()))));
+ EXPECT_EQ(nullptr, constant);
+ ASSERT_TRUE(Match(root, Constant(&constant)));
+ EXPECT_NE(nullptr, constant);
+}
+
+TEST(PatternMatcherTest, TestNoCapture) {
+ using match::Constant;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ ROOT v = f16[] constant(42)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* constant = nullptr;
+ ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false}));
+ EXPECT_EQ(nullptr, constant);
+}
+
+TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ u = f16[] parameter(0)
+ v = f16[] parameter(1)
+ ROOT add = f16[] add(u, v)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* addend0 = nullptr;
+ const HloInstruction* addend1 = nullptr;
+ const HloInstruction* addend2 = nullptr;
+ auto add2_pattern = Add(Op(&addend0), Op(&addend1));
+ auto add3_pattern = AnyOf<HloInstruction>(
+ AddAnyOrder(add2_pattern, Op(&addend2)), add2_pattern, Op(&addend0));
+
+ ASSERT_TRUE(Match(root, add3_pattern));
+ EXPECT_NE(nullptr, addend0);
+ EXPECT_NE(nullptr, addend1);
+ EXPECT_EQ(nullptr, addend2);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 178a78ede0..c522e7ae23 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -217,9 +218,12 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
if (platform->id() == se::host::kHostPlatformId) {
// On host "devices", StreamExecutor exports a device for each hardware
// thread. Because we parallelize a single computation across threads, it
- // doesn't make sense to expose these as separate devices, so fix the number
- // of devices to one.
- device_count = 1;
+ // doesn't make sense to expose these as separate devices, so by default we
+ // fix the number of devices to one. However, we let the user override this
+ // behavior to support tests that run models in parallel across multiple
+ // host devices.
+ device_count = legacy_flags::GetDebugOptionsFromFlags()
+ .xla_force_host_platform_device_count();
}
std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
VLOG(1) << "Initializing devices";
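
The platform_util.cc change above makes the host device count configurable through the xla_force_host_platform_device_count debug option instead of hard-coding it to one; the new multiple_devices_on_host_test target further down passes --xla_force_host_platform_device_count=4. A hedged sketch of what a caller then sees (the "Host" platform name and the helper function are assumptions for illustration):

// Sketch only: assumes the flag has already been parsed into DebugOptions,
// e.g. by starting the test binary with
// --xla_force_host_platform_device_count=4.
#include "tensorflow/compiler/xla/service/platform_util.h"

xla::StatusOr<std::vector<stream_executor::StreamExecutor*>>
GetHostExecutors() {
  TF_ASSIGN_OR_RETURN(stream_executor::Platform* const platform,
                      xla::PlatformUtil::GetPlatform("Host"));
  // Previously always a single executor; now sized by the debug option above.
  return xla::PlatformUtil::GetStreamExecutors(platform);
}
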
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index 256b231e3a..0b4e82e8d6 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -22,14 +22,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
-class ReducePrecisionInsertion : public HloPassInterface {
+class ReducePrecisionInsertion : public HloModulePass {
using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
public:
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1e86a0823a..a3db439e34 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -24,7 +24,7 @@ namespace xla {
// This now only moves them outputward across elementwise ops all whose operands
// are equivalent Reshapes or Transposes, but in future could potentially move
// them inputward also.
-class ReshapeMover : public HloPassInterface {
+class ReshapeMover : public HloModulePass {
public:
absl::string_view name() const override { return "reshape-mover"; }
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index 2f4b2667c4..de7aee262e 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -155,6 +155,53 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
}
+static StatusOr<HloInstruction*> CheckIndexValidity(
+ HloComputation* computation, HloInstruction* index,
+ absl::Span<const int64> operand_dims, absl::Span<const int64> window_sizes,
+ HloModule* module) {
+ DCHECK_NE(nullptr, module);
+ DCHECK_EQ(operand_dims.size(), window_sizes.size());
+
+ // Valid range for the index: [0, operand_dims - window_sizes]
+
+ // Check if the index has any negative values.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * zero_index,
+ BroadcastZeros(computation, index->shape().element_type(),
+ AsInt64Slice(index->shape().dimensions())));
+ TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check,
+ MakeBinaryHlo(HloOpcode::kLe, zero_index, index));
+
+ // Check if the index is OOB w.r.t. the operand dimensions and window sizes.
+ std::vector<int64> max_valid_index(operand_dims.size());
+ for (int i = 0; i < operand_dims.size(); ++i) {
+ max_valid_index[i] = operand_dims[i] - window_sizes[i];
+ }
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * max_valid_index_constant,
+ MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
+ max_valid_index));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * oob_index_check,
+ MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index));
+
+ // Combine the results of the two checks above.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * valid_index,
+ MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check));
+
+ // Reduce the index validity check vector into a scalar predicate.
+ auto reduction_init = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * valid_index_reduced,
+ MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module));
+
+ // Return a broadcasted value of the scalar predicate to the same size as the
+ // window.
+ return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
+}
+
// Body of the while loop that performs the scatter operation using other HLOs.
static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
HloInstruction* scatter, HloInstruction* induction_var,
@@ -222,7 +269,16 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
InsertDegenerateDims(update_slice_for_scatter,
AsInt64Slice(dim_numbers.inserted_window_dims())));
- // Extact the slice to update from `operand` tensor.
+ // Note that the following transformation assumes that both DynamicSlice and
+ // DynamicUpdateSlice follow the same semantics for OOB indices. For example,
+ // if there are negative indices and DynamicSlice uses "clamping" semantics,
+ // then the extracted data will be "shifted". Since DynamicUpdateSlice also
+ // follows the same "clamping" semantics, writing the update will also be
+ // "shifted" by exactly the same amount. So, this transformation is correct as
+ // long as the semantics of handling OOB indices remain the same in
+ // DynamicSlice and DynamicUpdateSlice.
+
+ // Extract the slice to update from `operand` tensor.
const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
TF_ASSIGN_OR_RETURN(
HloInstruction * operand_slice_to_update,
@@ -237,10 +293,24 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
scatter->to_apply()));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * is_index_valid,
+ CheckIndexValidity(
+ operand->parent(), scatter_slice_start,
+ AsInt64Slice(operand->shape().dimensions()),
+ AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()),
+ scatter->GetModule()));
+
+ // Select the updated operand only if the index is valid. If not, select the
+ // original value.
+ TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply,
+ MakeSelectHlo(is_index_valid, updated_operand_slice,
+ operand_slice_to_update));
+
// Write the updated value of the slice into `operand` tensor.
- TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
- MakeDynamicUpdateSliceHlo(operand, updated_operand_slice,
- scatter_slice_start));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * updated_operand,
+ MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start));
return StatusOr<std::vector<HloInstruction*>>{
{updated_operand, scatter_indices, updates}};
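
CheckIndexValidity builds, entirely in HLOs, the predicate that a scatter index lies in [0, operand_dims - window_sizes] componentwise, AND-reduces it to a scalar, and broadcasts it so the loop body can select between the updated and the original slice. The same predicate restated as plain C++ for clarity (a sketch; the function name and signature are illustrative, not part of the pass):

// Plain-C++ restatement of the predicate CheckIndexValidity emits as
// Le/Ge/And HLOs followed by an And-reduce; names here are illustrative.
#include <cstdint>
#include "absl/types/span.h"

bool IsScatterIndexValid(absl::Span<const int64_t> index,
                         absl::Span<const int64_t> operand_dims,
                         absl::Span<const int64_t> window_sizes) {
  for (size_t i = 0; i < index.size(); ++i) {
    // Valid range per component: [0, operand_dims[i] - window_sizes[i]].
    const int64_t max_valid = operand_dims[i] - window_sizes[i];
    if (index[i] < 0 || index[i] > max_valid) {
      return false;  // Out of bounds: the loop body keeps the original slice.
    }
  }
  return true;  // In bounds: the loop body applies the updated slice.
}
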
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 14f062c89c..559a85dccf 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -20,7 +20,7 @@ limitations under the License.
namespace xla {
-class ScatterExpander : public HloPassInterface {
+class ScatterExpander : public HloModulePass {
public:
absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 74bdf2a2e3..e379911462 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -577,7 +577,7 @@ Status ValidateDotDimensionNumbers(
// Check that dimension numbers are unique.
auto dims_unique = [](absl::Span<const int64> contracting_dims,
absl::Span<const int64> batch_dims) -> bool {
- tensorflow::gtl::FlatSet<int64> dim_set;
+ absl::flat_hash_set<int64> dim_set;
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
};
@@ -1665,10 +1665,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
"Expected LHS feature dimension (value %d) to match RHS "
- "input feature dimension * feature_group_count (value %d); "
+ "input feature dimension * feature_group_count (value %d * %d = %d); "
"got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
- input_features, kernel_input_features * feature_group_count,
+ input_features, kernel_input_features, feature_group_count,
+ kernel_input_features * feature_group_count,
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
dnums.DebugString());
}
@@ -2379,7 +2380,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
!std::is_permutation(dimensions.begin(), dimensions.end(),
indices.begin())) {
return InvalidArgument(
- "Transpose dimensions not a permutation of the operand dimensions.");
+ "Transpose dimensions [%s] are not a permutation of the operand "
+ "dimensions (operand shape is %s).",
+ StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
}
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 921a984589..56952e3ada 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -147,7 +147,7 @@ void ScopedShapedBuffer::Deallocate() {
// Deallocate all non-null buffers. A buffer may appear in more than one spot
// in the shape (eg, a tuple with a repeated element) so keep track of what
// has been deallocated.
- tensorflow::gtl::FlatSet<void*> deallocated_ptrs;
+ absl::flat_hash_set<void*> deallocated_ptrs;
for (auto& pair : buffers_) {
se::DeviceMemoryBase& memory_base = pair.second;
if (!memory_base.is_null() &&
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index 5d1cd1c442..ec09dff924 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
// Re-use an existing stream from the pool.
stream = std::move(streams_.back());
streams_.pop_back();
- VLOG(1) << stream->DebugStreamPointers()
- << " StreamPool reusing existing stream";
+ if (stream->ok()) {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool reusing existing stream";
+ } else {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " stream was not ok, StreamPool deleting";
+ stream = nullptr;
+ }
}
}
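
With the stream_pool.cc change above, BorrowStream now checks ok() on a pooled stream before reusing it and deletes streams that went bad while sitting in the pool. A usage sketch under those semantics (illustrative; `executor` and the wrapper function are assumptions):

// Illustrative sketch; `executor` is assumed to be a live StreamExecutor*.
void StreamPoolSketch(stream_executor::StreamExecutor* executor) {
  xla::StreamPool pool;
  {
    xla::StreamPool::Ptr stream = pool.BorrowStream(executor);
    // ... enqueue work on *stream ...
  }  // The Ptr deleter returns the stream to the pool (dropped now if !ok()).

  // A pooled stream whose ok() flipped to false only after it was returned
  // (e.g. through a raw pointer someone kept) is deleted here rather than
  // reused, so a freshly borrowed stream is always ok().
  xla::StreamPool::Ptr next = pool.BorrowStream(executor);
  CHECK(next->ok());
}
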
diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc
index aaf5c37b0d..92f47579d3 100644
--- a/tensorflow/compiler/xla/service/stream_pool_test.cc
+++ b/tensorflow/compiler/xla/service/stream_pool_test.cc
@@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) {
EXPECT_EQ(stream2_ptr, stream3_ptr);
}
+TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
+ std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+ StreamPool pool;
+
+ // Borrow a stream.
+ StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream1->ok());
+
+ // Return the stream, but hold a handle to it.
+ se::Stream* stream1_ptr = stream1.get();
+ stream1 = nullptr;
+
+ // Now that stream1 is back in the pool, force an error on the stream. Here
+ // we call a method that requires DNN support, which we know the Host
+ // platform doesn't support.
+ stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(stream1_ptr->ok());
+
+ // Borrow stream2.
+ StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream2->ok());
+
+ // The underlying streams should be different. They would have been
+ // the same, but since we forced an error on stream1, it cannot be
+ // put back into the pool. Sadly we can't just check:
+ // EXPECT_NE(stream1_ptr, stream2_ptr);
+ //
+ // The above should hold logically, but it may fail if the new
+ // stream instance allocated for stream2 happens to reside in the
+ // same memory address as stream1, which has been deleted.
+ //
+ // The stream2->ok() check above serves as a good-enough substitute.
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 3e5aa2db60..f95f982eb8 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -23,7 +23,7 @@ namespace xla {
// HLO pass that folds transpose operators into Dot operators, where the Dot
// operator is implemented by a GEMM kernel that can transpose its inputs.
-class TransposeFolding : public HloPassInterface {
+class TransposeFolding : public HloModulePass {
public:
using OperandIndices = std::vector<int64>;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 6fed7c76d0..811ac55e2d 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -280,16 +280,6 @@ Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
return Status::OK();
}
-Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) {
- // A kSlice instruction aliases its operand if the backend lowers it to an
- // in-place implementation.
- if (slice->IsInPlaceSlice()) {
- CreateCopiedPointsToSet(slice, slice->operand(0));
- return Status::OK();
- }
- return DefaultAction(slice);
-}
-
Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
// RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
// output. The other indices ({} and {1}) define their own buffers.
@@ -455,15 +445,10 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
- // kSlice ops that are lowered to an in-place version are expected to not
- // define their output buffer.
- if (buffer.instruction()->opcode() != HloOpcode::kSlice ||
- !buffer.instruction()->IsInPlaceSlice()) {
- return FailedPrecondition(
- "LogicalBuffer %s is ill-defined: instruction %s does not define a "
- "buffer at that index",
- buffer.ToString(), buffer.instruction()->name());
- }
+ return FailedPrecondition(
+ "LogicalBuffer %s is ill-defined: instruction %s does not define a "
+ "buffer at that index",
+ buffer.ToString(), buffer.instruction()->name());
}
if (buffer.id() < 0 ||
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index a9e8a51e09..30c365053c 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -36,8 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -249,7 +247,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleDomain(HloInstruction* domain) override;
- Status HandleSlice(HloInstruction* slice) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 8c91d6e69d..e126a53023 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which simplifies patterns of Tuple and GetTupleElement instructions in
// the module.
-class TupleSimplifier : public HloPassInterface {
+class TupleSimplifier : public HloModulePass {
public:
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
index 56145822be..067cfcc17d 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 2dba7d7f75..577bad6c70 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -50,7 +50,7 @@ namespace xla {
// conditions as well.
//
// TODO(b/79121449): We should also sink broadcasts of constants.
-class WhileLoopConstantSinking : public HloPassInterface {
+class WhileLoopConstantSinking : public HloModulePass {
public:
~WhileLoopConstantSinking() override = default;
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index e8fe33e626..9795b2830b 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -15,18 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
+using absl::flat_hash_map;
+using absl::flat_hash_set;
using absl::InlinedVector;
-using tensorflow::gtl::FlatMap;
-using tensorflow::gtl::FlatSet;
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
// operands as needed. All of its transitive operands are expected to be either
@@ -34,8 +34,8 @@ using tensorflow::gtl::FlatSet;
// function hoists the operands in `unhoisted_invariant_instructions` and moves
// them into `hoisted_instructions`.
static void CreateLoopInvariantCopy(
- FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions,
- FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
+ flat_hash_map<HloInstruction*, HloInstruction*>* hoisted_instructions,
+ flat_hash_set<HloInstruction*>* unhoisted_invariant_instructions,
HloInstruction* while_instr, HloInstruction* to_hoist) {
HloComputation* parent_of_while = while_instr->parent();
HloComputation* while_body = while_instr->while_body();
@@ -147,13 +147,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
// Maps instructions in the while body to instructions hoisted outside the
// while that compute the same value.
- FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions;
+ flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions;
// Contains instructions that can be legally hoisted, but were deemed to be
// unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
// hoist an instruction in this set, we move it from
// unhoisted_invariant_instructions to hoisted_instructions.
- FlatSet<HloInstruction*> unhoisted_invariant_instructions;
+ flat_hash_set<HloInstruction*> unhoisted_invariant_instructions;
// Invariant GTE's axiomatically satisfy the constraints for
// unhoisted_invariant_instructions -- they can be legally hoisted, but there
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 2cdf20ce80..3031899f71 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that rewrites while loops to hoist loop invariant instructions in
// the while body into the computation that contains the while instruction.
-class WhileLoopInvariantCodeMotion : public HloPassInterface {
+class WhileLoopInvariantCodeMotion : public HloModulePass {
public:
// If `hoist_constants` is true then constants are always hoisted out of while
// loop bodies. Otherwise they are only hoisted out if they enable other
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 6a7bfe3f12..630d71e5ca 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -14,12 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -114,7 +115,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
return false;
}
- tensorflow::gtl::FlatSet<int64> used_tuple_indices;
+ absl::flat_hash_set<int64> used_tuple_indices;
for (HloComputation* comp : {while_body, while_cond}) {
// The HLO verifier ensures that while_input's shape matches while_init's
// shape, which we verified above is a tuple.
@@ -181,7 +182,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
used_tuple_indices.end());
std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());
- tensorflow::gtl::FlatMap<int64, int64> old_to_new_tuple_idx;
+ absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
int64 old_idx = new_to_old_tuple_idx[new_idx];
old_to_new_tuple_idx[old_idx] = new_idx;
@@ -252,7 +253,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// Create the new while condition, body, and init value.
std::unique_ptr<HloComputation> new_while_cond =
while_cond->CloneWithReplacements(
- make_while_computation_replacements(while_cond));
+ make_while_computation_replacements(while_cond), /*extras=*/{});
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
while_body_replacements = make_while_computation_replacements(while_body);
@@ -265,7 +266,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
while_body_replacements.emplace(
while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
std::unique_ptr<HloComputation> new_while_body =
- while_body->CloneWithReplacements(std::move(while_body_replacements));
+ while_body->CloneWithReplacements(std::move(while_body_replacements),
+ /*extras=*/{});
// Add a new while_init instruction that repackages the old while_init
// instruction's elements. We rely on the AlgebraicSimplifier and DCE to
@@ -404,7 +406,7 @@ static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
// build a map from the tuple element index to the constant value. Limit this
// to scalar constant values because propagating array constants can regress
// performance by forcing us to copy constants.
- tensorflow::gtl::FlatMap<int, const HloInstruction*> index_to_constant;
+ absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
for (int i = 0; i < root_operands.size(); i++) {
HloInstruction* instr = root_operands[i];
if (instr->opcode() == HloOpcode::kGetTupleElement &&
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 78024f14dc..0bc5a0107b 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -30,7 +30,7 @@ namespace xla {
// - Elements of a while loop's tuple that the loop doesn't use are removed
// from the tuple.
//
-class WhileLoopSimplifier : public HloPassInterface {
+class WhileLoopSimplifier : public HloModulePass {
public:
~WhileLoopSimplifier() override {}
absl::string_view name() const override { return "simplify-while-loops"; }
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index a7f0e207eb..87294120d5 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -21,7 +21,7 @@ limitations under the License.
// HLO pass that replaces zero sized Hlos with a zero sized constant literal.
namespace xla {
-class ZeroSizedHloElimination : public HloPassInterface {
+class ZeroSizedHloElimination : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override {
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 96c80fd577..d244923532 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -422,8 +422,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
- CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
- CHECK_EQ(shape.dimensions_size(), Rank(shape));
+ DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
+ DCHECK_EQ(shape.dimensions_size(), Rank(shape));
+ if (shape.dimensions().size() == 1) {
+ return shape.dimensions()[0];
+ }
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
std::multiplies<int64>());
@@ -828,7 +831,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
- if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
+ if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
+ !PrimitiveType_IsValid(shape.element_type())) {
return InvalidArgument("shape has invalid element type: %s",
shape.ShortDebugString());
}
@@ -865,11 +869,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
return Status::OK();
}
- if (Rank(shape) != shape.dimensions_size()) {
- return InvalidArgument(
- "shape's rank is mismatched with dimension count; rank=%d "
- "dimensions_size=%d",
- Rank(shape), shape.dimensions_size());
+ if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) {
+ return InvalidArgument("sparse arrays must have rank > 0");
}
for (int64 i = 0; i < Rank(shape); ++i) {
int64 dimension = shape.dimensions(i);
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 623ae39de8..d8bb27beae 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <initializer_list>
#include <string>
+#include "absl/base/macros.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
@@ -479,8 +480,7 @@ class ShapeUtil {
// Shorthand for testing whether a shape is of a given element type and
// sequence of dimensions.
- //
- // DEPRECATED: Use Equal() instead.
+ ABSL_DEPRECATED("Use Equal() instead.")
static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
std::initializer_list<int64> dimensions);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 30e3077edb..8a0ae33042 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -29,6 +29,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites"
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()
@@ -150,11 +154,31 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
+tf_cc_test(
+ name = "hlo_verified_test_base_test",
+ srcs = ["hlo_verified_test_base_test.cc"],
+ deps = [
+ ":hlo_test_base",
+ ":hlo_verified_test_base",
+ ":test_macros_cpu",
+ ":test_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_binary(
name = "local_client_aot_test_helper",
srcs = ["local_client_aot_test_helper.cc"],
@@ -398,6 +422,7 @@ xla_test(
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
@@ -1797,7 +1822,7 @@ xla_test(
tf_cc_test(
name = "llvm_compiler_test",
srcs = ["llvm_compiler_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test_helpers",
@@ -2096,7 +2121,7 @@ tf_cc_test(
name = "sample_file_test",
srcs = ["sample_file_test.cc"],
data = ["isolated_convolution.hlo"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":hlo_test_base",
"//tensorflow/compiler/xla:test",
@@ -2121,11 +2146,11 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -2144,3 +2169,21 @@ xla_test(
"//tensorflow/core:lib",
],
)
+
+tf_cc_test(
+ name = "multiple_devices_on_host_test",
+ srcs = ["multiple_devices_on_host_test.cc"],
+ args = ["--xla_force_host_platform_device_count=4"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index 53f2c3bfbf..05d4d04034 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -3,256 +3,266 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
all_backends = ["cpu", "gpu"] + plugins.keys()
def filter_backends(backends):
- """Removes "gpu" from a backend list if CUDA is not enabled.
-
- This allows us to simply hardcode lists including "gpu" here and in the
- BUILD file, without causing failures when CUDA isn't enabled.'
-
- Args:
- backends: A list of backends to filter.
-
- Returns:
- The filtered list of backends.
- """
- if cuda_is_configured():
- return backends
- else:
- return [backend for backend in backends if backend != "gpu"]
-
-
-def xla_test(name,
- srcs,
- deps,
- xla_test_library_deps=[],
- backends=[],
- blacklisted_backends=[],
- args=[],
- tags=[],
- copts=[],
- data=[],
- backend_tags={},
- backend_args={},
- **kwargs):
- """Generates cc_test targets for the given XLA backends.
-
- This rule generates a cc_test target for one or more XLA backends and also a
- platform-agnostic cc_library rule. The arguments are identical to cc_test with
- two additions: 'backends' and 'backend_args'. 'backends' specifies the
- backends to generate tests for ("cpu", "gpu"), and
- 'backend_args'/'backend_tags' specifies backend-specific args parameters to
- use when generating the cc_test.
-
- The name of the cc_tests are the provided name argument with the backend name
- appended, and the cc_library target name is the provided name argument with
- "_lib" appended. For example, if name parameter is "foo_test", then the cpu
- test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
-
- The cc_library target can be used to link with other plugins outside of
- xla_test.
-
- The build rule also defines a test suite ${name} which includes the tests for
- each of the supported backends.
-
- Each generated cc_test target has a tag indicating which backend the test is
- for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
- tags can be used to gather tests for a particular backend into a test_suite.
-
- Examples:
-
- # Generates the targets: foo_test_cpu and foo_test_gpu.
- xla_test(
- name = "foo_test",
- srcs = ["foo_test.cc"],
- backends = ["cpu", "gpu"],
- deps = [...],
- )
+ """Removes "gpu" from a backend list if CUDA is not enabled.
- # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
- # includes the additional arg "--special_cpu_flag".
- xla_test(
- name = "bar_test",
- srcs = ["bar_test.cc"],
- backends = ["cpu", "gpu"],
- backend_args = {"cpu": ["--special_cpu_flag"]}
- deps = [...],
- )
+ This allows us to simply hardcode lists including "gpu" here and in the
+    BUILD file, without causing failures when CUDA isn't enabled.
- The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
- to the value 1 where ${BACKEND} is the uppercase name of the backend.
-
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- deps: Dependencies of the target.
- xla_test_library_deps: If set, the generated test targets will depend on the
- respective cc_libraries generated by the xla_test_library rule.
- backends: A list of backends to generate tests for. Supported values: "cpu",
- "gpu". If this list is empty, the test will be generated for all supported
- backends.
- blacklisted_backends: A list of backends to NOT generate tests for.
- args: Test arguments for the target.
- tags: Tags for the target.
- copts: Additional copts to pass to the build.
- data: Additional data to pass to the build.
- backend_tags: A dict mapping backend name to list of additional tags to
- use for that target.
- backend_args: A dict mapping backend name to list of additional args to
- use for that target.
- **kwargs: Additional keyword arguments to pass to native.cc_test.
- """
- test_names = []
- if not backends:
- backends = all_backends
-
- backends = [backend for backend in backends
- if backend not in blacklisted_backends]
-
- native.cc_library(
- name="%s_lib" % name,
- srcs=srcs,
- copts=copts,
- testonly=True,
- deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
- )
-
- for backend in filter_backends(backends):
- test_name = "%s_%s" % (name, backend)
- this_backend_tags = ["xla_%s" % backend]
- this_backend_copts = []
- this_backend_args = backend_args.get(backend, [])
- this_backend_data = []
- if backend == "cpu":
- backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
- elif backend == "gpu":
- backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
- this_backend_tags += ["requires-gpu-sm35"]
- elif backend in plugins:
- backend_deps = []
- backend_deps += plugins[backend]["deps"]
- this_backend_copts += plugins[backend]["copts"]
- this_backend_tags += plugins[backend]["tags"]
- this_backend_args += plugins[backend]["args"]
- this_backend_data += plugins[backend]["data"]
- else:
- fail("Unknown backend %s" % backend)
-
- if xla_test_library_deps:
- for lib_dep in xla_test_library_deps:
- backend_deps += ["%s_%s" % (lib_dep, backend)]
-
- tf_cc_test(
- name=test_name,
- srcs=srcs,
- tags=tags + backend_tags.get(backend, []) + this_backend_tags,
- extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
- this_backend_copts,
- args=args + this_backend_args,
- deps=deps + backend_deps,
- data=data + this_backend_data,
- **kwargs)
-
- test_names.append(test_name)
-
- native.test_suite(name=name, tests=test_names)
-
-def xla_test_library(name,
- srcs,
- hdrs=[],
- deps=[],
- backends=[]):
- """Generates cc_library targets for the given XLA backends.
-
- This rule forces the sources to be compiled for each backend so that the
- backend specific macros could expand correctly. It's useful when test targets
- in different directories referring to the same sources but test with different
- arguments.
-
- Examples:
-
- # Generates the targets: foo_test_library_cpu and foo_test_gpu.
- xla_test_library(
- name = "foo_test_library",
- srcs = ["foo_test.cc"],
- backends = ["cpu", "gpu"],
- deps = [...],
- )
- # Then use the xla_test rule to generate test targets:
- xla_test(
- name = "foo_test",
- srcs = [],
- backends = ["cpu", "gpu"],
- deps = [...],
- xla_test_library_deps = [":foo_test_library"],
- )
+ Args:
+ backends: A list of backends to filter.
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- hdrs: Headers for the target.
- deps: Dependencies of the target.
- backends: A list of backends to generate libraries for.
- Supported values: "cpu", "gpu". If this list is empty, the
- library will be generated for all supported backends.
- """
-
- if not backends:
- backends = all_backends
-
- for backend in filter_backends(backends):
- this_backend_copts = []
- if backend in ["cpu", "gpu"]:
- backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
- elif backend in plugins:
- backend_deps = plugins[backend]["deps"]
- this_backend_copts += plugins[backend]["copts"]
+ Returns:
+ The filtered list of backends.
+ """
+ if cuda_is_configured():
+ return backends
else:
- fail("Unknown backend %s" % backend)
+ return [backend for backend in backends if backend != "gpu"]
+
+def xla_test(
+ name,
+ srcs,
+ deps,
+ xla_test_library_deps = [],
+ backends = [],
+ blacklisted_backends = [],
+ args = [],
+ tags = [],
+ copts = [],
+ data = [],
+ backend_tags = {},
+ backend_args = {},
+ **kwargs):
+ """Generates cc_test targets for the given XLA backends.
+
+ This rule generates a cc_test target for one or more XLA backends and also a
+ platform-agnostic cc_library rule. The arguments are identical to cc_test with
+ two additions: 'backends' and 'backend_args'. 'backends' specifies the
+ backends to generate tests for ("cpu", "gpu"), and
+    'backend_args'/'backend_tags' specify backend-specific args and tags to
+    use when generating each cc_test.
+
+    The names of the cc_tests are the provided name argument with the backend name
+ appended, and the cc_library target name is the provided name argument with
+ "_lib" appended. For example, if name parameter is "foo_test", then the cpu
+ test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
+
+ The cc_library target can be used to link with other plugins outside of
+ xla_test.
+
+ The build rule also defines a test suite ${name} which includes the tests for
+ each of the supported backends.
+
+ Each generated cc_test target has a tag indicating which backend the test is
+    for. This tag is of the form "xla_${BACKEND}" (e.g., "xla_cpu"). These
+ tags can be used to gather tests for a particular backend into a test_suite.
+
+ Examples:
+
+ # Generates the targets: foo_test_cpu and foo_test_gpu.
+ xla_test(
+ name = "foo_test",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+
+ # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
+ # includes the additional arg "--special_cpu_flag".
+ xla_test(
+ name = "bar_test",
+ srcs = ["bar_test.cc"],
+ backends = ["cpu", "gpu"],
+ backend_args = {"cpu": ["--special_cpu_flag"]}
+ deps = [...],
+ )
+
+ The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
+ to the value 1 where ${BACKEND} is the uppercase name of the backend.
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ xla_test_library_deps: If set, the generated test targets will depend on the
+ respective cc_libraries generated by the xla_test_library rule.
+ backends: A list of backends to generate tests for. Supported values: "cpu",
+ "gpu". If this list is empty, the test will be generated for all supported
+ backends.
+ blacklisted_backends: A list of backends to NOT generate tests for.
+ args: Test arguments for the target.
+ tags: Tags for the target.
+ copts: Additional copts to pass to the build.
+ data: Additional data to pass to the build.
+ backend_tags: A dict mapping backend name to list of additional tags to
+ use for that target.
+ backend_args: A dict mapping backend name to list of additional args to
+ use for that target.
+ **kwargs: Additional keyword arguments to pass to native.cc_test.
+ """
+ test_names = []
+ if not backends:
+ backends = all_backends
+
+ backends = [
+ backend
+ for backend in backends
+ if backend not in blacklisted_backends
+ ]
native.cc_library(
- name = "%s_%s" % (name, backend),
+ name = "%s_lib" % name,
srcs = srcs,
+ copts = copts,
testonly = True,
- hdrs = hdrs,
- copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()]
- + this_backend_copts,
- deps = deps + backend_deps,
+ deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
)
-
-def generate_backend_suites(backends=[]):
- if not backends:
- backends = all_backends
- for backend in filter_backends(backends):
- native.test_suite(name="%s_tests" % backend,
- tags = ["xla_%s" % backend])
-
-
-def generate_backend_test_macros(backends=[]):
- if not backends:
- backends = all_backends
- for backend in filter_backends(backends):
- manifest = ""
- if backend in plugins:
- manifest = plugins[backend]["disabled_manifest"]
-
- native.cc_library(
- name="test_macros_%s" % backend,
- testonly = True,
- srcs = ["test_macros.cc"],
- hdrs = ["test_macros.h"],
- copts = [
- "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
- "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
- ],
- deps = [
- "//tensorflow/compiler/xla:types",
- "//tensorflow/core:lib",
- "//tensorflow/core:regexp_internal",
- "//tensorflow/core:test",
- ])
+ for backend in filter_backends(backends):
+ test_name = "%s_%s" % (name, backend)
+ this_backend_tags = ["xla_%s" % backend]
+ this_backend_copts = []
+ this_backend_args = backend_args.get(backend, [])
+ this_backend_data = []
+ if backend == "cpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+ elif backend == "gpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
+ this_backend_tags += tf_cuda_tests_tags()
+ elif backend in plugins:
+ backend_deps = []
+ backend_deps += plugins[backend]["deps"]
+ this_backend_copts += plugins[backend]["copts"]
+ this_backend_tags += plugins[backend]["tags"]
+ this_backend_args += plugins[backend]["args"]
+ this_backend_data += plugins[backend]["data"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ if xla_test_library_deps:
+ for lib_dep in xla_test_library_deps:
+ backend_deps += ["%s_%s" % (lib_dep, backend)]
+
+ tf_cc_test(
+ name = test_name,
+ srcs = srcs,
+ tags = tags + backend_tags.get(backend, []) + this_backend_tags,
+ extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ args = args + this_backend_args,
+ deps = deps + backend_deps,
+ data = data + this_backend_data,
+ **kwargs
+ )
+
+ test_names.append(test_name)
+
+ native.test_suite(name = name, tests = test_names)
+
+def xla_test_library(
+ name,
+ srcs,
+ hdrs = [],
+ deps = [],
+ backends = []):
+ """Generates cc_library targets for the given XLA backends.
+
+ This rule forces the sources to be compiled for each backend so that the
+    backend-specific macros can expand correctly. It's useful when test targets
+    in different directories refer to the same sources but test with different
+    arguments.
+
+ Examples:
+
+ # Generates the targets: foo_test_library_cpu and foo_test_gpu.
+ xla_test_library(
+ name = "foo_test_library",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+ # Then use the xla_test rule to generate test targets:
+ xla_test(
+ name = "foo_test",
+ srcs = [],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ xla_test_library_deps = [":foo_test_library"],
+ )
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ hdrs: Headers for the target.
+ deps: Dependencies of the target.
+ backends: A list of backends to generate libraries for.
+ Supported values: "cpu", "gpu". If this list is empty, the
+ library will be generated for all supported backends.
+ """
+
+ if not backends:
+ backends = all_backends
+
+ for backend in filter_backends(backends):
+ this_backend_copts = []
+ if backend in ["cpu", "gpu"]:
+ backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
+ elif backend in plugins:
+ backend_deps = plugins[backend]["deps"]
+ this_backend_copts += plugins[backend]["copts"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ native.cc_library(
+ name = "%s_%s" % (name, backend),
+ srcs = srcs,
+ testonly = True,
+ hdrs = hdrs,
+ copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ deps = deps + backend_deps,
+ )
+
+def generate_backend_suites(backends = []):
+ if not backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
+ native.test_suite(
+ name = "%s_tests" % backend,
+ tags = ["xla_%s" % backend, "-broken", "manual"],
+ )
+
+def generate_backend_test_macros(backends = []):
+ if not backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
+ manifest = ""
+ if backend in plugins:
+ manifest = plugins[backend]["disabled_manifest"]
+
+ native.cc_library(
+ name = "test_macros_%s" % backend,
+ testonly = True,
+ srcs = ["test_macros.cc"],
+ hdrs = ["test_macros.h"],
+ copts = [
+ "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
+ "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ "//tensorflow/core:test",
+ ],
+ )
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 070b092d18..b851db14ec 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
XlaBuilder builder(TestName());
auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
- Conv(lhs, rhs, {1, 1}, Padding::kValid);
+ PrecisionConfig precision;
+    // The left-hand side of the convolution contains values between 0 and 2304,
+    // which require at least 11 mantissa bits to represent exactly, while the
+    // DEFAULT precision config is allowed to round to bfloat16, which has only 7
+    // mantissa bits.
+ precision.add_operand_precision(PrecisionConfig::HIGHEST);
+ precision.add_operand_precision(PrecisionConfig::DEFAULT);
+ Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1,
+ &precision);
ComputeAndCompare(&builder, {}, error_spec_);
}
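To make the mantissa-bit argument in the new comment concrete, here is a standalone sketch. It truncates to bfloat16 by zeroing the low 16 bits of the float32 encoding, which only approximates the rounding XLA or Eigen would actually perform; the point is just that 7 mantissa bits cannot represent all integers near 2304.

#include <cstdint>
#include <cstring>
#include <iostream>

// bfloat16 keeps the top 16 bits of an IEEE-754 float: sign, 8 exponent bits
// and 7 mantissa bits. Truncation is used here purely for illustration.
float TruncateToBfloat16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;  // drop the low 16 mantissa bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  // 2304 happens to survive, but 2301 collapses to 2288, so accumulating the
  // convolution at bfloat16 precision would not be integer-exact.
  std::cout << TruncateToBfloat16(2304.0f) << " "    // 2304
            << TruncateToBfloat16(2301.0f) << "\n";  // 2288
}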
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 0171f51583..6c0847a875 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -394,6 +394,10 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
ParametricDotTestWithoutLayoutAssignment() {
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"layout-assignment");
+ // Disable algebraic simplification because the pass may replace a dot
+ // instruction with a layout-changing multiplication instruction.
+ execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
+ "algsimp");
}
};
@@ -404,31 +408,18 @@ std::vector<DotTestParam> CreateNoLayoutAssignmentDotTestParameters() {
for (bool lhs_row_major : {true, false}) {
for (bool rhs_row_major : {true, false}) {
for (bool has_addend : {true, false}) {
+ // The addend needs to be row major to match the result of the dot.
params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
/*dot_lhs_row_major=*/lhs_row_major,
/*dot_rhs_row_major=*/rhs_row_major,
/*has_addend=*/has_addend,
/*addend_row_major=*/true});
- if (has_addend) {
- params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
- /*dot_lhs_row_major=*/lhs_row_major,
- /*dot_rhs_row_major=*/rhs_row_major,
- /*has_addend=*/has_addend,
- /*addend_row_major=*/false});
- }
if (n != 1) {
params.push_back({/*m=*/n, /*k=*/k, /*n=*/1,
/*dot_lhs_row_major=*/lhs_row_major,
/*dot_rhs_row_major=*/rhs_row_major,
/*has_addend=*/has_addend,
/*addend_row_major=*/true});
- if (has_addend) {
- params.push_back({/*m=*/n, /*k=*/k, /*n=*/1,
- /*dot_lhs_row_major=*/lhs_row_major,
- /*dot_rhs_row_major=*/rhs_row_major,
- /*has_addend=*/has_addend,
- /*addend_row_major=*/false});
- }
}
}
}
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 9c94acb437..4d4b676a53 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -764,8 +764,10 @@ XLA_TEST_F(FusionTest, Clamp2D) {
TestElementwise2D<float, 3>(HloOpcode::kClamp);
}
-// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast.
-XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) {
+// TODO(b/117156505): Remove this test once the bug is fixed; the CPU backend
+// should not generate layout-changing elementwise operations.
+#ifdef XLA_TEST_BACKEND_CPU
+XLA_TEST_F(FusionTest, LayoutChangingElementWiseOp) {
const string hlo_text = R"(
HloModule Cluster
@@ -794,6 +796,7 @@ ENTRY main {
LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
result));
}
+#endif
class FusionClientLibraryTest : public ClientLibraryTestBase {};
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index bdd4fd7e3d..7ab2ecda58 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
- bool allow_mixed_precision_in_hlo_verifier)
+ bool allow_mixed_precision_in_hlo_verifier,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
: HloTestBase(GetTestPlatform(), GetReferencePlatform(),
verifier_layout_sensitive,
- allow_mixed_precision_in_hlo_verifier) {}
+ allow_mixed_precision_in_hlo_verifier,
+ instruction_can_change_layout_func) {}
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform,
bool verifier_layout_sensitive,
- bool allow_mixed_precision_in_hlo_verifier)
+ bool allow_mixed_precision_in_hlo_verifier,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
: test_runner_(test_platform), reference_runner_(reference_platform) {
hlo_verifier_ = absl::make_unique<HloVerifier>(
/*layout_sensitive=*/verifier_layout_sensitive,
- /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
+ /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
+ instruction_can_change_layout_func);
}
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 0ae4bdc104..217428befa 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -88,14 +88,18 @@ class HloTestBase : public ::testing::Test {
// interpreter is the only supported backend, it will be both the test backend
// and the reference backend.
HloTestBase(bool verifier_layout_sensitive = false,
- bool allow_mixed_precision_in_hlo_verifier = true);
+ bool allow_mixed_precision_in_hlo_verifier = true,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = {});
// If your test doesn't use interpreter as the reference backend, you can use
// this constructor. Note that your test target is responsible for linking in
// both needed backends.
HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
bool verifier_layout_sensitive = false,
- bool allow_mixed_precision_in_hlo_verifier = true);
+ bool allow_mixed_precision_in_hlo_verifier = true,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = {});
~HloTestBase() override {}
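A sketch of how a test fixture might use the new constructor parameter; the fixture name LayoutStrictTest and the always-false predicate are illustrative, not part of this change. The std::function receives each HloInstruction and, going by its name, reports whether that instruction is allowed to change layout.

#include <functional>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

namespace xla {

// Hypothetical fixture: forwards a predicate that marks every instruction as
// layout-preserving, the strictest setting for the verifier. A real backend
// would plug in its own instruction_can_change_layout rule instead.
class LayoutStrictTest : public HloTestBase {
 protected:
  LayoutStrictTest()
      : HloTestBase(/*verifier_layout_sensitive=*/true,
                    /*allow_mixed_precision_in_hlo_verifier=*/true,
                    /*instruction_can_change_layout_func=*/
                    [](const HloInstruction*) { return false; }) {}
};

}  // namespace xla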
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index 8f86c528d0..8bd0a729b7 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -21,64 +21,68 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/test.h"
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
- bool allow_mixed_precision)
- : HloTestBase(
- /*verifier_layout_sensitive=*/layout_sensitive,
- /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
-
-HloVerifiedTestBase::~HloVerifiedTestBase() {
- // We can't call the ASSERT or EXPECT test macros in destructors, so we
- // perform HLO verification in TearDown, and use the CHECK here to ensure
- // users don't accidentally override the verification.
- CHECK(tear_down_called_)
- << "TearDown was never called; subclasses of HloVerifiedTestBase that "
- << "override TearDown must call the superclass TearDown.";
-}
-
-void HloVerifiedTestBase::TearDown() {
- EXPECT_FALSE(tear_down_called_)
- << "TearDown called more than once; it should be called exactly once.";
- tear_down_called_ = true;
- if (module_) {
- VerifyModule(module_.get());
+Status VerifiedHloModule::Verify() {
+ if (computation_count() == 0) {
+ // The computation was never built. Nothing to verify.
+ return Status::OK();
}
- for (int i = 0; i < modules_.size(); ++i) {
- VerifyModule(modules_.at(i).get());
- }
- HloTestBase::TearDown();
+ return verifier_.Run(this).status();
}
-void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- xla::StatusOr<bool> mutated = verifier().Run(module);
- if (!mutated.ok()) {
- ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
- } else {
- EXPECT_FALSE(mutated.ValueOrDie())
- << "HloVerifier should never mutate the HloModule";
+void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
+ Status status = Verify();
+ if (!status.ok()) {
+ ADD_FAILURE() << "HloVerifier failed on module " << name()
+ << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
+ << ": " << status;
}
}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
+ verifier_layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
+
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = HloTestBase::CreateNewModule();
+ module_ = CreateNewVerifiedModule(TestName());
}
return *module_;
}
HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
- modules_.emplace_back(HloTestBase::CreateNewModule());
+ modules_.emplace_back(CreateNewVerifiedModule(name));
return modules_.back().get();
}
void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
- TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
- VerifyModule(module_.get());
+ module_ = CreateNewVerifiedModule(TestName());
+ TF_CHECK_OK(ParseHloString(hlo_text, module_.get()));
+ module_->VerifyOrAddFailure("after parsing");
}
+
+StatusOr<std::unique_ptr<VerifiedHloModule>>
+HloVerifiedTestBase::ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text, const HloModuleConfig& config) {
+ auto module = CreateNewVerifiedModule(TestName());
+ TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
+ TF_RETURN_IF_ERROR(module->Verify());
+ return std::move(module);
+}
+
+std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
+ const string& name) {
+ return absl::make_unique<VerifiedHloModule>(
+ name, GetModuleConfigForTest(), verifier_layout_sensitive_,
+ allow_mixed_precision_in_hlo_verifier_);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 8fbc4fa753..388a99bb36 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -20,53 +20,84 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
-// A base class for HLO tests that stores a default HloModule, and automatically
-// performs verification on that module on tear-down.
+// An HloModule-derived class which verifies itself on destruction. This class
+// is intended to be used in unit tests. Any verification errors are raised via
+// ADD_FAILURE.
+class VerifiedHloModule : public HloModule {
+ public:
+ VerifiedHloModule(const string& name, const HloModuleConfig& config,
+ bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
+ : HloModule(name, config),
+ verifier_(verifier_layout_sensitive,
+ allow_mixed_precision_in_hlo_verifier) {}
+
+ ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
+
+ // Verifies the module using HloVerifier and returns the status.
+ Status Verify();
+
+ // Verifies the module and flags any error with ADD_FAILURE. 'message' is
+ // included in the failure message.
+ void VerifyOrAddFailure(const string& message);
+
+ private:
+ HloVerifier verifier_;
+};
+
+// A base class for HLO tests that stores a default VerifiedHloModule.
class HloVerifiedTestBase : public HloTestBase {
protected:
- explicit HloVerifiedTestBase(bool layout_sensitive = false,
- bool allow_mixed_precision = false);
- ~HloVerifiedTestBase() override;
+ HloVerifiedTestBase(bool layout_sensitive = false,
+ bool allow_mixed_precision = false);
// Constructs a default shape verifier.
std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
- // Performs verification on the default HloModule returned by module().
- // Automatically called by the testing framework for each test.
- //
- // REQUIRED: subclasses that override TearDown() must call this explicitly.
- void TearDown() override;
-
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule& module();
+
+ ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.")
void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
+  // Parses the given string and returns the module as a VerifiedHloModule.
+ StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text,
+ const HloModuleConfig& config = HloModuleConfig());
+
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule* CreateNewModule(const string& name = TestName());
- private:
- void VerifyModule(HloModule* module);
+ // Creates and returns a verified HLO module with the given name.
+ std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
+ const string& name = TestName());
+ private:
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
//
// Lazily populated. Access via module().
- std::unique_ptr<HloModule> module_;
+ std::unique_ptr<VerifiedHloModule> module_;
+
// Populated by calls to CreateNewModule.
- std::vector<std::unique_ptr<HloModule>> modules_;
+ std::vector<std::unique_ptr<VerifiedHloModule>> modules_;
- bool tear_down_called_ = false;
+ bool verifier_layout_sensitive_;
+ bool allow_mixed_precision_in_hlo_verifier_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
new file mode 100644
index 0000000000..5c0263e811
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// This class includes unit tests which are expected to fail because invalid HLO
+// modules are intentionally built. Unfortunately, TensorFlow doesn't appear to
+// include the gunit machinery needed to check for expected failures (the
+// EXPECT_NONFATAL_FAILURE macro), so these tests are disabled. They can be run
+// manually with disabled tests enabled, and their failures compared against
+// expectations.
+class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
+
+XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
+ // Test shouldn't fail if no module is created at all.
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) {
+ // Use module() to lazily create an empty module, build it up, and verify no
+ // failures.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) {
+ // Use module() to lazily create an empty module and build up an invalid
+ // module.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+
+ *hlo_module.entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) {
+ // Call CreateNewModule and build up a valid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) {
+  // Call CreateNewModule and build up an invalid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+
+ *module->entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndVerifyModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ ParseAndVerifyModule(hlo_string);
+ EXPECT_EQ(module().entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+
+RANDOM GARBAGE
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleBad
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[1234] add(x,y)
+}
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
new file mode 100644
index 0000000000..c530591c6e
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+StatusOr<XlaComputation> BuildComputation() {
+ XlaBuilder b("computation");
+ Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
+ XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32);
+ return b.Build(
+ OutfeedWithToken(GetTupleElement(infeed, 0) +
+ ConstantLiteral(&b, LiteralUtil::CreateR0<int32>(1)),
+ GetTupleElement(infeed, 1), scalar_s32, ""));
+}
+
+void CompileAndExecute(
+ LocalExecutable* executable, int device_ordinal, LocalClient* client,
+ absl::Mutex* results_mutex,
+ std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>>* results) {
+ xla::ExecutableRunOptions execute_options;
+ execute_options.set_intra_op_thread_pool(
+ client->backend().eigen_intra_op_thread_pool_device());
+ execute_options.set_device_ordinal(device_ordinal);
+ execute_options.set_allocator(
+ xla::ClientLibrary::GetXlaService(client->platform())
+ ->backend()
+ .memory_allocator());
+ StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options);
+ {
+ absl::MutexLock lock(results_mutex);
+ results->emplace_back(device_ordinal, std::move(result));
+ }
+}
+
+void TestWithDeviceCount(const int device_count) {
+ // Run `device_count` copies of the XLA program built by BuildComputation.
+ TF_ASSERT_OK_AND_ASSIGN(
+ se::Platform* const platform,
+ perftools::gputools::MultiPlatformManager::PlatformWithName("Host"));
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform);
+ TF_ASSERT_OK_AND_ASSIGN(
+ LocalClient* const client,
+ xla::ClientLibrary::GetOrCreateLocalClient(client_options));
+
+ TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<LocalExecutable> executable,
+ client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{}));
+ std::vector<tensorflow::Thread*> threads;
+ absl::Mutex results_mutex;
+ std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>> results;
+ tensorflow::Env* env = tensorflow::Env::Default();
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ tensorflow::Thread* t = env->StartThread(
+ tensorflow::ThreadOptions{}, absl::StrCat("thread-", device_ordinal),
+ [&executable, device_ordinal, client, &results_mutex, &results] {
+ CompileAndExecute(executable.get(), device_ordinal, client,
+ &results_mutex, &results);
+ });
+ threads.push_back(t);
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK(client->TransferToInfeedLocal(
+ LiteralUtil::CreateR0<int32>(device_ordinal * 100), device_ordinal));
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK_AND_ASSIGN(Literal outfeed,
+ client->TransferFromOutfeedLocal(
+ ShapeUtil::MakeShape(S32, {}), device_ordinal));
+ EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1));
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ delete threads[device_ordinal];
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK(results[device_ordinal].second.status());
+ }
+}
+
+// NB! This test requires --xla_force_host_platform_device_count=4
+
+TEST(MultipleDeviceOnHostTest, OneDevice) { TestWithDeviceCount(1); }
+
+TEST(MultipleDeviceOnHostTest, TwoDevices) { TestWithDeviceCount(2); }
+
+TEST(MultipleDeviceOnHostTest, ThreeDevices) { TestWithDeviceCount(3); }
+
+TEST(MultipleDeviceOnHostTest, FourDevices) { TestWithDeviceCount(4); }
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 63491a90bf..22fe4a2670 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -638,6 +638,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
/*padding=*/padding);
CHECK(reducer == kAdd || reducer == kMax);
@@ -1158,7 +1160,10 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*init_value=*/init_value,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/padding);
+ /*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
+ /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }
@@ -1303,11 +1308,19 @@ struct R1ReduceWindowTestData {
/*pad_high=*/{0},
/*reducer=*/Reducer::kAdd},
+ // The pattern generated by inclusive scan (cumsum/cumprod).
{/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
/*strides=*/{1},
/*pad_low=*/{4095},
/*pad_high=*/{0},
/*reducer=*/Reducer::kMax},
+
+ // The pattern generated by exclusive scan (cumsum/cumprod).
+ {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
+ /*strides=*/{1},
+ /*pad_low=*/{4096},
+ /*pad_high=*/{0},
+ /*reducer=*/Reducer::kMax},
};
string R1ReduceWindowTestDataToString(
@@ -1361,7 +1374,10 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*init_value=*/init_value,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/padding);
+ /*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
+ /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }
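The new R1 test entries mirror the window/padding pattern that cumsum/cumprod lower to. As a standalone sketch (names are illustrative, not XLA APIs): a 1D reduce-window with window size n, low padding n-1 filled with the reducer's identity, and stride 1 yields an inclusive prefix sum.

#include <iostream>
#include <vector>

// Simulates an add reduce-window with window n, pad_low n-1, stride 1.
// out[i] ends up being the sum of in[0..i], i.e. an inclusive scan.
std::vector<float> InclusiveScanViaReduceWindow(const std::vector<float>& in) {
  const int n = static_cast<int>(in.size());
  std::vector<float> padded(n - 1, 0.0f);  // identity padding for kAdd
  padded.insert(padded.end(), in.begin(), in.end());
  std::vector<float> out(n, 0.0f);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) out[i] += padded[i + j];  // one window
  }
  return out;
}

int main() {
  for (float v : InclusiveScanViaReduceWindow({1, 2, 3, 4})) {
    std::cout << v << " ";  // prints: 1 3 6 10
  }
  std::cout << "\n";
}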
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index d20dba028a..b21dd56045 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -507,6 +507,36 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
+XLA_TEST_F(ScatterTest, OutOfBoundsUpdateWindow) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd_OobUpdateWindow
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[1,2] parameter(1)
+ updates = s32[1,2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ Literal operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
+}
+
XLA_TEST_F(ScatterTest, OneScalarIndex) {
const char* hlo_text = R"(
HloModule OneScalarIndex
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index a40c2d7de6..2cc33ab096 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -412,6 +412,7 @@ INSTANTIATE_TEST_CASE_P(
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}}, //
+ R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}}, //
R2Spec{
511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}}, //
R2Spec{
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 181e5cbe29..bc433eac8f 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -145,7 +146,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
ASSERT_EQ(args.size(), 2);
const Literal& key_arg = args[0];
- tensorflow::gtl::FlatSet<uint32> key_set;
+ absl::flat_hash_set<uint32> key_set;
for (const float& value : key_arg.data<float>()) {
EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
}
@@ -168,7 +169,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
ASSERT_EQ(args.size(), 2);
const Literal& key_arg = args[0];
- tensorflow::gtl::FlatSet<int32> key_set;
+ absl::flat_hash_set<int32> key_set;
for (const int32& value : key_arg.data<int32>()) {
EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
}
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 7abd8651d5..8b1b9e1519 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -763,9 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
-// Test while nodes that share the while body computation.
-// TODO(b/37245345): Fails on GPU backend.
-TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
+TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
ShapeUtil::MakeShape(F32, {10})};
Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index db5a824de0..a6e70eb6ca 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -83,7 +83,7 @@ struct ParsedProfileOutputLine {
Status ParseOneProfileOutputLine(
const string& line, bool expect_hlo,
- gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
+ absl::flat_hash_map<string, ParsedProfileOutputLine>* parsed_results,
absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
string separator = "[^:]*:: +";
string match_percentage = R"(\d+\.\d*% +\d+Σ)";
@@ -208,7 +208,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
std::vector<string> profile_output_lines =
absl::StrSplit(profile_output, '\n');
- gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
+ absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;
TF_ASSERT_OK(ParseOneProfileOutputLine(
profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines));
@@ -314,7 +314,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
ASSERT_NE(while_body_profile_end, profile_output_lines.end());
- gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
+ absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;
for (auto while_body_profile_i = while_body_profile_start + 1;
while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index b53f89d63b..60d25a6407 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -200,6 +200,15 @@ message DebugOptions {
// among different algorithms.
bool xla_gpu_crash_on_verification_failures = 101;
+  // Force the host platform to pretend that there are this many host
+ // "devices". All these devices are backed by the same threadpool. Defaults
+ // to 1.
+ //
+ // Setting this to anything other than 1 can increase overhead from context
+  // switching, but the override is useful for host-side tests that run models
+  // in parallel across multiple devices.
+ int32 xla_force_host_platform_device_count = 102;
+
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
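A short sketch of setting the new field programmatically. It assumes the standard protobuf-generated accessors on xla::DebugOptions and the usual generated header path for xla.proto; the flag spelling --xla_force_host_platform_device_count=4 is the one the multiple_devices_on_host_test BUILD args use earlier in this diff.

#include <iostream>

#include "tensorflow/compiler/xla/xla.pb.h"

int main() {
  // Assumes standard protobuf codegen: an int32 field gets a set_*/getter
  // accessor pair on the generated xla::DebugOptions class.
  xla::DebugOptions debug_options;
  debug_options.set_xla_force_host_platform_device_count(4);
  std::cout << debug_options.xla_force_host_platform_device_count() << "\n";
}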
diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
index fda4c31298..40ec1b0ba9 100644
--- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("XRTExecute")
- .Attr("Ninputs: int")
+ .Attr("Ninputs: int >= 0")
.Input("computation_handle: int64")
.Input("execution_config: string")
.Input("input_handles: Ninputs * int64")
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index 09ab4ed95f..b6dcfc4eb9 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -8,6 +8,10 @@ package(
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
cc_library(
name = "raw_api_test_lib",
@@ -57,7 +61,7 @@ tf_cuda_cc_test(
size = "medium",
srcs = [],
args = ["--xla_test_device=XLA_GPU"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":raw_api_test_lib",
"//tensorflow/compiler/jit:xla_gpu_device",
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 2952feb16a..f590fbf0d9 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -108,6 +108,14 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a,
return equal;
}
+xla::XlaComputation OnePlusTwo() {
+ xla::XlaBuilder builder("OnePlusTwo");
+ auto c0 = xla::ConstantR0(&builder, 1.0f);
+ auto c1 = xla::ConstantR0(&builder, 2.0f);
+ xla::Add(c0, c1);
+ return builder.Build().ValueOrDie();
+}
+
xla::XlaComputation AddAndScale() {
xla::XlaBuilder builder("AddAndScale");
auto p0 = xla::Parameter(&builder, 0,
@@ -346,6 +354,39 @@ TEST(RawApiTest, CompileAndExecute) {
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
+TEST(RawApiTest, CompileAndExecuteZeroArg) {
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {});
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+ StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot());
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ std::initializer_list<Input>({}));
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+}
+
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
xrt::XLAAllocation p0;
p0.set_device_ordinal(0);